aarch64 intrinsics for MontgomeryIntegerPolynomialP256.mult()

Ben Perez 2025-10-22 20:48:01 -04:00
parent 3e20a9392f
commit c0c1493026
3 changed files with 361 additions and 1 deletion
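This stub vectorizes the multiply loop of sun.security.util.math.intpoly.MontgomeryIntegerPolynomialP256, which stores 256-bit field elements as five 52-bit limbs and computes the Montgomery product a*b*2^-260 mod p. As a reading aid, here is a minimal scalar sketch of the per-limb loop the assembly below implements. It is illustrative only (names like mont_mult_ref are not from the commit) and assumes a compiler with unsigned __int128:

    #include <stdint.h>

    static const int BITS_PER_LIMB = 52;
    static const uint64_t LIMB_MASK = ~0ULL >> (64 - BITS_PER_LIMB);

    // Scalar model of the stub's main loop. c[0..9] must be zero on entry,
    // just as the stub zeroes its 80-byte stack area.
    static void mont_mult_ref(const uint64_t a[5], const uint64_t b[5],
                              const uint64_t mod[5], uint64_t c[10]) {
      for (int i = 0; i < 5; i++) {
        uint64_t lo[5], hi[5];
        // Split each 128-bit product a[i]*b[j] into 52-bit-aligned halves.
        for (int j = 0; j < 5; j++) {
          unsigned __int128 p = (unsigned __int128)a[i] * b[j];
          hi[j] = ((uint64_t)(p >> 64) << 12) | ((uint64_t)p >> 52);
          lo[j] = (uint64_t)p & LIMB_MASK;
        }
        lo[0] += c[i];
        uint64_t n = lo[0] & LIMB_MASK;
        // mod[0] = 2^52 - 1 == -1 (mod 2^52), so adding n*mod clears limb 0.
        for (int j = 0; j < 5; j++) {
          unsigned __int128 p = (unsigned __int128)n * mod[j];
          hi[j] += ((uint64_t)(p >> 64) << 12) | ((uint64_t)p >> 52);
          lo[j] += (uint64_t)p & LIMB_MASK;
        }
        // Fold this round into the carries; the cleared limb 0 is shifted out.
        c[i + 1] += lo[1] + hi[0] + (lo[0] >> 52);
        c[i + 2] += lo[2] + hi[1];
        c[i + 3] += lo[3] + hi[2];
        c[i + 4] += lo[4] + hi[3];
        c[i + 5]  = hi[4];
      }
    }

After five rounds c[5..9] hold the unreduced result; the stub then carry-propagates these limbs and conditionally subtracts the modulus, mirroring the Java fallback.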


@@ -51,7 +51,7 @@
do_arch_blob, \
do_arch_entry, \
do_arch_entry_init) \
- do_arch_blob(compiler, 70000) \
+ do_arch_blob(compiler, 75000) \
do_stub(compiler, vector_iota_indices) \
do_arch_entry(aarch64, compiler, vector_iota_indices, \
vector_iota_indices, vector_iota_indices) \


@@ -7140,6 +7140,362 @@ class StubGenerator: public StubCodeGenerator {
return start;
}
address generate_intpoly_montgomeryMult_P256() {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_intpoly_montgomeryMult_P256_id;
StubCodeMark mark(this, stub_id);
address start = __ pc();
__ enter();
const Register a = c_rarg0;
const Register b = c_rarg1;
const Register result = c_rarg2;
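// The 52-bit limbs of the P-256 prime p = 2^256 - 2^224 + 2^192 + 2^96 - 1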
static const int64_t modulus[5] = {
0x000fffffffffffffL, 0x00000fffffffffffL, 0x0000000000000000L,
0x0000001000000000L, 0x0000ffffffff0000L
};
Register c_ptr = r9;
Register a_i = r10;
Register c_idx = r10; // c_idx is not used at the same time as a_i
Register limb_mask_scalar = r11;
Register b_j = r12;
Register mod_j = r12;
Register mod_ptr = r13;
Register mul_tmp = r14;
Register n = r15;
FloatRegister low_01 = v16;
FloatRegister low_23 = v17;
FloatRegister low_4x = v18;
FloatRegister high_01 = v19;
FloatRegister high_23 = v20;
FloatRegister high_4x = v21;
FloatRegister modmul_low = v22;
FloatRegister modmul_high = v23;
FloatRegister c_01 = v24;
FloatRegister c_23 = v25;
FloatRegister limb_mask = v28;
FloatRegister tmp = v29;
int shift1 = 12;
int shift2 = 52;
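// Each 64x64->128-bit product (lo, hi) is realigned to the 52-bit limb size:
// high = (hi << shift1) | (lo >>> shift2) and low = lo & LIMB_MASK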
// Push callee-saved registers onto the stack
RegSet callee_saved = RegSet::range(r19, r28);
__ push(callee_saved, sp);
// Allocate stack space for the ten 64-bit carry values and zero it
__ sub(sp, sp, 80);
__ mov(c_ptr, sp);
__ eor(a_i, a_i, a_i);
for (int i = 0; i < 10; i++) {
__ str(a_i, Address(sp, i * 8));
}
// Calculate limb mask = -1L >>> (64 - BITS_PER_LIMB);
__ mov(limb_mask_scalar, 1);
__ neg(limb_mask_scalar, limb_mask_scalar);
__ lsr(limb_mask_scalar, limb_mask_scalar, 12);
__ dup(limb_mask, __ T2D, limb_mask_scalar);
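// Both lanes of limb_mask now hold 0x000fffffffffffff (the low 52 bits set)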
// Get pointer for modulus
__ lea(mod_ptr, ExternalAddress((address)modulus));
for (int i = 0; i < 5; i++) {
// Load limb a[i] into the a_i register
__ ldr(a_i, Address(a, i * 8));
// Iterate through b, multiplying each limb by a_i and
// storing the low and high parts in separate vectors.
// Compute high[i] = high[i] << shift1 | (low[i] >>> shift2)
// and low[i] &= LIMB_MASK
__ ldr(b_j, Address(b));
__ mul(mul_tmp, a_i, b_j);
__ mov(low_01, Assembler::D, 0, mul_tmp);
__ umulh(mul_tmp, a_i, b_j);
__ mov(high_01, Assembler::D, 0, mul_tmp);
//mul64ToVec(a_i, b_j, low_01, high_01, 0);
__ ldr(b_j, Address(b, 8));
__ mul(mul_tmp, a_i, b_j);
__ mov(low_01, Assembler::D, 1, mul_tmp);
__ umulh(mul_tmp, a_i, b_j);
__ mov(high_01, Assembler::D, 1, mul_tmp);
//mul64ToVec(a_i, b_j, low_01, high_01, 1);
__ shl(high_01, __ T2D, high_01, shift1);
__ ushr(tmp, __ T2D, low_01, shift2);
__ orr(high_01, __ T16B, high_01, tmp);
__ andr(low_01, __ T16B, low_01, limb_mask);
__ ldr(b_j, Address(b, 16));
__ mul(mul_tmp, a_i, b_j);
__ mov(low_23, Assembler::D, 0, mul_tmp);
__ umulh(mul_tmp, a_i, b_j);
__ mov(high_23, Assembler::D, 0, mul_tmp);
//mul64ToVec(a_i, b_j, low_23, high_23, 0);
__ ldr(b_j, Address(b, 24));
__ mul(mul_tmp, a_i, b_j);
__ mov(low_23, Assembler::D, 1, mul_tmp);
__ umulh(mul_tmp, a_i, b_j);
__ mov(high_23, Assembler::D, 1, mul_tmp);
//mul64ToVec(a_i, b_j, low_23, high_23, 1);
__ shl(high_23, __ T2D, high_23, shift1);
__ ushr(tmp, __ T2D, low_23, shift2);
__ orr(high_23, __ T16B, high_23, tmp);
__ andr(low_23, __ T16B, low_23, limb_mask);
__ ldr(b_j, Address(b, 32));
__ mul(mul_tmp, a_i, b_j);
__ mov(low_4x, Assembler::D, 0, mul_tmp);
__ umulh(mul_tmp, a_i, b_j);
__ mov(high_4x, Assembler::D, 0, mul_tmp);
//mul64ToVec(a_i, b_j, low_4x, high_4x, 0);
__ shl(high_4x, __ T2D, high_4x, shift1);
__ ushr(tmp, __ T2D, low_4x, shift2);
__ orr(high_4x, __ T16B, high_4x, tmp);
__ andr(low_4x, __ T16B, low_4x, limb_mask);
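// b has only five limbs, so lane 1 of low_4x/high_4x is never written or read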
// Load c_i and perform
// low_0 += c_i
// n = low_0 & limb_mask
__ eor(c_01, __ T16B, c_01, c_01);
__ ld1(c_01, __ D, 0, c_ptr);
__ addv(low_01, __ T2D, low_01, c_01);
__ mov(n, low_01, __ D, 0);
__ andr(n, n, limb_mask_scalar);
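// Montgomery step: modulus[0] = 2^52 - 1 == -1 (mod 2^52), so adding
// n * modulus makes limb 0 of the running sum divisible by 2^52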
// Iterate through the modulus, multiplying each limb by n and
// storing low and high parts in separate vectors.
// Compute high += modmul_high << shift1 | (modmul_low >>> shift2);
// and low += modmul_low & LIMB_MASK
__ ldr(mod_j, Address(mod_ptr));
__ mul(mul_tmp, n, mod_j);
__ mov(modmul_low, Assembler::D, 0, mul_tmp);
__ umulh(mul_tmp, n, mod_j);
__ mov(modmul_high, Assembler::D, 0, mul_tmp);
//mul64ToVec(n, mod_j, modmul_low, modmul_high, 0);
__ ldr(mod_j, Address(mod_ptr, 8));
__ mul(mul_tmp, n, mod_j);
__ mov(modmul_low, Assembler::D, 1, mul_tmp);
__ umulh(mul_tmp, n, mod_j);
__ mov(modmul_high, Assembler::D, 1, mul_tmp);
//mul64ToVec(n, mod_j, modmul_low, modmul_high, 1);
__ shl(modmul_high, __ T2D, modmul_high, shift1);
__ ushr(tmp, __ T2D, modmul_low, shift2);
__ orr(modmul_high, __ T16B, modmul_high, tmp);
__ addv(high_01, __ T2D, high_01, modmul_high);
__ andr(modmul_low, __ T16B, modmul_low, limb_mask);
__ addv(low_01, __ T2D, low_01, modmul_low);
__ ldr(mod_j, Address(mod_ptr, 16));
__ mul(mul_tmp, n, mod_j);
__ mov(modmul_low, Assembler::D, 0, mul_tmp);
__ umulh(mul_tmp, n, mod_j);
__ mov(modmul_high, Assembler::D, 0, mul_tmp);
//mul64ToVec(n, mod_j, modmul_low, modmul_high, 0);
__ ldr(mod_j, Address(mod_ptr, 24));
__ mul(mul_tmp, n, mod_j);
__ mov(modmul_low, Assembler::D, 1, mul_tmp);
__ umulh(mul_tmp, n, mod_j);
__ mov(modmul_high, Assembler::D, 1, mul_tmp);
//mul64ToVec(n, mod_j, modmul_low, modmul_high, 1);
__ shl(modmul_high, __ T2D, modmul_high, shift1);
__ ushr(tmp, __ T2D, modmul_low, shift2);
__ orr(modmul_high, __ T16B, modmul_high, tmp);
__ addv(high_23, __ T2D, high_23, modmul_high);
__ andr(modmul_low, __ T16B, modmul_low, limb_mask);
__ addv(low_23, __ T2D, low_23, modmul_low);
__ ldr(mod_j, Address(mod_ptr, 32));
__ mul(mul_tmp, n, mod_j);
__ mov(modmul_low, Assembler::D, 0, mul_tmp);
__ umulh(mul_tmp, n, mod_j);
__ mov(modmul_high, Assembler::D, 0, mul_tmp);
//mul64ToVec(n, mod_j, modmul_low, modmul_high, 0);
__ shl(modmul_high, __ T2D, modmul_high, shift1);
__ ushr(tmp, __ T2D, modmul_low, shift2);
__ orr(modmul_high, __ T16B, modmul_high, tmp);
__ addv(high_4x, __ T2D, high_4x, modmul_high);
__ andr(modmul_low, __ T16B, modmul_low, limb_mask);
__ addv(low_4x, __ T2D, low_4x, modmul_low);
// Compute carry values
// c_i+1 += low_1 + high_0 + (low_0 >>> shift2)
// c_i+2 += low_2 + high_1
// c_i+3 += low_3 + high_2
// c_i+4 += low_4 + high_3;
// c_i+5 = high_4
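// Limb 0 of the sum is now 0 mod 2^52, so advancing c_ptr one slot drops it;
// its remaining upper bits reach c_i+1 through the (low_0 >>> shift2) term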
__ add(c_ptr, c_ptr, 8);
__ ld1(c_01, c_23, __ T2D, c_ptr);
__ add(c_idx, c_ptr, 32);
__ st1(high_4x, __ D, 0, c_idx);
// Add high values to c
__ addv(c_01, __ T2D, c_01, high_01);
__ addv(c_23, __ T2D, c_23, high_23);
// Reorder the low limbs so they line up with the carry lanes:
// tmp = [low_1, low_2] is added to c_i+1/c_i+2, then
// tmp = [low_3, low_4] to c_i+3/c_i+4
__ ins(tmp, __ D, low_01, 0, 1);
__ ins(tmp, __ D, low_23, 1, 0);
__ addv(c_01, __ T2D, c_01, tmp);
__ ins(tmp, __ D, low_23, 0, 1);
__ ins(tmp, __ D, low_4x, 1, 0);
__ addv(c_23, __ T2D, c_23, tmp);
// Shift low_0 and add to c_i+1
__ ushr(low_01, __ T2D, low_01, shift2);
__ eor(tmp, __ T16B, tmp, tmp); // zero out tmp
__ ins(tmp, __ D, low_01, 0, 0);
__ addv(c_01, __ T2D, c_01, tmp);
// Write back carry values to stack
__ st1(c_01, c_23, __ T2D, c_ptr);
}
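// The ten stack slots now hold the accumulated carries; c[5..9] are the
// unreduced result limbs, c[0..4] have already been consumed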
// Final carry propagate and write result
Register tmp_0 = r10;
Register c0 = r19;
Register c1 = r20;
Register c2 = r21;
Register c3 = r22;
Register c4 = r23;
Register c5 = r24;
Register c6 = r25;
Register c7 = r26;
Register c8 = r27;
Register c9 = r28;
__ pop(callee_saved, sp); // loads the ten carry slots c0..c9 into r19..r28; the real callee-saved values are restored by the second pop below
// c6 += (c5 >>> BITS_PER_LIMB);
// c7 += (c6 >>> BITS_PER_LIMB);
// c8 += (c7 >>> BITS_PER_LIMB);
// c9 += (c8 >>> BITS_PER_LIMB);
__ lsr(tmp_0, c5, shift2);
__ add(c6, c6, tmp_0);
__ lsr(tmp_0, c6, shift2);
__ add(c7, c7, tmp_0);
__ lsr(tmp_0, c7, shift2);
__ add(c8, c8, tmp_0);
__ lsr(tmp_0, c8, shift2);
__ add(c9, c9, tmp_0);
__ andr(c5, c5, limb_mask_scalar);
__ andr(c6, c6, limb_mask_scalar);
__ andr(c7, c7, limb_mask_scalar);
__ andr(c8, c8, limb_mask_scalar);
// c0 = c5 - modulus[0];
// c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
// c0 &= LIMB_MASK;
// c2 = c7 + (c1 >> BITS_PER_LIMB);
// c1 &= LIMB_MASK;
// c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
// c2 &= LIMB_MASK;
// c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
// c3 &= LIMB_MASK;
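// modulus[2] is zero, so c2 needs no subtraction (the third post-incremented
// load below only advances mod_ptr); asr keeps the borrow's sign when shifting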
__ ldr(mod_j, __ post(mod_ptr, 8));
__ sub(c0, c5, mod_j);
__ ldr(mod_j, __ post(mod_ptr, 8));
__ sub(c1, c6, mod_j);
__ asr(tmp_0, c0, shift2);
__ add(c1, c1, tmp_0);
__ ldr(mod_j, __ post(mod_ptr, 8));
__ asr(c2, c1, shift2);
__ add(c2, c2, c7);
__ ldr(mod_j, __ post(mod_ptr, 8));
__ sub(c3, c8, mod_j);
__ asr(tmp_0, c2, shift2);
__ add(c3, c3, tmp_0);
__ ldr(mod_j, __ post(mod_ptr, 8));
__ sub(c4, c9, mod_j);
__ asr(tmp_0, c3, shift2);
__ add(c4, c4, tmp_0);
// Apply limb mask
__ andr(c0, c0, limb_mask_scalar);
__ andr(c1, c1, limb_mask_scalar);
__ andr(c2, c2, limb_mask_scalar);
__ andr(c3, c3, limb_mask_scalar);
// Final write back
// mask = c4 >> 63
// r[0] = ((c5 & mask) | (c0 & ~mask));
// r[1] = ((c6 & mask) | (c1 & ~mask));
// r[2] = ((c7 & mask) | (c2 & ~mask));
// r[3] = ((c8 & mask) | (c3 & ~mask));
// r[4] = ((c9 & mask) | (c4 & ~mask));
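// Branchless select: if the subtraction borrowed, c4 < 0 and mask is all
// ones, keeping the unreduced c5..c9; otherwise the reduced c0..c4 are stored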
Register res_0 = r9;
Register res_1 = r10;
Register res_2 = r11;
Register res_3 = r12;
Register res_4 = r13;
Register mask = r14;
Register nmask = r15;
Register tmp_1 = r19;
RegSet res = RegSet::range(r9, r13);
__ asr(mask, c4, 63);
__ mvn(nmask, mask);
__ andr(res_0, c5, mask);
__ andr(tmp_1, c0, nmask);
__ orr(res_0, res_0, tmp_1);
__ andr(res_1, c6, mask);
__ andr(tmp_1, c1, nmask);
__ orr(res_1, res_1, tmp_1);
__ andr(res_2, c7, mask);
__ andr(tmp_1, c2, nmask);
__ orr(res_2, res_2, tmp_1);
__ andr(res_3, c8, mask);
__ andr(tmp_1, c3, nmask);
__ orr(res_3, res_3, tmp_1);
__ andr(res_4, c9, mask);
__ andr(tmp_1, c4, nmask);
__ orr(res_4, res_4, tmp_1);
__ str(res_0, Address(result));
__ str(res_1, Address(result, 8));
__ str(res_2, Address(result, 16));
__ str(res_3, Address(result, 24));
__ str(res_4, Address(result, 32));
// End intrinsic call
__ pop(callee_saved, sp);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov(r0, zr); // return 0
__ ret(lr);
return start;
}
// Multiply a by b (64x64 -> 128 bits), placing the low and high halves in the given lane of the low/high vectors
void mul64ToVec(Register a, Register b, FloatRegister low, FloatRegister high, int lane) {
Register tmp = r14;
__ mul(tmp, a, b);
__ mov(low, Assembler::D, lane, tmp);
__ umulh(tmp, a, b);
__ mov(high, Assembler::D, lane, tmp);
}
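// For reference, mul64ToVec(a_i, b_j, low_01, high_01, 0) is the helper form
// of the mul/umulh/mov sequences inlined (and left as comments) in the stub above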
void bcax5(Register a0, Register a1, Register a2, Register a3, Register a4,
Register tmp0, Register tmp1, Register tmp2) {
__ bic(tmp0, a2, a1); // for a0


@@ -444,6 +444,10 @@ void VM_Version::initialize() {
FLAG_SET_DEFAULT(UseDilithiumIntrinsics, false);
}
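// Enable the intpoly Montgomery multiply stub by default on aarch64; an
// explicit UseIntPolyIntrinsics setting on the command line still wins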
if (FLAG_IS_DEFAULT(UseIntPolyIntrinsics)) {
UseIntPolyIntrinsics = true;
}
if (FLAG_IS_DEFAULT(UseBASE64Intrinsics)) {
UseBASE64Intrinsics = true;
}