diff --git a/src/hotspot/cpu/aarch64/assembler_aarch64.hpp b/src/hotspot/cpu/aarch64/assembler_aarch64.hpp index 4eb2f6010c0..ae2b9ac9bf7 100644 --- a/src/hotspot/cpu/aarch64/assembler_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/assembler_aarch64.hpp @@ -3151,6 +3151,41 @@ public: _pmull(Vd, Ta, Vn, Vm, Tb); } + // Vector by element variant of UMULL/UMULL2 (shared encoder for umullv and umull2v). + // Multiplies the elements of Vn (arrangement Tb) by element `lane` of Vm and + // writes the double-width products to Vd (arrangement Ta). + // Half-word elements (Ta == T4S): lane is 0..7 and Vm must be one of V0..V15. + // Word elements (Ta == T2D): lane is 0..3 and Vm may be any of V0..V31. + void _umullv(FloatRegister Vd, SIMD_Arrangement Ta, FloatRegister Vn, + SIMD_Arrangement Tb, FloatRegister Vm, SIMD_RegVariant Ts, int lane) { + starti; + int size = (Ta == T4S) ? 0b01 : 0b10; + int q = (Tb == T4H || Tb == T2S) ? 0 : 1; + int h = (size == 0b01) ? ((lane >> 2) & 1) : ((lane >> 1) & 1); + int l = (size == 0b01) ? ((lane >> 1) & 1) : (lane & 1); + assert(Ta == T4S || Ta == T2D, "umull{2}v destination register must have arrangement T4S or T2D"); + assert(lane >= 0 && (size == 0b10 ? lane < 4 : lane < 8), "umull{2}v assumes lane < 8 when using half-words and lane < 4 otherwise"); + assert(Ts == H ? Vm->encoding() < 16 : Vm->encoding() < 32, "umull{2}v requires Vm to be in range V0..V15 when Ts is H"); + // bits 20:16 hold M:Rm - for half-words M is the low bit of the lane index, for words it is the top bit of Vm + int m_rm = ((size == 0b01) ? ((lane & 1) << 4) : 0) | Vm->encoding(); + f(0, 31), f(q, 30), f(0b101111, 29, 24), f(size, 23, 22), f(l, 21); + f(m_rm, 20, 16), f(0b1010, 15, 12), f(h, 11), f(0, 10), rf(Vn, 5), rf(Vd, 0); + } + + // Vector by element variant of UMULL (Tb is T4H or T2S, i.e. the lower 64 bits of Vn) + void umullv(FloatRegister Vd, SIMD_Arrangement Ta, FloatRegister Vn, + SIMD_Arrangement Tb, FloatRegister Vm, SIMD_RegVariant Ts, int lane) { + assert(Ta == T4S ? (Tb == T4H && Ts == H) : (Tb == T2S && Ts == S), "umullv register arrangements must adhere to spec"); + _umullv(Vd, Ta, Vn, Tb, Vm, Ts, lane); + } + + // Vector by element variant of UMULL2 (Tb is T8H or T4S, i.e. the upper 64 bits of Vn) + void umull2v(FloatRegister Vd, SIMD_Arrangement Ta, FloatRegister Vn, + SIMD_Arrangement Tb, FloatRegister Vm, SIMD_RegVariant Ts, int lane) { + assert(Ta == T4S ? 
(Tb == T8H && Ts == H) : (Tb == T4S && Ts == S), "umull2v register arrangements must adhere to spec"); + _umullv(Vd, Ta, Vn, Tb, Vm, Ts, lane); + } + void uqxtn(FloatRegister Vd, SIMD_Arrangement Tb, FloatRegister Vn, SIMD_Arrangement Ta) { starti; int size_b = (int)Tb >> 1; diff --git a/src/hotspot/cpu/aarch64/register_aarch64.hpp b/src/hotspot/cpu/aarch64/register_aarch64.hpp index d1e0632c80b..ab83307d526 100644 --- a/src/hotspot/cpu/aarch64/register_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/register_aarch64.hpp @@ -535,6 +535,17 @@ VSeq vs_odd(const VSeq& v) { return VSeq(v.base() + v.delta(), v.delta() * 2); } +template +FloatRegister vs_head(const VSeq& v) { + static_assert(N > 1, "sequence length must be greater than 1"); + return v.base(); +} + +template +VSeq vs_tail(const VSeq& v) { + return VSeq(v.base() + v.delta(), v.delta()); +} + // convenience method to construct a vector register sequence that // indexes its elements in reverse order to the original diff --git a/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp b/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp index d1f59e479db..d1e0621f6a9 100644 --- a/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/stubDeclarations_aarch64.hpp @@ -57,7 +57,7 @@ do_arch_entry, \ do_arch_entry_init, \ do_arch_entry_array) \ - do_arch_blob(compiler, 70000) \ + do_arch_blob(compiler, 75000) \ do_stub(compiler, vector_iota_indices) \ do_arch_entry_array(aarch64, compiler, vector_iota_indices, \ vector_iota_indices, vector_iota_indices, \ diff --git a/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp b/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp index f41a54e9d26..f89b6e2d579 100644 --- a/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp +++ b/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp @@ -152,6 +152,12 @@ static const uint64_t _double_keccak_round_consts[24] = { 0x8000000000008080L, 0x0000000080000001L, 0x8000000080008008L }; +//Omit 3rd limb of modulus since it 
is 0 (only limbs 0, 1, 3 and 4 are stored) +static const int64_t _modulus_P256[4] = { + 0x000fffffffffffffL, 0x00000fffffffffffL, + 0x0000001000000000L, 0x0000ffffffff0000L +}; + static const char _encodeBlock_toBase64[64] = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', @@ -5311,6 +5317,32 @@ class StubGenerator: public StubCodeGenerator { } } + template + void vs_shl(const VSeq& v, Assembler::SIMD_Arrangement T, + const VSeq& v1, int shift) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + + for (int i = 0; i < N; i++) { + __ shl(v[i], T, v1[i], shift); + } + } + + template + void vs_ushr(const VSeq& v, Assembler::SIMD_Arrangement T, + const VSeq& v1, int shift) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + + for (int i = 0; i < N; i++) { + __ ushr(v[i], T, v1[i], shift); + } + } + template void vs_sshr(const VSeq& v, Assembler::SIMD_Arrangement T, const VSeq& v1, int shift) { @@ -5335,6 +5367,29 @@ class StubGenerator: public StubCodeGenerator { } } + template + void vs_andr(const VSeq& v, const VSeq& v1, const FloatRegister v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + for (int i = 0; i < N; i++) { + __ andr(v[i], __ T16B, v1[i], v2); + } + } + + template + void vs_eor(const VSeq& v, const VSeq& v1, const VSeq& v2) { + // output must not be constant + assert(N == 1 || !v.is_constant(), "cannot output multiple values to a constant vector"); + // 
output cannot overwrite pending inputs + assert(!vs_write_before_read(v, v1), "output overwrites input"); + assert(!vs_write_before_read(v, v2), "output overwrites input"); + for (int i = 0; i < N; i++) { + __ eor(v[i], __ T16B, v1[i], v2[i]); + } + } + template void vs_orr(const VSeq& v, const VSeq& v1, const VSeq& v2) { // output must not be constant @@ -5388,7 +5443,7 @@ class StubGenerator: public StubCodeGenerator { template void vs_ldpq(const VSeq& v, Register base) { for (int i = 0; i < N; i += 2) { - __ ldpq(v[i], v[i+1], Address(base, 32 * i)); + __ ldpq(v[i], v[i+1], Address(base, 16 * i)); } } @@ -5436,6 +5491,18 @@ class StubGenerator: public StubCodeGenerator { } } + // store two vector register sequences of length N + // interleaved into N pairs of quadword memory locations + // starting at the address supplied in dest using + // post-increment addressing. + template + void vs_st1_interleaved(VSeq A, VSeq B, Register dest) { + for (int i = 0; i < N; i++) { + __ st1(A[i], __ T2D, __ post(dest, 16)); + __ st1(B[i], __ T2D, __ post(dest, 16)); + } + } + // load N quadword values from memory de-interleaved into N vector // registers 3 elements at a time via the address supplied in base. template @@ -7674,6 +7741,919 @@ class StubGenerator: public StubCodeGenerator { return start; } + static constexpr int montMulP256Shift1 = 12; // 64 - bits per limb + static constexpr int montMulP256Shift2 = 52; // bits per limb + // stack space needed for carry computation + static constexpr int cDataSize = 6 * BytesPerLong; + // stack space needed for data computed by the neon side + static constexpr int mulDataSize = 16 * BytesPerLong; + + + // Subroutine used by the 52 x 52 bit multiplication algorithm in + // generate_intpoly_montgomeryMult_P256(). + // This function computes partial results of eight 52 x 52 bit multiplications, + // where the multiplicands are stored as 64-bit values, specifically + // (b_0, b_1, b_2, b_3) * (a_3, a_4). 
(The 4 calls to this function + // together provide the results of these limb-multiplications.) + // Calls to this function accept either the low 32 bits or high 20 bits + // of each b_i packed into bs in ascending order. a_3 and a_4 are packed + // into successive 64 bit elements of as. lane selects the low 32 or high + // 20 bits of each a_j value. So four calls with the appropriate parameters + // will produce the 64-bit low32 * low32, low32 * high20, high20 * low32, + // high20 * high20 values in the output register sequences vs. The + // 64-bit partial products are returned in vs in ascending order: + // vs[0] = (b_0*a_3, b_1*a_3) . . . vs[3] = (b_2*a_4, b_3*a_4) + + void neon_partial_mult_64(const VSeq<4>& vs, FloatRegister bs, FloatRegister as, int lane_lo) { + __ umullv(vs[0], __ T2D, bs, __ T2S, as, __ S, lane_lo); + __ umull2v(vs[1], __ T2D, bs, __ T4S, as, __ S, lane_lo); + __ umullv(vs[2], __ T2D, bs, __ T2S, as, __ S, lane_lo + 2); + __ umull2v(vs[3], __ T2D, bs, __ T4S, as, __ S, lane_lo + 2); + } + + // Subroutine used by the generate_intpoly_montgomeryMult_P256() function + // to compute the result of a 52 x 52 bit multiplications where the + // multiplicands, a and b are available as 64-bit values. + // The result is going to two 64-bit registers lo (least significant 52 bits) + // and hi (most significant 52 bits). + void gpr_partial_mult_52(Register a, Register b, Register hi, Register lo, + Register mask) { + // compute 104-bit (40 + 64) full product + __ umulh(hi, a, b); + __ mul(lo, a, b); + // combine 40 + 12 bits into hi result + // on certain implementations of aarch64 (e.g. 
apple M1) replacing extr() + // with the following equivalent instruction sequence the performance + // improves slightly (despite it is two instructions longer and needs + // an additional register) + // __ lsl(hi, hi, montMulP256Shift1); + // __ lsr(tmp, lo, montMulP256Shift2); + // __ orr(hi, hi, tmp); + __ extr(hi, hi, lo, montMulP256Shift2); + // mask off 52 bits of lo result + __ andr(lo, lo, mask); + } + + // This assembly follows the Java code in MontgomeryIntegerPolynomial256.mult() + // quite closely. The main difference is that the computations done with the + // last two limbs of `a` are done using Neon registers. This allows us to take + // advantage of both the Neon registers and GPRs simultaneously. + // It is also worth noting that since Neon does not support 64 bit + // multiplication, we split each 64 bit value into lower and upper halves + // and use the "schoolbook" multiplication algorithm. + address generate_intpoly_montgomeryMult_P256() { + assert(UseIntPolyIntrinsics, "what are we doing here?"); + StubId stub_id = StubId::stubgen_intpoly_montgomeryMult_P256_id; + int entry_count = StubInfo::entry_count(stub_id); + assert(entry_count == 1, "sanity check"); + address start = load_archive_data(stub_id); + if (start != nullptr) { + return start; + } + __ align(CodeEntryAlignment); + StubCodeMark mark(this, stub_id); + start = __ pc(); + __ enter(); + + // Registers that are used throughout entire routine + const Register a = c_rarg0; + const Register b = c_rarg1; + const Register result = c_rarg2; + + RegSet regs = RegSet::range(r0, r28) - rscratch1 - rscratch2 + - r16 - r17 - r18_tls - a - b - result; + + auto common_regs = regs.begin(); + Register limb_mask = *common_regs++, + c_ptr = *common_regs++, + mod_0 = *common_regs++, + mod_1 = *common_regs++, + mod_3 = *common_regs++, + mod_4 = *common_regs++, + b_0 = *common_regs++, + b_1 = *common_regs++, + b_2 = *common_regs++, + b_3 = *common_regs++, + b_4 = *common_regs++; + + FloatRegSet 
floatRegs = FloatRegSet::range(v0, v31) + - FloatRegSet::range(v8, v15) // Callee-saved vectors (excluded so they need not be preserved) + - FloatRegSet::range(v16, v31); // Manually-allocated vectors + + auto common_vectors = floatRegs.begin(); + FloatRegister limb_mask_vec = *common_vectors++, + b_lows = *common_vectors++, + b_highs = *common_vectors++, + a_vals = *common_vectors++; + + // Push callee saved registers on to the stack + RegSet callee_saved = RegSet::range(r19, r28); + __ push(callee_saved, sp); + + // Allocate space on the stack for carry values + __ sub(sp, sp, cDataSize); + __ mov(c_ptr, sp); + + // Calculate (52-bit) limb masks for both gpr and vector registers + __ mov(limb_mask, -UCONST64(1) >> montMulP256Shift1); + __ dup(limb_mask_vec, __ T2D, limb_mask); + + //Load input arrays and modulus + Register a_ptr = *common_regs++, mod_ptr = *common_regs++; + // skip 3 limbs so a_ptr addresses trailing pair {a3, a4} + __ add(a_ptr, a, 3 * BytesPerLong); + __ lea(mod_ptr, ExternalAddress((address)_modulus_P256)); + __ ldr(b_0, Address(b)); + __ ldr(b_1, Address(b, BytesPerLong)); + __ ldr(b_2, Address(b, 2 * BytesPerLong)); + __ ldr(b_3, Address(b, 3 * BytesPerLong)); + __ ldr(b_4, Address(b, 4 * BytesPerLong)); + __ ldr(mod_0, __ post(mod_ptr, BytesPerLong)); + __ ldr(mod_1, __ post(mod_ptr, BytesPerLong)); + __ ldr(mod_3, __ post(mod_ptr, BytesPerLong)); + __ ldr(mod_4, mod_ptr); + __ ld1(a_vals, __ T2D, a_ptr); + // use an interleaved load to group low 32 bits and high 20 bits + // of 4 successive b values into two vector registers + // n.b. these are the same inputs as the ones in b_0 ... 
b4 + __ ld2(b_lows, b_highs, __ T4S, b); + common_regs = common_regs.remaining() + + a_ptr + mod_ptr; + a_ptr = mod_ptr = noreg; + + //Regs used throughout the main "loop", which is partially unrolled here + Register high = *common_regs++, + low = *common_regs++, + mul_ptr = *common_regs++, + mod_high = *common_regs++, + mod_low = *common_regs++, + a_i = *common_regs++, + c_i = *common_regs++, + tmp = *common_regs++, + n = *common_regs++; + + // vector sequences used to compute and combine partial products of + // b_i * a_j for i = {0,1,2,3} j = {3,4} + VSeq<4> A(16); + VSeq<4> B(20); + VSeq<4> C(24); + VSeq<4> D(28); + + + // neon and gpr computations are interleaved to maximize parallelism + + // allocate stack space for the neon results + __ sub(sp, sp, mulDataSize); + __ mov(mul_ptr, sp); + + // cross-multiply low * low for limbs b0-b3 and a3-a4 in parallel + neon_partial_mult_64(A, b_lows, a_vals, 0); + + // Limb 0 + __ ldr(a_i, __ post(a, BytesPerLong)); + gpr_partial_mult_52(a_i, b_0, high, low, limb_mask); + __ mov(n, low); + // __ andr(n, low, limb_mask); + + // cross-multiply high * low for limbs b0-b3 and a3-a4 in parallel + neon_partial_mult_64(B, b_highs, a_vals, 0); + + // Limb 0 modulus computation + // n.b. modulus computation requires multiplying successive + // limbs of the product by corresponding limbs of the p256 + // prime adding the result to the limb and folding this + // partial result into a running 256-bit sum in c_i. Limbs + // of c_i are stored via c_ptr once carries are included. + // n.b. the mul + add is omitted for limb 2 since the + // corresponding prime bits are zero. 
+ gpr_partial_mult_52(n, mod_0, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ lsr(c_i, low, montMulP256Shift2); + __ add(c_i, c_i, high); + + // cross-multiply low * high for limbs b0-b3 and a3-a4 in parallel + neon_partial_mult_64(C, b_lows, a_vals, 1); + + // Limb 1 + gpr_partial_mult_52(a_i, b_1, high, low, limb_mask); + + // cross-multiply high * high for limbs b0-b3 and a3-a4 in parallel + neon_partial_mult_64(D, b_highs, a_vals, 1); + + gpr_partial_mult_52(n, mod_1, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, c_ptr); + __ mov(c_i, high); + + // combine neon 32-bit partial products, regrouping to produce + // 8*52-bit low products in A and 8*52-bit high products in D + + // add low*high/high*low intermediate products before regrouping + vs_addv(B, __ T2D, B, C); // Store (B+C) in B + + // Limb 2 + gpr_partial_mult_52(a_i, b_2, high, low, limb_mask); + __ add(c_i, c_i, low); + __ str(c_i, Address(c_ptr, 8)); + __ mov(c_i, high); + + // shift high*high (40-bit) product up into 52-bits of output + vs_shl(D, __ T2D, D, montMulP256Shift1); + + // Limb 3 + gpr_partial_mult_52(a_i, b_3, high, low, limb_mask); + + // shift high 32 (or 33) bits of intermediate products for addition to D + vs_ushr(C, __ T2D, B, 32 - montMulP256Shift1); // Use C for ((B+C) >>> 20) + + gpr_partial_mult_52(n, mod_3, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, Address(c_ptr, 2 * BytesPerLong)); + __ mov(c_i, high); + + // shift low 32 bits of intermediate product up for masking and addition to A + vs_shl(B, __ T2D, B, 32); + + // Limb 4 + gpr_partial_mult_52(a_i, b_4, high, low, limb_mask); + + // add high bits of intermediate product into D + vs_addv(D, __ T2D, D, C); + + gpr_partial_mult_52(n, mod_4, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ 
add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, Address(c_ptr, 3 * BytesPerLong)); + __ str(high, Address(c_ptr, 4 * BytesPerLong)); + + // top 12 bits of 32*32 bit product in A need adding into high 52-bit output + vs_ushr(C, __ T2D, A, 52); // C now holds (A >>> 52) + // Only 20 of the 32 bits now in the top of B should be added into A + vs_andr(B, B, limb_mask_vec); + // reduce original 64-bit product to 52-bits + vs_andr(A, A, limb_mask_vec); + // add intermediate products to high 52-bit result in D + vs_addv(D, __ T2D, D, C); + // add 20/21 bits of intermediate product in top of B into low 52-bit result + vs_addv(A, __ T2D, A, B); + // save and then mask off any overflow bit from computing low 52-bit result + vs_ushr(B, __ T2D, A, montMulP256Shift2); + vs_andr(A, A, limb_mask_vec); + // add any remaining carry into the high 52-bit result + vs_addv(D, __ T2D, D, B); + + // the write interleaves the 4 successive pairs of low and + // high results: (l0, l1), (h0, h1), ... 
(l6, l7), (h6, h7) + vs_st1_interleaved(A, D, mul_ptr); + + // Free mul_ptr + common_regs = common_regs.remaining() + mul_ptr; + mul_ptr = noreg; + + ///////////////////////// + // Loop 2 & 3 + ///////////////////////// + + for (int i = 0; i < 2; i++) { + // Load a_i and increment by 8 bytes + __ ldr(a_i, __ post(a, BytesPerLong)); + __ ldr(c_i, c_ptr); //Load prior c_i + + // Limb 0 + gpr_partial_mult_52(a_i, b_0, high, low, limb_mask); + __ add(low, low, c_i); + __ ldr(c_i, Address(c_ptr, BytesPerLong)); + __ andr(n, low, limb_mask); + gpr_partial_mult_52(n, mod_0, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ lsr(tmp, low, montMulP256Shift2); + __ add(c_i, c_i, tmp); + __ add(c_i, c_i, high); + + // Limb 1 + gpr_partial_mult_52(a_i, b_1, high, low, limb_mask); + gpr_partial_mult_52(n, mod_1, mod_high, mod_low, limb_mask); + __ ldr(tmp, Address(c_ptr, 2 * BytesPerLong)); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, c_ptr); + __ add(c_i, tmp, high); + + // Limb 2 + gpr_partial_mult_52(a_i, b_2, high, low, limb_mask); + __ ldr(tmp, Address(c_ptr, 3 * BytesPerLong)); + __ add(c_i, c_i, low); + __ str(c_i, Address(c_ptr, BytesPerLong)); + __ add(c_i, tmp, high); + + // Limb 3 + gpr_partial_mult_52(a_i, b_3, high, low, limb_mask); + gpr_partial_mult_52(n, mod_3, mod_high, mod_low, limb_mask); + __ ldr(tmp, Address(c_ptr, 4 * BytesPerLong)); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, Address(c_ptr, 2 * BytesPerLong)); + __ add(c_i, tmp, high); + + // Limb 4 + gpr_partial_mult_52(a_i, b_4, high, low, limb_mask); + gpr_partial_mult_52(n, mod_4, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, Address(c_ptr, 3 * BytesPerLong)); + __ str(high, Address(c_ptr, 4 * BytesPerLong)); + } + // Reallocate regs b_0, b_1, b_2 and b_3 + 
common_regs = common_regs.remaining() + + b_0 + b_1 + b_2 + b_3; + b_0 = b_1 = b_2 = b_3 = noreg; + + Register low_1 = *common_regs++; + Register high_1 = *common_regs++; + + ////////////////////////////// + // a[3] + ////////////////////////////// + + // For a_3 and a_4 we have already computed the cross-products + // with b_0 ... b_3 and stored them on the stack relative to + // `mul_ptr` i.e. the current `sp` in the order + // l(a_3 * b_0), l(a_3 * b_1), h(a_3 * b_0), h(a_3 * b_1), + // l(a_3 * b_2), l(a_3 * b_3), h(a_3 * b_2), h(a_3 * b_3), + // l(a_4 * b_0), l(a_4 * b_1), h(a_4 * b_0), h(a_4 * b_1), + // l(a_4 * b_2), l(a_4 * b_3), h(a_4 * b_2), h(a_4 * b_3), + // where l(x) is the low 52 bits of x and h(x) is the high 52 bits + + __ ldr(low_1, Address(sp)); + __ ldr(high_1, Address(sp, 2 * BytesPerLong)); + + __ ldr(low, Address(sp, BytesPerLong)); + __ ldr(high, Address(sp, 3 * BytesPerLong)); + __ ldr(a_i, __ post(a, BytesPerLong)); + __ ldr(c_i, c_ptr); + + // Limb 0 + __ add(low_1, low_1, c_i); + __ ldr(c_i, Address(c_ptr, BytesPerLong)); + __ andr(n, low_1, limb_mask); + gpr_partial_mult_52(n, mod_0, mod_high, mod_low, limb_mask); + __ add(low_1, low_1, mod_low); + __ add(high_1, high_1, mod_high); + __ lsr(tmp, low_1, montMulP256Shift2); + __ add(c_i, c_i, tmp); + __ add(c_i, c_i, high_1); + + // Limb 1 + __ ldr(low_1, Address(sp, 4 * BytesPerLong)); + __ ldr(high_1, Address(sp, 6 * BytesPerLong)); + gpr_partial_mult_52(n, mod_1, mod_high, mod_low, limb_mask); + __ ldr(tmp, Address(c_ptr, 2 * BytesPerLong)); + __ andr(mod_low, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, c_ptr); + __ add(c_i, tmp, high); + + // Limb 2 + __ ldr(low, Address(sp, 5 * BytesPerLong)); + __ ldr(high, Address(sp, 7 * BytesPerLong)); + __ ldr(tmp, Address(c_ptr, 3 * BytesPerLong)); + __ add(c_i, c_i, low_1); + __ str(c_i, Address(c_ptr, BytesPerLong)); + __ add(c_i, tmp, high_1); + + // Limb 3 + 
gpr_partial_mult_52(n, mod_3, mod_high, mod_low, limb_mask); + __ ldr(tmp, Address(c_ptr, 4 * BytesPerLong)); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ add(c_i, c_i, low); + __ str(c_i, Address(c_ptr, 2 * BytesPerLong)); + __ add(c_i, tmp, high); + + // Limb 4 + __ ldr(low, Address(sp, 8 * BytesPerLong)); + __ ldr(high, Address(sp, 10 * BytesPerLong)); + gpr_partial_mult_52(a_i, b_4, high_1, low_1, limb_mask); + gpr_partial_mult_52(n, mod_4, mod_high, mod_low, limb_mask); + __ add(low_1, low_1, mod_low); + __ add(high_1, high_1, mod_high); + __ add(c_i, c_i, low_1); + __ str(c_i, Address(c_ptr, 3 * BytesPerLong)); + __ str(high_1, Address(c_ptr, 4 * BytesPerLong)); + + ////////////////////////////// + // a[4] + ////////////////////////////// + + Register c5 = *common_regs++, + c6 = *common_regs++, + c7 = *common_regs++; + + __ ldr(a_i, a); + __ ldr(c_i, c_ptr); + + // Limb 0 + __ ldr(low_1, Address(sp, 9 * BytesPerLong)); + __ ldr(high_1, Address(sp, 11 * BytesPerLong)); + + __ add(low, low, c_i); + __ ldr(c_i, Address(c_ptr, BytesPerLong)); + __ andr(n, low, limb_mask); + gpr_partial_mult_52(n, mod_0, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + __ lsr(tmp, low, montMulP256Shift2); + __ add(c_i, c_i, tmp); + __ add(c_i, c_i, high); + + __ ldr(low, Address(sp, 12 * BytesPerLong)); + __ ldr(high, Address(sp, 14 * BytesPerLong)); + gpr_partial_mult_52(n, mod_1, mod_high, mod_low, limb_mask); + __ add(low_1, low_1, mod_low); + __ add(high_1, high_1, mod_high); + __ add(c5, c_i, low_1); + __ ldr(c_i, Address(c_ptr, 2 * BytesPerLong)); + __ lsr(tmp, c5, montMulP256Shift2); + __ add(c_i, c_i, tmp); + __ add(c_i, c_i, high_1); + + // Limb 2 + __ ldr(low_1, Address(sp, 13 * BytesPerLong)); + __ ldr(high_1, Address(sp, 15 * BytesPerLong)); + __ add(c6, c_i, low); + __ ldr(c_i, Address(c_ptr, 3 * BytesPerLong)); + __ lsr(tmp, c6, montMulP256Shift2); + __ add(c_i, c_i, tmp); + __ add(c_i, c_i, high); + + 
// Limb 3 + gpr_partial_mult_52(n, mod_3, mod_high, mod_low, limb_mask); + __ add(low_1, low_1, mod_low); + __ add(high_1, high_1, mod_high); + __ add(c7, c_i, low_1); + __ ldr(c_i, Address(c_ptr, 4 * BytesPerLong)); + __ lsr(tmp, c7, montMulP256Shift2); + __ add(c_i, c_i, tmp); + __ add(c_i, c_i, high_1); + + // Limb 4 + gpr_partial_mult_52(a_i, b_4, high, low, limb_mask); + gpr_partial_mult_52(n, mod_4, mod_high, mod_low, limb_mask); + __ add(low, low, mod_low); + __ add(high, high, mod_high); + + // Reallocate b_4 + common_regs = common_regs.remaining() + b_4; + b_4 = noreg; + + Register c8 = *common_regs++, + c9 = *common_regs++; + + __ add(c8, c_i, low); + __ lsr(c9, c8, montMulP256Shift2); + __ add(c9, c9, high); + + __ andr(c5, c5, limb_mask); + __ andr(c6, c6, limb_mask); + __ andr(c7, c7, limb_mask); + __ andr(c8, c8, limb_mask); + + ///////////////////////////// + // Final carry propagate + ///////////////////////////// + + // c0 = c5 - modulus[0]; + // c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB); + // c0 &= LIMB_MASK; + // c2 = c7 + (c1 >> BITS_PER_LIMB); + // c1 &= LIMB_MASK; + // c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB); + // c2 &= LIMB_MASK; + // c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB); + // c3 &= LIMB_MASK; + + // Free up all unused regs + common_regs = common_regs.remaining() + + c_ptr + low + high + mod_high + + mod_low + a_i + c_i + n + low_1 + high_1; + c_ptr = low = high = mod_high + = mod_low = a_i = c_i = n = low_1 = high_1 = noreg; + + Register c0 = *common_regs++, + c1 = *common_regs++, + c2 = *common_regs++, + c3 = *common_regs++, + c4 = *common_regs++; + + __ sub(c0, c5, mod_0); + __ sub(c1, c6, mod_1); + __ sub(c3, c8, mod_3); + __ sub(c4, c9, mod_4); + __ add(c1, c1, c0, Assembler::ASR, montMulP256Shift2); + __ andr(c0, c0, limb_mask); + __ add(c2, c7, c1, Assembler::ASR, montMulP256Shift2); + __ andr(c1, c1, limb_mask); + __ add(c3, c3, c2, Assembler::ASR, montMulP256Shift2); + __ andr(c2, c2, limb_mask); + __ add(c4, c4, c3, 
Assembler::ASR, montMulP256Shift2); + __ andr(c3, c3, limb_mask); + + // Final write back + // mask = c4 >> 63 + // r[0] = ((c5 & mask) | (c0 & ~mask)); + // r[1] = ((c6 & mask) | (c1 & ~mask)); + // r[2] = ((c7 & mask) | (c2 & ~mask)); + // r[3] = ((c8 & mask) | (c3 & ~mask)); + // r[4] = ((c9 & mask) | (c4 & ~mask)); + + common_regs = common_regs.remaining() + + mod_0 + mod_1 + mod_3 + mod_4; + mod_0 = mod_1 = mod_3 = mod_4 = noreg; + + Register mask = *common_regs++; + Register nmask = *common_regs++; + + __ asr(mask, c4, 63); + __ mvn(nmask, mask); + __ andr(c5, c5, mask); + __ andr(tmp, c0, nmask); + __ orr(c5, c5, tmp); + __ andr(c6, c6, mask); + __ andr(tmp, c1, nmask); + __ orr(c6, c6, tmp); + __ andr(c7, c7, mask); + __ andr(tmp, c2, nmask); + __ orr(c7, c7, tmp); + __ andr(c8, c8, mask); + __ andr(tmp, c3, nmask); + __ orr(c8, c8, tmp); + __ andr(c9, c9, mask); + __ andr(tmp, c4, nmask); + __ orr(c9, c9, tmp); + + __ str(c5, result); + __ str(c6, Address(result, BytesPerLong)); + __ str(c7, Address(result, 2 * BytesPerLong)); + __ str(c8, Address(result, 3 * BytesPerLong)); + __ str(c9, Address(result, 4 * BytesPerLong)); + + // End intrinsic call + __ add(sp, sp, cDataSize + mulDataSize); + __ pop(callee_saved, sp); + __ leave(); + __ mov(r0, zr); // return 0 + __ ret(lr); + + // record the stub entry and end + store_archive_data(stub_id, start, __ pc()); + + return start; + } + + address generate_intpoly_assign() { + // KNOWN Lengths: + // MontgomeryIntPolynP256: 5 = 4 + 1 + // IntegerPolynomial1305: 5 = 4 + 1 + // IntegerPolynomial25519: 10 = 8 + 2 + // IntegerPolynomialP256: 10 = 8 + 2 + // Curve25519OrderField: 10 = 8 + 2 + // Curve25519OrderField: 10 = 8 + 2 + // P256OrderField: 10 = 8 + 2 + // IntegerPolynomialP384: 14 = 8 + 4 + 2 + // P384OrderField: 14 = 8 + 4 + 2 + // IntegerPolynomial448: 16 = 8 + 8 + // Curve448OrderField: 16 = 8 + 8 + // Curve448OrderField: 16 = 8 + 8 + // IntegerPolynomialP521: 19 = 8 + 8 + 2 + 1 + // P521OrderField: 19 = 8 
+ 8 + 2 + 1 + // Special Cases 5, 10, 14, 16, 19 + assert(UseIntPolyIntrinsics, "what are we doing here?"); + StubId stub_id = StubId::stubgen_intpoly_assign_id; + int entry_count = StubInfo::entry_count(stub_id); + assert(entry_count == 1, "sanity check"); + address start = load_archive_data(stub_id); + if (start != nullptr) { + return start; + } + + __ align(CodeEntryAlignment); + StubCodeMark mark(this, stub_id); + start = __ pc(); + __ enter(); + + // Inputs + const Register set = c_rarg0; + const Register aLimbs = c_rarg1; + const Register bLimbs = c_rarg2; + const Register length = c_rarg3; + + Label L_Length5, L_Length10, L_Length14, L_Length16, L_Length19, L_Default, L_Done; + + /* + int maskValue = -set; + for (int i = 0; i < a.length; i++) { + long dummyLimbs = maskValue & (a[i] ^ b[i]); + a[i] = dummyLimbs ^ a[i]; + } + */ + Register mask_scalar = r4; + FloatRegister mask_vec = v0; + + __ neg(mask_scalar, set); + __ dup(mask_vec, __ T2D, mask_scalar); + + __ cmp(length, (u1)5); + __ br(Assembler::EQ, L_Length5); + __ cmp(length, (u1)10); + __ br(Assembler::EQ, L_Length10); + __ cmp(length, (u1)14); + __ br(Assembler::EQ, L_Length14); + __ cmp(length, (u1)16); + __ br(Assembler::EQ, L_Length16); + __ cmp(length, (u1)19); + __ br(Assembler::EQ, L_Length19); + __ b(L_Default); + + + // Length = 5 + // Use 5 GPRs (neon not faster with this few limbs) + __ BIND(L_Length5); + { + Register a0 = r5; + Register a1 = r6; + Register a2 = r7; + Register a3 = r10; + Register a4 = r11; + Register b0 = r12; + Register b1 = r13; + Register b2 = r14; + Register b3 = r15; + Register b4 = r19; + + __ push(r19, sp); + + __ ldr(a0, aLimbs); + __ ldr(a1, Address(aLimbs, 1 * BytesPerLong)); + __ ldr(a2, Address(aLimbs, 2 * BytesPerLong)); + __ ldr(a3, Address(aLimbs, 3 * BytesPerLong)); + __ ldr(a4, Address(aLimbs, 4 * BytesPerLong)); + + __ ldr(b0, bLimbs); + __ ldr(b1, Address(bLimbs, 1 * BytesPerLong)); + __ ldr(b2, Address(bLimbs, 2 * BytesPerLong)); + __ ldr(b3, 
Address(bLimbs, 3 * BytesPerLong)); + __ ldr(b4, Address(bLimbs, 4 * BytesPerLong)); + + __ eor(b0, b0, a0); + __ eor(b1, b1, a1); + __ eor(b2, b2, a2); + __ eor(b3, b3, a3); + __ eor(b4, b4, a4); + + __ andr(b0, b0, mask_scalar); + __ andr(b1, b1, mask_scalar); + __ andr(b2, b2, mask_scalar); + __ andr(b3, b3, mask_scalar); + __ andr(b4, b4, mask_scalar); + + __ eor(a0, a0, b0); + __ eor(a1, a1, b1); + __ eor(a2, a2, b2); + __ eor(a3, a3, b3); + __ eor(a4, a4, b4); + + __ str(a0, aLimbs); + __ str(a1, Address(aLimbs, 1 * BytesPerLong)); + __ str(a2, Address(aLimbs, 2 * BytesPerLong)); + __ str(a3, Address(aLimbs, 3 * BytesPerLong)); + __ str(a4, Address(aLimbs, 4 * BytesPerLong)); + + __ pop(r19, sp); + __ b(L_Done); + } + + // Length = 10 + // Split into 4 neon regs and 2 GPRs + __ BIND(L_Length10); + { + Register a9 = r10; + Register a10 = r11; + Register b9 = r12; + Register b10 = r13; + + VSeq<4> a_vec(16); + VSeq<4> b_vec(20); + + __ ldr(a9, Address(aLimbs, 8 * BytesPerLong)); + __ ldr(a10, Address(aLimbs, 9 * BytesPerLong)); + __ ldr(b9, Address(bLimbs, 8 * BytesPerLong)); + __ ldr(b10, Address(bLimbs, 9 * BytesPerLong)); + + vs_ldpq(a_vec, aLimbs); + + __ eor(b9, b9, a9); + __ eor(b10, b10, a10); + + vs_ldpq(b_vec, bLimbs); + + __ andr(b9, b9, mask_scalar); + __ andr(b10, b10, mask_scalar); + + vs_eor(b_vec, b_vec, a_vec); + + __ eor(a9, a9, b9); + __ eor(a10, a10, b10); + + vs_andr(b_vec, b_vec, mask_vec); + + __ str(a9, Address(aLimbs, 8 * BytesPerLong)); + __ str(a10, Address(aLimbs, 9 * BytesPerLong)); + + vs_eor(a_vec, a_vec, b_vec); + vs_stpq_post(a_vec, aLimbs); + + __ b(L_Done); + } + + // Length = 14 + // Split into 5 neon regs and 4 GPRs + __ BIND(L_Length14); + { + Register a10 = r5; + Register a11 = r6; + Register a12 = r7; + Register a13 = r8; + Register b10 = r9; + Register b11 = r10; + Register b12 = r11; + Register b13 = r12; + + VSeq<5> a_vec(16); + VSeq<5> b_vec(22); + + int offsets[2] = { 0, 32 }; + + __ ldr(a10, Address(aLimbs, 10 * 
BytesPerLong)); + __ ldr(a11, Address(aLimbs, 11 * BytesPerLong)); + __ ldr(a12, Address(aLimbs, 12 * BytesPerLong)); + __ ldr(a13, Address(aLimbs, 13 * BytesPerLong)); + + __ ldr(b10, Address(bLimbs, 10 * BytesPerLong)); + __ ldr(b11, Address(bLimbs, 11 * BytesPerLong)); + __ ldr(b12, Address(bLimbs, 12 * BytesPerLong)); + __ ldr(b13, Address(bLimbs, 13 * BytesPerLong)); + + __ ld1(a_vec[0], __ T2D, aLimbs); + vs_ldpq_indexed(vs_tail(a_vec), aLimbs, 16, offsets); + + __ eor(b10, b10, a10); + __ eor(b11, b11, a11); + __ eor(b12, b12, a12); + __ eor(b13, b13, a13); + + __ ld1(b_vec[0], __ T2D, bLimbs); + vs_ldpq_indexed(vs_tail(b_vec), bLimbs, 16, offsets); + + __ andr(b10, b10, mask_scalar); + __ andr(b11, b11, mask_scalar); + __ andr(b12, b12, mask_scalar); + __ andr(b13, b13, mask_scalar); + + vs_eor(b_vec, b_vec, a_vec); + + __ eor(a10, a10, b10); + __ eor(a11, a11, b11); + __ eor(a12, a12, b12); + __ eor(a13, a13, b13); + + vs_andr(b_vec, b_vec, mask_vec); + + __ str(a10, Address(aLimbs, 10 * BytesPerLong)); + __ str(a11, Address(aLimbs, 11 * BytesPerLong)); + __ str(a12, Address(aLimbs, 12 * BytesPerLong)); + __ str(a13, Address(aLimbs, 13 * BytesPerLong)); + + vs_eor(a_vec, a_vec, b_vec); + + __ st1(a_vec[0], __ T2D, aLimbs); + vs_stpq_indexed(vs_tail(a_vec), aLimbs, 16, offsets); + + __ b(L_Done); + } + + // Length = 16 + // Use 8 neon regs + __ BIND(L_Length16); + { + VSeq<8> a_vec(16); + VSeq<8> b_vec(24); + + vs_ldpq(a_vec, aLimbs); + vs_ldpq(b_vec, bLimbs); + vs_eor(b_vec, b_vec, a_vec); + vs_andr(b_vec, b_vec, mask_vec); + vs_eor(a_vec, a_vec, b_vec); + vs_stpq_post(a_vec, aLimbs); + + __ b(L_Done); + } + + // Length = 19 + // Split into 8 neon regs and 3 GPRs + __ BIND(L_Length19); + { + Register a17 = r10; + Register a18 = r11; + Register a19 = r12; + Register b17 = r13; + Register b18 = r14; + Register b19 = r15; + + VSeq<8> a_vec(16); + VSeq<8> b_vec(24); + + __ ldr(a17, Address(aLimbs, 16 * BytesPerLong)); + __ ldr(a18, Address(aLimbs, 17 * 
BytesPerLong)); + __ ldr(a19, Address(aLimbs, 18 * BytesPerLong)); + __ ldr(b17, Address(bLimbs, 16 * BytesPerLong)); + __ ldr(b18, Address(bLimbs, 17 * BytesPerLong)); + __ ldr(b19, Address(bLimbs, 18 * BytesPerLong)); + + vs_ldpq(a_vec, aLimbs); + + __ eor(b17, b17, a17); + __ eor(b18, b18, a18); + __ eor(b19, b19, a19); + + vs_ldpq(b_vec, bLimbs); + + __ andr(b17, b17, mask_scalar); + __ andr(b18, b18, mask_scalar); + __ andr(b19, b19, mask_scalar); + + vs_eor(b_vec, b_vec, a_vec); + + __ eor(a17, a17, b17); + __ eor(a18, a18, b18); + __ eor(a19, a19, b19); + + vs_andr(b_vec, b_vec, mask_vec); + + __ str(a17, Address(aLimbs, 16 * BytesPerLong)); + __ str(a18, Address(aLimbs, 17 * BytesPerLong)); + __ str(a19, Address(aLimbs, 18 * BytesPerLong)); + + vs_eor(a_vec, a_vec, b_vec); + vs_stpq_post(a_vec, aLimbs); + + __ b(L_Done); + } + + __ BIND(L_Default); + { + Register ctr = r5; + Register a_val = r6; + Register b_val = r7; + + __ mov(ctr, length); // length (the number of limbs) is never 0 + + Label default_loop; + __ BIND(default_loop); + + __ ldr(a_val, aLimbs); + __ ldr(b_val, __ post(bLimbs, 8)); + __ eor(b_val, b_val, a_val); + __ andr(b_val, b_val, mask_scalar); + __ eor(a_val, a_val, b_val); + __ str(a_val, __ post(aLimbs, 8)); + __ sub(ctr, ctr, 1); + __ cmp(ctr, (u1)0); + __ br(Assembler::NE, default_loop); + } + + __ BIND(L_Done); + __ leave(); // required for proper stackwalking of RuntimeStub frame + __ mov(r0, zr); // return 0 + __ ret(lr); + + // record the stub entry and end + store_archive_data(stub_id, start, __ pc()); + + return start; + } + void bcax5(Register a0, Register a1, Register a2, Register a3, Register a4, Register tmp0, Register tmp1, Register tmp2) { __ bic(tmp0, a2, a1); // for a0 @@ -12734,6 +13714,11 @@ class StubGenerator: public StubCodeGenerator { StubRoutines::_chacha20Block = generate_chacha20Block_blockpar(); } + if (UseIntPolyIntrinsics) { + StubRoutines::_intpoly_montgomeryMult_P256 = 
generate_intpoly_montgomeryMult_P256(); + StubRoutines::_intpoly_assign = generate_intpoly_assign(); + } + if (UseKyberIntrinsics) { StubRoutines::_kyberNtt = generate_kyberNtt(); StubRoutines::_kyberInverseNtt = generate_kyberInverseNtt(); @@ -12846,6 +13831,7 @@ class StubGenerator: public StubCodeGenerator { ADD(_sha512_round_consts); ADD(_sha3_round_consts); ADD(_double_keccak_round_consts); + ADD(_modulus_P256); ADD(_encodeBlock_toBase64); ADD(_encodeBlock_toBase64URL); ADD(_decodeBlock_fromBase64ForNoSIMD); diff --git a/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp b/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp index d1cf8b6feed..e746447e013 100644 --- a/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp +++ b/src/hotspot/cpu/aarch64/vm_version_aarch64.cpp @@ -454,6 +454,10 @@ void VM_Version::initialize() { FLAG_SET_DEFAULT(UseChaCha20Intrinsics, false); } + if (FLAG_IS_DEFAULT(UseIntPolyIntrinsics)) { + UseIntPolyIntrinsics = true; + } + if (supports_feature(CPU_ASIMD)) { if (FLAG_IS_DEFAULT(UseKyberIntrinsics)) { UseKyberIntrinsics = true; diff --git a/src/hotspot/share/code/aotCodeCache.hpp b/src/hotspot/share/code/aotCodeCache.hpp index 777ada59a0b..448bab6fbc2 100644 --- a/src/hotspot/share/code/aotCodeCache.hpp +++ b/src/hotspot/share/code/aotCodeCache.hpp @@ -299,6 +299,7 @@ public: do_var(bool, UseSHA256Intrinsics) \ do_var(bool, UseSHA3Intrinsics) \ do_var(bool, UseSHA512Intrinsics) \ + do_var(bool, UseIntPolyIntrinsics) \ do_var(bool, UseVectorizedMismatchIntrinsic) \ do_fun(int, CompressedKlassPointers_shift, CompressedKlassPointers::shift()) \ do_fun(bool, JavaAssertions_systemClassDefault, JavaAssertions::systemClassDefault()) \ @@ -342,7 +343,6 @@ public: do_var(int, AVX3Threshold) /* array copy stubs and nmethods */ \ do_var(bool, EnableX86ECoreOpts) /* nmethods */ \ do_var(bool, UseLibmIntrinsic) \ - do_var(bool, UseIntPolyIntrinsics) \ // END #else #define AOTCODECACHE_CONFIGS_X86_DO(do_var, do_fun)