Mirror of https://github.com/openjdk/jdk.git
aarch64 intrinsics for MontgomeryIntegerPolynomialP256.mult()
This commit is contained in:
parent 3e20a9392f
commit c0c1493026
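The stub added below implements the 5-limb, 52-bits-per-limb Montgomery multiplication used by sun.security.util.math.intpoly.MontgomeryIntegerPolynomialP256 for the P-256 field. As a reading aid, here is a minimal, illustrative C++ model of the per-limb multiply that the assembly performs with mul/umulh plus shl/ushr/orr/and. The helper name mul_limb is made up for this sketch, and it assumes a compiler with unsigned __int128; the constants match the diff (BITS_PER_LIMB = 52, hence shift1 = 12 and shift2 = 52).

  #include <cstdint>

  constexpr int      BITS_PER_LIMB = 52;
  constexpr uint64_t LIMB_MASK     = ~uint64_t(0) >> (64 - BITS_PER_LIMB); // 0x000fffffffffffff

  // Multiply two limbs and re-express the 128-bit product as two 52-bit-aligned
  // halves, so that product == low + high * 2^52.
  static inline void mul_limb(uint64_t a, uint64_t b, uint64_t& low, uint64_t& high) {
    unsigned __int128 p = (unsigned __int128)a * b;
    uint64_t low64  = (uint64_t)p;          // bits   0..63  (scalar 'mul')
    uint64_t high64 = (uint64_t)(p >> 64);  // bits 64..127  (scalar 'umulh')
    high = (high64 << (64 - BITS_PER_LIMB)) | (low64 >> BITS_PER_LIMB);
    low  = low64 & LIMB_MASK;
  }

Each loop iteration in the stub applies this to one limb of a against all five limbs of b, adds the pending carry, derives the Montgomery factor n from the low limb, and folds in n * modulus before propagating carries.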
@@ -51,7 +51,7 @@
                                        do_arch_blob,                   \
                                        do_arch_entry,                  \
                                        do_arch_entry_init)             \
-  do_arch_blob(compiler, 70000)                                        \
+  do_arch_blob(compiler, 75000)                                        \
   do_stub(compiler, vector_iota_indices)                               \
   do_arch_entry(aarch64, compiler, vector_iota_indices,                \
                 vector_iota_indices, vector_iota_indices)              \
@@ -7140,6 +7140,362 @@ class StubGenerator: public StubCodeGenerator {
    return start;
  }

  address generate_intpoly_montgomeryMult_P256() {

    __ align(CodeEntryAlignment);
    StubId stub_id = StubId::stubgen_intpoly_montgomeryMult_P256_id;
    StubCodeMark mark(this, stub_id);
    address start = __ pc();
    __ enter();

    const Register a = c_rarg0;
    const Register b = c_rarg1;
    const Register result = c_rarg2;

    static const int64_t modulus[5] = {
      0x000fffffffffffffL, 0x00000fffffffffffL, 0x0000000000000000L,
      0x0000001000000000L, 0x0000ffffffff0000L
    };

    Register c_ptr = r9;
    Register a_i = r10;
    Register c_idx = r10; // c_idx is not used at the same time as a_i
    Register limb_mask_scalar = r11;
    Register b_j = r12;
    Register mod_j = r12;
    Register mod_ptr = r13;
    Register mul_tmp = r14;
    Register n = r15;

    FloatRegister low_01 = v16;
    FloatRegister low_23 = v17;
    FloatRegister low_4x = v18;
    FloatRegister high_01 = v19;
    FloatRegister high_23 = v20;
    FloatRegister high_4x = v21;
    FloatRegister modmul_low = v22;
    FloatRegister modmul_high = v23;
    FloatRegister c_01 = v24;
    FloatRegister c_23 = v25;
    FloatRegister limb_mask = v28;
    FloatRegister tmp = v29;

    int shift1 = 12;
    int shift2 = 52;

    // Push callee-saved registers onto the stack
    RegSet callee_saved = RegSet::range(r19, r28);
    __ push(callee_saved, sp);

    // Allocate space on the stack for the carry values and zero it
    __ sub(sp, sp, 80);
    __ mov(c_ptr, sp);
    __ eor(a_i, a_i, a_i);
    for (int i = 0; i < 10; i++) {
      __ str(a_i, Address(sp, i * 8));
    }

    // Calculate limb mask = -1L >>> (64 - BITS_PER_LIMB)
    __ mov(limb_mask_scalar, 1);
    __ neg(limb_mask_scalar, limb_mask_scalar);
    __ lsr(limb_mask_scalar, limb_mask_scalar, 12);
    __ dup(limb_mask, __ T2D, limb_mask_scalar);
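    // limb_mask is now (~0 >> 12) == 0x000fffffffffffff, i.e. 52 one-bits,
    // matching the 52-bit limb width used by MontgomeryIntegerPolynomialP256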

    // Get pointer for modulus
    __ lea(mod_ptr, ExternalAddress((address)modulus));

    for (int i = 0; i < 5; i++) {
      // Load a[i]
      __ ldr(a_i, Address(a, i * 8));

      // Iterate through b, multiplying each limb by a_i and
      // storing the low and high parts in separate vectors.
      // Compute high[j] = high[j] << shift1 | (low[j] >>> shift2)
      // and low[j] &= LIMB_MASK
      __ ldr(b_j, Address(b));
      __ mul(mul_tmp, a_i, b_j);
      __ mov(low_01, Assembler::D, 0, mul_tmp);
      __ umulh(mul_tmp, a_i, b_j);
      __ mov(high_01, Assembler::D, 0, mul_tmp);
      //mul64ToVec(a_i, b_j, low_01, high_01, 0);
      __ ldr(b_j, Address(b, 8));
      __ mul(mul_tmp, a_i, b_j);
      __ mov(low_01, Assembler::D, 1, mul_tmp);
      __ umulh(mul_tmp, a_i, b_j);
      __ mov(high_01, Assembler::D, 1, mul_tmp);
      //mul64ToVec(a_i, b_j, low_01, high_01, 1);
      __ shl(high_01, __ T2D, high_01, shift1);
      __ ushr(tmp, __ T2D, low_01, shift2);
      __ orr(high_01, __ T2D, high_01, tmp);
      __ andr(low_01, __ T2D, low_01, limb_mask);

      __ ldr(b_j, Address(b, 16));
      __ mul(mul_tmp, a_i, b_j);
      __ mov(low_23, Assembler::D, 0, mul_tmp);
      __ umulh(mul_tmp, a_i, b_j);
      __ mov(high_23, Assembler::D, 0, mul_tmp);
      //mul64ToVec(a_i, b_j, low_23, high_23, 0);
      __ ldr(b_j, Address(b, 24));
      __ mul(mul_tmp, a_i, b_j);
      __ mov(low_23, Assembler::D, 1, mul_tmp);
      __ umulh(mul_tmp, a_i, b_j);
      __ mov(high_23, Assembler::D, 1, mul_tmp);
      //mul64ToVec(a_i, b_j, low_23, high_23, 1);
      __ shl(high_23, __ T2D, high_23, shift1);
      __ ushr(tmp, __ T2D, low_23, shift2);
      __ orr(high_23, __ T2D, high_23, tmp);
      __ andr(low_23, __ T2D, low_23, limb_mask);

      __ ldr(b_j, Address(b, 32));
      __ mul(mul_tmp, a_i, b_j);
      __ mov(low_4x, Assembler::D, 0, mul_tmp);
      __ umulh(mul_tmp, a_i, b_j);
      __ mov(high_4x, Assembler::D, 0, mul_tmp);
      //mul64ToVec(a_i, b_j, low_4x, high_4x, 0);
      __ shl(high_4x, __ T2D, high_4x, shift1);
      __ ushr(tmp, __ T2D, low_4x, shift2);
      __ orr(high_4x, __ T2D, high_4x, tmp);
      __ andr(low_4x, __ T2D, low_4x, limb_mask);

      // Load c_i and perform
      // low_0 += c_i
      // n = low_0 & limb_mask
      __ eor(c_01, __ T2D, c_01, c_01);
      __ ld1(c_01, __ D, 0, c_ptr);
      __ addv(low_01, __ T2D, low_01, c_01);
      __ mov(n, low_01, __ D, 0);
      __ andr(n, n, limb_mask_scalar);

      // Iterate through the modulus, multiplying each limb by n and
      // storing the low and high parts in separate vectors.
      // Compute high += modmul_high << shift1 | (modmul_low >>> shift2)
      // and low += modmul_low & LIMB_MASK
      __ ldr(mod_j, Address(mod_ptr));
      __ mul(mul_tmp, n, mod_j);
      __ mov(modmul_low, Assembler::D, 0, mul_tmp);
      __ umulh(mul_tmp, n, mod_j);
      __ mov(modmul_high, Assembler::D, 0, mul_tmp);
      //mul64ToVec(n, mod_j, modmul_low, modmul_high, 0);
      __ ldr(mod_j, Address(mod_ptr, 8));
      __ mul(mul_tmp, n, mod_j);
      __ mov(modmul_low, Assembler::D, 1, mul_tmp);
      __ umulh(mul_tmp, n, mod_j);
      __ mov(modmul_high, Assembler::D, 1, mul_tmp);
      //mul64ToVec(n, mod_j, modmul_low, modmul_high, 1);
      __ shl(modmul_high, __ T2D, modmul_high, shift1);
      __ ushr(tmp, __ T2D, modmul_low, shift2);
      __ orr(modmul_high, __ T2D, modmul_high, tmp);
      __ addv(high_01, __ T2D, high_01, modmul_high);
      __ andr(modmul_low, __ T2D, modmul_low, limb_mask);
      __ addv(low_01, __ T2D, low_01, modmul_low);

      __ ldr(mod_j, Address(mod_ptr, 16));
      __ mul(mul_tmp, n, mod_j);
      __ mov(modmul_low, Assembler::D, 0, mul_tmp);
      __ umulh(mul_tmp, n, mod_j);
      __ mov(modmul_high, Assembler::D, 0, mul_tmp);
      //mul64ToVec(n, mod_j, modmul_low, modmul_high, 0);
      __ ldr(mod_j, Address(mod_ptr, 24));
      __ mul(mul_tmp, n, mod_j);
      __ mov(modmul_low, Assembler::D, 1, mul_tmp);
      __ umulh(mul_tmp, n, mod_j);
      __ mov(modmul_high, Assembler::D, 1, mul_tmp);
      //mul64ToVec(n, mod_j, modmul_low, modmul_high, 1);
      __ shl(modmul_high, __ T2D, modmul_high, shift1);
      __ ushr(tmp, __ T2D, modmul_low, shift2);
      __ orr(modmul_high, __ T2D, modmul_high, tmp);
      __ addv(high_23, __ T2D, high_23, modmul_high);
      __ andr(modmul_low, __ T2D, modmul_low, limb_mask);
      __ addv(low_23, __ T2D, low_23, modmul_low);

      __ ldr(mod_j, Address(mod_ptr, 32));
      __ mul(mul_tmp, n, mod_j);
      __ mov(modmul_low, Assembler::D, 0, mul_tmp);
      __ umulh(mul_tmp, n, mod_j);
      __ mov(modmul_high, Assembler::D, 0, mul_tmp);
      //mul64ToVec(n, mod_j, modmul_low, modmul_high, 0);
      __ shl(modmul_high, __ T2D, modmul_high, shift1);
      __ ushr(tmp, __ T2D, modmul_low, shift2);
      __ orr(modmul_high, __ T2D, modmul_high, tmp);
      __ addv(high_4x, __ T2D, high_4x, modmul_high);
      __ andr(modmul_low, __ T2D, modmul_low, limb_mask);
      __ addv(low_4x, __ T2D, low_4x, modmul_low);

      // Compute carry values
      // c_i+1 += low_1 + high_0 + (low_0 >>> shift2)
      // c_i+2 += low_2 + high_1
      // c_i+3 += low_3 + high_2
      // c_i+4 += low_4 + high_3
      // c_i+5  = high_4
      __ add(c_ptr, c_ptr, 8);
      __ ld1(c_01, c_23, __ T2D, c_ptr);
      __ add(c_idx, c_ptr, 32);
      __ st1(high_4x, __ D, 0, c_idx);

      // Add high values to c
      __ addv(c_01, __ T2D, c_01, high_01);
      __ addv(c_23, __ T2D, c_23, high_23);

      // Reorder the low limbs so they can be added to the carries with vector
      // adds: first tmp = {low_1, low_2}, then tmp = {low_3, low_4}
      __ ins(tmp, __ D, low_01, 0, 1);
      __ ins(tmp, __ D, low_23, 1, 0);
      __ addv(c_01, __ T2D, c_01, tmp);
      __ ins(tmp, __ D, low_23, 0, 1);
      __ ins(tmp, __ D, low_4x, 1, 0);
      __ addv(c_23, __ T2D, c_23, tmp);

      // Shift low_0 and add it to c_i+1
      __ ushr(low_01, __ T2D, low_01, shift2);
      __ eor(tmp, __ T16B, tmp, tmp); // zero out tmp
      __ ins(tmp, __ D, low_01, 0, 0);
      __ addv(c_01, __ T2D, c_01, tmp);

      // Write the carry values back to the stack
      __ st1(c_01, c_23, __ T2D, c_ptr);
    }
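
    // Summary of the loop above: each iteration computes, in 52-bit limbs,
    //   acc  = carry + a[i] * b
    //   n    = acc[0] & LIMB_MASK
    //   acc += n * modulus
    // Because modulus[0] == 2^52 - 1 == -1 (mod 2^52), the low limb of acc
    // becomes a multiple of 2^52 and is dropped by advancing c_ptr, which is
    // the standard word-by-word Montgomery reduction step.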

    // Final carry propagation and write-back of the result

    Register tmp_0 = r10;
    Register c0 = r19;
    Register c1 = r20;
    Register c2 = r21;
    Register c3 = r22;
    Register c4 = r23;
    Register c5 = r24;
    Register c6 = r25;
    Register c7 = r26;
    Register c8 = r27;
    Register c9 = r28;

    // The callee-saved register range lines up exactly with the ten carry
    // slots, so a single pop loads all of the carries into c0..c9.
    __ pop(callee_saved, sp);

    // c6 += (c5 >>> BITS_PER_LIMB);
    // c7 += (c6 >>> BITS_PER_LIMB);
    // c8 += (c7 >>> BITS_PER_LIMB);
    // c9 += (c8 >>> BITS_PER_LIMB);

    __ lsr(tmp_0, c5, shift2);
    __ add(c6, c6, tmp_0);
    __ lsr(tmp_0, c6, shift2);
    __ add(c7, c7, tmp_0);
    __ lsr(tmp_0, c7, shift2);
    __ add(c8, c8, tmp_0);
    __ lsr(tmp_0, c8, shift2);
    __ add(c9, c9, tmp_0);

    __ andr(c5, c5, limb_mask_scalar);
    __ andr(c6, c6, limb_mask_scalar);
    __ andr(c7, c7, limb_mask_scalar);
    __ andr(c8, c8, limb_mask_scalar);

    // c0 = c5 - modulus[0];
    // c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
    // c0 &= LIMB_MASK;
    // c2 = c7 + (c1 >> BITS_PER_LIMB);              (modulus[2] is zero)
    // c1 &= LIMB_MASK;
    // c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
    // c2 &= LIMB_MASK;
    // c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
    // c3 &= LIMB_MASK;

    __ ldr(mod_j, __ post(mod_ptr, 8));
    __ sub(c0, c5, mod_j);

    __ ldr(mod_j, __ post(mod_ptr, 8));
    __ sub(c1, c6, mod_j);
    __ asr(tmp_0, c0, shift2);
    __ add(c1, c1, tmp_0);

    __ ldr(mod_j, __ post(mod_ptr, 8));
    __ asr(c2, c1, shift2);
    __ add(c2, c2, c7);

    __ ldr(mod_j, __ post(mod_ptr, 8));
    __ sub(c3, c8, mod_j);
    __ asr(tmp_0, c2, shift2);
    __ add(c3, c3, tmp_0);

    __ ldr(mod_j, __ post(mod_ptr, 8));
    __ sub(c4, c9, mod_j);
    __ asr(tmp_0, c3, shift2);
    __ add(c4, c4, tmp_0);

    // Apply the limb mask
    __ andr(c0, c0, limb_mask_scalar);
    __ andr(c1, c1, limb_mask_scalar);
    __ andr(c2, c2, limb_mask_scalar);
    __ andr(c3, c3, limb_mask_scalar);

    // Final write-back
    // mask = c4 >> 63
    // r[0] = ((c5 & mask) | (c0 & ~mask));
    // r[1] = ((c6 & mask) | (c1 & ~mask));
    // r[2] = ((c7 & mask) | (c2 & ~mask));
    // r[3] = ((c8 & mask) | (c3 & ~mask));
    // r[4] = ((c9 & mask) | (c4 & ~mask));
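    // A mask-and-select (rather than a branch on the borrow) keeps the final
    // reduction branch-free, so the choice is independent of the operand values.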

    Register res_0 = r9;
    Register res_1 = r10;
    Register res_2 = r11;
    Register res_3 = r12;
    Register res_4 = r13;
    Register mask = r14;
    Register nmask = r15;
    Register tmp_1 = r19;

    RegSet res = RegSet::range(r9, r13);

    __ asr(mask, c4, 63);
    __ mvn(nmask, mask);

    __ andr(res_0, c5, mask);
    __ andr(tmp_1, c0, nmask);
    __ orr(res_0, res_0, tmp_1);

    __ andr(res_1, c6, mask);
    __ andr(tmp_1, c1, nmask);
    __ orr(res_1, res_1, tmp_1);

    __ andr(res_2, c7, mask);
    __ andr(tmp_1, c2, nmask);
    __ orr(res_2, res_2, tmp_1);

    __ andr(res_3, c8, mask);
    __ andr(tmp_1, c3, nmask);
    __ orr(res_3, res_3, tmp_1);

    __ andr(res_4, c9, mask);
    __ andr(tmp_1, c4, nmask);
    __ orr(res_4, res_4, tmp_1);

    __ str(res_0, Address(result));
    __ str(res_1, Address(result, 8));
    __ str(res_2, Address(result, 16));
    __ str(res_3, Address(result, 24));
    __ str(res_4, Address(result, 32));

    // End intrinsic call
    __ pop(callee_saved, sp);
    __ leave(); // required for proper stackwalking of RuntimeStub frame
    __ mov(r0, zr); // return 0
    __ ret(lr);

    return start;
  }

  // Multiply two 64-bit registers and place the low and high halves of the
  // 128-bit product into the given 64-bit lane of the low and high vectors
  void mul64ToVec(Register a, Register b, FloatRegister low, FloatRegister high, int lane) {

    Register tmp = r14;

    __ mul(tmp, a, b);
    __ mov(low, Assembler::D, lane, tmp);
    __ umulh(tmp, a, b);
    __ mov(high, Assembler::D, lane, tmp);
  }

  void bcax5(Register a0, Register a1, Register a2, Register a3, Register a4,
             Register tmp0, Register tmp1, Register tmp2) {
    __ bic(tmp0, a2, a1); // for a0

@@ -444,6 +444,10 @@ void VM_Version::initialize() {
    FLAG_SET_DEFAULT(UseDilithiumIntrinsics, false);
  }

  if (FLAG_IS_DEFAULT(UseIntPolyIntrinsics)) {
    UseIntPolyIntrinsics = true;
  }
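  // The aarch64 stub generated above provides the implementation behind
  // MontgomeryIntegerPolynomialP256.mult(), so the IntPoly intrinsics now
  // default to enabled on this platform.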

  if (FLAG_IS_DEFAULT(UseBASE64Intrinsics)) {
    UseBASE64Intrinsics = true;
  }