mirror of
https://github.com/openjdk/jdk.git
synced 2026-03-15 18:33:41 +00:00
8350459: MontgomeryIntegerPolynomialP256 multiply intrinsic with AVX2 on x86_64
Reviewed-by: ascarpino, sviswanathan
This commit is contained in:
parent
c029220379
commit
a269bef04c
@ -3529,6 +3529,30 @@ void Assembler::vmovdqu(Address dst, XMMRegister src) {
|
||||
emit_operand(src, dst, 0);
|
||||
}
|
||||
|
||||
// Move Aligned 256bit Vector
|
||||
// Aligned 256-bit load: emits the 0x6F opcode with a 66/0F VEX prefix,
// moving the vector at `src` into `dst`.
void Assembler::vmovdqa(XMMRegister dst, Address src) {
  assert(UseAVX > 0, "");
  InstructionMark mark(this);
  InstructionAttr attrs(AVX_256bit,
                        /* vex_w */ false,
                        /* legacy_mode */ false,
                        /* no_mask_reg */ true,
                        /* uses_vl */ true);
  attrs.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit);
  // Emission order: prefix, opcode byte, then the register/memory operand.
  vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attrs);
  emit_int8(0x6F);
  emit_operand(dst, src, 0);
}
|
||||
|
||||
// Aligned 256-bit store: emits the 0x7F opcode with a 66/0F VEX prefix,
// writing the vector in `src` to the memory at `dst`.
void Assembler::vmovdqa(Address dst, XMMRegister src) {
  assert(UseAVX > 0, "");
  assert(src != xnoreg, "sanity");
  InstructionMark mark(this);
  InstructionAttr attrs(AVX_256bit,
                        /* vex_w */ false,
                        /* legacy_mode */ false,
                        /* no_mask_reg */ true,
                        /* uses_vl */ true);
  attrs.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit);
  attrs.reset_is_clear_context();
  // For the encoding the roles are swapped: the memory operand is passed
  // first and the source register supplies the register encoding.
  vex_prefix(dst, 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attrs);
  emit_int8(0x7F);
  emit_operand(src, dst, 0);
}
|
||||
|
||||
void Assembler::vpmaskmovd(XMMRegister dst, XMMRegister mask, Address src, int vector_len) {
|
||||
assert((VM_Version::supports_avx2() && vector_len == AVX_256bit), "");
|
||||
InstructionMark im(this);
|
||||
@ -3791,6 +3815,27 @@ void Assembler::evmovdquq(XMMRegister dst, KRegister mask, Address src, bool mer
|
||||
emit_operand(dst, src, 0);
|
||||
}
|
||||
|
||||
// Move Aligned 512bit Vector
|
||||
// Unmasked aligned load: delegates to the masked overload with k0 as the
// opmask and merge disabled.
void Assembler::evmovdqaq(XMMRegister dst, Address src, int vector_len) {
  const bool merge = false;
  evmovdqaq(dst, k0, src, merge, vector_len);
}
|
||||
|
||||
// Masked aligned load (EVEX encoded, vex_w set): loads `src` into `dst`
// under opmask `mask`. When `merge` is true the clear-context attribute is
// reset, so masked-off lanes are merged rather than cleared.
void Assembler::evmovdqaq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len) {
  assert(VM_Version::supports_evex(), "");
  InstructionMark mark(this);
  InstructionAttr attrs(vector_len,
                        /* vex_w */ true,
                        /* legacy_mode */ false,
                        /* no_mask_reg */ false,
                        /* uses_vl */ true);
  attrs.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit);
  attrs.set_embedded_opmask_register_specifier(mask);
  attrs.set_is_evex_instruction();
  if (merge) {
    attrs.reset_is_clear_context();
  }
  // Emission order: prefix, opcode byte, then the register/memory operand.
  vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attrs);
  emit_int8(0x6F);
  emit_operand(dst, src, 0);
}
|
||||
|
||||
void Assembler::evmovntdquq(Address dst, XMMRegister src, int vector_len) {
|
||||
// Unmasked instruction
|
||||
evmovntdquq(dst, k0, src, /*merge*/ true, vector_len);
|
||||
|
||||
@ -1758,6 +1758,10 @@ private:
|
||||
void vmovdqu(XMMRegister dst, Address src);
|
||||
void vmovdqu(XMMRegister dst, XMMRegister src);
|
||||
|
||||
// Move Aligned 256bit Vector
|
||||
void vmovdqa(XMMRegister dst, Address src);
|
||||
void vmovdqa(Address dst, XMMRegister src);
|
||||
|
||||
// Move Unaligned 512bit Vector
|
||||
void evmovdqub(XMMRegister dst, XMMRegister src, int vector_len);
|
||||
void evmovdqub(XMMRegister dst, Address src, int vector_len);
|
||||
@ -1791,6 +1795,10 @@ private:
|
||||
void evmovdquq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len);
|
||||
void evmovdquq(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len);
|
||||
|
||||
// Move Aligned 512bit Vector
|
||||
void evmovdqaq(XMMRegister dst, Address src, int vector_len);
|
||||
void evmovdqaq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len);
|
||||
|
||||
// Move lower 64bit to high 64bit in 128bit register
|
||||
void movlhps(XMMRegister dst, XMMRegister src);
|
||||
|
||||
|
||||
@ -2720,6 +2720,60 @@ void MacroAssembler::vmovdqu(XMMRegister dst, AddressLiteral src, int vector_len
|
||||
}
|
||||
}
|
||||
|
||||
// Register-to-register unaligned vector move, dispatched on vector length:
// 512-bit uses evmovdquq, 256-bit uses vmovdqu, anything else uses movdqu.
void MacroAssembler::vmovdqu(XMMRegister dst, XMMRegister src, int vector_len) {
  switch (vector_len) {
    case AVX_512bit:
      evmovdquq(dst, src, AVX_512bit);
      break;
    case AVX_256bit:
      vmovdqu(dst, src);
      break;
    default:
      movdqu(dst, src);
      break;
  }
}
|
||||
|
||||
// Unaligned vector store, dispatched on vector length:
// 512-bit uses evmovdquq, 256-bit uses vmovdqu, anything else uses movdqu.
void MacroAssembler::vmovdqu(Address dst, XMMRegister src, int vector_len) {
  switch (vector_len) {
    case AVX_512bit:
      evmovdquq(dst, src, AVX_512bit);
      break;
    case AVX_256bit:
      vmovdqu(dst, src);
      break;
    default:
      movdqu(dst, src);
      break;
  }
}
|
||||
|
||||
// Unaligned vector load, dispatched on vector length:
// 512-bit uses evmovdquq, 256-bit uses vmovdqu, anything else uses movdqu.
void MacroAssembler::vmovdqu(XMMRegister dst, Address src, int vector_len) {
  switch (vector_len) {
    case AVX_512bit:
      evmovdquq(dst, src, AVX_512bit);
      break;
    case AVX_256bit:
      vmovdqu(dst, src);
      break;
    default:
      movdqu(dst, src);
      break;
  }
}
|
||||
|
||||
// Aligned 256-bit load from an address literal. If the literal is not
// directly reachable, its address is first materialized in `rscratch`
// (which must then be a real register, as the assert enforces).
void MacroAssembler::vmovdqa(XMMRegister dst, AddressLiteral src, Register rscratch) {
  assert(rscratch != noreg || always_reachable(src), "missing");

  if (!reachable(src)) {
    lea(rscratch, src);
    vmovdqa(dst, Address(rscratch, 0));
    return;
  }
  vmovdqa(dst, as_Address(src));
}
|
||||
|
||||
// Aligned vector load from an address literal, dispatched on vector length:
// 512-bit uses evmovdqaq, 256-bit uses vmovdqa, anything else uses movdqa.
// `rscratch` is forwarded for the case where the literal is not reachable.
void MacroAssembler::vmovdqa(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch) {
  assert(rscratch != noreg || always_reachable(src), "missing");

  switch (vector_len) {
    case AVX_512bit:
      evmovdqaq(dst, src, AVX_512bit, rscratch);
      break;
    case AVX_256bit:
      vmovdqa(dst, src, rscratch);
      break;
    default:
      movdqa(dst, src, rscratch);
      break;
  }
}
|
||||
|
||||
void MacroAssembler::kmov(KRegister dst, Address src) {
|
||||
if (VM_Version::supports_avx512bw()) {
|
||||
kmovql(dst, src);
|
||||
@ -2844,6 +2898,29 @@ void MacroAssembler::evmovdquq(XMMRegister dst, AddressLiteral src, int vector_l
|
||||
}
|
||||
}
|
||||
|
||||
// Masked aligned load from an address literal. If the literal is not
// directly reachable, its address is first materialized in `rscratch`.
void MacroAssembler::evmovdqaq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch) {
  assert(rscratch != noreg || always_reachable(src), "missing");

  if (!reachable(src)) {
    lea(rscratch, src);
    Assembler::evmovdqaq(dst, mask, Address(rscratch, 0), merge, vector_len);
    return;
  }
  Assembler::evmovdqaq(dst, mask, as_Address(src), merge, vector_len);
}
|
||||
|
||||
// Unmasked aligned load from an address literal. If the literal is not
// directly reachable, its address is first materialized in `rscratch`.
void MacroAssembler::evmovdqaq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch) {
  assert(rscratch != noreg || always_reachable(src), "missing");

  if (!reachable(src)) {
    lea(rscratch, src);
    Assembler::evmovdqaq(dst, Address(rscratch, 0), vector_len);
    return;
  }
  Assembler::evmovdqaq(dst, as_Address(src), vector_len);
}
|
||||
|
||||
|
||||
void MacroAssembler::movdqa(XMMRegister dst, AddressLiteral src, Register rscratch) {
|
||||
assert(rscratch != noreg || always_reachable(src), "missing");
|
||||
|
||||
|
||||
@ -1348,6 +1348,14 @@ public:
|
||||
void vmovdqu(XMMRegister dst, XMMRegister src);
|
||||
void vmovdqu(XMMRegister dst, AddressLiteral src, Register rscratch = noreg);
|
||||
void vmovdqu(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
|
||||
void vmovdqu(XMMRegister dst, XMMRegister src, int vector_len);
|
||||
void vmovdqu(XMMRegister dst, Address src, int vector_len);
|
||||
void vmovdqu(Address dst, XMMRegister src, int vector_len);
|
||||
|
||||
// AVX Aligned forms
|
||||
using Assembler::vmovdqa;
|
||||
void vmovdqa(XMMRegister dst, AddressLiteral src, Register rscratch = noreg);
|
||||
void vmovdqa(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
|
||||
|
||||
// AVX512 Unaligned
|
||||
void evmovdqu(BasicType type, KRegister kmask, Address dst, XMMRegister src, bool merge, int vector_len);
|
||||
@ -1404,6 +1412,7 @@ public:
|
||||
// Load form: forwards unchanged to Assembler::evmovdquq.
void evmovdquq(XMMRegister dst, Address src, int vector_len) {
  Assembler::evmovdquq(dst, src, vector_len);
}
|
||||
// Store form: forwards unchanged to Assembler::evmovdquq.
void evmovdquq(Address dst, XMMRegister src, int vector_len) {
  Assembler::evmovdquq(dst, src, vector_len);
}
|
||||
void evmovdquq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
|
||||
void evmovdqaq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
|
||||
|
||||
void evmovdquq(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len) {
|
||||
if (dst->encoding() != src->encoding() || mask != k0) {
|
||||
@ -1413,6 +1422,7 @@ public:
|
||||
// Masked store form: forwards unchanged to Assembler::evmovdquq.
void evmovdquq(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len) {
  Assembler::evmovdquq(dst, mask, src, merge, vector_len);
}
|
||||
// Masked load form: forwards unchanged to Assembler::evmovdquq.
void evmovdquq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len) {
  Assembler::evmovdquq(dst, mask, src, merge, vector_len);
}
|
||||
void evmovdquq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg);
|
||||
void evmovdqaq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg);
|
||||
|
||||
// Move Aligned Double Quadword
|
||||
// Register-to-register form: forwards unchanged to Assembler::movdqa.
void movdqa(XMMRegister dst, XMMRegister src) {
  Assembler::movdqa(dst, src);
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2024, Intel Corporation. All rights reserved.
|
||||
* Copyright (c) 2024, 2025, Intel Corporation. All rights reserved.
|
||||
*
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
@ -28,17 +28,17 @@
|
||||
|
||||
#define __ _masm->
|
||||
|
||||
ATTRIBUTE_ALIGNED(64) uint64_t MODULUS_P256[] = {
|
||||
ATTRIBUTE_ALIGNED(64) constexpr uint64_t MODULUS_P256[] = {
|
||||
0x000fffffffffffffULL, 0x00000fffffffffffULL,
|
||||
0x0000000000000000ULL, 0x0000001000000000ULL,
|
||||
0x0000ffffffff0000ULL, 0x0000000000000000ULL,
|
||||
0x0000000000000000ULL, 0x0000000000000000ULL
|
||||
};
|
||||
static address modulus_p256() {
|
||||
return (address)MODULUS_P256;
|
||||
static address modulus_p256(int index = 0) {
|
||||
return (address)&MODULUS_P256[index];
|
||||
}
|
||||
|
||||
ATTRIBUTE_ALIGNED(64) uint64_t P256_MASK52[] = {
|
||||
ATTRIBUTE_ALIGNED(64) constexpr uint64_t P256_MASK52[] = {
|
||||
0x000fffffffffffffULL, 0x000fffffffffffffULL,
|
||||
0x000fffffffffffffULL, 0x000fffffffffffffULL,
|
||||
0xffffffffffffffffULL, 0xffffffffffffffffULL,
|
||||
@ -48,7 +48,7 @@ static address p256_mask52() {
|
||||
return (address)P256_MASK52;
|
||||
}
|
||||
|
||||
ATTRIBUTE_ALIGNED(64) uint64_t SHIFT1R[] = {
|
||||
ATTRIBUTE_ALIGNED(64) constexpr uint64_t SHIFT1R[] = {
|
||||
0x0000000000000001ULL, 0x0000000000000002ULL,
|
||||
0x0000000000000003ULL, 0x0000000000000004ULL,
|
||||
0x0000000000000005ULL, 0x0000000000000006ULL,
|
||||
@ -58,7 +58,7 @@ static address shift_1R() {
|
||||
return (address)SHIFT1R;
|
||||
}
|
||||
|
||||
ATTRIBUTE_ALIGNED(64) uint64_t SHIFT1L[] = {
|
||||
ATTRIBUTE_ALIGNED(64) constexpr uint64_t SHIFT1L[] = {
|
||||
0x0000000000000007ULL, 0x0000000000000000ULL,
|
||||
0x0000000000000001ULL, 0x0000000000000002ULL,
|
||||
0x0000000000000003ULL, 0x0000000000000004ULL,
|
||||
@ -68,6 +68,14 @@ static address shift_1L() {
|
||||
return (address)SHIFT1L;
|
||||
}
|
||||
|
||||
ATTRIBUTE_ALIGNED(64) constexpr uint64_t MASKL5[] = {
|
||||
0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL,
|
||||
0xFFFFFFFFFFFFFFFFULL, 0x0000000000000000ULL,
|
||||
};
|
||||
static address mask_limb5() {
|
||||
return (address)MASKL5;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unrolled Word-by-Word Montgomery Multiplication
|
||||
* r = a * b * 2^-260 (mod P)
|
||||
@ -94,26 +102,25 @@ static address shift_1L() {
|
||||
* B = replicate(bLimbs[i]) |bi|bi|bi|bi|bi|bi|bi|bi|
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* | 0| 0| 0|a5|a4|a3|a2|a1|
|
||||
* Acc1 += A * B *|bi|bi|bi|bi|bi|bi|bi|bi|
|
||||
* Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1|
|
||||
* *| 0| 0| 0|a5|a4|a3|a2|a1|
|
||||
* Acc1 += A * B |bi|bi|bi|bi|bi|bi|bi|bi|
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* Acc2+=| 0| 0| 0| 0| 0| 0| 0| 0|
|
||||
* *h| 0| 0| 0|a5|a4|a3|a2|a1|
|
||||
* Acc2 += A *h B |bi|bi|bi|bi|bi|bi|bi|bi|
|
||||
* | 0| 0| 0|a5|a4|a3|a2|a1|
|
||||
* Acc2 += A *h B *h|bi|bi|bi|bi|bi|bi|bi|bi|
|
||||
* Acc2+=| 0| 0| 0| d5|d4|d3|d2|d1|
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* N = replicate(Acc1[0]) |n0|n0|n0|n0|n0|n0|n0|n0|
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1|
|
||||
* *| 0| 0| 0|m5|m4|m3|m2|m1|
|
||||
* Acc1 += M * N |n0|n0|n0|n0|n0|n0|n0|n0| Note: 52 low bits of Acc1[0] == 0 due to Montgomery!
|
||||
* | 0| 0| 0|m5|m4|m3|m2|m1|
|
||||
* Acc1 += M * N *|n0|n0|n0|n0|n0|n0|n0|n0|
|
||||
* Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1| Note: 52 low bits of c1 == 0 due to Montgomery!
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* | 0| 0| 0|m5|m4|m3|m2|m1|
|
||||
* Acc2 += M *h N *h|n0|n0|n0|n0|n0|n0|n0|n0|
|
||||
* Acc2+=| 0| 0| 0|d5|d4|d3|d2|d1|
|
||||
* *h| 0| 0| 0|m5|m4|m3|m2|m1|
|
||||
* Acc2 += M *h N |n0|n0|n0|n0|n0|n0|n0|n0|
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* if (i == 4) break;
|
||||
* // Combine high/low partial sums Acc1 + Acc2
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* carry = Acc1[0] >> 52 | 0| 0| 0| 0| 0| 0| 0|c1|
|
||||
@ -124,13 +131,35 @@ static address shift_1L() {
|
||||
* +--+--+--+--+--+--+--+--+
|
||||
* Acc1 = Acc1 + Acc2
|
||||
* ---- done
|
||||
* // Last Carry round: Combine high/low partial sums Acc1<high_bits> + Acc1 + Acc2
|
||||
* carry = Acc1 >> 52
|
||||
* Acc1 = Acc1 shift one q element >>
|
||||
* Acc1 = mask52(Acc1)
|
||||
* Acc2 += carry
|
||||
* Acc1 = Acc1 + Acc2
|
||||
* output to rLimbs
|
||||
*
|
||||
* At this point the result in Acc1 can overflow by 1 Modulus and needs carry
|
||||
* propagation. Subtract one modulus, carry-propagate both results and select
|
||||
* (constant-time) the positive number of the two
|
||||
*
|
||||
* Carry = Acc1[0] >> 52
|
||||
* Acc1L = Acc1[0] & mask52
|
||||
* Acc1 = Acc1 shift one q element>>
|
||||
* Acc1 += Carry
|
||||
*
|
||||
* Carry = Acc2[0] >> 52
|
||||
* Acc2L = Acc2[0] & mask52
|
||||
* Acc2 = Acc2 shift one q element>>
|
||||
* Acc2 += Carry
|
||||
*
|
||||
* for col:=1 to 4
|
||||
* Carry = Acc2[col]>>52
|
||||
* Carry = Carry shift one q element<<
|
||||
* Acc2 += Carry
|
||||
*
|
||||
* Carry = Acc1[col]>>52
|
||||
* Carry = Carry shift one q element<<
|
||||
* Acc1 += Carry
|
||||
* done
|
||||
*
|
||||
* Acc1 &= mask52
|
||||
* Acc2 &= mask52
|
||||
* Mask = sign(Acc2)
|
||||
* Result = select(Mask ? Acc1 or Acc2)
|
||||
*/
|
||||
void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Register rLimbs, const Register tmp, MacroAssembler* _masm) {
|
||||
Register t0 = tmp;
|
||||
@ -145,26 +174,30 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi
|
||||
XMMRegister Acc1 = xmm10;
|
||||
XMMRegister Acc2 = xmm11;
|
||||
XMMRegister N = xmm12;
|
||||
XMMRegister carry = xmm13;
|
||||
XMMRegister Carry = xmm13;
|
||||
|
||||
// // Constants
|
||||
XMMRegister modulus = xmm20;
|
||||
XMMRegister shift1L = xmm21;
|
||||
XMMRegister shift1R = xmm22;
|
||||
XMMRegister mask52 = xmm23;
|
||||
KRegister limb0 = k1;
|
||||
KRegister allLimbs = k2;
|
||||
XMMRegister modulus = xmm5;
|
||||
XMMRegister shift1L = xmm6;
|
||||
XMMRegister shift1R = xmm7;
|
||||
XMMRegister Mask52 = xmm8;
|
||||
KRegister allLimbs = k1;
|
||||
KRegister limb0 = k2;
|
||||
KRegister masks[] = {limb0, k3, k4, k5};
|
||||
|
||||
for (int i=0; i<4; i++) {
|
||||
__ mov64(t0, 1ULL<<i);
|
||||
__ kmovql(masks[i], t0);
|
||||
}
|
||||
|
||||
__ mov64(t0, 0x1);
|
||||
__ kmovql(limb0, t0);
|
||||
__ mov64(t0, 0x1f);
|
||||
__ kmovql(allLimbs, t0);
|
||||
__ evmovdquq(shift1L, allLimbs, ExternalAddress(shift_1L()), false, Assembler::AVX_512bit, rscratch);
|
||||
__ evmovdquq(shift1R, allLimbs, ExternalAddress(shift_1R()), false, Assembler::AVX_512bit, rscratch);
|
||||
__ evmovdquq(mask52, allLimbs, ExternalAddress(p256_mask52()), false, Assembler::AVX_512bit, rscratch);
|
||||
__ evmovdqaq(shift1L, allLimbs, ExternalAddress(shift_1L()), false, Assembler::AVX_512bit, rscratch);
|
||||
__ evmovdqaq(shift1R, allLimbs, ExternalAddress(shift_1R()), false, Assembler::AVX_512bit, rscratch);
|
||||
__ evmovdqaq(Mask52, allLimbs, ExternalAddress(p256_mask52()), false, Assembler::AVX_512bit, rscratch);
|
||||
|
||||
// M = load(*modulus_p256)
|
||||
__ evmovdquq(modulus, allLimbs, ExternalAddress(modulus_p256()), false, Assembler::AVX_512bit, rscratch);
|
||||
__ evmovdqaq(modulus, allLimbs, ExternalAddress(modulus_p256()), false, Assembler::AVX_512bit, rscratch);
|
||||
|
||||
// A = load(*aLimbs); masked evmovdquq() can be slow. Instead load full 256bit, and combine with 64bit
|
||||
__ evmovdquq(A, Address(aLimbs, 8), Assembler::AVX_256bit);
|
||||
@ -196,15 +229,13 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi
|
||||
// Acc2 += M *h N
|
||||
__ evpmadd52huq(Acc2, modulus, N, Assembler::AVX_512bit);
|
||||
|
||||
if (i == 4) break;
|
||||
|
||||
// Combine high/low partial sums Acc1 + Acc2
|
||||
|
||||
// carry = Acc1[0] >> 52
|
||||
__ evpsrlq(carry, limb0, Acc1, 52, true, Assembler::AVX_512bit);
|
||||
__ evpsrlq(Carry, limb0, Acc1, 52, true, Assembler::AVX_512bit);
|
||||
|
||||
// Acc2[0] += carry
|
||||
__ evpaddq(Acc2, limb0, carry, Acc2, true, Assembler::AVX_512bit);
|
||||
__ evpaddq(Acc2, limb0, Carry, Acc2, true, Assembler::AVX_512bit);
|
||||
|
||||
// Acc1 = Acc1 shift one q element >>
|
||||
__ evpermq(Acc1, allLimbs, shift1R, Acc1, false, Assembler::AVX_512bit);
|
||||
@ -213,26 +244,317 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi
|
||||
__ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_512bit);
|
||||
}
|
||||
|
||||
// Last Carry round: Combine high/low partial sums Acc1<high_bits> + Acc1 + Acc2
|
||||
// carry = Acc1 >> 52
|
||||
__ evpsrlq(carry, allLimbs, Acc1, 52, true, Assembler::AVX_512bit);
|
||||
// At this point the result is in Acc1, but needs to be normalized to 52bit
|
||||
// limbs (i.e. needs carry propagation) It can also overflow by 1 modulus.
|
||||
// Subtract one modulus from Acc1 into Acc2 then carry propagate both
|
||||
// simultaneously
|
||||
|
||||
// Acc1 = Acc1 shift one q element >>
|
||||
XMMRegister Acc1L = A;
|
||||
XMMRegister Acc2L = B;
|
||||
__ vpsubq(Acc2, Acc1, modulus, Assembler::AVX_512bit);
|
||||
|
||||
// digit 0 carry out
|
||||
// Also split Acc1 and Acc2 into two 256-bit vectors each {Acc1, Acc1L} and
|
||||
// {Acc2, Acc2L} to use 256bit operations
|
||||
__ evpsraq(Carry, limb0, Acc2, 52, false, Assembler::AVX_256bit);
|
||||
__ evpandq(Acc2L, limb0, Acc2, Mask52, false, Assembler::AVX_256bit);
|
||||
__ evpermq(Acc2, allLimbs, shift1R, Acc2, false, Assembler::AVX_512bit);
|
||||
__ vpaddq(Acc2, Acc2, Carry, Assembler::AVX_256bit);
|
||||
|
||||
__ evpsraq(Carry, limb0, Acc1, 52, false, Assembler::AVX_256bit);
|
||||
__ evpandq(Acc1L, limb0, Acc1, Mask52, false, Assembler::AVX_256bit);
|
||||
__ evpermq(Acc1, allLimbs, shift1R, Acc1, false, Assembler::AVX_512bit);
|
||||
__ vpaddq(Acc1, Acc1, Carry, Assembler::AVX_256bit);
|
||||
|
||||
// Acc1 = mask52(Acc1)
|
||||
__ evpandq(Acc1, Acc1, mask52, Assembler::AVX_512bit); // Clear top 12 bits
|
||||
/* remaining digits carry
|
||||
* Note1: Carry register contains just the carry for the particular
|
||||
* column (zero-mask the rest) and gets progressively shifted left
|
||||
* Note2: 'element shift' with vpermq is more expensive, so using vpalignr when
|
||||
* possible. vpalignr shifts 'right' not left, so place the carry appropriately
|
||||
* +--+--+--+--+ +--+--+--+--+ +--+--+
|
||||
* vpalignr(X, X, X, 8): |x4|x3|x2|x1| >> |x2|x1|x2|x1| |x1|x2|
|
||||
* +--+--+--+--+ +--+--+--+--+ >> +--+--+
|
||||
* | +--+--+--+--+ +--+--+
|
||||
* | |x4|x3|x4|x3| |x3|x4|
|
||||
* | +--+--+--+--+ +--+--+
|
||||
* | vv
|
||||
* | +--+--+--+--+
|
||||
* (x3 and x1 is effectively shifted +------------------------> |x3|x4|x1|x2|
|
||||
* left; zero-mask everything but one column of interest) +--+--+--+--+
|
||||
*/
|
||||
for (int i = 1; i<4; i++) {
|
||||
__ evpsraq(Carry, masks[i-1], Acc2, 52, false, Assembler::AVX_256bit);
|
||||
if (i == 1 || i == 3) {
|
||||
__ vpalignr(Carry, Carry, Carry, 8, Assembler::AVX_256bit);
|
||||
} else {
|
||||
__ vpermq(Carry, Carry, 0b10010011, Assembler::AVX_256bit);
|
||||
}
|
||||
__ vpaddq(Acc2, Acc2, Carry, Assembler::AVX_256bit);
|
||||
|
||||
// Acc2 += carry
|
||||
__ evpaddq(Acc2, allLimbs, carry, Acc2, true, Assembler::AVX_512bit);
|
||||
__ evpsraq(Carry, masks[i-1], Acc1, 52, false, Assembler::AVX_256bit);
|
||||
if (i == 1 || i == 3) {
|
||||
__ vpalignr(Carry, Carry, Carry, 8, Assembler::AVX_256bit);
|
||||
} else {
|
||||
__ vpermq(Carry, Carry, 0b10010011, Assembler::AVX_256bit); //0b-2-1-0-3
|
||||
}
|
||||
__ vpaddq(Acc1, Acc1, Carry, Assembler::AVX_256bit);
|
||||
}
|
||||
|
||||
// Acc1 = Acc1 + Acc2
|
||||
__ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_512bit);
|
||||
// Iff Acc2 is negative, then Acc1 contains the result.
|
||||
// if Acc2 is negative, upper 12 bits will be set; arithmetic shift by 64 bits
|
||||
// generates a mask from Acc2 sign bit
|
||||
__ evpsraq(Carry, Acc2, 64, Assembler::AVX_256bit);
|
||||
__ vpermq(Carry, Carry, 0b11111111, Assembler::AVX_256bit); //0b-3-3-3-3
|
||||
__ evpandq(Acc1, Acc1, Mask52, Assembler::AVX_256bit);
|
||||
__ evpandq(Acc2, Acc2, Mask52, Assembler::AVX_256bit);
|
||||
|
||||
// Acc2 = (Acc1 & Mask) | (Acc2 & !Mask)
|
||||
__ vpandn(Acc2L, Carry, Acc2L, Assembler::AVX_256bit);
|
||||
__ vpternlogq(Acc2L, 0xF8, Carry, Acc1L, Assembler::AVX_256bit); // A | B&C orAandBC
|
||||
__ vpandn(Acc2, Carry, Acc2, Assembler::AVX_256bit);
|
||||
__ vpternlogq(Acc2, 0xF8, Carry, Acc1, Assembler::AVX_256bit);
|
||||
|
||||
// output to rLimbs (1 + 4 limbs)
|
||||
__ movq(Address(rLimbs, 0), Acc1);
|
||||
__ evpermq(Acc1, k0, shift1R, Acc1, true, Assembler::AVX_512bit);
|
||||
__ evmovdquq(Address(rLimbs, 8), k0, Acc1, true, Assembler::AVX_256bit);
|
||||
__ movq(Address(rLimbs, 0), Acc2L);
|
||||
__ evmovdquq(Address(rLimbs, 8), Acc2, Assembler::AVX_256bit);
|
||||
|
||||
// Cleanup
|
||||
// Zero out zmm0-zmm15, higher registers not used by intrinsic.
|
||||
__ vzeroall();
|
||||
}
|
||||
|
||||
/**
|
||||
* Unrolled Word-by-Word Montgomery Multiplication
|
||||
* r = a * b * 2^-260 (mod P)
|
||||
*
|
||||
* Use vpmadd52{l,h}uq multiply for upper four limbs and use
|
||||
* scalar mulq for the lowest limb.
|
||||
*
|
||||
* One has to be careful with mulq vs vpmadd52 'crossovers'; mulq high/low
|
||||
* is split as 40:64 bits vs 52:52 in the vector version. Shifts are required
|
||||
* to line up values before addition (see following ascii art)
|
||||
*
|
||||
* Pseudocode:
|
||||
*
|
||||
* +--+--+--+--+ +--+
|
||||
* M = load(*modulus_p256) |m5|m4|m3|m2| |m1|
|
||||
* +--+--+--+--+ +--+
|
||||
* A = load(*aLimbs) |a5|a4|a3|a2| |a1|
|
||||
* +--+--+--+--+ +--+
|
||||
* Acc1 = 0 | 0| 0| 0| 0| | 0|
|
||||
* +--+--+--+--+ +--+
|
||||
* ---- for i = 0 to 4
|
||||
* +--+--+--+--+ +--+
|
||||
* Acc2 = 0 | 0| 0| 0| 0| | 0|
|
||||
* +--+--+--+--+ +--+
|
||||
* B = replicate(bLimbs[i]) |bi|bi|bi|bi| |bi|
|
||||
* +--+--+--+--+ +--+
|
||||
* +--+--+--+--+ +--+
|
||||
* |a5|a4|a3|a2| |a1|
|
||||
* Acc1 += A * B *|bi|bi|bi|bi| |bi|
|
||||
* Acc1+=|c5|c4|c3|c2| |c1|
|
||||
* +--+--+--+--+ +--+
|
||||
* |a5|a4|a3|a2| |a1|
|
||||
* Acc2 += A *h B *h|bi|bi|bi|bi| |bi|
|
||||
* Acc2+=|d5|d4|d3|d2| |d1|
|
||||
* +--+--+--+--+ +--+
|
||||
* N = replicate(Acc1[0]) |n0|n0|n0|n0| |n0|
|
||||
* +--+--+--+--+ +--+
|
||||
* +--+--+--+--+ +--+
|
||||
* |m5|m4|m3|m2| |m1|
|
||||
* Acc1 += M * N *|n0|n0|n0|n0| |n0|
|
||||
* Acc1+=|c5|c4|c3|c2| |c1| Note: 52 low bits of c1 == 0 due to Montgomery!
|
||||
* +--+--+--+--+ +--+
|
||||
* |m5|m4|m3|m2| |m1|
|
||||
* Acc2 += M *h N *h|n0|n0|n0|n0| |n0|
|
||||
* Acc2+=|d5|d4|d3|d2| |d1|
|
||||
* +--+--+--+--+ +--+
|
||||
* // Combine high/low partial sums Acc1 + Acc2
|
||||
* +--+
|
||||
* carry = Acc1[0] >> 52 |c1|
|
||||
* +--+
|
||||
* Acc2[0] += carry |d1|
|
||||
* +--+
|
||||
* +--+--+--+--+ +--+
|
||||
* Acc1 = Acc1 shift one q element>> | 0|c5|c4|c3| |c2|
|
||||
* +|d5|d4|d3|d2| |d1|
|
||||
* Acc1 = Acc1 + Acc2 Acc1+=|c5|c4|c3|c2| |c1|
|
||||
* +--+--+--+--+ +--+
|
||||
* ---- done
|
||||
* +--+--+--+--+ +--+
|
||||
* Acc2 = Acc1 - M |d5|d4|d3|d2| |d1|
|
||||
* +--+--+--+--+ +--+
|
||||
* Carry propagate Acc2
|
||||
* Carry propagate Acc1
|
||||
* Mask = sign(Acc2)
|
||||
* Result = select(Mask ? Acc1 or Acc2)
|
||||
*
|
||||
* Acc1 can overflow by one modulus (hence Acc2); Either Acc1 or Acc2 contain
|
||||
* the correct result. However, they both need carry propagation (i.e. normalize
|
||||
* limbs down to 52 bits each).
|
||||
*
|
||||
* Carry propagation would require relatively expensive vector lane operations,
|
||||
* so instead dump to memory and read as scalar registers
|
||||
*
|
||||
* Note: the order of reduce-then-propagate vs propagate-then-reduce is different
|
||||
* in Java
|
||||
*/
|
||||
// Emits the AVX2 variant of the P256 Montgomery multiply. The lowest limb is
// processed in general-purpose registers with mulq, the upper four limbs in
// 256-bit vectors with vpmadd52{l,h}uq. The generated code stores vectors to
// the 32 bytes at [rsp] with the aligned vmovdqa and reads them back as
// scalars, so the caller has to provide that scratch slot (presumably
// 32-byte aligned -- confirm against the stub prologue).
// tmp_rax/tmp_rdx are used as the implicit operands of mulq (see the
// "rdx:rax" notes below), so they are expected to be rax and rdx.
void montgomeryMultiplyAVX2(const Register aLimbs, const Register bLimbs, const Register rLimbs,
|
||||
const Register tmp_rax, const Register tmp_rdx, const Register tmp1, const Register tmp2,
|
||||
const Register tmp3, const Register tmp4, const Register tmp5, const Register tmp6,
|
||||
const Register tmp7, MacroAssembler* _masm) {
|
||||
// rscratch shares tmp1 with `a` below; it is only used by the constant loads
// that are emitted before `a` is loaded.
Register rscratch = tmp1;
|
||||
|
||||
// Inputs
|
||||
Register a = tmp1;
|
||||
XMMRegister A = xmm0;
|
||||
XMMRegister B = xmm1;
|
||||
|
||||
// Intermediates
|
||||
Register acc1 = tmp2;
|
||||
XMMRegister Acc1 = xmm3;
|
||||
Register acc2 = tmp3;
|
||||
XMMRegister Acc2 = xmm4;
|
||||
XMMRegister N = xmm5;
|
||||
XMMRegister Carry = xmm6;
|
||||
|
||||
// Constants
|
||||
Register modulus = tmp4;
|
||||
XMMRegister Modulus = xmm7;
|
||||
Register mask52 = tmp5;
|
||||
XMMRegister Mask52 = xmm8;
|
||||
XMMRegister MaskLimb5 = xmm9;
|
||||
XMMRegister Zero = xmm10;
|
||||
|
||||
__ mov64(mask52, P256_MASK52[0]);
|
||||
__ movq(Mask52, mask52);
|
||||
__ vpbroadcastq(Mask52, Mask52, Assembler::AVX_256bit);
|
||||
__ vmovdqa(MaskLimb5, ExternalAddress(mask_limb5()), Assembler::AVX_256bit, rscratch);
|
||||
__ vpxor(Zero, Zero, Zero, Assembler::AVX_256bit);
|
||||
|
||||
// M = load(*modulus_p256)
|
||||
// Limb 0 of the modulus has the same value as the 52-bit mask
// (MODULUS_P256[0] == P256_MASK52[0]), so it is copied from mask52; the
// vector register receives the limbs starting at index 1.
__ movq(modulus, mask52);
|
||||
__ vmovdqu(Modulus, ExternalAddress(modulus_p256(1)), Assembler::AVX_256bit, rscratch);
|
||||
|
||||
// A = load(*aLimbs);
|
||||
__ movq(a, Address(aLimbs, 0));
|
||||
__ vmovdqu(A, Address(aLimbs, 8)); //Assembler::AVX_256bit
|
||||
|
||||
// Acc1 = 0
|
||||
__ vpxor(Acc1, Acc1, Acc1, Assembler::AVX_256bit);
|
||||
for (int i = 0; i< 5; i++) {
|
||||
// Acc2 = 0
|
||||
__ vpxor(Acc2, Acc2, Acc2, Assembler::AVX_256bit);
|
||||
|
||||
// B = replicate(bLimbs[i])
|
||||
__ movq(tmp_rax, Address(bLimbs, i*8)); //(b==rax)
|
||||
__ vpbroadcastq(B, Address(bLimbs, i*8), Assembler::AVX_256bit);
|
||||
|
||||
// Acc1 += A * B
|
||||
// Acc2 += A *h B
|
||||
__ mulq(a); // rdx:rax = a*rax
|
||||
if (i == 0) {
|
||||
__ movq(acc1, tmp_rax);
|
||||
__ movq(acc2, tmp_rdx);
|
||||
} else {
|
||||
// Careful with limb size/carries; from mulq, tmp_rax uses full 64 bits
|
||||
__ xorq(acc2, acc2);
|
||||
__ addq(acc1, tmp_rax);
|
||||
__ adcq(acc2, tmp_rdx);
|
||||
}
|
||||
__ vpmadd52luq(Acc1, A, B, Assembler::AVX_256bit);
|
||||
__ vpmadd52huq(Acc2, A, B, Assembler::AVX_256bit);
|
||||
|
||||
// N = replicate(Acc1[0])
|
||||
if (i != 0) {
|
||||
__ movq(tmp_rax, acc1); // (n==rax)
|
||||
}
|
||||
__ andq(tmp_rax, mask52);
|
||||
__ movq(N, acc1); // masking implicit in vpmadd52
|
||||
__ vpbroadcastq(N, N, Assembler::AVX_256bit);
|
||||
|
||||
// Acc1 += M * N
|
||||
__ mulq(modulus); // rdx:rax = modulus*rax
|
||||
__ vpmadd52luq(Acc1, Modulus, N, Assembler::AVX_256bit);
|
||||
__ addq(acc1, tmp_rax); //carry flag set!
|
||||
|
||||
// Acc2 += M *h N
|
||||
__ adcq(acc2, tmp_rdx);
|
||||
__ vpmadd52huq(Acc2, Modulus, N, Assembler::AVX_256bit);
|
||||
|
||||
// Combine high/low partial sums Acc1 + Acc2
|
||||
|
||||
// carry = Acc1[0] >> 52
|
||||
__ shrq(acc1, 52); // low 52 of acc1 ignored, is zero, because Montgomery
|
||||
|
||||
// Acc2[0] += carry
|
||||
__ shlq(acc2, 12);
|
||||
__ addq(acc2, acc1);
|
||||
|
||||
// Acc1 = Acc1 shift one q element >>
|
||||
__ movq(acc1, Acc1);
|
||||
__ vpermq(Acc1, Acc1, 0b11111001, Assembler::AVX_256bit);
|
||||
__ vpand(Acc1, Acc1, MaskLimb5, Assembler::AVX_256bit);
|
||||
|
||||
// Acc1 = Acc1 + Acc2
|
||||
__ addq(acc1, acc2);
|
||||
__ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_256bit);
|
||||
}
|
||||
|
||||
__ movq(acc2, acc1);
|
||||
__ subq(acc2, modulus);
|
||||
__ vpsubq(Acc2, Acc1, Modulus, Assembler::AVX_256bit);
|
||||
__ vmovdqa(Address(rsp, 0), Acc2); //Assembler::AVX_256bit
|
||||
|
||||
// Carry propagate the subtraction result Acc2 first (since the last carry is
|
||||
// used to select result). Careful, following registers overlap:
|
||||
// acc1 = tmp2; acc2 = tmp3; mask52 = tmp5
|
||||
// Note that Acc2 limbs are signed (i.e. result of a subtract with modulus)
|
||||
// i.e. using signed shift is needed for correctness
|
||||
// limb[1] (tmp1) and limb[2] (tmp4) reuse the registers of `a` and `modulus`,
// neither of which is read after this point.
Register limb[] = {acc2, tmp1, tmp4, tmp_rdx, tmp6};
|
||||
Register carry = tmp_rax;
|
||||
for (int i = 0; i<5; i++) {
|
||||
if (i > 0) {
|
||||
__ movq(limb[i], Address(rsp, -8+i*8));
|
||||
__ addq(limb[i], carry);
|
||||
}
|
||||
__ movq(carry, limb[i]);
|
||||
if (i==4) break;
|
||||
__ sarq(carry, 52);
|
||||
}
|
||||
// carry holds the unshifted top limb of Acc2 here. Smear its sign bit over
// the register and invert: select is all-ones when Acc1 - M is non-negative
// and zero when it is negative.
__ sarq(carry, 63);
|
||||
__ notq(carry); //select
|
||||
Register select = carry;
|
||||
carry = tmp7;
|
||||
|
||||
// Now carry propagate the multiply result and (constant-time) select correct
|
||||
// output digit
|
||||
Register digit = acc1;
|
||||
__ vmovdqa(Address(rsp, 0), Acc1); //Assembler::AVX_256bit
|
||||
|
||||
for (int i = 0; i<5; i++) {
|
||||
if (i>0) {
|
||||
__ movq(digit, Address(rsp, -8+i*8));
|
||||
__ addq(digit, carry);
|
||||
}
|
||||
__ movq(carry, digit);
|
||||
__ sarq(carry, 52);
|
||||
|
||||
// long dummyLimbs = maskValue & (a[i] ^ b[i]);
|
||||
// a[i] = dummyLimbs ^ a[i];
|
||||
// With select == all-ones digit becomes limb[i] (the Acc1 - M value); with
// select == 0 digit is left unchanged.
__ xorq(limb[i], digit);
|
||||
__ andq(limb[i], select);
|
||||
__ xorq(digit, limb[i]);
|
||||
|
||||
__ andq(digit, mask52);
|
||||
__ movq(Address(rLimbs, i*8), digit);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
// Zero out ymm0-ymm15.
|
||||
__ vzeroall();
|
||||
// Overwrite the stack scratch slot with zeros so no intermediate limbs remain.
__ vpxor(Acc1, Acc1, Acc1, Assembler::AVX_256bit);
|
||||
__ vmovdqa(Address(rsp, 0), Acc1); //Assembler::AVX_256bit
|
||||
}
|
||||
|
||||
address StubGenerator::generate_intpoly_montgomeryMult_P256() {
|
||||
@ -242,13 +564,58 @@ address StubGenerator::generate_intpoly_montgomeryMult_P256() {
|
||||
address start = __ pc();
|
||||
__ enter();
|
||||
|
||||
// Register Map
|
||||
const Register aLimbs = c_rarg0; // rdi | rcx
|
||||
const Register bLimbs = c_rarg1; // rsi | rdx
|
||||
const Register rLimbs = c_rarg2; // rdx | r8
|
||||
const Register tmp = r9;
|
||||
if (EnableX86ECoreOpts && UseAVX > 1) {
|
||||
__ push(r12);
|
||||
__ push(r13);
|
||||
__ push(r14);
|
||||
#ifdef _WIN64
|
||||
__ push(rsi);
|
||||
__ push(rdi);
|
||||
#endif
|
||||
__ push(rbp);
|
||||
__ movq(rbp, rsp);
|
||||
__ andq(rsp, -32);
|
||||
__ subptr(rsp, 32);
|
||||
|
||||
montgomeryMultiply(aLimbs, bLimbs, rLimbs, tmp, _masm);
|
||||
// Register Map
|
||||
const Register aLimbs = c_rarg0; // c_rarg0: rdi | rcx
|
||||
const Register bLimbs = rsi; // c_rarg1: rsi | rdx
|
||||
const Register rLimbs = r8; // c_rarg2: rdx | r8
|
||||
const Register tmp1 = r9;
|
||||
const Register tmp2 = r10;
|
||||
const Register tmp3 = r11;
|
||||
const Register tmp4 = r12;
|
||||
const Register tmp5 = r13;
|
||||
const Register tmp6 = r14;
|
||||
#ifdef _WIN64
|
||||
const Register tmp7 = rdi;
|
||||
__ movq(bLimbs, c_rarg1); // free-up rdx
|
||||
#else
|
||||
const Register tmp7 = rcx;
|
||||
__ movq(rLimbs, c_rarg2); // free-up rdx
|
||||
#endif
|
||||
|
||||
montgomeryMultiplyAVX2(aLimbs, bLimbs, rLimbs, rax, rdx,
|
||||
tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, _masm);
|
||||
|
||||
__ movq(rsp, rbp);
|
||||
__ pop(rbp);
|
||||
#ifdef _WIN64
|
||||
__ pop(rdi);
|
||||
__ pop(rsi);
|
||||
#endif
|
||||
__ pop(r14);
|
||||
__ pop(r13);
|
||||
__ pop(r12);
|
||||
} else {
|
||||
// Register Map
|
||||
const Register aLimbs = c_rarg0; // rdi | rcx
|
||||
const Register bLimbs = c_rarg1; // rsi | rdx
|
||||
const Register rLimbs = c_rarg2; // rdx | r8
|
||||
const Register tmp = r9;
|
||||
|
||||
montgomeryMultiply(aLimbs, bLimbs, rLimbs, tmp, _masm);
|
||||
}
|
||||
|
||||
__ leave();
|
||||
__ ret(0);
|
||||
@ -259,17 +626,34 @@ address StubGenerator::generate_intpoly_montgomeryMult_P256() {
|
||||
// Must be:
|
||||
// - constant time (i.e. no branches)
|
||||
// - no-side channel (i.e. all memory must always be accessed, and in same order)
|
||||
// Conditionally overwrites the vector stored at aAddr with the vector stored
// at bAddr, using an opmask register (KRegister) as the selector.
// Emitted sequence: load A <- [aAddr], load B <- [bAddr], masked move of B
// into A under `select`, store [aAddr] <- A. Both operands are always loaded
// and aAddr is always stored, independent of the value held in `select`.
// NOTE(review): the call sites that pass (aLimbs, bLimbs, offset, select, A, B, ...)
// match the XMMRegister-select overload rather than this one; this looks like
// the pre-change side of the diff -- confirm before relying on it.
void assign_avx(XMMRegister A, Address aAddr, XMMRegister B, Address bAddr, KRegister select, int vector_len, MacroAssembler* _masm) {
|
||||
__ evmovdquq(A, aAddr, vector_len);
|
||||
__ evmovdquq(B, bAddr, vector_len);
|
||||
// Masked register-to-register move. The `true` argument is presumably the
// merge flag (lanes whose mask bit is clear keep A's value) -- TODO confirm.
__ evmovdquq(A, select, B, true, vector_len);
|
||||
__ evmovdquq(aAddr, A, vector_len);
|
||||
}
|
||||
// Conditionally overwrites the vector at [aBase + offset] with the vector at
// [bBase + offset], using a vector register as the selector instead of an
// opmask register.
//   aBase, bBase : base registers of the destination / source limb arrays
//   offset       : byte displacement applied to both bases
//   select       : vector mask ANDed with (a ^ b); presumably all-ones or
//                  all-zeros in every lane -- TODO confirm against the caller
//   tmp, aTmp    : scratch vector registers, both are overwritten
//   vector_len   : Assembler::AVX_* width of the access
// The emitted instruction sequence is the same for either value of the mask:
// the destination is always loaded and always stored.
void assign_avx(Register aBase, Register bBase, int offset, XMMRegister select, XMMRegister tmp, XMMRegister aTmp, int vector_len, MacroAssembler* _masm) {
|
||||
// A 512-bit request when UseAVX < 3 is emitted as two 256-bit halves,
// the second one 32 bytes further into both arrays.
if (vector_len == Assembler::AVX_512bit && UseAVX < 3) {
|
||||
assign_avx(aBase, bBase, offset, select, tmp, aTmp, Assembler::AVX_256bit, _masm);
|
||||
assign_avx(aBase, bBase, offset + 32, select, tmp, aTmp, Assembler::AVX_256bit, _masm);
|
||||
return;
|
||||
}
|
||||
|
||||
Address aAddr = Address(aBase, offset);
|
||||
Address bAddr = Address(bBase, offset);
|
||||
|
||||
// NOTE(review): the next line is a second function signature nested inside
// this body. It looks like the removed (pre-change) assign_scalar signature
// left over from the diff rendering rather than part of assign_avx -- confirm.
void assign_scalar(Address aAddr, Address bAddr, Register select, Register tmp, MacroAssembler* _masm) {
|
||||
// Original java:
|
||||
// long dummyLimbs = maskValue & (a[i] ^ b[i]);
|
||||
// a[i] = dummyLimbs ^ a[i];
|
||||
// tmp = a; aTmp keeps an untouched copy of a for the final XOR.
__ vmovdqu(tmp, aAddr, vector_len);
|
||||
__ vmovdqu(aTmp, tmp, vector_len);
|
||||
// tmp = (a ^ b) & select  -- the "dummyLimbs" value from the Java code above.
__ vpxor(tmp, tmp, bAddr, vector_len);
|
||||
__ vpand(tmp, tmp, select, vector_len);
|
||||
// tmp = dummyLimbs ^ a, then written back over a.
__ vpxor(tmp, tmp, aTmp, vector_len);
|
||||
__ vmovdqu(aAddr, tmp, vector_len);
|
||||
}
|
||||
|
||||
void assign_scalar(Register aBase, Register bBase, int offset, Register select, Register tmp, MacroAssembler* _masm) {
|
||||
// Original java:
|
||||
// long dummyLimbs = maskValue & (a[i] ^ b[i]);
|
||||
// a[i] = dummyLimbs ^ a[i];
|
||||
|
||||
Address aAddr = Address(aBase, offset);
|
||||
Address bAddr = Address(bBase, offset);
|
||||
|
||||
__ movq(tmp, aAddr);
|
||||
__ xorq(tmp, bAddr);
|
||||
@ -308,13 +692,18 @@ address StubGenerator::generate_intpoly_assign() {
|
||||
const Register length = c_rarg3;
|
||||
XMMRegister A = xmm0;
|
||||
XMMRegister B = xmm1;
|
||||
XMMRegister select = xmm2;
|
||||
|
||||
Register tmp = r9;
|
||||
KRegister select = k1;
|
||||
Label L_Length5, L_Length10, L_Length14, L_Length16, L_Length19, L_DefaultLoop, L_Done;
|
||||
|
||||
__ negq(set);
|
||||
__ kmovql(select, set);
|
||||
if (UseAVX > 2) {
|
||||
__ evpbroadcastq(select, set, Assembler::AVX_512bit);
|
||||
} else {
|
||||
__ movq(select, set);
|
||||
__ vpbroadcastq(select, select, Assembler::AVX_256bit);
|
||||
}
|
||||
|
||||
// NOTE! Crypto code cannot branch on user input. However; allowed to branch on number of limbs;
|
||||
// Number of limbs is a constant in each IntegerPolynomial (i.e. this side-channel branch leaks
|
||||
@ -334,7 +723,7 @@ address StubGenerator::generate_intpoly_assign() {
|
||||
__ cmpl(length, 0);
|
||||
__ jcc(Assembler::lessEqual, L_Done);
|
||||
__ bind(L_DefaultLoop);
|
||||
assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm);
|
||||
assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm);
|
||||
__ subl(length, 1);
|
||||
__ lea(aLimbs, Address(aLimbs,8));
|
||||
__ lea(bLimbs, Address(bLimbs,8));
|
||||
@ -343,31 +732,31 @@ address StubGenerator::generate_intpoly_assign() {
|
||||
__ jmp(L_Done);
|
||||
|
||||
__ bind(L_Length5); // 1 + 4
|
||||
assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm);
|
||||
assign_avx(A, Address(aLimbs, 8), B, Address(bLimbs, 8), select, Assembler::AVX_256bit, _masm);
|
||||
assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm);
|
||||
assign_avx (aLimbs, bLimbs, 8, select, A, B, Assembler::AVX_256bit, _masm);
|
||||
__ jmp(L_Done);
|
||||
|
||||
__ bind(L_Length10); // 2 + 8
|
||||
assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_128bit, _masm);
|
||||
assign_avx(A, Address(aLimbs, 16), B, Address(bLimbs, 16), select, Assembler::AVX_512bit, _masm);
|
||||
assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_128bit, _masm);
|
||||
assign_avx(aLimbs, bLimbs, 16, select, A, B, Assembler::AVX_512bit, _masm);
|
||||
__ jmp(L_Done);
|
||||
|
||||
__ bind(L_Length14); // 2 + 4 + 8
|
||||
assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_128bit, _masm);
|
||||
assign_avx(A, Address(aLimbs, 16), B, Address(bLimbs, 16), select, Assembler::AVX_256bit, _masm);
|
||||
assign_avx(A, Address(aLimbs, 48), B, Address(bLimbs, 48), select, Assembler::AVX_512bit, _masm);
|
||||
assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_128bit, _masm);
|
||||
assign_avx(aLimbs, bLimbs, 16, select, A, B, Assembler::AVX_256bit, _masm);
|
||||
assign_avx(aLimbs, bLimbs, 48, select, A, B, Assembler::AVX_512bit, _masm);
|
||||
__ jmp(L_Done);
|
||||
|
||||
__ bind(L_Length16); // 8 + 8
|
||||
assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_512bit, _masm);
|
||||
assign_avx(A, Address(aLimbs, 64), B, Address(bLimbs, 64), select, Assembler::AVX_512bit, _masm);
|
||||
assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_512bit, _masm);
|
||||
assign_avx(aLimbs, bLimbs, 64, select, A, B, Assembler::AVX_512bit, _masm);
|
||||
__ jmp(L_Done);
|
||||
|
||||
__ bind(L_Length19); // 1 + 2 + 8 + 8
|
||||
assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm);
|
||||
assign_avx(A, Address(aLimbs, 8), B, Address(bLimbs, 8), select, Assembler::AVX_128bit, _masm);
|
||||
assign_avx(A, Address(aLimbs, 24), B, Address(bLimbs, 24), select, Assembler::AVX_512bit, _masm);
|
||||
assign_avx(A, Address(aLimbs, 88), B, Address(bLimbs, 88), select, Assembler::AVX_512bit, _masm);
|
||||
assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm);
|
||||
assign_avx (aLimbs, bLimbs, 8, select, A, B, Assembler::AVX_128bit, _masm);
|
||||
assign_avx (aLimbs, bLimbs, 24, select, A, B, Assembler::AVX_512bit, _masm);
|
||||
assign_avx (aLimbs, bLimbs, 88, select, A, B, Assembler::AVX_512bit, _masm);
|
||||
|
||||
__ bind(L_Done);
|
||||
__ leave();
|
||||
|
||||
@ -1406,7 +1406,7 @@ void VM_Version::get_processor_features() {
|
||||
}
|
||||
|
||||
#ifdef _LP64
|
||||
if (supports_avx512ifma() && supports_avx512vlbw()) {
|
||||
if ((supports_avx512ifma() && supports_avx512vlbw()) || supports_avxifma()) {
|
||||
if (FLAG_IS_DEFAULT(UseIntPolyIntrinsics)) {
|
||||
FLAG_SET_DEFAULT(UseIntPolyIntrinsics, true);
|
||||
}
|
||||
|
||||
@ -532,7 +532,7 @@ class methodHandle;
|
||||
/* support for sun.security.util.math.intpoly.MontgomeryIntegerPolynomialP256 */ \
|
||||
do_class(sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, "sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256") \
|
||||
do_intrinsic(_intpoly_montgomeryMult_P256, sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, intPolyMult_name, intPolyMult_signature, F_R) \
|
||||
do_name(intPolyMult_name, "multImpl") \
|
||||
do_name(intPolyMult_name, "mult") \
|
||||
do_signature(intPolyMult_signature, "([J[J[J)V") \
|
||||
\
|
||||
do_class(sun_security_util_math_intpoly_IntegerPolynomial, "sun/security/util/math/intpoly/IntegerPolynomial") \
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
@ -159,14 +159,8 @@ public final class MontgomeryIntegerPolynomialP256 extends IntegerPolynomial
|
||||
* numAdds to reuse existing overflow logic.
|
||||
*/
|
||||
@Override
|
||||
protected void mult(long[] a, long[] b, long[] r) {
|
||||
multImpl(a, b, r);
|
||||
reducePositive(r);
|
||||
}
|
||||
|
||||
@ForceInline
|
||||
@IntrinsicCandidate
|
||||
private void multImpl(long[] a, long[] b, long[] r) {
|
||||
protected void mult(long[] a, long[] b, long[] r) {
|
||||
long aa0 = a[0];
|
||||
long aa1 = a[1];
|
||||
long aa2 = a[2];
|
||||
@ -398,17 +392,43 @@ public final class MontgomeryIntegerPolynomialP256 extends IntegerPolynomial
|
||||
dd4 += Math.unsignedMultiplyHigh(n, modulus[4]) << shift1 | (n4 >>> shift2);
|
||||
d4 += n4 & LIMB_MASK;
|
||||
|
||||
// Final carry propagate
|
||||
c5 += d1 + dd0 + (d0 >>> BITS_PER_LIMB);
|
||||
c6 += d2 + dd1;
|
||||
c7 += d3 + dd2;
|
||||
c8 += d4 + dd3;
|
||||
c9 = dd4;
|
||||
c6 += d2 + dd1 + (c5 >>> BITS_PER_LIMB);
|
||||
c7 += d3 + dd2 + (c6 >>> BITS_PER_LIMB);
|
||||
c8 += d4 + dd3 + (c7 >>> BITS_PER_LIMB);
|
||||
c9 = dd4 + (c8 >>> BITS_PER_LIMB);
|
||||
|
||||
r[0] = c5;
|
||||
r[1] = c6;
|
||||
r[2] = c7;
|
||||
r[3] = c8;
|
||||
r[4] = c9;
|
||||
c5 &= LIMB_MASK;
|
||||
c6 &= LIMB_MASK;
|
||||
c7 &= LIMB_MASK;
|
||||
c8 &= LIMB_MASK;
|
||||
|
||||
// At this point, the result {c5, c6, c7, c8, c9} could overflow by
|
||||
// one modulus. Subtract one modulus (with carry propagation), into
|
||||
// {c0, c1, c2, c3, c4}. Note that in this calculation, limbs are
|
||||
// signed
|
||||
c0 = c5 - modulus[0];
|
||||
c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
|
||||
c0 &= LIMB_MASK;
|
||||
c2 = c7 - modulus[2] + (c1 >> BITS_PER_LIMB);
|
||||
c1 &= LIMB_MASK;
|
||||
c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
|
||||
c2 &= LIMB_MASK;
|
||||
c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
|
||||
c3 &= LIMB_MASK;
|
||||
|
||||
// We now must select a result that is in range of [0,modulus). i.e.
|
||||
// either {c0-4} or {c5-9}. Iff {c0-4} is negative, then {c5-9} contains
|
||||
// the result. (After carry propagation) IF c4 is negative, {c0-4} is
|
||||
// negative. Arithmetic shift by 63 bits generates a mask from c4 that
|
||||
// can be used to select 'constant time' either {c0-4} or {c5-9}.
|
||||
long mask = c4 >> 63;
|
||||
r[0] = ((c5 & mask) | (c0 & ~mask));
|
||||
r[1] = ((c6 & mask) | (c1 & ~mask));
|
||||
r[2] = ((c7 & mask) | (c2 & ~mask));
|
||||
r[3] = ((c8 & mask) | (c3 & ~mask));
|
||||
r[4] = ((c9 & mask) | (c4 & ~mask));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2024, Intel Corporation. All rights reserved.
|
||||
* Copyright (c) 2024, 2025, Intel Corporation. All rights reserved.
|
||||
*
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
@ -23,9 +23,8 @@
|
||||
*/
|
||||
|
||||
import java.util.Random;
|
||||
import sun.security.util.math.IntegerMontgomeryFieldModuloP;
|
||||
import sun.security.util.math.ImmutableIntegerModuloP;
|
||||
import java.math.BigInteger;
|
||||
import sun.security.util.math.*;
|
||||
import sun.security.util.math.intpoly.*;
|
||||
|
||||
/*
|
||||
@ -35,7 +34,7 @@ import sun.security.util.math.intpoly.*;
|
||||
* java.base/sun.security.util.math.intpoly
|
||||
* @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:-UseIntPolyIntrinsics
|
||||
* MontgomeryPolynomialFuzzTest
|
||||
* @summary Unit test MontgomeryPolynomialFuzzTest.
|
||||
* @summary Unit test MontgomeryPolynomialFuzzTest without intrinsic, plain java
|
||||
*/
|
||||
|
||||
/*
|
||||
@ -45,10 +44,11 @@ import sun.security.util.math.intpoly.*;
|
||||
* java.base/sun.security.util.math.intpoly
|
||||
* @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:+UseIntPolyIntrinsics
|
||||
* MontgomeryPolynomialFuzzTest
|
||||
* @summary Unit test MontgomeryPolynomialFuzzTest.
|
||||
* @summary Unit test MontgomeryPolynomialFuzzTest with intrinsic enabled
|
||||
*/
|
||||
|
||||
// This test case is NOT entirely deterministic, it uses a random seed for pseudo-random number generator
|
||||
// This test case is NOT entirely deterministic, it uses a random seed for
|
||||
// pseudo-random number generator
|
||||
// If a failure occurs, hardcode the seed to make the test case deterministic
|
||||
public class MontgomeryPolynomialFuzzTest {
|
||||
public static void main(String[] args) throws Exception {
|
||||
@ -60,15 +60,38 @@ public class MontgomeryPolynomialFuzzTest {
|
||||
System.out.println("Fuzz Success");
|
||||
}
|
||||
|
||||
private static void check(BigInteger reference,
|
||||
private static void checkOverflow(String opMsg,
|
||||
ImmutableIntegerModuloP testValue, long seed) {
|
||||
if (!reference.equals(testValue.asBigInteger())) {
|
||||
throw new RuntimeException("SEED: " + seed);
|
||||
long limbs[] = testValue.getLimbs();
|
||||
BigInteger mod = MontgomeryIntegerPolynomialP256.ONE.MODULUS;
|
||||
BigInteger ref = BigInteger.ZERO;
|
||||
for (int i = 0; i<limbs.length; i++) {
|
||||
ref.add(BigInteger.valueOf(limbs[i]).shiftLeft(i*52));
|
||||
}
|
||||
if (ref.compareTo(mod)!=-1) {
|
||||
String msg = "Error while " + opMsg + System.lineSeparator()
|
||||
+ ref.toString(16) + " != " + mod.toString(16) + System.lineSeparator()
|
||||
+ "To reproduce, set SEED to [" + seed + "L]: ";
|
||||
throw new RuntimeException(msg);
|
||||
}
|
||||
}
|
||||
|
||||
private static void check(String opMsg, BigInteger reference,
|
||||
ImmutableIntegerModuloP testValue, long seed) {
|
||||
BigInteger test = testValue.asBigInteger();
|
||||
if (!reference.equals(test)) {
|
||||
String msg = "Error while " + opMsg + System.lineSeparator()
|
||||
+ reference.toString(16) + " != " + test.toString(16)
|
||||
+ System.lineSeparator()+ "To reproduce, set SEED to ["
|
||||
+ seed + "L]: ";
|
||||
throw new RuntimeException(msg);
|
||||
}
|
||||
}
|
||||
|
||||
public static void run() throws Exception {
|
||||
Random rnd = new Random();
|
||||
// To reproduce an error, fix the value of the seed to the value from
|
||||
// the failure
|
||||
long seed = rnd.nextLong();
|
||||
rnd.setSeed(seed);
|
||||
|
||||
@ -77,22 +100,75 @@ public class MontgomeryPolynomialFuzzTest {
|
||||
BigInteger r = BigInteger.ONE.shiftLeft(260).mod(P);
|
||||
BigInteger rInv = r.modInverse(P);
|
||||
BigInteger aRef = (new BigInteger(P.bitLength(), rnd)).mod(P);
|
||||
BigInteger bRef = (new BigInteger(P.bitLength(), rnd)).mod(P);
|
||||
SmallValue two = montField.getSmallValue(2);
|
||||
SmallValue three = montField.getSmallValue(3);
|
||||
SmallValue four = montField.getSmallValue(4);
|
||||
|
||||
// Test conversion to montgomery domain
|
||||
ImmutableIntegerModuloP a = montField.getElement(aRef);
|
||||
String msg = "converting "+aRef.toString(16) + " to montgomery domain";
|
||||
aRef = aRef.multiply(r).mod(P);
|
||||
check(aRef, a, seed);
|
||||
check(msg, aRef, a, seed);
|
||||
checkOverflow(msg, a, seed);
|
||||
|
||||
ImmutableIntegerModuloP b = montField.getElement(bRef);
// Build the message from bRef (not aRef) and before bRef is rewritten below,
// so that a failure names the value that was actually passed to getElement().
msg = "converting "+bRef.toString(16) + " to montgomery domain";
bRef = bRef.multiply(r).mod(P);
check(msg, bRef, b, seed);
checkOverflow(msg, b, seed);
|
||||
|
||||
if (rnd.nextBoolean()) {
|
||||
msg = "squaring "+aRef.toString(16);
|
||||
aRef = aRef.multiply(aRef).multiply(rInv).mod(P);
|
||||
a = a.multiply(a);
|
||||
check(aRef, a, seed);
|
||||
check(msg, aRef, a, seed);
|
||||
checkOverflow(msg, a, seed);
|
||||
}
|
||||
|
||||
if (rnd.nextBoolean()) {
|
||||
msg = "doubling "+aRef.toString(16);
|
||||
aRef = aRef.add(aRef).mod(P);
|
||||
a = a.add(a);
|
||||
check(aRef, a, seed);
|
||||
check(msg, aRef, a, seed);
|
||||
}
|
||||
|
||||
if (rnd.nextBoolean()) {
|
||||
msg = "subtracting "+bRef.toString(16)+" from "+aRef.toString(16);
|
||||
aRef = aRef.subtract(bRef).mod(P);
|
||||
a = a.mutable().setDifference(b).fixed();
|
||||
check(msg, aRef, a, seed);
|
||||
}
|
||||
|
||||
if (rnd.nextBoolean()) {
|
||||
msg = "multiplying "+bRef.toString(16)+" with "+aRef.toString(16);
|
||||
aRef = aRef.multiply(bRef).multiply(rInv).mod(P);
|
||||
a = a.multiply(b);
|
||||
check(msg, aRef, a, seed);
|
||||
checkOverflow(msg, a, seed);
|
||||
}
|
||||
|
||||
if (rnd.nextBoolean()) {
|
||||
msg = "multiplying "+aRef.toString(16)+" with constant 2";
|
||||
aRef = aRef.multiply(BigInteger.valueOf(2)).mod(P);
|
||||
a = a.mutable().setProduct(two).fixed();
|
||||
check(msg, aRef, a, seed);
|
||||
}
|
||||
|
||||
if (rnd.nextBoolean()) {
|
||||
msg = "multiplying "+aRef.toString(16)+" with constant 3";
|
||||
aRef = aRef.multiply(BigInteger.valueOf(3)).mod(P);
|
||||
a = a.mutable().setProduct(three).fixed();
|
||||
check(msg, aRef, a, seed);
|
||||
checkOverflow(msg, a, seed);
|
||||
}
|
||||
|
||||
if (rnd.nextBoolean()) {
|
||||
msg = "multiplying "+aRef.toString(16)+" with constant 4";
|
||||
aRef = aRef.multiply(BigInteger.valueOf(4)).mod(P);
|
||||
a = a.mutable().setProduct(four).fixed();
|
||||
check(msg, aRef, a, seed);
|
||||
checkOverflow(msg, a, seed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user