8350459: MontgomeryIntegerPolynomialP256 multiply intrinsic with AVX2 on x86_64

Reviewed-by: ascarpino, sviswanathan
Volodymyr Paprotski 2025-03-28 15:20:31 +00:00
parent c029220379
commit a269bef04c
9 changed files with 740 additions and 115 deletions

View File

@ -3529,6 +3529,30 @@ void Assembler::vmovdqu(Address dst, XMMRegister src) {
emit_operand(src, dst, 0);
}
// Move Aligned 256bit Vector
void Assembler::vmovdqa(XMMRegister dst, Address src) {
assert(UseAVX > 0, "");
InstructionMark im(this);
InstructionAttr attributes(AVX_256bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true);
attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit);
vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attributes);
emit_int8(0x6F);
emit_operand(dst, src, 0);
}
void Assembler::vmovdqa(Address dst, XMMRegister src) {
assert(UseAVX > 0, "");
InstructionMark im(this);
InstructionAttr attributes(AVX_256bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true);
attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit);
attributes.reset_is_clear_context();
// swap src<->dst for encoding
assert(src != xnoreg, "sanity");
vex_prefix(dst, 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attributes);
emit_int8(0x7F);
emit_operand(src, dst, 0);
}
void Assembler::vpmaskmovd(XMMRegister dst, XMMRegister mask, Address src, int vector_len) {
assert((VM_Version::supports_avx2() && vector_len == AVX_256bit), "");
InstructionMark im(this);
@ -3791,6 +3815,27 @@ void Assembler::evmovdquq(XMMRegister dst, KRegister mask, Address src, bool mer
emit_operand(dst, src, 0);
}
// Move Aligned 512bit Vector
void Assembler::evmovdqaq(XMMRegister dst, Address src, int vector_len) {
// Unmasked instruction
evmovdqaq(dst, k0, src, /*merge*/ false, vector_len);
}
void Assembler::evmovdqaq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len) {
assert(VM_Version::supports_evex(), "");
InstructionMark im(this);
InstructionAttr attributes(vector_len, /* vex_w */ true, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true);
attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit);
attributes.set_embedded_opmask_register_specifier(mask);
attributes.set_is_evex_instruction();
if (merge) {
attributes.reset_is_clear_context();
}
vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F, &attributes);
emit_int8(0x6F);
emit_operand(dst, src, 0);
}
void Assembler::evmovntdquq(Address dst, XMMRegister src, int vector_len) {
// Unmasked instruction
evmovntdquq(dst, k0, src, /*merge*/ true, vector_len);

View File

@ -1758,6 +1758,10 @@ private:
void vmovdqu(XMMRegister dst, Address src);
void vmovdqu(XMMRegister dst, XMMRegister src);
// Move Aligned 256bit Vector
void vmovdqa(XMMRegister dst, Address src);
void vmovdqa(Address dst, XMMRegister src);
// Move Unaligned 512bit Vector
void evmovdqub(XMMRegister dst, XMMRegister src, int vector_len);
void evmovdqub(XMMRegister dst, Address src, int vector_len);
@ -1791,6 +1795,10 @@ private:
void evmovdquq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len);
void evmovdquq(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len);
// Move Aligned 512bit Vector
void evmovdqaq(XMMRegister dst, Address src, int vector_len);
void evmovdqaq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len);
// Move lower 64bit to high 64bit in 128bit register
void movlhps(XMMRegister dst, XMMRegister src);

View File

@ -2720,6 +2720,60 @@ void MacroAssembler::vmovdqu(XMMRegister dst, AddressLiteral src, int vector_len
}
}
void MacroAssembler::vmovdqu(XMMRegister dst, XMMRegister src, int vector_len) {
if (vector_len == AVX_512bit) {
evmovdquq(dst, src, AVX_512bit);
} else if (vector_len == AVX_256bit) {
vmovdqu(dst, src);
} else {
movdqu(dst, src);
}
}
void MacroAssembler::vmovdqu(Address dst, XMMRegister src, int vector_len) {
if (vector_len == AVX_512bit) {
evmovdquq(dst, src, AVX_512bit);
} else if (vector_len == AVX_256bit) {
vmovdqu(dst, src);
} else {
movdqu(dst, src);
}
}
void MacroAssembler::vmovdqu(XMMRegister dst, Address src, int vector_len) {
if (vector_len == AVX_512bit) {
evmovdquq(dst, src, AVX_512bit);
} else if (vector_len == AVX_256bit) {
vmovdqu(dst, src);
} else {
movdqu(dst, src);
}
}
void MacroAssembler::vmovdqa(XMMRegister dst, AddressLiteral src, Register rscratch) {
assert(rscratch != noreg || always_reachable(src), "missing");
if (reachable(src)) {
vmovdqa(dst, as_Address(src));
} else {
lea(rscratch, src);
vmovdqa(dst, Address(rscratch, 0));
}
}
void MacroAssembler::vmovdqa(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch) {
assert(rscratch != noreg || always_reachable(src), "missing");
if (vector_len == AVX_512bit) {
evmovdqaq(dst, src, AVX_512bit, rscratch);
} else if (vector_len == AVX_256bit) {
vmovdqa(dst, src, rscratch);
} else {
movdqa(dst, src, rscratch);
}
}
void MacroAssembler::kmov(KRegister dst, Address src) {
if (VM_Version::supports_avx512bw()) {
kmovql(dst, src);
@ -2844,6 +2898,29 @@ void MacroAssembler::evmovdquq(XMMRegister dst, AddressLiteral src, int vector_l
}
}
void MacroAssembler::evmovdqaq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch) {
assert(rscratch != noreg || always_reachable(src), "missing");
if (reachable(src)) {
Assembler::evmovdqaq(dst, mask, as_Address(src), merge, vector_len);
} else {
lea(rscratch, src);
Assembler::evmovdqaq(dst, mask, Address(rscratch, 0), merge, vector_len);
}
}
void MacroAssembler::evmovdqaq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch) {
assert(rscratch != noreg || always_reachable(src), "missing");
if (reachable(src)) {
Assembler::evmovdqaq(dst, as_Address(src), vector_len);
} else {
lea(rscratch, src);
Assembler::evmovdqaq(dst, Address(rscratch, 0), vector_len);
}
}
void MacroAssembler::movdqa(XMMRegister dst, AddressLiteral src, Register rscratch) {
assert(rscratch != noreg || always_reachable(src), "missing");

View File

@ -1348,6 +1348,14 @@ public:
void vmovdqu(XMMRegister dst, XMMRegister src);
void vmovdqu(XMMRegister dst, AddressLiteral src, Register rscratch = noreg);
void vmovdqu(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
void vmovdqu(XMMRegister dst, XMMRegister src, int vector_len);
void vmovdqu(XMMRegister dst, Address src, int vector_len);
void vmovdqu(Address dst, XMMRegister src, int vector_len);
// AVX Aligned forms
using Assembler::vmovdqa;
void vmovdqa(XMMRegister dst, AddressLiteral src, Register rscratch = noreg);
void vmovdqa(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
// AVX512 Unaligned
void evmovdqu(BasicType type, KRegister kmask, Address dst, XMMRegister src, bool merge, int vector_len);
@ -1404,6 +1412,7 @@ public:
void evmovdquq(XMMRegister dst, Address src, int vector_len) { Assembler::evmovdquq(dst, src, vector_len); }
void evmovdquq(Address dst, XMMRegister src, int vector_len) { Assembler::evmovdquq(dst, src, vector_len); }
void evmovdquq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
void evmovdqaq(XMMRegister dst, AddressLiteral src, int vector_len, Register rscratch = noreg);
void evmovdquq(XMMRegister dst, KRegister mask, XMMRegister src, bool merge, int vector_len) {
if (dst->encoding() != src->encoding() || mask != k0) {
@ -1413,6 +1422,7 @@ public:
void evmovdquq(Address dst, KRegister mask, XMMRegister src, bool merge, int vector_len) { Assembler::evmovdquq(dst, mask, src, merge, vector_len); }
void evmovdquq(XMMRegister dst, KRegister mask, Address src, bool merge, int vector_len) { Assembler::evmovdquq(dst, mask, src, merge, vector_len); }
void evmovdquq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg);
void evmovdqaq(XMMRegister dst, KRegister mask, AddressLiteral src, bool merge, int vector_len, Register rscratch = noreg);
// Move Aligned Double Quadword
void movdqa(XMMRegister dst, XMMRegister src) { Assembler::movdqa(dst, src); }

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Intel Corporation. All rights reserved.
* Copyright (c) 2024, 2025, Intel Corporation. All rights reserved.
*
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
@ -28,17 +28,17 @@
#define __ _masm->
ATTRIBUTE_ALIGNED(64) uint64_t MODULUS_P256[] = {
ATTRIBUTE_ALIGNED(64) constexpr uint64_t MODULUS_P256[] = {
0x000fffffffffffffULL, 0x00000fffffffffffULL,
0x0000000000000000ULL, 0x0000001000000000ULL,
0x0000ffffffff0000ULL, 0x0000000000000000ULL,
0x0000000000000000ULL, 0x0000000000000000ULL
};
static address modulus_p256() {
return (address)MODULUS_P256;
static address modulus_p256(int index = 0) {
return (address)&MODULUS_P256[index];
}
ATTRIBUTE_ALIGNED(64) uint64_t P256_MASK52[] = {
ATTRIBUTE_ALIGNED(64) constexpr uint64_t P256_MASK52[] = {
0x000fffffffffffffULL, 0x000fffffffffffffULL,
0x000fffffffffffffULL, 0x000fffffffffffffULL,
0xffffffffffffffffULL, 0xffffffffffffffffULL,
@ -48,7 +48,7 @@ static address p256_mask52() {
return (address)P256_MASK52;
}
ATTRIBUTE_ALIGNED(64) uint64_t SHIFT1R[] = {
ATTRIBUTE_ALIGNED(64) constexpr uint64_t SHIFT1R[] = {
0x0000000000000001ULL, 0x0000000000000002ULL,
0x0000000000000003ULL, 0x0000000000000004ULL,
0x0000000000000005ULL, 0x0000000000000006ULL,
@ -58,7 +58,7 @@ static address shift_1R() {
return (address)SHIFT1R;
}
ATTRIBUTE_ALIGNED(64) uint64_t SHIFT1L[] = {
ATTRIBUTE_ALIGNED(64) constexpr uint64_t SHIFT1L[] = {
0x0000000000000007ULL, 0x0000000000000000ULL,
0x0000000000000001ULL, 0x0000000000000002ULL,
0x0000000000000003ULL, 0x0000000000000004ULL,
@ -68,6 +68,14 @@ static address shift_1L() {
return (address)SHIFT1L;
}
ATTRIBUTE_ALIGNED(64) constexpr uint64_t MASKL5[] = {
0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL,
0xFFFFFFFFFFFFFFFFULL, 0x0000000000000000ULL,
};
static address mask_limb5() {
return (address)MASKL5;
}
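
For reference, the first five 52-bit entries of MODULUS_P256 above reassemble to exactly the NIST P-256 prime, 2^256 - 2^224 + 2^192 + 2^96 - 1. A minimal standalone Java check of that fact (class name is illustrative, not part of this change):

    import java.math.BigInteger;

    // Sanity check: MODULUS_P256's five 52-bit limbs encode the P-256 prime.
    public class P256LimbCheck {
        public static void main(String[] args) {
            long[] limbs = {
                0x000fffffffffffffL, 0x00000fffffffffffL,
                0x0000000000000000L, 0x0000001000000000L,
                0x0000ffffffff0000L
            };
            BigInteger p = BigInteger.ZERO;
            for (int i = 0; i < limbs.length; i++) {
                p = p.add(BigInteger.valueOf(limbs[i]).shiftLeft(i * 52));
            }
            BigInteger p256 = BigInteger.ONE.shiftLeft(256)
                    .subtract(BigInteger.ONE.shiftLeft(224))
                    .add(BigInteger.ONE.shiftLeft(192))
                    .add(BigInteger.ONE.shiftLeft(96))
                    .subtract(BigInteger.ONE);
            System.out.println(p.equals(p256)); // prints: true
        }
    }
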
/**
* Unrolled Word-by-Word Montgomery Multiplication
* r = a * b * 2^-260 (mod P)
@ -94,26 +102,25 @@ static address shift_1L() {
* B = replicate(bLimbs[i]) |bi|bi|bi|bi|bi|bi|bi|bi|
* +--+--+--+--+--+--+--+--+
* +--+--+--+--+--+--+--+--+
* | 0| 0| 0|a5|a4|a3|a2|a1|
* Acc1 += A * B *|bi|bi|bi|bi|bi|bi|bi|bi|
* Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1|
* *| 0| 0| 0|a5|a4|a3|a2|a1|
* Acc1 += A * B |bi|bi|bi|bi|bi|bi|bi|bi|
* +--+--+--+--+--+--+--+--+
* Acc2+=| 0| 0| 0| 0| 0| 0| 0| 0|
* *h| 0| 0| 0|a5|a4|a3|a2|a1|
* Acc2 += A *h B |bi|bi|bi|bi|bi|bi|bi|bi|
* | 0| 0| 0|a5|a4|a3|a2|a1|
* Acc2 += A *h B *h|bi|bi|bi|bi|bi|bi|bi|bi|
* Acc2+=| 0| 0| 0| d5|d4|d3|d2|d1|
* +--+--+--+--+--+--+--+--+
* N = replicate(Acc1[0]) |n0|n0|n0|n0|n0|n0|n0|n0|
* +--+--+--+--+--+--+--+--+
* +--+--+--+--+--+--+--+--+
* Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1|
* *| 0| 0| 0|m5|m4|m3|m2|m1|
* Acc1 += M * N |n0|n0|n0|n0|n0|n0|n0|n0| Note: 52 low bits of Acc1[0] == 0 due to Montgomery!
* | 0| 0| 0|m5|m4|m3|m2|m1|
* Acc1 += M * N *|n0|n0|n0|n0|n0|n0|n0|n0|
* Acc1+=| 0| 0| 0|c5|c4|c3|c2|c1| Note: 52 low bits of c1 == 0 due to Montgomery!
* +--+--+--+--+--+--+--+--+
* | 0| 0| 0|m5|m4|m3|m2|m1|
* Acc2 += M *h N *h|n0|n0|n0|n0|n0|n0|n0|n0|
* Acc2+=| 0| 0| 0|d5|d4|d3|d2|d1|
* *h| 0| 0| 0|m5|m4|m3|m2|m1|
* Acc2 += M *h N |n0|n0|n0|n0|n0|n0|n0|n0|
* +--+--+--+--+--+--+--+--+
* if (i == 4) break;
* // Combine high/low partial sums Acc1 + Acc2
* +--+--+--+--+--+--+--+--+
* carry = Acc1[0] >> 52 | 0| 0| 0| 0| 0| 0| 0|c1|
@ -124,13 +131,35 @@ static address shift_1L() {
* +--+--+--+--+--+--+--+--+
* Acc1 = Acc1 + Acc2
* ---- done
* // Last Carry round: Combine high/low partial sums Acc1<high_bits> + Acc1 + Acc2
* carry = Acc1 >> 52
* Acc1 = Acc1 shift one q element >>
* Acc1 = mask52(Acc1)
* Acc2 += carry
* Acc1 = Acc1 + Acc2
* output to rLimbs
*
* At this point the result in Acc1 can overflow by one modulus and still needs
* carry propagation. Subtract one modulus, carry-propagate both results, and
* select (in constant time) whichever of the two is in range.
*
* Carry = Acc1[0] >> 52
* Acc1L = Acc1[0] & mask52
* Acc1 = Acc1 shift one q element>>
* Acc1 += Carry
*
* Carry = Acc2[0] >> 52
* Acc2L = Acc2[0] & mask52
* Acc2 = Acc2 shift one q element>>
* Acc2 += Carry
*
* for col:=1 to 4
* Carry = Acc2[col]>>52
* Carry = Carry shift one q element<<
* Acc2 += Carry
*
* Carry = Acc1[col]>>52
* Carry = Carry shift one q element<<
* Acc1 += Carry
* done
*
* Acc1 &= mask52
* Acc2 &= mask52
* Mask = sign(Acc2)
* Result = select(Mask ? Acc1 or Acc2)
*/
void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Register rLimbs, const Register tmp, MacroAssembler* _masm) {
Register t0 = tmp;
@ -145,26 +174,30 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi
XMMRegister Acc1 = xmm10;
XMMRegister Acc2 = xmm11;
XMMRegister N = xmm12;
XMMRegister carry = xmm13;
XMMRegister Carry = xmm13;
// Constants
XMMRegister modulus = xmm20;
XMMRegister shift1L = xmm21;
XMMRegister shift1R = xmm22;
XMMRegister mask52 = xmm23;
KRegister limb0 = k1;
KRegister allLimbs = k2;
XMMRegister modulus = xmm5;
XMMRegister shift1L = xmm6;
XMMRegister shift1R = xmm7;
XMMRegister Mask52 = xmm8;
KRegister allLimbs = k1;
KRegister limb0 = k2;
KRegister masks[] = {limb0, k3, k4, k5};
for (int i=0; i<4; i++) {
__ mov64(t0, 1ULL<<i);
__ kmovql(masks[i], t0);
}
__ mov64(t0, 0x1);
__ kmovql(limb0, t0);
__ mov64(t0, 0x1f);
__ kmovql(allLimbs, t0);
__ evmovdquq(shift1L, allLimbs, ExternalAddress(shift_1L()), false, Assembler::AVX_512bit, rscratch);
__ evmovdquq(shift1R, allLimbs, ExternalAddress(shift_1R()), false, Assembler::AVX_512bit, rscratch);
__ evmovdquq(mask52, allLimbs, ExternalAddress(p256_mask52()), false, Assembler::AVX_512bit, rscratch);
__ evmovdqaq(shift1L, allLimbs, ExternalAddress(shift_1L()), false, Assembler::AVX_512bit, rscratch);
__ evmovdqaq(shift1R, allLimbs, ExternalAddress(shift_1R()), false, Assembler::AVX_512bit, rscratch);
__ evmovdqaq(Mask52, allLimbs, ExternalAddress(p256_mask52()), false, Assembler::AVX_512bit, rscratch);
// M = load(*modulus_p256)
__ evmovdquq(modulus, allLimbs, ExternalAddress(modulus_p256()), false, Assembler::AVX_512bit, rscratch);
__ evmovdqaq(modulus, allLimbs, ExternalAddress(modulus_p256()), false, Assembler::AVX_512bit, rscratch);
// A = load(*aLimbs); masked evmovdquq() can be slow. Instead load the full 256 bits and combine with a 64-bit load
__ evmovdquq(A, Address(aLimbs, 8), Assembler::AVX_256bit);
@ -196,15 +229,13 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi
// Acc2 += M *h N
__ evpmadd52huq(Acc2, modulus, N, Assembler::AVX_512bit);
if (i == 4) break;
// Combine high/low partial sums Acc1 + Acc2
// carry = Acc1[0] >> 52
__ evpsrlq(carry, limb0, Acc1, 52, true, Assembler::AVX_512bit);
__ evpsrlq(Carry, limb0, Acc1, 52, true, Assembler::AVX_512bit);
// Acc2[0] += carry
__ evpaddq(Acc2, limb0, carry, Acc2, true, Assembler::AVX_512bit);
__ evpaddq(Acc2, limb0, Carry, Acc2, true, Assembler::AVX_512bit);
// Acc1 = Acc1 shift one q element >>
__ evpermq(Acc1, allLimbs, shift1R, Acc1, false, Assembler::AVX_512bit);
@ -213,26 +244,317 @@ void montgomeryMultiply(const Register aLimbs, const Register bLimbs, const Regi
__ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_512bit);
}
// Last Carry round: Combine high/low partial sums Acc1<high_bits> + Acc1 + Acc2
// carry = Acc1 >> 52
__ evpsrlq(carry, allLimbs, Acc1, 52, true, Assembler::AVX_512bit);
// At this point the result is in Acc1, but needs to be normalized to 52-bit
// limbs (i.e. needs carry propagation). It can also overflow by one modulus.
// Subtract one modulus from Acc1 into Acc2, then carry-propagate both
// simultaneously.
// Acc1 = Acc1 shift one q element >>
XMMRegister Acc1L = A;
XMMRegister Acc2L = B;
__ vpsubq(Acc2, Acc1, modulus, Assembler::AVX_512bit);
// digit 0 carry out
// Also split Acc1 and Acc2 into two 256-bit vectors each {Acc1, Acc1L} and
// {Acc2, Acc2L} to use 256bit operations
__ evpsraq(Carry, limb0, Acc2, 52, false, Assembler::AVX_256bit);
__ evpandq(Acc2L, limb0, Acc2, Mask52, false, Assembler::AVX_256bit);
__ evpermq(Acc2, allLimbs, shift1R, Acc2, false, Assembler::AVX_512bit);
__ vpaddq(Acc2, Acc2, Carry, Assembler::AVX_256bit);
__ evpsraq(Carry, limb0, Acc1, 52, false, Assembler::AVX_256bit);
__ evpandq(Acc1L, limb0, Acc1, Mask52, false, Assembler::AVX_256bit);
__ evpermq(Acc1, allLimbs, shift1R, Acc1, false, Assembler::AVX_512bit);
__ vpaddq(Acc1, Acc1, Carry, Assembler::AVX_256bit);
// Acc1 = mask52(Acc1)
__ evpandq(Acc1, Acc1, mask52, Assembler::AVX_512bit); // Clear top 12 bits
/* remaining digits carry
* Note1: Carry register contains just the carry for the particular
* column (zero-mask the rest) and gets progressively shifted left
* Note2: 'element shift' with vpermq is more expensive, so vpalignr is used when
* possible. vpalignr shifts 'right' not left, so place the carry appropriately
* +--+--+--+--+ +--+--+--+--+ +--+--+
* vpalignr(X, X, X, 8): |x4|x3|x2|x1| >> |x2|x1|x2|x1| |x1|x2|
* +--+--+--+--+ +--+--+--+--+ >> +--+--+
* | +--+--+--+--+ +--+--+
* | |x4|x3|x4|x3| |x3|x4|
* | +--+--+--+--+ +--+--+
* | vv
* | +--+--+--+--+
* (x3 and x1 are effectively shifted +------------------------> |x3|x4|x1|x2|
* left; zero-mask everything but one column of interest) +--+--+--+--+
*/
for (int i = 1; i<4; i++) {
__ evpsraq(Carry, masks[i-1], Acc2, 52, false, Assembler::AVX_256bit);
if (i == 1 || i == 3) {
__ vpalignr(Carry, Carry, Carry, 8, Assembler::AVX_256bit);
} else {
__ vpermq(Carry, Carry, 0b10010011, Assembler::AVX_256bit);
}
__ vpaddq(Acc2, Acc2, Carry, Assembler::AVX_256bit);
// Acc2 += carry
__ evpaddq(Acc2, allLimbs, carry, Acc2, true, Assembler::AVX_512bit);
__ evpsraq(Carry, masks[i-1], Acc1, 52, false, Assembler::AVX_256bit);
if (i == 1 || i == 3) {
__ vpalignr(Carry, Carry, Carry, 8, Assembler::AVX_256bit);
} else {
__ vpermq(Carry, Carry, 0b10010011, Assembler::AVX_256bit); //0b-2-1-0-3
}
__ vpaddq(Acc1, Acc1, Carry, Assembler::AVX_256bit);
}
// Acc1 = Acc1 + Acc2
__ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_512bit);
// Iff Acc2 is negative, then Acc1 contains the result.
// If Acc2 is negative, the upper 12 bits of its top limb are set; an
// arithmetic shift by 64 bits generates a mask from the Acc2 sign bit.
__ evpsraq(Carry, Acc2, 64, Assembler::AVX_256bit);
__ vpermq(Carry, Carry, 0b11111111, Assembler::AVX_256bit); //0b-3-3-3-3
__ evpandq(Acc1, Acc1, Mask52, Assembler::AVX_256bit);
__ evpandq(Acc2, Acc2, Mask52, Assembler::AVX_256bit);
// Acc2 = (Acc1 & Mask) | (Acc2 & ~Mask)
__ vpandn(Acc2L, Carry, Acc2L, Assembler::AVX_256bit);
__ vpternlogq(Acc2L, 0xF8, Carry, Acc1L, Assembler::AVX_256bit); // 0xF8: A | (B & C), i.e. Acc2L | (Carry & Acc1L)
__ vpandn(Acc2, Carry, Acc2, Assembler::AVX_256bit);
__ vpternlogq(Acc2, 0xF8, Carry, Acc1, Assembler::AVX_256bit);
// output to rLimbs (1 + 4 limbs)
__ movq(Address(rLimbs, 0), Acc1);
__ evpermq(Acc1, k0, shift1R, Acc1, true, Assembler::AVX_512bit);
__ evmovdquq(Address(rLimbs, 8), k0, Acc1, true, Assembler::AVX_256bit);
__ movq(Address(rLimbs, 0), Acc2L);
__ evmovdquq(Address(rLimbs, 8), Acc2, Assembler::AVX_256bit);
// Cleanup
// Zero out zmm0-zmm15; higher registers are not used by this intrinsic.
__ vzeroall();
}
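
The subtract/propagate/select tail above mirrors the scalar logic added to MontgomeryIntegerPolynomialP256.java later in this change. A minimal sketch in plain Java, assuming five 52-bit little-endian limbs and a result less than twice the modulus (the helper name is hypothetical):

    // The result may exceed the modulus by at most one multiple, so compute
    // res - modulus in parallel, carry-propagate both candidates, and select
    // the in-range one in constant time (no data-dependent branches).
    static void reduceAndSelect(long[] res, long[] modulus, long[] out) {
        final long LIMB_MASK = (1L << 52) - 1;
        long[] sub = new long[5];
        long borrow = 0;
        for (int i = 0; i < 5; i++) {
            long v = res[i] - modulus[i] + borrow;
            borrow = v >> 52;       // arithmetic shift: limbs are signed here
            sub[i] = v & LIMB_MASK;
        }
        long mask = borrow;         // all-ones iff res < modulus (keep res)
        long carry = 0;
        for (int i = 0; i < 5; i++) {
            long v = res[i] + carry;
            carry = v >>> 52;       // logical shift: limbs are non-negative
            out[i] = ((v & LIMB_MASK) & mask) | (sub[i] & ~mask);
        }
    }
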
/**
* Unrolled Word-by-Word Montgomery Multiplication
* r = a * b * 2^-260 (mod P)
*
* Use vpmadd52{l,h}uq multiply for upper four limbs and use
* scalar mulq for the lowest limb.
*
* One has to be careful with mulq vs. vpmadd52 'crossovers': for 52-bit inputs,
* mulq splits the 104-bit product as 40:64 bits (rdx:rax), whereas the vector
* version splits it 52:52. Shifts are required to line up the values before
* addition (see the following ASCII art)
*
* Pseudocode:
*
* +--+--+--+--+ +--+
* M = load(*modulus_p256) |m5|m4|m3|m2| |m1|
* +--+--+--+--+ +--+
* A = load(*aLimbs) |a5|a4|a3|a2| |a1|
* +--+--+--+--+ +--+
* Acc1 = 0 | 0| 0| 0| 0| | 0|
* +--+--+--+--+ +--+
* ---- for i = 0 to 4
* +--+--+--+--+ +--+
* Acc2 = 0 | 0| 0| 0| 0| | 0|
* +--+--+--+--+ +--+
* B = replicate(bLimbs[i]) |bi|bi|bi|bi| |bi|
* +--+--+--+--+ +--+
* +--+--+--+--+ +--+
* |a5|a4|a3|a2| |a1|
* Acc1 += A * B *|bi|bi|bi|bi| |bi|
* Acc1+=|c5|c4|c3|c2| |c1|
* +--+--+--+--+ +--+
* |a5|a4|a3|a2| |a1|
* Acc2 += A *h B *h|bi|bi|bi|bi| |bi|
* Acc2+=|d5|d4|d3|d2| |d1|
* +--+--+--+--+ +--+
* N = replicate(Acc1[0]) |n0|n0|n0|n0| |n0|
* +--+--+--+--+ +--+
* +--+--+--+--+ +--+
* |m5|m4|m3|m2| |m1|
* Acc1 += M * N *|n0|n0|n0|n0| |n0|
* Acc1+=|c5|c4|c3|c2| |c1| Note: 52 low bits of c1 == 0 due to Montgomery!
* +--+--+--+--+ +--+
* |m5|m4|m3|m2| |m1|
* Acc2 += M *h N *h|n0|n0|n0|n0| |n0|
* Acc2+=|d5|d4|d3|d2| |d1|
* +--+--+--+--+ +--+
* // Combine high/low partial sums Acc1 + Acc2
* +--+
* carry = Acc1[0] >> 52 |c1|
* +--+
* Acc2[0] += carry |d1|
* +--+
* +--+--+--+--+ +--+
* Acc1 = Acc1 shift one q element>> | 0|c5|c4|c3| |c2|
* +|d5|d4|d3|d2| |d1|
* Acc1 = Acc1 + Acc2 Acc1+=|c5|c4|c3|c2| |c1|
* +--+--+--+--+ +--+
* ---- done
* +--+--+--+--+ +--+
* Acc2 = Acc1 - M |d5|d4|d3|d2| |d1|
* +--+--+--+--+ +--+
* Carry propagate Acc2
* Carry propagate Acc1
* Mask = sign(Acc2)
* Result = select(Mask ? Acc1 or Acc2)
*
* Acc1 can overflow by one modulus (hence Acc2); either Acc1 or Acc2 contains
* the correct result. However, both need carry propagation (i.e. limbs must be
* normalized down to 52 bits each).
*
* Carry propagation would require relatively expensive vector lane operations,
* so instead the limbs are dumped to memory and read back as scalar registers.
*
* Note: the order of reduce-then-propagate vs propagate-then-reduce is different
* in Java
*/
void montgomeryMultiplyAVX2(const Register aLimbs, const Register bLimbs, const Register rLimbs,
const Register tmp_rax, const Register tmp_rdx, const Register tmp1, const Register tmp2,
const Register tmp3, const Register tmp4, const Register tmp5, const Register tmp6,
const Register tmp7, MacroAssembler* _masm) {
Register rscratch = tmp1;
// Inputs
Register a = tmp1;
XMMRegister A = xmm0;
XMMRegister B = xmm1;
// Intermediates
Register acc1 = tmp2;
XMMRegister Acc1 = xmm3;
Register acc2 = tmp3;
XMMRegister Acc2 = xmm4;
XMMRegister N = xmm5;
XMMRegister Carry = xmm6;
// Constants
Register modulus = tmp4;
XMMRegister Modulus = xmm7;
Register mask52 = tmp5;
XMMRegister Mask52 = xmm8;
XMMRegister MaskLimb5 = xmm9;
XMMRegister Zero = xmm10;
__ mov64(mask52, P256_MASK52[0]);
__ movq(Mask52, mask52);
__ vpbroadcastq(Mask52, Mask52, Assembler::AVX_256bit);
__ vmovdqa(MaskLimb5, ExternalAddress(mask_limb5()), Assembler::AVX_256bit, rscratch);
__ vpxor(Zero, Zero, Zero, Assembler::AVX_256bit);
// M = load(*modulus_p256)
__ movq(modulus, mask52);
__ vmovdqu(Modulus, ExternalAddress(modulus_p256(1)), Assembler::AVX_256bit, rscratch);
// A = load(*aLimbs);
__ movq(a, Address(aLimbs, 0));
__ vmovdqu(A, Address(aLimbs, 8)); //Assembler::AVX_256bit
// Acc1 = 0
__ vpxor(Acc1, Acc1, Acc1, Assembler::AVX_256bit);
for (int i = 0; i< 5; i++) {
// Acc2 = 0
__ vpxor(Acc2, Acc2, Acc2, Assembler::AVX_256bit);
// B = replicate(bLimbs[i])
__ movq(tmp_rax, Address(bLimbs, i*8)); //(b==rax)
__ vpbroadcastq(B, Address(bLimbs, i*8), Assembler::AVX_256bit);
// Acc1 += A * B
// Acc2 += A *h B
__ mulq(a); // rdx:rax = a*rax
if (i == 0) {
__ movq(acc1, tmp_rax);
__ movq(acc2, tmp_rdx);
} else {
// Careful with limb size/carries; from mulq, tmp_rax uses full 64 bits
__ xorq(acc2, acc2);
__ addq(acc1, tmp_rax);
__ adcq(acc2, tmp_rdx);
}
__ vpmadd52luq(Acc1, A, B, Assembler::AVX_256bit);
__ vpmadd52huq(Acc2, A, B, Assembler::AVX_256bit);
// N = replicate(Acc1[0])
if (i != 0) {
__ movq(tmp_rax, acc1); // (n==rax)
}
__ andq(tmp_rax, mask52);
__ movq(N, acc1); // masking implicit in vpmadd52
__ vpbroadcastq(N, N, Assembler::AVX_256bit);
// Acc1 += M * N
__ mulq(modulus); // rdx:rax = modulus*rax
__ vpmadd52luq(Acc1, Modulus, N, Assembler::AVX_256bit);
__ addq(acc1, tmp_rax); //carry flag set!
// Acc2 += M *h N
__ adcq(acc2, tmp_rdx);
__ vpmadd52huq(Acc2, Modulus, N, Assembler::AVX_256bit);
// Combine high/low partial sums Acc1 + Acc2
// carry = Acc1[0] >> 52
__ shrq(acc1, 52); // low 52 bits of acc1 are zero (Montgomery), safe to shift out
// Acc2[0] += carry
__ shlq(acc2, 12);
__ addq(acc2, acc1);
// Acc1 = Acc1 shift one q element >>
__ movq(acc1, Acc1);
__ vpermq(Acc1, Acc1, 0b11111001, Assembler::AVX_256bit);
__ vpand(Acc1, Acc1, MaskLimb5, Assembler::AVX_256bit);
// Acc1 = Acc1 + Acc2
__ addq(acc1, acc2);
__ vpaddq(Acc1, Acc1, Acc2, Assembler::AVX_256bit);
}
__ movq(acc2, acc1);
__ subq(acc2, modulus);
__ vpsubq(Acc2, Acc1, Modulus, Assembler::AVX_256bit);
__ vmovdqa(Address(rsp, 0), Acc2); //Assembler::AVX_256bit
// Carry-propagate the subtraction result Acc2 first (since the last carry is
// used to select the result). Careful, the following registers overlap:
// acc1 = tmp2; acc2 = tmp3; mask52 = tmp5
// Note that Acc2 limbs are signed (the result of subtracting the modulus),
// so a signed (arithmetic) shift is needed for correctness.
Register limb[] = {acc2, tmp1, tmp4, tmp_rdx, tmp6};
Register carry = tmp_rax;
for (int i = 0; i<5; i++) {
if (i > 0) {
__ movq(limb[i], Address(rsp, -8+i*8));
__ addq(limb[i], carry);
}
__ movq(carry, limb[i]);
if (i==4) break;
__ sarq(carry, 52);
}
__ sarq(carry, 63);
__ notq(carry); // select: all-ones iff Acc2 is non-negative (pick the subtracted result)
Register select = carry;
carry = tmp7;
// Now carry propagate the multiply result and (constant-time) select correct
// output digit
Register digit = acc1;
__ vmovdqa(Address(rsp, 0), Acc1); //Assembler::AVX_256bit
for (int i = 0; i<5; i++) {
if (i>0) {
__ movq(digit, Address(rsp, -8+i*8));
__ addq(digit, carry);
}
__ movq(carry, digit);
__ sarq(carry, 52);
// long dummyLimbs = maskValue & (a[i] ^ b[i]);
// a[i] = dummyLimbs ^ a[i];
__ xorq(limb[i], digit);
__ andq(limb[i], select);
__ xorq(digit, limb[i]);
__ andq(digit, mask52);
__ movq(Address(rLimbs, i*8), digit);
}
// Cleanup
// Zero out ymm0-ymm15.
__ vzeroall();
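// Also scrub the 32-byte stack scratch area used above for the Acc1/Acc2 spills.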
__ vpxor(Acc1, Acc1, Acc1, Assembler::AVX_256bit);
__ vmovdqa(Address(rsp, 0), Acc1); //Assembler::AVX_256bit
}
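
The mulq-vs-vpmadd52 'crossover' described in the comment block above can be modeled in plain Java: for 52-bit inputs the 104-bit scalar product arrives split 40:64 across rdx:rax, and a shift by 12 realigns it to the 52:52 split that vpmadd52l/huq produce. A small sketch, assuming non-negative inputs below 2^52 (method name is illustrative):

    // Splits the 104-bit product of two 52-bit values at bit 52, the way
    // vpmadd52l/huq would, starting from the scalar 64:64 (mulq-style) halves.
    static long[] mul52(long a, long b) {
        final long LIMB_MASK = (1L << 52) - 1;
        long lo = a * b;                           // low 64 bits of the product
        long hi = Math.unsignedMultiplyHigh(a, b); // high 64 bits (top 40 in use)
        long limbLo = lo & LIMB_MASK;              // product bits 0..51
        long limbHi = (hi << 12) | (lo >>> 52);    // product bits 52..103
        return new long[] { limbLo, limbHi };
    }
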
address StubGenerator::generate_intpoly_montgomeryMult_P256() {
@ -242,13 +564,58 @@ address StubGenerator::generate_intpoly_montgomeryMult_P256() {
address start = __ pc();
__ enter();
// Register Map
const Register aLimbs = c_rarg0; // rdi | rcx
const Register bLimbs = c_rarg1; // rsi | rdx
const Register rLimbs = c_rarg2; // rdx | r8
const Register tmp = r9;
if (EnableX86ECoreOpts && UseAVX > 1) {
__ push(r12);
__ push(r13);
__ push(r14);
#ifdef _WIN64
__ push(rsi);
__ push(rdi);
#endif
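// Set up a frame pointer and reserve a 32-byte-aligned scratch slot; the
// vmovdqa spills in montgomeryMultiplyAVX2 require the alignment.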
__ push(rbp);
__ movq(rbp, rsp);
__ andq(rsp, -32);
__ subptr(rsp, 32);
montgomeryMultiply(aLimbs, bLimbs, rLimbs, tmp, _masm);
// Register Map
const Register aLimbs = c_rarg0; // c_rarg0: rdi | rcx
const Register bLimbs = rsi; // c_rarg1: rsi | rdx
const Register rLimbs = r8; // c_rarg2: rdx | r8
const Register tmp1 = r9;
const Register tmp2 = r10;
const Register tmp3 = r11;
const Register tmp4 = r12;
const Register tmp5 = r13;
const Register tmp6 = r14;
#ifdef _WIN64
const Register tmp7 = rdi;
__ movq(bLimbs, c_rarg1); // free-up rdx
#else
const Register tmp7 = rcx;
__ movq(rLimbs, c_rarg2); // free-up rdx
#endif
montgomeryMultiplyAVX2(aLimbs, bLimbs, rLimbs, rax, rdx,
tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, _masm);
__ movq(rsp, rbp);
__ pop(rbp);
#ifdef _WIN64
__ pop(rdi);
__ pop(rsi);
#endif
__ pop(r14);
__ pop(r13);
__ pop(r12);
} else {
// Register Map
const Register aLimbs = c_rarg0; // rdi | rcx
const Register bLimbs = c_rarg1; // rsi | rdx
const Register rLimbs = c_rarg2; // rdx | r8
const Register tmp = r9;
montgomeryMultiply(aLimbs, bLimbs, rLimbs, tmp, _masm);
}
__ leave();
__ ret(0);
@ -259,17 +626,34 @@ address StubGenerator::generate_intpoly_montgomeryMult_P256() {
// Must be:
// - constant time (i.e. no branches)
// - no side channels (i.e. all memory must always be accessed, and in the same order)
void assign_avx(XMMRegister A, Address aAddr, XMMRegister B, Address bAddr, KRegister select, int vector_len, MacroAssembler* _masm) {
__ evmovdquq(A, aAddr, vector_len);
__ evmovdquq(B, bAddr, vector_len);
__ evmovdquq(A, select, B, true, vector_len);
__ evmovdquq(aAddr, A, vector_len);
}
void assign_avx(Register aBase, Register bBase, int offset, XMMRegister select, XMMRegister tmp, XMMRegister aTmp, int vector_len, MacroAssembler* _masm) {
if (vector_len == Assembler::AVX_512bit && UseAVX < 3) {
assign_avx(aBase, bBase, offset, select, tmp, aTmp, Assembler::AVX_256bit, _masm);
assign_avx(aBase, bBase, offset + 32, select, tmp, aTmp, Assembler::AVX_256bit, _masm);
return;
}
Address aAddr = Address(aBase, offset);
Address bAddr = Address(bBase, offset);
void assign_scalar(Address aAddr, Address bAddr, Register select, Register tmp, MacroAssembler* _masm) {
// Original java:
// long dummyLimbs = maskValue & (a[i] ^ b[i]);
// a[i] = dummyLimbs ^ a[i];
__ vmovdqu(tmp, aAddr, vector_len);
__ vmovdqu(aTmp, tmp, vector_len);
__ vpxor(tmp, tmp, bAddr, vector_len);
__ vpand(tmp, tmp, select, vector_len);
__ vpxor(tmp, tmp, aTmp, vector_len);
__ vmovdqu(aAddr, tmp, vector_len);
}
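
Both the vector and scalar assign helpers implement the same branch-free select that the quoted Java comment describes. In plain Java (illustrative method, not part of the patch):

    // Branch-free conditional assign: copies b into a iff set is true.
    // The same loads, stores, and operations run regardless of the condition,
    // so there is no secret-dependent branch or memory-access pattern.
    static void assign(long[] a, long[] b, boolean set) {
        long maskValue = set ? -1L : 0L;  // the stub derives this via negq(set)
        for (int i = 0; i < a.length; i++) {
            long dummyLimbs = maskValue & (a[i] ^ b[i]);
            a[i] = dummyLimbs ^ a[i];
        }
    }
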
void assign_scalar(Register aBase, Register bBase, int offset, Register select, Register tmp, MacroAssembler* _masm) {
// Original java:
// long dummyLimbs = maskValue & (a[i] ^ b[i]);
// a[i] = dummyLimbs ^ a[i];
Address aAddr = Address(aBase, offset);
Address bAddr = Address(bBase, offset);
__ movq(tmp, aAddr);
__ xorq(tmp, bAddr);
@ -308,13 +692,18 @@ address StubGenerator::generate_intpoly_assign() {
const Register length = c_rarg3;
XMMRegister A = xmm0;
XMMRegister B = xmm1;
XMMRegister select = xmm2;
Register tmp = r9;
KRegister select = k1;
Label L_Length5, L_Length10, L_Length14, L_Length16, L_Length19, L_DefaultLoop, L_Done;
__ negq(set);
__ kmovql(select, set);
if (UseAVX > 2) {
__ evpbroadcastq(select, set, Assembler::AVX_512bit);
} else {
__ movq(select, set);
__ vpbroadcastq(select, select, Assembler::AVX_256bit);
}
// NOTE! Crypto code cannot branch on user input. However, it is allowed to branch on the number of limbs;
// Number of limbs is a constant in each IntegerPolynomial (i.e. this side-channel branch leaks
@ -334,7 +723,7 @@ address StubGenerator::generate_intpoly_assign() {
__ cmpl(length, 0);
__ jcc(Assembler::lessEqual, L_Done);
__ bind(L_DefaultLoop);
assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm);
assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm);
__ subl(length, 1);
__ lea(aLimbs, Address(aLimbs,8));
__ lea(bLimbs, Address(bLimbs,8));
@ -343,31 +732,31 @@ address StubGenerator::generate_intpoly_assign() {
__ jmp(L_Done);
__ bind(L_Length5); // 1 + 4
assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm);
assign_avx(A, Address(aLimbs, 8), B, Address(bLimbs, 8), select, Assembler::AVX_256bit, _masm);
assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm);
assign_avx (aLimbs, bLimbs, 8, select, A, B, Assembler::AVX_256bit, _masm);
__ jmp(L_Done);
__ bind(L_Length10); // 2 + 8
assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_128bit, _masm);
assign_avx(A, Address(aLimbs, 16), B, Address(bLimbs, 16), select, Assembler::AVX_512bit, _masm);
assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_128bit, _masm);
assign_avx(aLimbs, bLimbs, 16, select, A, B, Assembler::AVX_512bit, _masm);
__ jmp(L_Done);
__ bind(L_Length14); // 2 + 4 + 8
assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_128bit, _masm);
assign_avx(A, Address(aLimbs, 16), B, Address(bLimbs, 16), select, Assembler::AVX_256bit, _masm);
assign_avx(A, Address(aLimbs, 48), B, Address(bLimbs, 48), select, Assembler::AVX_512bit, _masm);
assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_128bit, _masm);
assign_avx(aLimbs, bLimbs, 16, select, A, B, Assembler::AVX_256bit, _masm);
assign_avx(aLimbs, bLimbs, 48, select, A, B, Assembler::AVX_512bit, _masm);
__ jmp(L_Done);
__ bind(L_Length16); // 8 + 8
assign_avx(A, Address(aLimbs, 0), B, Address(bLimbs, 0), select, Assembler::AVX_512bit, _masm);
assign_avx(A, Address(aLimbs, 64), B, Address(bLimbs, 64), select, Assembler::AVX_512bit, _masm);
assign_avx(aLimbs, bLimbs, 0, select, A, B, Assembler::AVX_512bit, _masm);
assign_avx(aLimbs, bLimbs, 64, select, A, B, Assembler::AVX_512bit, _masm);
__ jmp(L_Done);
__ bind(L_Length19); // 1 + 2 + 8 + 8
assign_scalar(Address(aLimbs, 0), Address(bLimbs, 0), set, tmp, _masm);
assign_avx(A, Address(aLimbs, 8), B, Address(bLimbs, 8), select, Assembler::AVX_128bit, _masm);
assign_avx(A, Address(aLimbs, 24), B, Address(bLimbs, 24), select, Assembler::AVX_512bit, _masm);
assign_avx(A, Address(aLimbs, 88), B, Address(bLimbs, 88), select, Assembler::AVX_512bit, _masm);
assign_scalar(aLimbs, bLimbs, 0, set, tmp, _masm);
assign_avx (aLimbs, bLimbs, 8, select, A, B, Assembler::AVX_128bit, _masm);
assign_avx (aLimbs, bLimbs, 24, select, A, B, Assembler::AVX_512bit, _masm);
assign_avx (aLimbs, bLimbs, 88, select, A, B, Assembler::AVX_512bit, _masm);
__ bind(L_Done);
__ leave();

View File

@ -1406,7 +1406,7 @@ void VM_Version::get_processor_features() {
}
#ifdef _LP64
if (supports_avx512ifma() && supports_avx512vlbw()) {
if ((supports_avx512ifma() && supports_avx512vlbw()) || supports_avxifma()) {
if (FLAG_IS_DEFAULT(UseIntPolyIntrinsics)) {
FLAG_SET_DEFAULT(UseIntPolyIntrinsics, true);
}

View File

@ -532,7 +532,7 @@ class methodHandle;
/* support for sun.security.util.math.intpoly.MontgomeryIntegerPolynomialP256 */ \
do_class(sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, "sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256") \
do_intrinsic(_intpoly_montgomeryMult_P256, sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, intPolyMult_name, intPolyMult_signature, F_R) \
do_name(intPolyMult_name, "multImpl") \
do_name(intPolyMult_name, "mult") \
do_signature(intPolyMult_signature, "([J[J[J)V") \
\
do_class(sun_security_util_math_intpoly_IntegerPolynomial, "sun/security/util/math/intpoly/IntegerPolynomial") \

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -159,14 +159,8 @@ public final class MontgomeryIntegerPolynomialP256 extends IntegerPolynomial
* numAdds to reuse existing overflow logic.
*/
@Override
protected void mult(long[] a, long[] b, long[] r) {
multImpl(a, b, r);
reducePositive(r);
}
@ForceInline
@IntrinsicCandidate
private void multImpl(long[] a, long[] b, long[] r) {
protected void mult(long[] a, long[] b, long[] r) {
long aa0 = a[0];
long aa1 = a[1];
long aa2 = a[2];
@ -398,17 +392,43 @@ public final class MontgomeryIntegerPolynomialP256 extends IntegerPolynomial
dd4 += Math.unsignedMultiplyHigh(n, modulus[4]) << shift1 | (n4 >>> shift2);
d4 += n4 & LIMB_MASK;
// Final carry propagate
c5 += d1 + dd0 + (d0 >>> BITS_PER_LIMB);
c6 += d2 + dd1;
c7 += d3 + dd2;
c8 += d4 + dd3;
c9 = dd4;
c6 += d2 + dd1 + (c5 >>> BITS_PER_LIMB);
c7 += d3 + dd2 + (c6 >>> BITS_PER_LIMB);
c8 += d4 + dd3 + (c7 >>> BITS_PER_LIMB);
c9 = dd4 + (c8 >>> BITS_PER_LIMB);
r[0] = c5;
r[1] = c6;
r[2] = c7;
r[3] = c8;
r[4] = c9;
c5 &= LIMB_MASK;
c6 &= LIMB_MASK;
c7 &= LIMB_MASK;
c8 &= LIMB_MASK;
// At this point, the result {c5, c6, c7, c8, c9} could overflow by
// one modulus. Subtract one modulus (with carry propagation), into
// {c0, c1, c2, c3, c4}. Note that in this calculation, limbs are
// signed
c0 = c5 - modulus[0];
c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
c0 &= LIMB_MASK;
c2 = c7 - modulus[2] + (c1 >> BITS_PER_LIMB);
c1 &= LIMB_MASK;
c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
c2 &= LIMB_MASK;
c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
c3 &= LIMB_MASK;
// We must now select the result that is in the range [0, modulus), i.e.
// either {c0-4} or {c5-9}. Iff {c0-4} is negative, {c5-9} contains
// the result. After carry propagation, {c0-4} is negative iff c4 is
// negative. An arithmetic shift by 63 bits generates a mask from c4 that
// can be used to select, in constant time, either {c0-4} or {c5-9}.
long mask = c4 >> 63;
r[0] = ((c5 & mask) | (c0 & ~mask));
r[1] = ((c6 & mask) | (c1 & ~mask));
r[2] = ((c7 & mask) | (c2 & ~mask));
r[3] = ((c8 & mask) | (c3 & ~mask));
r[4] = ((c9 & mask) | (c4 & ~mask));
}
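
With the reduction folded into mult, its post-condition can be stated directly against BigInteger. A hedged sketch of the invariant that the fuzz test below exercises (R = 2^260; method names are illustrative; assumes java.math.BigInteger is imported):

    // Invariant of the Montgomery multiply: r == a * b * R^-1 (mod P) and
    // r < P, where values are encoded as five 52-bit limbs, little-endian.
    static boolean multPostcondition(long[] a, long[] b, long[] r, BigInteger P) {
        BigInteger rInv = BigInteger.ONE.shiftLeft(260).modInverse(P);
        BigInteger got = toBig(r);
        BigInteger want = toBig(a).multiply(toBig(b)).multiply(rInv).mod(P);
        return got.equals(want) && got.compareTo(P) < 0;
    }

    static BigInteger toBig(long[] limbs) {
        BigInteger v = BigInteger.ZERO;
        for (int i = 0; i < limbs.length; i++) {
            v = v.add(BigInteger.valueOf(limbs[i]).shiftLeft(i * 52));
        }
        return v;
    }
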
@Override

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Intel Corporation. All rights reserved.
* Copyright (c) 2024, 2025, Intel Corporation. All rights reserved.
*
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
@ -23,9 +23,8 @@
*/
import java.util.Random;
import sun.security.util.math.IntegerMontgomeryFieldModuloP;
import sun.security.util.math.ImmutableIntegerModuloP;
import java.math.BigInteger;
import sun.security.util.math.*;
import sun.security.util.math.intpoly.*;
/*
@ -35,7 +34,7 @@ import sun.security.util.math.intpoly.*;
* java.base/sun.security.util.math.intpoly
* @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:-UseIntPolyIntrinsics
* MontgomeryPolynomialFuzzTest
* @summary Unit test MontgomeryPolynomialFuzzTest.
* @summary Unit test MontgomeryPolynomialFuzzTest without intrinsic, plain java
*/
/*
@ -45,10 +44,11 @@ import sun.security.util.math.intpoly.*;
* java.base/sun.security.util.math.intpoly
* @run main/othervm -XX:+UnlockDiagnosticVMOptions -XX:+UseIntPolyIntrinsics
* MontgomeryPolynomialFuzzTest
* @summary Unit test MontgomeryPolynomialFuzzTest.
* @summary Unit test MontgomeryPolynomialFuzzTest with intrinsic enabled
*/
// This test case is NOT entirely deterministic, it uses a random seed for pseudo-random number generator
// This test case is NOT entirely deterministic; it uses a random seed for
// the pseudo-random number generator.
// If a failure occurs, hardcode the seed to make the test case deterministic
public class MontgomeryPolynomialFuzzTest {
public static void main(String[] args) throws Exception {
@ -60,15 +60,38 @@ public class MontgomeryPolynomialFuzzTest {
System.out.println("Fuzz Success");
}
private static void check(BigInteger reference,
private static void checkOverflow(String opMsg,
ImmutableIntegerModuloP testValue, long seed) {
if (!reference.equals(testValue.asBigInteger())) {
throw new RuntimeException("SEED: " + seed);
long[] limbs = testValue.getLimbs();
BigInteger mod = MontgomeryIntegerPolynomialP256.ONE.MODULUS;
BigInteger ref = BigInteger.ZERO;
for (int i = 0; i < limbs.length; i++) {
ref = ref.add(BigInteger.valueOf(limbs[i]).shiftLeft(i*52));
}
if (ref.compareTo(mod) >= 0) {
String msg = "Error while " + opMsg + System.lineSeparator()
+ ref.toString(16) + " >= " + mod.toString(16) + System.lineSeparator()
+ "To reproduce, set SEED to [" + seed + "L]: ";
throw new RuntimeException(msg);
}
}
private static void check(String opMsg, BigInteger reference,
ImmutableIntegerModuloP testValue, long seed) {
BigInteger test = testValue.asBigInteger();
if (!reference.equals(test)) {
String msg = "Error while " + opMsg + System.lineSeparator()
+ reference.toString(16) + " != " + test.toString(16)
+ System.lineSeparator()+ "To reproduce, set SEED to ["
+ seed + "L]: ";
throw new RuntimeException(msg);
}
}
public static void run() throws Exception {
Random rnd = new Random();
// To reproduce an error, fix the value of the seed to the value from
// the failure
long seed = rnd.nextLong();
rnd.setSeed(seed);
@ -77,22 +100,75 @@ public class MontgomeryPolynomialFuzzTest {
BigInteger r = BigInteger.ONE.shiftLeft(260).mod(P);
BigInteger rInv = r.modInverse(P);
BigInteger aRef = (new BigInteger(P.bitLength(), rnd)).mod(P);
BigInteger bRef = (new BigInteger(P.bitLength(), rnd)).mod(P);
SmallValue two = montField.getSmallValue(2);
SmallValue three = montField.getSmallValue(3);
SmallValue four = montField.getSmallValue(4);
// Test conversion to montgomery domain
ImmutableIntegerModuloP a = montField.getElement(aRef);
String msg = "converting "+aRef.toString(16) + " to montgomery domain";
aRef = aRef.multiply(r).mod(P);
check(aRef, a, seed);
check(msg, aRef, a, seed);
checkOverflow(msg, a, seed);
ImmutableIntegerModuloP b = montField.getElement(bRef);
msg = "converting "+aRef.toString(16) + " to montgomery domain";
bRef = bRef.multiply(r).mod(P);
check(msg, bRef, b, seed);
checkOverflow(msg, b, seed);
if (rnd.nextBoolean()) {
msg = "squaring "+aRef.toString(16);
aRef = aRef.multiply(aRef).multiply(rInv).mod(P);
a = a.multiply(a);
check(aRef, a, seed);
check(msg, aRef, a, seed);
checkOverflow(msg, a, seed);
}
if (rnd.nextBoolean()) {
msg = "doubling "+aRef.toString(16);
aRef = aRef.add(aRef).mod(P);
a = a.add(a);
check(aRef, a, seed);
check(msg, aRef, a, seed);
}
if (rnd.nextBoolean()) {
msg = "subtracting "+bRef.toString(16)+" from "+aRef.toString(16);
aRef = aRef.subtract(bRef).mod(P);
a = a.mutable().setDifference(b).fixed();
check(msg, aRef, a, seed);
}
if (rnd.nextBoolean()) {
msg = "multiplying "+bRef.toString(16)+" with "+aRef.toString(16);
aRef = aRef.multiply(bRef).multiply(rInv).mod(P);
a = a.multiply(b);
check(msg, aRef, a, seed);
checkOverflow(msg, a, seed);
}
if (rnd.nextBoolean()) {
msg = "multiplying "+aRef.toString(16)+" with constant 2";
aRef = aRef.multiply(BigInteger.valueOf(2)).mod(P);
a = a.mutable().setProduct(two).fixed();
check(msg, aRef, a, seed);
}
if (rnd.nextBoolean()) {
msg = "multiplying "+aRef.toString(16)+" with constant 3";
aRef = aRef.multiply(BigInteger.valueOf(3)).mod(P);
a = a.mutable().setProduct(three).fixed();
check(msg, aRef, a, seed);
checkOverflow(msg, a, seed);
}
if (rnd.nextBoolean()) {
msg = "multiplying "+aRef.toString(16)+" with constant 4";
aRef = aRef.multiply(BigInteger.valueOf(4)).mod(P);
a = a.mutable().setProduct(four).fixed();
check(msg, aRef, a, seed);
checkOverflow(msg, a, seed);
}
}
}