8351034: Add AVX-512 intrinsics for ML-DSA

Reviewed-by: sviswanathan, lmesnik, vpaprotski, jbhateja
Ferenc Rakoczi authored on 2025-04-09 21:15:45 +00:00; committed by Sandhya Viswanathan
parent e3f26b056e
commit e87ff328d5
10 changed files with 1280 additions and 42 deletions


@@ -85,7 +85,7 @@
do_arch_blob, \
do_arch_entry, \
do_arch_entry_init) \
do_arch_blob(compiler, 20000 LP64_ONLY(+64000) WINDOWS_ONLY(+2000)) \
do_arch_blob(compiler, 20000 LP64_ONLY(+89000) WINDOWS_ONLY(+2000)) \
do_stub(compiler, vector_float_sign_mask) \
do_arch_entry(x86, compiler, vector_float_sign_mask, \
vector_float_sign_mask, vector_float_sign_mask) \


@@ -4204,6 +4204,8 @@ void StubGenerator::generate_compiler_stubs() {
generate_chacha_stubs();
generate_dilithium_stubs();
generate_sha3_stubs();
// data cache line writeback


@@ -489,8 +489,9 @@ class StubGenerator: public StubCodeGenerator {
// SHA3 stubs
void generate_sha3_stubs();
address generate_sha3_implCompress(StubGenStubId stub_id);
// Dilithium stubs and helper functions
void generate_dilithium_stubs();
// BASE64 stubs
address base64_shuffle_addr();

File diff suppressed because it is too large.


@@ -38,6 +38,8 @@
#define BIND(label) bind(label); BLOCK_COMMENT(#label ":")
#define xmm(i) as_XMMRegister(i)
// Constants
ATTRIBUTE_ALIGNED(64) static const uint64_t round_consts_arr[24] = {
0x0000000000000001L, 0x0000000000008082L, 0x800000000000808AL,
@@ -79,13 +81,6 @@ static address permsAndRotsAddr() {
return (address) permsAndRots;
}
void StubGenerator::generate_sha3_stubs() {
if (UseSHA3Intrinsics) {
StubRoutines::_sha3_implCompress = generate_sha3_implCompress(StubGenStubId::sha3_implCompress_id);
StubRoutines::_sha3_implCompressMB = generate_sha3_implCompress(StubGenStubId::sha3_implCompressMB_id);
}
}
// Arguments:
//
// Inputs:
@@ -95,7 +90,9 @@ void StubGenerator::generate_sha3_stubs() {
// c_rarg3 - int offset
// c_rarg4 - int limit
//
address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
static address generate_sha3_implCompress(StubGenStubId stub_id,
StubGenerator *stubgen,
MacroAssembler *_masm) {
bool multiBlock;
switch(stub_id) {
case sha3_implCompress_id:
@@ -109,7 +106,7 @@ address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
}
__ align(CodeEntryAlignment);
StubCodeMark mark(this, stub_id);
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
const Register buf = c_rarg0;
@@ -154,29 +151,16 @@ address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
__ kshiftrwl(k1, k5, 4);
// load the state
__ evmovdquq(xmm0, k5, Address(state, 0), false, Assembler::AVX_512bit);
__ evmovdquq(xmm1, k5, Address(state, 40), false, Assembler::AVX_512bit);
__ evmovdquq(xmm2, k5, Address(state, 80), false, Assembler::AVX_512bit);
__ evmovdquq(xmm3, k5, Address(state, 120), false, Assembler::AVX_512bit);
__ evmovdquq(xmm4, k5, Address(state, 160), false, Assembler::AVX_512bit);
for (int i = 0; i < 5; i++) {
__ evmovdquq(xmm(i), k5, Address(state, i * 40), false, Assembler::AVX_512bit);
}
// load the permutation and rotation constants
__ evmovdquq(xmm17, Address(permsAndRots, 0), Assembler::AVX_512bit);
__ evmovdquq(xmm18, Address(permsAndRots, 64), Assembler::AVX_512bit);
__ evmovdquq(xmm19, Address(permsAndRots, 128), Assembler::AVX_512bit);
__ evmovdquq(xmm20, Address(permsAndRots, 192), Assembler::AVX_512bit);
__ evmovdquq(xmm21, Address(permsAndRots, 256), Assembler::AVX_512bit);
__ evmovdquq(xmm22, Address(permsAndRots, 320), Assembler::AVX_512bit);
__ evmovdquq(xmm23, Address(permsAndRots, 384), Assembler::AVX_512bit);
__ evmovdquq(xmm24, Address(permsAndRots, 448), Assembler::AVX_512bit);
__ evmovdquq(xmm25, Address(permsAndRots, 512), Assembler::AVX_512bit);
__ evmovdquq(xmm26, Address(permsAndRots, 576), Assembler::AVX_512bit);
__ evmovdquq(xmm27, Address(permsAndRots, 640), Assembler::AVX_512bit);
__ evmovdquq(xmm28, Address(permsAndRots, 704), Assembler::AVX_512bit);
__ evmovdquq(xmm29, Address(permsAndRots, 768), Assembler::AVX_512bit);
__ evmovdquq(xmm30, Address(permsAndRots, 832), Assembler::AVX_512bit);
__ evmovdquq(xmm31, Address(permsAndRots, 896), Assembler::AVX_512bit);
for (int i = 0; i < 15; i++) {
__ evmovdquq(xmm(i + 17), Address(permsAndRots, i * 64), Assembler::AVX_512bit);
}
__ align(OptoLoopAlignment);
__ BIND(sha3_loop);
// there will be 24 keccak rounds
@@ -231,6 +215,7 @@ address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
// The implementation closely follows the Java version, with the state
// array "rows" in the lowest 5 64-bit slots of zmm0 - zmm4, i.e.
// each row of the SHA3 specification is located in one zmm register.
__ align(OptoLoopAlignment);
__ BIND(rounds24_loop);
__ subl(roundsLeft, 1);
@@ -257,7 +242,7 @@ address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
// Do the cyclical permutation of the 24 moving state elements
// and the required rotations within each element (the combined
// rho and sigma steps).
// rho and pi steps).
__ evpermt2q(xmm4, xmm17, xmm3, Assembler::AVX_512bit);
__ evpermt2q(xmm3, xmm18, xmm2, Assembler::AVX_512bit);
__ evpermt2q(xmm2, xmm17, xmm1, Assembler::AVX_512bit);
@@ -279,7 +264,7 @@ address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
__ evpermt2q(xmm2, xmm24, xmm4, Assembler::AVX_512bit);
__ evpermt2q(xmm3, xmm25, xmm4, Assembler::AVX_512bit);
__ evpermt2q(xmm4, xmm26, xmm5, Assembler::AVX_512bit);
// The combined rho and sigma steps are done.
// The combined rho and pi steps are done.
// Do the chi step (the same operation on all 5 rows).
// vpternlogq(x, 180, y, z) does x = x ^ (y & ~z).
@@ -320,11 +305,9 @@ address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
}
// store the state
__ evmovdquq(Address(state, 0), k5, xmm0, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 40), k5, xmm1, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 80), k5, xmm2, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 120), k5, xmm3, true, Assembler::AVX_512bit);
__ evmovdquq(Address(state, 160), k5, xmm4, true, Assembler::AVX_512bit);
for (int i = 0; i < 5; i++) {
__ evmovdquq(Address(state, i * 40), k5, xmm(i), true, Assembler::AVX_512bit);
}
__ pop(r14);
__ pop(r13);
@@ -335,3 +318,193 @@ address StubGenerator::generate_sha3_implCompress(StubGenStubId stub_id) {
return start;
}
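Note: the two vpternlogq immediates used by these Keccak stubs (see also generate_double_keccak below) encode three-input boolean functions on the 64-bit lanes: 150 (0x96) is the three-way XOR used by the theta step, and 180 (0xB4) is x ^ (y & ~z), which is exactly the chi step. A minimal scalar sketch of those two steps in Java, on a 5x5 lane array, follows; it is illustrative only (not the JDK sources), and the rho, pi and iota steps are omitted.

    // Illustrative reference for the truth-table comments above; a[y][x] holds lane (x, y).
    static void thetaAndChi(long[][] a) {
        long[] c = new long[5];
        for (int x = 0; x < 5; x++) {                 // column parities (three-way XORs)
            c[x] = a[0][x] ^ a[1][x] ^ a[2][x] ^ a[3][x] ^ a[4][x];
        }
        for (int x = 0; x < 5; x++) {                 // theta: mix parities into every lane
            long d = c[(x + 4) % 5] ^ Long.rotateLeft(c[(x + 1) % 5], 1);
            for (int y = 0; y < 5; y++) {
                a[y][x] ^= d;
            }
        }
        for (int y = 0; y < 5; y++) {                 // chi: lane ^ (lane_x+2 & ~lane_x+1), per row
            long[] row = a[y].clone();
            for (int x = 0; x < 5; x++) {
                a[y][x] = row[x] ^ (row[(x + 2) % 5] & ~row[(x + 1) % 5]);
            }
        }
    }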
// Inputs:
// c_rarg0 - long[] state0
// c_rarg1 - long[] state1
//
// Performs two keccak() computations in parallel; the steps of the two
// computations are interleaved.
static address generate_double_keccak(StubGenerator *stubgen, MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubGenStubId stub_id = double_keccak_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
const Register state0 = c_rarg0;
const Register state1 = c_rarg1;
const Register permsAndRots = c_rarg2;
const Register round_consts = c_rarg3;
const Register constant2use = r10;
const Register roundsLeft = r11;
Label rounds24_loop;
__ enter();
__ lea(permsAndRots, ExternalAddress(permsAndRotsAddr()));
__ lea(round_consts, ExternalAddress(round_constsAddr()));
// set up the masks
__ movl(rax, 0x1F);
__ kmovwl(k5, rax);
__ kshiftrwl(k4, k5, 1);
__ kshiftrwl(k3, k5, 2);
__ kshiftrwl(k2, k5, 3);
__ kshiftrwl(k1, k5, 4);
// load the states
for (int i = 0; i < 5; i++) {
__ evmovdquq(xmm(i), k5, Address(state0, i * 40), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 5; i++) {
__ evmovdquq(xmm(10 + i), k5, Address(state1, i * 40), false, Assembler::AVX_512bit);
}
// load the permutation and rotation constants
for (int i = 0; i < 15; i++) {
__ evmovdquq(xmm(17 + i), Address(permsAndRots, i * 64), Assembler::AVX_512bit);
}
// there will be 24 keccak rounds
// The same operations as the ones in generate_sha3_implCompress are
// performed, but in parallel for two states: one in regs z0-z5, using z6
// as the scratch register, and the other in z10-z15, using z16 as the
// scratch register.
// The permutation and rotation constants, which are loaded into z17-z31,
// are shared between the two computations.
__ movl(roundsLeft, 24);
// load round_constants base
__ movptr(constant2use, round_consts);
__ align(OptoLoopAlignment);
__ BIND(rounds24_loop);
__ subl( roundsLeft, 1);
__ evmovdquw(xmm5, xmm0, Assembler::AVX_512bit);
__ evmovdquw(xmm15, xmm10, Assembler::AVX_512bit);
__ vpternlogq(xmm5, 150, xmm1, xmm2, Assembler::AVX_512bit);
__ vpternlogq(xmm15, 150, xmm11, xmm12, Assembler::AVX_512bit);
__ vpternlogq(xmm5, 150, xmm3, xmm4, Assembler::AVX_512bit);
__ vpternlogq(xmm15, 150, xmm13, xmm14, Assembler::AVX_512bit);
__ evprolq(xmm6, xmm5, 1, Assembler::AVX_512bit);
__ evprolq(xmm16, xmm15, 1, Assembler::AVX_512bit);
__ evpermt2q(xmm5, xmm30, xmm5, Assembler::AVX_512bit);
__ evpermt2q(xmm15, xmm30, xmm15, Assembler::AVX_512bit);
__ evpermt2q(xmm6, xmm31, xmm6, Assembler::AVX_512bit);
__ evpermt2q(xmm16, xmm31, xmm16, Assembler::AVX_512bit);
__ vpternlogq(xmm0, 150, xmm5, xmm6, Assembler::AVX_512bit);
__ vpternlogq(xmm10, 150, xmm15, xmm16, Assembler::AVX_512bit);
__ vpternlogq(xmm1, 150, xmm5, xmm6, Assembler::AVX_512bit);
__ vpternlogq(xmm11, 150, xmm15, xmm16, Assembler::AVX_512bit);
__ vpternlogq(xmm2, 150, xmm5, xmm6, Assembler::AVX_512bit);
__ vpternlogq(xmm12, 150, xmm15, xmm16, Assembler::AVX_512bit);
__ vpternlogq(xmm3, 150, xmm5, xmm6, Assembler::AVX_512bit);
__ vpternlogq(xmm13, 150, xmm15, xmm16, Assembler::AVX_512bit);
__ vpternlogq(xmm4, 150, xmm5, xmm6, Assembler::AVX_512bit);
__ vpternlogq(xmm14, 150, xmm15, xmm16, Assembler::AVX_512bit);
__ evpermt2q(xmm4, xmm17, xmm3, Assembler::AVX_512bit);
__ evpermt2q(xmm14, xmm17, xmm13, Assembler::AVX_512bit);
__ evpermt2q(xmm3, xmm18, xmm2, Assembler::AVX_512bit);
__ evpermt2q(xmm13, xmm18, xmm12, Assembler::AVX_512bit);
__ evpermt2q(xmm2, xmm17, xmm1, Assembler::AVX_512bit);
__ evpermt2q(xmm12, xmm17, xmm11, Assembler::AVX_512bit);
__ evpermt2q(xmm1, xmm19, xmm0, Assembler::AVX_512bit);
__ evpermt2q(xmm11, xmm19, xmm10, Assembler::AVX_512bit);
__ evpermt2q(xmm4, xmm20, xmm2, Assembler::AVX_512bit);
__ evpermt2q(xmm14, xmm20, xmm12, Assembler::AVX_512bit);
__ evprolvq(xmm1, xmm1, xmm27, Assembler::AVX_512bit);
__ evprolvq(xmm11, xmm11, xmm27, Assembler::AVX_512bit);
__ evprolvq(xmm3, xmm3, xmm28, Assembler::AVX_512bit);
__ evprolvq(xmm13, xmm13, xmm28, Assembler::AVX_512bit);
__ evprolvq(xmm4, xmm4, xmm29, Assembler::AVX_512bit);
__ evprolvq(xmm14, xmm14, xmm29, Assembler::AVX_512bit);
__ evmovdquw(xmm2, xmm1, Assembler::AVX_512bit);
__ evmovdquw(xmm12, xmm11, Assembler::AVX_512bit);
__ evmovdquw(xmm5, xmm3, Assembler::AVX_512bit);
__ evmovdquw(xmm15, xmm13, Assembler::AVX_512bit);
__ evpermt2q(xmm0, xmm21, xmm4, Assembler::AVX_512bit);
__ evpermt2q(xmm10, xmm21, xmm14, Assembler::AVX_512bit);
__ evpermt2q(xmm1, xmm22, xmm3, Assembler::AVX_512bit);
__ evpermt2q(xmm11, xmm22, xmm13, Assembler::AVX_512bit);
__ evpermt2q(xmm5, xmm22, xmm2, Assembler::AVX_512bit);
__ evpermt2q(xmm15, xmm22, xmm12, Assembler::AVX_512bit);
__ evmovdquw(xmm3, xmm1, Assembler::AVX_512bit);
__ evmovdquw(xmm13, xmm11, Assembler::AVX_512bit);
__ evmovdquw(xmm2, xmm5, Assembler::AVX_512bit);
__ evmovdquw(xmm12, xmm15, Assembler::AVX_512bit);
__ evpermt2q(xmm1, xmm23, xmm4, Assembler::AVX_512bit);
__ evpermt2q(xmm11, xmm23, xmm14, Assembler::AVX_512bit);
__ evpermt2q(xmm2, xmm24, xmm4, Assembler::AVX_512bit);
__ evpermt2q(xmm12, xmm24, xmm14, Assembler::AVX_512bit);
__ evpermt2q(xmm3, xmm25, xmm4, Assembler::AVX_512bit);
__ evpermt2q(xmm13, xmm25, xmm14, Assembler::AVX_512bit);
__ evpermt2q(xmm4, xmm26, xmm5, Assembler::AVX_512bit);
__ evpermt2q(xmm14, xmm26, xmm15, Assembler::AVX_512bit);
__ evpermt2q(xmm5, xmm31, xmm0, Assembler::AVX_512bit);
__ evpermt2q(xmm15, xmm31, xmm10, Assembler::AVX_512bit);
__ evpermt2q(xmm6, xmm31, xmm5, Assembler::AVX_512bit);
__ evpermt2q(xmm16, xmm31, xmm15, Assembler::AVX_512bit);
__ vpternlogq(xmm0, 180, xmm6, xmm5, Assembler::AVX_512bit);
__ vpternlogq(xmm10, 180, xmm16, xmm15, Assembler::AVX_512bit);
__ evpermt2q(xmm5, xmm31, xmm1, Assembler::AVX_512bit);
__ evpermt2q(xmm15, xmm31, xmm11, Assembler::AVX_512bit);
__ evpermt2q(xmm6, xmm31, xmm5, Assembler::AVX_512bit);
__ evpermt2q(xmm16, xmm31, xmm15, Assembler::AVX_512bit);
__ vpternlogq(xmm1, 180, xmm6, xmm5, Assembler::AVX_512bit);
__ vpternlogq(xmm11, 180, xmm16, xmm15, Assembler::AVX_512bit);
__ evpxorq(xmm0, k1, xmm0, Address(constant2use, 0), true, Assembler::AVX_512bit);
__ evpxorq(xmm10, k1, xmm10, Address(constant2use, 0), true, Assembler::AVX_512bit);
__ addptr(constant2use, 8);
__ evpermt2q(xmm5, xmm31, xmm2, Assembler::AVX_512bit);
__ evpermt2q(xmm15, xmm31, xmm12, Assembler::AVX_512bit);
__ evpermt2q(xmm6, xmm31, xmm5, Assembler::AVX_512bit);
__ evpermt2q(xmm16, xmm31, xmm15, Assembler::AVX_512bit);
__ vpternlogq(xmm2, 180, xmm6, xmm5, Assembler::AVX_512bit);
__ vpternlogq(xmm12, 180, xmm16, xmm15, Assembler::AVX_512bit);
__ evpermt2q(xmm5, xmm31, xmm3, Assembler::AVX_512bit);
__ evpermt2q(xmm15, xmm31, xmm13, Assembler::AVX_512bit);
__ evpermt2q(xmm6, xmm31, xmm5, Assembler::AVX_512bit);
__ evpermt2q(xmm16, xmm31, xmm15, Assembler::AVX_512bit);
__ vpternlogq(xmm3, 180, xmm6, xmm5, Assembler::AVX_512bit);
__ vpternlogq(xmm13, 180, xmm16, xmm15, Assembler::AVX_512bit);
__ evpermt2q(xmm5, xmm31, xmm4, Assembler::AVX_512bit);
__ evpermt2q(xmm15, xmm31, xmm14, Assembler::AVX_512bit);
__ evpermt2q(xmm6, xmm31, xmm5, Assembler::AVX_512bit);
__ evpermt2q(xmm16, xmm31, xmm15, Assembler::AVX_512bit);
__ vpternlogq(xmm4, 180, xmm6, xmm5, Assembler::AVX_512bit);
__ vpternlogq(xmm14, 180, xmm16, xmm15, Assembler::AVX_512bit);
__ cmpl(roundsLeft, 0);
__ jcc(Assembler::notEqual, rounds24_loop);
// store the states
for (int i = 0; i < 5; i++) {
__ evmovdquq(Address(state0, i * 40), k5, xmm(i), true, Assembler::AVX_512bit);
}
for (int i = 0; i < 5; i++) {
__ evmovdquq(Address(state1, i * 40), k5, xmm(10 + i), true, Assembler::AVX_512bit);
}
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ ret(0);
return start;
}
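Note: functionally, double_keccak is just two independent Keccak-f[1600] permutations; the interleaving (states in z0-z5 and z10-z15, scratch in z6/z16, shared constants in z17-z31) only serves to hide instruction latencies. A hedged sketch of the intended semantics, with the permutation passed in as a parameter since the scalar reference itself is not part of this change:

    import java.util.function.Consumer;

    // Reference semantics only: the stub must be equivalent to applying any
    // correct scalar Keccak-f[1600] permutation to each 25-lane state in turn.
    final class DoubleKeccakReference {
        static void doubleKeccak(Consumer<long[]> keccakF1600, long[] state0, long[] state1) {
            keccakF1600.accept(state0);   // 5x5 lanes, 24 rounds
            keccakF1600.accept(state1);
        }
    }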
void StubGenerator::generate_sha3_stubs() {
if (UseSHA3Intrinsics) {
StubRoutines::_sha3_implCompress =
generate_sha3_implCompress(StubGenStubId::sha3_implCompress_id, this, _masm);
StubRoutines::_double_keccak =
generate_double_keccak(this, _masm);
StubRoutines::_sha3_implCompressMB =
generate_sha3_implCompress(StubGenStubId::sha3_implCompressMB_id, this, _masm);
}
}


@@ -1246,6 +1246,20 @@ void VM_Version::get_processor_features() {
}
#endif // _LP64
// Dilithium Intrinsics
// Currently we only have them for AVX512
#ifdef _LP64
if (supports_evex() && supports_avx512bw()) {
if (FLAG_IS_DEFAULT(UseDilithiumIntrinsics)) {
UseDilithiumIntrinsics = true;
}
} else
#endif
if (UseDilithiumIntrinsics) {
warning("Intrinsics for ML-DSA are not available on this CPU.");
FLAG_SET_DEFAULT(UseDilithiumIntrinsics, false);
}
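Note: with this gating, the intrinsics default to on only when the CPU reports both EVEX and AVX512BW; on other CPUs an explicit -XX:+UseDilithiumIntrinsics produces the warning above and the flag is reset to false. A hedged way to check what the running VM actually selected, using the standard platform MXBean API (the flag name comes from this change):

    import java.lang.management.ManagementFactory;
    import com.sun.management.HotSpotDiagnosticMXBean;

    public class CheckDilithiumFlag {
        public static void main(String[] args) {
            HotSpotDiagnosticMXBean hs =
                ManagementFactory.getPlatformMXBean(HotSpotDiagnosticMXBean.class);
            // Prints the effective value of UseDilithiumIntrinsics in this VM.
            System.out.println(hs.getVMOption("UseDilithiumIntrinsics"));
        }
    }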
// Base64 Intrinsics (Check the condition for which the intrinsic will be active)
if (UseAVX >= 2) {
if (FLAG_IS_DEFAULT(UseBASE64Intrinsics)) {


@@ -570,7 +570,7 @@ class methodHandle;
do_signature(chacha20Block_signature, "([I[B)I") \
\
/* support for sun.security.provider.ML_DSA */ \
do_class(sun_security_provider_ML_DSA, "sun/security/provider/ML_DSA") \
do_class(sun_security_provider_ML_DSA, "sun/security/provider/ML_DSA") \
do_signature(IaII_signature, "([II)I") \
do_signature(IaIaI_signature, "([I[I)I") \
do_signature(IaIaIaI_signature, "([I[I[I)I") \
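Note: the descriptor strings use standard JVM signature syntax: "([II)I" is (int[], int) returning int, "([I[I)I" is (int[], int[]) returning int, and "([I[I[I)I" is (int[], int[], int[]) returning int. They correspond to the intrinsic-candidate helpers called from the ML_DSA.java hunks below; the sketch that follows shows only the shapes implied by the descriptors and call sites (bodies and modifiers are placeholders, not the real sources):

    // Shapes only; the real methods live in sun.security.provider.ML_DSA, are
    // annotated @IntrinsicCandidate, and their int return values are ignored by
    // the public wrappers shown in the ML_DSA.java diff.
    final class MlDsaIntrinsicShapes {
        static int implDilithiumMontMulByConstant(int[] coeffs, int constant) { return 0; } // "([II)I"
        static int implDilithiumAlmostNtt(int[] coeffs, int[] zetas)          { return 0; } // "([I[I)I"
        static int implDilithiumNttMult(int[] product, int[] c1, int[] c2)    { return 0; } // "([I[I[I)I"
    }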


@@ -740,11 +740,11 @@
do_stub(compiler, sha3_implCompress) \
do_entry(compiler, sha3_implCompress, sha3_implCompress, \
sha3_implCompress) \
do_stub(compiler, double_keccak) \
do_entry(compiler, double_keccak, double_keccak, double_keccak) \
do_stub(compiler, sha3_implCompressMB) \
do_entry(compiler, sha3_implCompressMB, sha3_implCompressMB, \
sha3_implCompressMB) \
do_stub(compiler, double_keccak) \
do_entry(compiler, double_keccak, double_keccak, double_keccak) \
do_stub(compiler, updateBytesAdler32) \
do_entry(compiler, updateBytesAdler32, updateBytesAdler32, \
updateBytesAdler32) \


@@ -26,7 +26,6 @@
package sun.security.provider;
import jdk.internal.vm.annotation.IntrinsicCandidate;
import sun.security.provider.SHA3.SHAKE128;
import sun.security.provider.SHA3.SHAKE256;
import sun.security.provider.SHA3Parallel.Shake128Parallel;
@@ -1317,6 +1316,7 @@ public class ML_DSA {
*/
public static void mlDsaNtt(int[] coeffs) {
assert coeffs.length == ML_DSA_N;
implDilithiumAlmostNtt(coeffs, MONT_ZETAS_FOR_VECTOR_NTT);
implDilithiumMontMulByConstant(coeffs, MONT_R_MOD_Q);
}
@@ -1343,6 +1343,7 @@
}
public static void mlDsaInverseNtt(int[] coeffs) {
assert coeffs.length == ML_DSA_N;
implDilithiumAlmostInverseNtt(coeffs, MONT_ZETAS_FOR_VECTOR_INVERSE_NTT);
implDilithiumMontMulByConstant(coeffs, MONT_DIM_INVERSE);
}
@@ -1382,6 +1383,7 @@
}
public static void mlDsaNttMultiply(int[] product, int[] coeffs1, int[] coeffs2) {
assert (coeffs1.length == ML_DSA_N) && (coeffs2.length == ML_DSA_N);
implDilithiumNttMult(product, coeffs1, coeffs2);
}
@@ -1412,6 +1414,8 @@
public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) {
assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
&& (highPart.length == ML_DSA_N);
implDilithiumDecomposePoly(input, lowPart, highPart, twoGamma2, multiplier);
}
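Note: each public wrapper above now asserts the fixed ML-DSA polynomial length (ML_DSA_N) before delegating to its intrinsic-candidate helper, and the MONT_* constants indicate that coefficients are handled in Montgomery form. As a hedged reference for what a Montgomery product means here, with BigInteger standing in for the branch-free reduction of the real code (q = 8380417 is the ML-DSA modulus from FIPS 204; R = 2^32 is an assumption based on the 32-bit coefficient representation):

    import java.math.BigInteger;

    // Illustrative only: montMul(a, b) == a * b * R^-1 mod q, reduced to [0, q).
    final class MontgomeryReference {
        static final BigInteger Q = BigInteger.valueOf(8_380_417);
        static final BigInteger R_INV = BigInteger.ONE.shiftLeft(32).modInverse(Q);

        static int montMul(int a, int b) {
            return BigInteger.valueOf(a).multiply(BigInteger.valueOf(b))
                             .multiply(R_INV).mod(Q).intValueExact();
        }
    }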


@@ -37,6 +37,16 @@ import java.util.zip.ZipFile;
* @bug 8342442 8345057
* @library /test/lib
* @modules java.base/sun.security.provider
* @run main Launcher
*/
/*
* @test
* @summary Test verifying the intrinsic implementation.
* @bug 8342442 8345057
* @library /test/lib
* @modules java.base/sun.security.provider
* @run main/othervm -Xcomp Launcher
*/
/// This test runs on `internalProjection.json`-style files generated by NIST's