8349721: Add aarch64 intrinsics for ML-KEM

Reviewed-by: adinn
This commit is contained in:
Ferenc Rakoczi 2025-04-16 12:35:24 +00:00 committed by Andrew Dinn
parent 1ad869f844
commit 465c8e6583
20 changed files with 2523 additions and 149 deletions

View File

@ -58,23 +58,3 @@ const char* PRegister::PRegisterImpl::name() const {
};
return is_valid() ? names[encoding()] : "pnoreg";
}
// convenience methods for splitting 8-way vector register sequences
// in half -- needed because vector operations can normally only
// benefit from 4-way instruction parallelism
// First half of an 8-way sequence: same start, same stride.
VSeq<4> vs_front(const VSeq<8>& v) {
  const int start = v.base();
  const int step = v.delta();
  return VSeq<4>(start, step);
}
// Second half of an 8-way sequence: starts 4 steps in, same stride.
VSeq<4> vs_back(const VSeq<8>& v) {
  const int step = v.delta();
  return VSeq<4>(v.base() + 4 * step, step);
}
// Even-indexed elements of an 8-way sequence: same start, doubled stride.
VSeq<4> vs_even(const VSeq<8>& v) {
  const int step = v.delta();
  return VSeq<4>(v.base(), 2 * step);
}
// Odd-indexed elements of an 8-way sequence: starts one step in,
// doubled stride.
VSeq<4> vs_odd(const VSeq<8>& v) {
  const int step = v.delta();
  return VSeq<4>(v.base() + 1, 2 * step);
}

View File

@ -436,19 +436,20 @@ enum RC { rc_bad, rc_int, rc_float, rc_predicate, rc_stack };
// inputs into front and back halves or odd and even halves (see
// convenience methods below).
// helper macro for computing register masks
#define VS_MASK_BIT(base, delta, i) (1 << (base + delta * i))
// A VSeq names a sequence of N vector registers by their encodings:
// _base, _base + _delta, ..., _base + (N - 1) * _delta.  A zero delta
// names N copies of one register; a negative delta enumerates a range
// in reverse (see vs_reverse), which is why the constructor checks
// both ends of the range rather than requiring _delta >= 0.
//
// NOTE(review): this span contained diff-splice residue -- an embedded
// hunk header, both the pre- and post-change assert lines (the stale
// "_delta >= 0" assert would reject reversed sequences), and both the
// old and new mask() bodies.  Only the post-change lines are kept; the
// suppressed operator[] body is restored from the register encoding
// scheme the class itself defines -- confirm against upstream.
template<int N> class VSeq {
  static_assert(N >= 2, "vector sequence length must be greater than 1");
  static_assert(N <= 8, "vector sequence length must not exceed 8");
  static_assert((N & (N - 1)) == 0, "vector sequence length must be power of two");
private:
  int _base;  // index of first register in sequence
  int _delta; // increment to derive successive indices
public:
  VSeq(FloatRegister base_reg, int delta = 1) : VSeq(base_reg->encoding(), delta) { }
  VSeq(int base, int delta = 1) : _base(base), _delta(delta) {
    // every register in the sequence must encode into V0..V31
    assert (_base >= 0 && _base <= 31, "invalid base register");
    assert ((_base + (N - 1) * _delta) >= 0, "register range underflow");
    assert ((_base + (N - 1) * _delta) < 32, "register range overflow");
  }
  // indexed access to sequence
  FloatRegister operator [](int i) const {
    assert (0 <= i && i < N, "index out of range");
    return as_FloatRegister(_base + i * _delta);
  }
  // bit mask with one bit set per register named by the sequence
  int mask() const {
    int m = 0;
    for (int i = 0; i < N; i++) {
      m |= VS_MASK_BIT(_base, _delta, i);
    }
    return m;
  }
  int base() const { return _base; }
  int delta() const { return _delta; }
  // a constant sequence names the same register N times
  bool is_constant() const { return _delta == 0; }
};
// declare convenience methods for splitting vector register sequences
VSeq<4> vs_front(const VSeq<8>& v);
VSeq<4> vs_back(const VSeq<8>& v);
VSeq<4> vs_even(const VSeq<8>& v);
VSeq<4> vs_odd(const VSeq<8>& v);
// methods for use in asserts to check VSeq inputs and outputs are
// either disjoint or equal
// Two sequences are disjoint when no register is named by both.
template<int N, int M> bool vs_disjoint(const VSeq<N>& n, const VSeq<M>& m) {
  return (n.mask() & m.mask()) == 0;
}
// Two sequences are the same when they name exactly the same set of
// registers (order and stride may still differ).
template<int N> bool vs_same(const VSeq<N>& n, const VSeq<N>& m) {
  return n.mask() == m.mask();
}
// method for use in asserts to check whether registers appearing in
// an output sequence will be written before they are read from an
// input sequence.
// Check whether a register in the output sequence is written by an
// earlier step before a later step reads it from the input sequence.
// Step i of a sequence-wide operation reads vin[i] and then writes
// vout[i], so a clash exists when vout[i] names the same register as
// vin[j] for some j > i.  Returns true when such a clash exists.
template<int N> bool vs_write_before_read(const VSeq<N>& vout, const VSeq<N>& vin) {
  int b_in = vin.base();
  int d_in = vin.delta();
  int b_out = vout.base();
  int d_out = vout.delta();

  int mask_read = vin.mask(); // all pending reads
  int mask_write = 0;         // no writes as yet

  // Iterate over all N steps (the previous bound of N - 1 never
  // checked the step N-2 write against the step N-1 read).
  for (int i = 0; i < N; i++) {
    // check whether a write from an earlier step clashes with a
    // still-pending read
    if ((mask_write & mask_read) != 0) {
      return true;
    }
    // retire this step's read -- unless this is a constant sequence,
    // in which case every step reads the same register and the read
    // must stay pending
    if (d_in != 0) {
      mask_read ^= VS_MASK_BIT(b_in, d_in, i);
    }
    // record this step's write
    mask_write |= VS_MASK_BIT(b_out, d_out, i);
  }
  // no write before read
  return false;
}
// convenience methods for splitting 8-way or 4-way vector register
// sequences in half -- needed because vector operations can normally
// benefit from 4-way instruction parallelism or, occasionally, 2-way
// parallelism
// First half of v: same start, same stride, half the length.
template<int N>
VSeq<N/2> vs_front(const VSeq<N>& v) {
  static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
  const int step = v.delta();
  return VSeq<N/2>(v.base(), step);
}
// Second half of v: starts N/2 steps in, same stride.
template<int N>
VSeq<N/2> vs_back(const VSeq<N>& v) {
  static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
  const int step = v.delta();
  return VSeq<N/2>(v.base() + (N / 2) * step, step);
}
// Even-indexed elements of v: same start, doubled stride.
template<int N>
VSeq<N/2> vs_even(const VSeq<N>& v) {
  static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
  const int step = v.delta();
  return VSeq<N/2>(v.base(), 2 * step);
}
// Odd-indexed elements of v: starts one step in, doubled stride.
template<int N>
VSeq<N/2> vs_odd(const VSeq<N>& v) {
  static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
  const int step = v.delta();
  return VSeq<N/2>(v.base() + step, 2 * step);
}
// convenience method to construct a vector register sequence that
// indexes its elements in reverse order to the original
// Name the same registers as v but enumerated in the opposite order:
// start at v's last register and step backwards by negating the delta.
template<int N>
VSeq<N> vs_reverse(const VSeq<N>& v) {
  const int last = v.base() + (N - 1) * v.delta();
  return VSeq<N>(last, -v.delta());
}
#endif // CPU_AARCH64_REGISTER_AARCH64_HPP

View File

@ -44,7 +44,7 @@
do_arch_blob, \
do_arch_entry, \
do_arch_entry_init) \
do_arch_blob(compiler, 55000 ZGC_ONLY(+5000)) \
do_arch_blob(compiler, 75000 ZGC_ONLY(+5000)) \
do_stub(compiler, vector_iota_indices) \
do_arch_entry(aarch64, compiler, vector_iota_indices, \
vector_iota_indices, vector_iota_indices) \

File diff suppressed because it is too large Load Diff

View File

@ -48,6 +48,17 @@ STUBGEN_ARCH_ENTRIES_DO(DEFINE_ARCH_ENTRY, DEFINE_ARCH_ENTRY_INIT)
bool StubRoutines::aarch64::_completed = false;
// Eight-lane (128-bit vector) broadcast constants for the ML-KEM
// (Kyber) stubs.  The modulus is kyber_q = 3329 = 0x0D01 and the
// Montgomery radix is R = 2^16.
ATTRIBUTE_ALIGNED(64) uint16_t StubRoutines::aarch64::_kyberConsts[] =
{
// Because we sometimes load these in pairs, montQInvModR, kyber_q
// and kyberBarrettMultiplier should stay together and in this order.
0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, // montQInvModR = 62209 = q^-1 mod 2^16 (3329 * 62209 == 1 mod 2^16)
0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, // kyber_q = 3329
0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, // kyberBarrettMultiplier = 20159 = round(2^26 / q)
0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, // toMont((kyber_n / 2)^-1 (mod kyber_q)) = 512 = 128^-1 * 2^16 mod q
0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549 // montRSquareModQ = 1353 = 2^32 mod q
};
ATTRIBUTE_ALIGNED(64) uint32_t StubRoutines::aarch64::_dilithiumConsts[] =
{
58728449, 58728449, 58728449, 58728449, // montQInvModR

View File

@ -110,6 +110,7 @@ private:
}
private:
static uint16_t _kyberConsts[];
static uint32_t _dilithiumConsts[];
static juint _crc_table[];
static jubyte _adler_table[];

View File

@ -414,13 +414,24 @@ void VM_Version::initialize() {
FLAG_SET_DEFAULT(UseChaCha20Intrinsics, false);
}
if (_features & CPU_ASIMD) {
if (FLAG_IS_DEFAULT(UseKyberIntrinsics)) {
UseKyberIntrinsics = true;
}
} else if (UseKyberIntrinsics) {
if (!FLAG_IS_DEFAULT(UseKyberIntrinsics)) {
warning("Kyber intrinsics require ASIMD instructions");
}
FLAG_SET_DEFAULT(UseKyberIntrinsics, false);
}
if (_features & CPU_ASIMD) {
if (FLAG_IS_DEFAULT(UseDilithiumIntrinsics)) {
UseDilithiumIntrinsics = true;
}
} else if (UseDilithiumIntrinsics) {
if (!FLAG_IS_DEFAULT(UseDilithiumIntrinsics)) {
warning("Dilithium intrinsic requires ASIMD instructions");
warning("Dilithium intrinsics require ASIMD instructions");
}
FLAG_SET_DEFAULT(UseDilithiumIntrinsics, false);
}
@ -703,6 +714,7 @@ void VM_Version::initialize_cpu_information(void) {
get_compatible_board(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len);
desc_len = (int)strlen(_cpu_desc);
snprintf(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len, " %s", _features_string);
fprintf(stderr, "_features_string = \"%s\"", _features_string);
_initialized = true;
}

View File

@ -488,6 +488,14 @@ bool vmIntrinsics::disabled_by_jvm_flags(vmIntrinsics::ID id) {
case vmIntrinsics::_chacha20Block:
if (!UseChaCha20Intrinsics) return true;
break;
case vmIntrinsics::_kyberNtt:
case vmIntrinsics::_kyberInverseNtt:
case vmIntrinsics::_kyberNttMult:
case vmIntrinsics::_kyberAddPoly_2:
case vmIntrinsics::_kyberAddPoly_3:
case vmIntrinsics::_kyber12To16:
case vmIntrinsics::_kyberBarrettReduce:
if (!UseKyberIntrinsics) return true;
case vmIntrinsics::_dilithiumAlmostNtt:
case vmIntrinsics::_dilithiumAlmostInverseNtt:
case vmIntrinsics::_dilithiumNttMult:

View File

@ -569,6 +569,27 @@ class methodHandle;
do_name(chacha20Block_name, "implChaCha20Block") \
do_signature(chacha20Block_signature, "([I[B)I") \
\
/* support for com.sun.crypto.provider.ML_KEM */ \
do_class(com_sun_crypto_provider_ML_KEM, "com/sun/crypto/provider/ML_KEM") \
do_signature(SaSaSaSaI_signature, "([S[S[S[S)I") \
do_signature(BaISaII_signature, "([BI[SI)I") \
do_signature(SaSaSaI_signature, "([S[S[S)I") \
do_signature(SaSaI_signature, "([S[S)I") \
do_signature(SaI_signature, "([S)I") \
do_name(kyberAddPoly_name, "implKyberAddPoly") \
do_intrinsic(_kyberNtt, com_sun_crypto_provider_ML_KEM, kyberNtt_name, SaSaI_signature, F_S) \
do_name(kyberNtt_name, "implKyberNtt") \
do_intrinsic(_kyberInverseNtt, com_sun_crypto_provider_ML_KEM, kyberInverseNtt_name, SaSaI_signature, F_S) \
do_name(kyberInverseNtt_name, "implKyberInverseNtt") \
do_intrinsic(_kyberNttMult, com_sun_crypto_provider_ML_KEM, kyberNttMult_name, SaSaSaSaI_signature, F_S) \
do_name(kyberNttMult_name, "implKyberNttMult") \
do_intrinsic(_kyberAddPoly_2, com_sun_crypto_provider_ML_KEM, kyberAddPoly_name, SaSaSaI_signature, F_S) \
do_intrinsic(_kyberAddPoly_3, com_sun_crypto_provider_ML_KEM, kyberAddPoly_name, SaSaSaSaI_signature, F_S) \
do_intrinsic(_kyber12To16, com_sun_crypto_provider_ML_KEM, kyber12To16_name, BaISaII_signature, F_S) \
do_name(kyber12To16_name, "implKyber12To16") \
do_intrinsic(_kyberBarrettReduce, com_sun_crypto_provider_ML_KEM, kyberBarrettReduce_name, SaI_signature, F_S) \
do_name(kyberBarrettReduce_name, "implKyberBarrettReduce") \
\
/* support for sun.security.provider.ML_DSA */ \
do_class(sun_security_provider_ML_DSA, "sun/security/provider/ML_DSA") \
do_signature(IaII_signature, "([II)I") \

View File

@ -395,6 +395,13 @@
static_field(StubRoutines, _sha3_implCompress, address) \
static_field(StubRoutines, _double_keccak, address) \
static_field(StubRoutines, _sha3_implCompressMB, address) \
static_field(StubRoutines, _kyberNtt, address) \
static_field(StubRoutines, _kyberInverseNtt, address) \
static_field(StubRoutines, _kyberNttMult, address) \
static_field(StubRoutines, _kyberAddPoly_2, address) \
static_field(StubRoutines, _kyberAddPoly_3, address) \
static_field(StubRoutines, _kyber12To16, address) \
static_field(StubRoutines, _kyberBarrettReduce, address) \
static_field(StubRoutines, _dilithiumAlmostNtt, address) \
static_field(StubRoutines, _dilithiumAlmostInverseNtt, address) \
static_field(StubRoutines, _dilithiumNttMult, address) \

View File

@ -792,6 +792,13 @@ bool C2Compiler::is_intrinsic_supported(vmIntrinsics::ID id) {
case vmIntrinsics::_vectorizedMismatch:
case vmIntrinsics::_ghash_processBlocks:
case vmIntrinsics::_chacha20Block:
case vmIntrinsics::_kyberNtt:
case vmIntrinsics::_kyberInverseNtt:
case vmIntrinsics::_kyberNttMult:
case vmIntrinsics::_kyberAddPoly_2:
case vmIntrinsics::_kyberAddPoly_3:
case vmIntrinsics::_kyber12To16:
case vmIntrinsics::_kyberBarrettReduce:
case vmIntrinsics::_dilithiumAlmostNtt:
case vmIntrinsics::_dilithiumAlmostInverseNtt:
case vmIntrinsics::_dilithiumNttMult:

View File

@ -2192,6 +2192,13 @@ void ConnectionGraph::process_call_arguments(CallNode *call) {
strcmp(call->as_CallLeaf()->_name, "intpoly_assign") == 0 ||
strcmp(call->as_CallLeaf()->_name, "ghash_processBlocks") == 0 ||
strcmp(call->as_CallLeaf()->_name, "chacha20Block") == 0 ||
strcmp(call->as_CallLeaf()->_name, "kyberNtt") == 0 ||
strcmp(call->as_CallLeaf()->_name, "kyberInverseNtt") == 0 ||
strcmp(call->as_CallLeaf()->_name, "kyberNttMult") == 0 ||
strcmp(call->as_CallLeaf()->_name, "kyberAddPoly_2") == 0 ||
strcmp(call->as_CallLeaf()->_name, "kyberAddPoly_3") == 0 ||
strcmp(call->as_CallLeaf()->_name, "kyber12To16") == 0 ||
strcmp(call->as_CallLeaf()->_name, "kyberBarrettReduce") == 0 ||
strcmp(call->as_CallLeaf()->_name, "dilithiumAlmostNtt") == 0 ||
strcmp(call->as_CallLeaf()->_name, "dilithiumAlmostInverseNtt") == 0 ||
strcmp(call->as_CallLeaf()->_name, "dilithiumNttMult") == 0 ||

View File

@ -626,6 +626,20 @@ bool LibraryCallKit::try_to_inline(int predicate) {
return inline_ghash_processBlocks();
case vmIntrinsics::_chacha20Block:
return inline_chacha20Block();
case vmIntrinsics::_kyberNtt:
return inline_kyberNtt();
case vmIntrinsics::_kyberInverseNtt:
return inline_kyberInverseNtt();
case vmIntrinsics::_kyberNttMult:
return inline_kyberNttMult();
case vmIntrinsics::_kyberAddPoly_2:
return inline_kyberAddPoly_2();
case vmIntrinsics::_kyberAddPoly_3:
return inline_kyberAddPoly_3();
case vmIntrinsics::_kyber12To16:
return inline_kyber12To16();
case vmIntrinsics::_kyberBarrettReduce:
return inline_kyberBarrettReduce();
case vmIntrinsics::_dilithiumAlmostNtt:
return inline_dilithiumAlmostNtt();
case vmIntrinsics::_dilithiumAlmostInverseNtt:
@ -7640,6 +7654,245 @@ bool LibraryCallKit::inline_chacha20Block() {
return true;
}
//------------------------------inline_kyberNtt
bool LibraryCallKit::inline_kyberNtt() {
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 2, "kyberNtt has 2 parameters");
stubAddr = StubRoutines::kyberNtt();
stubName = "kyberNtt";
if (!stubAddr) return false;
Node* coeffs = argument(0);
Node* ntt_zetas = argument(1);
coeffs = must_be_not_null(coeffs, true);
ntt_zetas = must_be_not_null(ntt_zetas, true);
Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT);
assert(coeffs_start, "coeffs is null");
Node* ntt_zetas_start = array_element_address(ntt_zetas, intcon(0), T_SHORT);
assert(ntt_zetas_start, "ntt_zetas is null");
Node* kyberNtt = make_runtime_call(RC_LEAF|RC_NO_FP,
OptoRuntime::kyberNtt_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
coeffs_start, ntt_zetas_start);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(kyberNtt, TypeFunc::Parms));
set_result(retvalue);
return true;
}
//------------------------------inline_kyberInverseNtt
bool LibraryCallKit::inline_kyberInverseNtt() {
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 2, "kyberInverseNtt has 2 parameters");
stubAddr = StubRoutines::kyberInverseNtt();
stubName = "kyberInverseNtt";
if (!stubAddr) return false;
Node* coeffs = argument(0);
Node* zetas = argument(1);
coeffs = must_be_not_null(coeffs, true);
zetas = must_be_not_null(zetas, true);
Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT);
assert(coeffs_start, "coeffs is null");
Node* zetas_start = array_element_address(zetas, intcon(0), T_SHORT);
assert(zetas_start, "inverseNtt_zetas is null");
Node* kyberInverseNtt = make_runtime_call(RC_LEAF|RC_NO_FP,
OptoRuntime::kyberInverseNtt_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
coeffs_start, zetas_start);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(kyberInverseNtt, TypeFunc::Parms));
set_result(retvalue);
return true;
}
//------------------------------inline_kyberNttMult
bool LibraryCallKit::inline_kyberNttMult() {
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 4, "kyberNttMult has 4 parameters");
stubAddr = StubRoutines::kyberNttMult();
stubName = "kyberNttMult";
if (!stubAddr) return false;
Node* result = argument(0);
Node* ntta = argument(1);
Node* nttb = argument(2);
Node* zetas = argument(3);
result = must_be_not_null(result, true);
ntta = must_be_not_null(ntta, true);
nttb = must_be_not_null(nttb, true);
zetas = must_be_not_null(zetas, true);
Node* result_start = array_element_address(result, intcon(0), T_SHORT);
assert(result_start, "result is null");
Node* ntta_start = array_element_address(ntta, intcon(0), T_SHORT);
assert(ntta_start, "ntta is null");
Node* nttb_start = array_element_address(nttb, intcon(0), T_SHORT);
assert(nttb_start, "nttb is null");
Node* zetas_start = array_element_address(zetas, intcon(0), T_SHORT);
assert(zetas_start, "nttMult_zetas is null");
Node* kyberNttMult = make_runtime_call(RC_LEAF|RC_NO_FP,
OptoRuntime::kyberNttMult_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
result_start, ntta_start, nttb_start,
zetas_start);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(kyberNttMult, TypeFunc::Parms));
set_result(retvalue);
return true;
}
//------------------------------inline_kyberAddPoly_2
bool LibraryCallKit::inline_kyberAddPoly_2() {
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 3, "kyberAddPoly_2 has 3 parameters");
stubAddr = StubRoutines::kyberAddPoly_2();
stubName = "kyberAddPoly_2";
if (!stubAddr) return false;
Node* result = argument(0);
Node* a = argument(1);
Node* b = argument(2);
result = must_be_not_null(result, true);
a = must_be_not_null(a, true);
b = must_be_not_null(b, true);
Node* result_start = array_element_address(result, intcon(0), T_SHORT);
assert(result_start, "result is null");
Node* a_start = array_element_address(a, intcon(0), T_SHORT);
assert(a_start, "a is null");
Node* b_start = array_element_address(b, intcon(0), T_SHORT);
assert(b_start, "b is null");
Node* kyberAddPoly_2 = make_runtime_call(RC_LEAF|RC_NO_FP,
OptoRuntime::kyberAddPoly_2_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
result_start, a_start, b_start);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(kyberAddPoly_2, TypeFunc::Parms));
set_result(retvalue);
return true;
}
//------------------------------inline_kyberAddPoly_3
bool LibraryCallKit::inline_kyberAddPoly_3() {
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 4, "kyberAddPoly_3 has 4 parameters");
stubAddr = StubRoutines::kyberAddPoly_3();
stubName = "kyberAddPoly_3";
if (!stubAddr) return false;
Node* result = argument(0);
Node* a = argument(1);
Node* b = argument(2);
Node* c = argument(3);
result = must_be_not_null(result, true);
a = must_be_not_null(a, true);
b = must_be_not_null(b, true);
c = must_be_not_null(c, true);
Node* result_start = array_element_address(result, intcon(0), T_SHORT);
assert(result_start, "result is null");
Node* a_start = array_element_address(a, intcon(0), T_SHORT);
assert(a_start, "a is null");
Node* b_start = array_element_address(b, intcon(0), T_SHORT);
assert(b_start, "b is null");
Node* c_start = array_element_address(c, intcon(0), T_SHORT);
assert(c_start, "c is null");
Node* kyberAddPoly_3 = make_runtime_call(RC_LEAF|RC_NO_FP,
OptoRuntime::kyberAddPoly_3_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
result_start, a_start, b_start, c_start);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(kyberAddPoly_3, TypeFunc::Parms));
set_result(retvalue);
return true;
}
//------------------------------inline_kyber12To16
bool LibraryCallKit::inline_kyber12To16() {
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 4, "kyber12To16 has 4 parameters");
stubAddr = StubRoutines::kyber12To16();
stubName = "kyber12To16";
if (!stubAddr) return false;
Node* condensed = argument(0);
Node* condensedOffs = argument(1);
Node* parsed = argument(2);
Node* parsedLength = argument(3);
condensed = must_be_not_null(condensed, true);
parsed = must_be_not_null(parsed, true);
Node* condensed_start = array_element_address(condensed, intcon(0), T_BYTE);
assert(condensed_start, "condensed is null");
Node* parsed_start = array_element_address(parsed, intcon(0), T_SHORT);
assert(parsed_start, "parsed is null");
Node* kyber12To16 = make_runtime_call(RC_LEAF|RC_NO_FP,
OptoRuntime::kyber12To16_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
condensed_start, condensedOffs, parsed_start, parsedLength);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(kyber12To16, TypeFunc::Parms));
set_result(retvalue);
return true;
}
//------------------------------inline_kyberBarrettReduce
bool LibraryCallKit::inline_kyberBarrettReduce() {
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 1, "kyberBarrettReduce has 1 parameters");
stubAddr = StubRoutines::kyberBarrettReduce();
stubName = "kyberBarrettReduce";
if (!stubAddr) return false;
Node* coeffs = argument(0);
coeffs = must_be_not_null(coeffs, true);
Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT);
assert(coeffs_start, "coeffs is null");
Node* kyberBarrettReduce = make_runtime_call(RC_LEAF|RC_NO_FP,
OptoRuntime::kyberBarrettReduce_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
coeffs_start);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(kyberBarrettReduce, TypeFunc::Parms));
set_result(retvalue);
return true;
}
//------------------------------inline_dilithiumAlmostNtt
bool LibraryCallKit::inline_dilithiumAlmostNtt() {
address stubAddr;
@ -7696,7 +7949,6 @@ bool LibraryCallKit::inline_dilithiumAlmostInverseNtt() {
OptoRuntime::dilithiumAlmostInverseNtt_Type(),
stubAddr, stubName, TypePtr::BOTTOM,
coeffs_start, zetas_start);
// return an int
Node* retvalue = _gvn.transform(new ProjNode(dilithiumAlmostInverseNtt, TypeFunc::Parms));
set_result(retvalue);
@ -7717,10 +7969,12 @@ bool LibraryCallKit::inline_dilithiumNttMult() {
Node* result = argument(0);
Node* ntta = argument(1);
Node* nttb = argument(2);
Node* zetas = argument(3);
result = must_be_not_null(result, true);
ntta = must_be_not_null(ntta, true);
nttb = must_be_not_null(nttb, true);
zetas = must_be_not_null(zetas, true);
Node* result_start = array_element_address(result, intcon(0), T_INT);
assert(result_start, "result is null");

View File

@ -314,6 +314,13 @@ class LibraryCallKit : public GraphKit {
Node* get_key_start_from_aescrypt_object(Node* aescrypt_object);
bool inline_ghash_processBlocks();
bool inline_chacha20Block();
bool inline_kyberNtt();
bool inline_kyberInverseNtt();
bool inline_kyberNttMult();
bool inline_kyberAddPoly_2();
bool inline_kyberAddPoly_3();
bool inline_kyber12To16();
bool inline_kyberBarrettReduce();
bool inline_dilithiumAlmostNtt();
bool inline_dilithiumAlmostInverseNtt();
bool inline_dilithiumNttMult();

View File

@ -242,13 +242,18 @@ const TypeFunc* OptoRuntime::_bigIntegerShift_Type = nullptr;
const TypeFunc* OptoRuntime::_vectorizedMismatch_Type = nullptr;
const TypeFunc* OptoRuntime::_ghash_processBlocks_Type = nullptr;
const TypeFunc* OptoRuntime::_chacha20Block_Type = nullptr;
const TypeFunc* OptoRuntime::_kyberNtt_Type = nullptr;
const TypeFunc* OptoRuntime::_kyberInverseNtt_Type = nullptr;
const TypeFunc* OptoRuntime::_kyberNttMult_Type = nullptr;
const TypeFunc* OptoRuntime::_kyberAddPoly_2_Type = nullptr;
const TypeFunc* OptoRuntime::_kyberAddPoly_3_Type = nullptr;
const TypeFunc* OptoRuntime::_kyber12To16_Type = nullptr;
const TypeFunc* OptoRuntime::_kyberBarrettReduce_Type = nullptr;
const TypeFunc* OptoRuntime::_dilithiumAlmostNtt_Type = nullptr;
const TypeFunc* OptoRuntime::_dilithiumAlmostInverseNtt_Type = nullptr;
const TypeFunc* OptoRuntime::_dilithiumNttMult_Type = nullptr;
const TypeFunc* OptoRuntime::_dilithiumMontMulByConstant_Type = nullptr;
const TypeFunc* OptoRuntime::_dilithiumDecomposePoly_Type = nullptr;
const TypeFunc* OptoRuntime::_base64_encodeBlock_Type = nullptr;
const TypeFunc* OptoRuntime::_base64_decodeBlock_Type = nullptr;
const TypeFunc* OptoRuntime::_string_IndexOf_Type = nullptr;
@ -1409,6 +1414,146 @@ static const TypeFunc* make_chacha20Block_Type() {
return TypeFunc::make(domain, range);
}
// Kyber NTT function
// Signature of the kyberNtt stub: (coeffs, zetas) -> int
static const TypeFunc* make_kyberNtt_Type() {
  // domain: two non-null array base pointers
  const int num_args = 2;
  const Type** fields = TypeTuple::fields(num_args);
  int argp = TypeFunc::Parms;
  fields[argp++] = TypePtr::NOTNULL; // coeffs
  fields[argp++] = TypePtr::NOTNULL; // NTT zetas
  assert(argp == TypeFunc::Parms + num_args, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + num_args, fields);

  // range: a single int result
  fields = TypeTuple::fields(1);
  fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);

  return TypeFunc::make(domain, range);
}
// Kyber inverse NTT function
// Signature of the kyberInverseNtt stub: (coeffs, zetas) -> int
static const TypeFunc* make_kyberInverseNtt_Type() {
  // domain: two non-null array base pointers
  const int num_args = 2;
  const Type** fields = TypeTuple::fields(num_args);
  int argp = TypeFunc::Parms;
  fields[argp++] = TypePtr::NOTNULL; // coeffs
  fields[argp++] = TypePtr::NOTNULL; // inverse NTT zetas
  assert(argp == TypeFunc::Parms + num_args, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + num_args, fields);

  // range: a single int result
  fields = TypeTuple::fields(1);
  fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);

  return TypeFunc::make(domain, range);
}
// Kyber NTT multiply function
// Signature of the kyberNttMult stub: (result, ntta, nttb, zetas) -> int
static const TypeFunc* make_kyberNttMult_Type() {
  // domain: four non-null array base pointers
  const int num_args = 4;
  const Type** fields = TypeTuple::fields(num_args);
  int argp = TypeFunc::Parms;
  fields[argp++] = TypePtr::NOTNULL; // result
  fields[argp++] = TypePtr::NOTNULL; // ntta
  fields[argp++] = TypePtr::NOTNULL; // nttb
  fields[argp++] = TypePtr::NOTNULL; // NTT multiply zetas
  assert(argp == TypeFunc::Parms + num_args, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + num_args, fields);

  // range: a single int result
  fields = TypeTuple::fields(1);
  fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);

  return TypeFunc::make(domain, range);
}
// Kyber add 2 polynomials function
// Signature of the kyberAddPoly_2 stub: (result, a, b) -> int
static const TypeFunc* make_kyberAddPoly_2_Type() {
  // domain: three non-null array base pointers
  const int num_args = 3;
  const Type** fields = TypeTuple::fields(num_args);
  int argp = TypeFunc::Parms;
  fields[argp++] = TypePtr::NOTNULL; // result
  fields[argp++] = TypePtr::NOTNULL; // a
  fields[argp++] = TypePtr::NOTNULL; // b
  assert(argp == TypeFunc::Parms + num_args, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + num_args, fields);

  // range: a single int result
  fields = TypeTuple::fields(1);
  fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);

  return TypeFunc::make(domain, range);
}
// Kyber add 3 polynomials function
// Signature of the kyberAddPoly_3 stub: (result, a, b, c) -> int
static const TypeFunc* make_kyberAddPoly_3_Type() {
  // domain: four non-null array base pointers
  const int num_args = 4;
  const Type** fields = TypeTuple::fields(num_args);
  int argp = TypeFunc::Parms;
  fields[argp++] = TypePtr::NOTNULL; // result
  fields[argp++] = TypePtr::NOTNULL; // a
  fields[argp++] = TypePtr::NOTNULL; // b
  fields[argp++] = TypePtr::NOTNULL; // c
  assert(argp == TypeFunc::Parms + num_args, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + num_args, fields);

  // range: a single int result
  fields = TypeTuple::fields(1);
  fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);

  return TypeFunc::make(domain, range);
}
// Kyber XOF output parsing into polynomial coefficients candidates
// or decompress(12,...) function
// Signature of the kyber12To16 stub:
// (condensed, condensedOffs, parsed, parsedLength) -> int
static const TypeFunc* make_kyber12To16_Type() {
  // domain: two non-null array base pointers and two ints
  const int num_args = 4;
  const Type** fields = TypeTuple::fields(num_args);
  int argp = TypeFunc::Parms;
  fields[argp++] = TypePtr::NOTNULL; // condensed
  fields[argp++] = TypeInt::INT;     // condensedOffs
  fields[argp++] = TypePtr::NOTNULL; // parsed
  fields[argp++] = TypeInt::INT;     // parsedLength
  assert(argp == TypeFunc::Parms + num_args, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + num_args, fields);

  // range: a single int result
  fields = TypeTuple::fields(1);
  fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);

  return TypeFunc::make(domain, range);
}
// Kyber Barrett reduce function
// Signature of the kyberBarrettReduce stub: (coeffs) -> int
static const TypeFunc* make_kyberBarrettReduce_Type() {
  // domain: a single non-null array base pointer
  const int num_args = 1;
  const Type** fields = TypeTuple::fields(num_args);
  int argp = TypeFunc::Parms;
  fields[argp++] = TypePtr::NOTNULL; // coeffs
  assert(argp == TypeFunc::Parms + num_args, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + num_args, fields);

  // range: a single int result
  fields = TypeTuple::fields(1);
  fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);

  return TypeFunc::make(domain, range);
}
// Dilithium NTT function except for the final "normalization" to |coeff| < Q
static const TypeFunc* make_dilithiumAlmostNtt_Type() {
int argcnt = 2;
@ -2120,13 +2265,18 @@ void OptoRuntime::initialize_types() {
_vectorizedMismatch_Type = make_vectorizedMismatch_Type();
_ghash_processBlocks_Type = make_ghash_processBlocks_Type();
_chacha20Block_Type = make_chacha20Block_Type();
_kyberNtt_Type = make_kyberNtt_Type();
_kyberInverseNtt_Type = make_kyberInverseNtt_Type();
_kyberNttMult_Type = make_kyberNttMult_Type();
_kyberAddPoly_2_Type = make_kyberAddPoly_2_Type();
_kyberAddPoly_3_Type = make_kyberAddPoly_3_Type();
_kyber12To16_Type = make_kyber12To16_Type();
_kyberBarrettReduce_Type = make_kyberBarrettReduce_Type();
_dilithiumAlmostNtt_Type = make_dilithiumAlmostNtt_Type();
_dilithiumAlmostInverseNtt_Type = make_dilithiumAlmostInverseNtt_Type();
_dilithiumNttMult_Type = make_dilithiumNttMult_Type();
_dilithiumMontMulByConstant_Type = make_dilithiumMontMulByConstant_Type();
_dilithiumDecomposePoly_Type = make_dilithiumDecomposePoly_Type();
_base64_encodeBlock_Type = make_base64_encodeBlock_Type();
_base64_decodeBlock_Type = make_base64_decodeBlock_Type();
_string_IndexOf_Type = make_string_IndexOf_Type();

View File

@ -180,6 +180,13 @@ class OptoRuntime : public AllStatic {
static const TypeFunc* _vectorizedMismatch_Type;
static const TypeFunc* _ghash_processBlocks_Type;
static const TypeFunc* _chacha20Block_Type;
static const TypeFunc* _kyberNtt_Type;
static const TypeFunc* _kyberInverseNtt_Type;
static const TypeFunc* _kyberNttMult_Type;
static const TypeFunc* _kyberAddPoly_2_Type;
static const TypeFunc* _kyberAddPoly_3_Type;
static const TypeFunc* _kyber12To16_Type;
static const TypeFunc* _kyberBarrettReduce_Type;
static const TypeFunc* _dilithiumAlmostNtt_Type;
static const TypeFunc* _dilithiumAlmostInverseNtt_Type;
static const TypeFunc* _dilithiumNttMult_Type;
@ -468,6 +475,10 @@ private:
return _unsafe_setmemory_Type;
}
// static const TypeFunc* digestBase_implCompress_Type(bool is_sha3);
// static const TypeFunc* digestBase_implCompressMB_Type(bool is_sha3);
// static const TypeFunc* double_keccak_Type();
static inline const TypeFunc* array_fill_Type() {
assert(_array_fill_Type != nullptr, "should be initialized");
return _array_fill_Type;
@ -584,6 +595,41 @@ private:
return _chacha20Block_Type;
}
static const TypeFunc* kyberNtt_Type() {
assert(_kyberNtt_Type != nullptr, "should be initialized");
return _kyberNtt_Type;
}
static const TypeFunc* kyberInverseNtt_Type() {
assert(_kyberInverseNtt_Type != nullptr, "should be initialized");
return _kyberInverseNtt_Type;
}
static const TypeFunc* kyberNttMult_Type() {
assert(_kyberNttMult_Type != nullptr, "should be initialized");
return _kyberNttMult_Type;
}
static const TypeFunc* kyberAddPoly_2_Type() {
assert(_kyberAddPoly_2_Type != nullptr, "should be initialized");
return _kyberAddPoly_2_Type;
}
static const TypeFunc* kyberAddPoly_3_Type() {
assert(_kyberAddPoly_3_Type != nullptr, "should be initialized");
return _kyberAddPoly_3_Type;
}
static const TypeFunc* kyber12To16_Type() {
assert(_kyber12To16_Type != nullptr, "should be initialized");
return _kyber12To16_Type;
}
static const TypeFunc* kyberBarrettReduce_Type() {
assert(_kyberBarrettReduce_Type != nullptr, "should be initialized");
return _kyberBarrettReduce_Type;
}
static inline const TypeFunc* dilithiumAlmostNtt_Type() {
assert(_dilithiumAlmostNtt_Type != nullptr, "should be initialized");
return _dilithiumAlmostNtt_Type;

View File

@ -325,6 +325,8 @@ const int ObjectAlignmentInBytes = 8;
product(bool, UseChaCha20Intrinsics, false, DIAGNOSTIC, \
"Use intrinsics for the vectorized version of ChaCha20") \
\
product(bool, UseKyberIntrinsics, false, DIAGNOSTIC, \
"Use intrinsics for the vectorized version of Kyber") \
product(bool, UseDilithiumIntrinsics, false, DIAGNOSTIC, \
"Use intrinsics for the vectorized version of Dilithium") \
\

View File

@ -678,6 +678,21 @@
ghash_processBlocks) \
do_stub(compiler, chacha20Block) \
do_entry(compiler, chacha20Block, chacha20Block, chacha20Block) \
do_stub(compiler, kyberNtt) \
do_entry(compiler, kyberNtt, kyberNtt, kyberNtt) \
do_stub(compiler, kyberInverseNtt) \
do_entry(compiler, kyberInverseNtt, kyberInverseNtt, kyberInverseNtt) \
do_stub(compiler, kyberNttMult) \
do_entry(compiler, kyberNttMult, kyberNttMult, kyberNttMult) \
do_stub(compiler, kyberAddPoly_2) \
do_entry(compiler, kyberAddPoly_2, kyberAddPoly_2, kyberAddPoly_2) \
do_stub(compiler, kyberAddPoly_3) \
do_entry(compiler, kyberAddPoly_3, kyberAddPoly_3, kyberAddPoly_3) \
do_stub(compiler, kyber12To16) \
do_entry(compiler, kyber12To16, kyber12To16, kyber12To16) \
do_stub(compiler, kyberBarrettReduce) \
do_entry(compiler, kyberBarrettReduce, kyberBarrettReduce, \
kyberBarrettReduce) \
do_stub(compiler, dilithiumAlmostNtt) \
do_entry(compiler, dilithiumAlmostNtt, \
dilithiumAlmostNtt, dilithiumAlmostNtt) \

View File

@ -28,6 +28,7 @@ package com.sun.crypto.provider;
import java.security.*;
import java.util.Arrays;
import javax.crypto.DecapsulateException;
import jdk.internal.vm.annotation.IntrinsicCandidate;
import sun.security.provider.SHA3.SHAKE256;
import sun.security.provider.SHA3Parallel.Shake128Parallel;
@ -71,6 +72,268 @@ public final class ML_KEM {
-1599, -709, -789, -1317, -57, 1049, -584
};
// Zeta constants for the vectorized forward-NTT intrinsic (Montgomery
// form, per the "mont" prefix -- confirm against the scalar tables).
// Layout: seven NTT levels of 128 shorts each; at level L the 2^L
// distinct zetas of that level are each replicated 128 / 2^L times
// (level 0: one value x128, ..., level 6: 64 values x2) so a stub can
// load them straight into vector lanes.
// NOTE(review): do not edit individual values by hand; they must stay
// consistent with MONT_ZETAS_FOR_NTT and the aarch64 stub generator.
private static final short[] montZetasForVectorNttArr = new short[]{
// level 0
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
// level 1
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
// level 2
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
287, 287, 287, 287, 287, 287, 287, 287,
287, 287, 287, 287, 287, 287, 287, 287,
287, 287, 287, 287, 287, 287, 287, 287,
287, 287, 287, 287, 287, 287, 287, 287,
202, 202, 202, 202, 202, 202, 202, 202,
202, 202, 202, 202, 202, 202, 202, 202,
202, 202, 202, 202, 202, 202, 202, 202,
202, 202, 202, 202, 202, 202, 202, 202,
// level 3
-171, -171, -171, -171, -171, -171, -171, -171,
-171, -171, -171, -171, -171, -171, -171, -171,
622, 622, 622, 622, 622, 622, 622, 622,
622, 622, 622, 622, 622, 622, 622, 622,
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
182, 182, 182, 182, 182, 182, 182, 182,
182, 182, 182, 182, 182, 182, 182, 182,
962, 962, 962, 962, 962, 962, 962, 962,
962, 962, 962, 962, 962, 962, 962, 962,
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
// level 4
573, 573, 573, 573, 573, 573, 573, 573,
-1325, -1325, -1325, -1325, -1325, -1325, -1325, -1325,
264, 264, 264, 264, 264, 264, 264, 264,
383, 383, 383, 383, 383, 383, 383, 383,
-829, -829, -829, -829, -829, -829, -829, -829,
1458, 1458, 1458, 1458, 1458, 1458, 1458, 1458,
-1602, -1602, -1602, -1602, -1602, -1602, -1602, -1602,
-130, -130, -130, -130, -130, -130, -130, -130,
-681, -681, -681, -681, -681, -681, -681, -681,
1017, 1017, 1017, 1017, 1017, 1017, 1017, 1017,
732, 732, 732, 732, 732, 732, 732, 732,
608, 608, 608, 608, 608, 608, 608, 608,
-1542, -1542, -1542, -1542, -1542, -1542, -1542, -1542,
411, 411, 411, 411, 411, 411, 411, 411,
-205, -205, -205, -205, -205, -205, -205, -205,
-1571, -1571, -1571, -1571, -1571, -1571, -1571, -1571,
// level 5
1223, 1223, 1223, 1223, 652, 652, 652, 652,
-552, -552, -552, -552, 1015, 1015, 1015, 1015,
-1293, -1293, -1293, -1293, 1491, 1491, 1491, 1491,
-282, -282, -282, -282, -1544, -1544, -1544, -1544,
516, 516, 516, 516, -8, -8, -8, -8,
-320, -320, -320, -320, -666, -666, -666, -666,
1711, 1711, 1711, 1711, -1162, -1162, -1162, -1162,
126, 126, 126, 126, 1469, 1469, 1469, 1469,
-853, -853, -853, -853, -90, -90, -90, -90,
-271, -271, -271, -271, 830, 830, 830, 830,
107, 107, 107, 107, -1421, -1421, -1421, -1421,
-247, -247, -247, -247, -951, -951, -951, -951,
-398, -398, -398, -398, 961, 961, 961, 961,
-1508, -1508, -1508, -1508, -725, -725, -725, -725,
448, 448, 448, 448, -1065, -1065, -1065, -1065,
677, 677, 677, 677, -1275, -1275, -1275, -1275,
// level 6
-1103, -1103, 430, 430, 555, 555, 843, 843,
-1251, -1251, 871, 871, 1550, 1550, 105, 105,
422, 422, 587, 587, 177, 177, -235, -235,
-291, -291, -460, -460, 1574, 1574, 1653, 1653,
-246, -246, 778, 778, 1159, 1159, -147, -147,
-777, -777, 1483, 1483, -602, -602, 1119, 1119,
-1590, -1590, 644, 644, -872, -872, 349, 349,
418, 418, 329, 329, -156, -156, -75, -75,
817, 817, 1097, 1097, 603, 603, 610, 610,
1322, 1322, -1285, -1285, -1465, -1465, 384, 384,
-1215, -1215, -136, -136, 1218, 1218, -1335, -1335,
-874, -874, 220, 220, -1187, -1187, 1670, 1670,
-1185, -1185, -1530, -1530, -1278, -1278, 794, 794,
-1510, -1510, -854, -854, -870, -870, 478, 478,
-108, -108, -308, -308, 996, 996, 991, 991,
958, 958, -1460, -1460, 1522, 1522, 1628, 1628
};
// Scalar zeta constants for the inverse NTT (Montgomery form, per the
// MONT_ prefix); 127 entries. Presumably consumed by the pure-Java
// inverse NTT implementation -- the consumer is outside this excerpt,
// so confirm against implKyberInverseNttJava.
private static final int[] MONT_ZETAS_FOR_INVERSE_NTT = new int[]{
584, -1049, 57, 1317, 789, 709, 1599, -1601,
-990, 604, 348, 857, 612, 474, 1177, -1014,
-88, -982, -191, 668, 1386, 486, -1153, -534,
514, 137, 586, -1178, 227, 339, -907, 244,
1200, -833, 1394, -30, 1074, 636, -317, -1192,
-1259, -355, -425, -884, -977, 1430, 868, 607,
184, 1448, 702, 1327, 431, 497, 595, -94,
1649, -1497, -620, 42, -172, 1107, -222, 1003,
426, -845, 395, -510, 1613, 825, 1269, -290,
-1429, 623, -567, 1617, 36, 1007, 1440, 332,
-201, 1313, -1382, -744, 669, -1538, 128, -1598,
1401, 1183, -553, 714, 405, -1155, -445, 406,
-1496, -49, 82, 1369, 259, 1604, 373, 909,
-1249, -1000, -25, -52, 530, -895, 1226, 819,
-185, 281, -742, 1253, 417, 1400, 35, -593,
97, -1263, 551, -585, 969, -914, -1188
};
// Zeta constants for the vectorized inverse-NTT intrinsic (Montgomery
// form, per the "mont" prefix). Layout mirrors the forward table in
// reverse: seven levels of 128 shorts; at level L the 2^(6-L) distinct
// zetas are each replicated 2^(L+1) times (level 0: 64 values x2, ...,
// level 6: one value x128).
// NOTE(review): some entries appear shifted by +3329, e.g. 1659 ==
// -1670 + 3329 and 1618 == -1711 + 3329 -- same residues mod ML_KEM_Q
// with magnitude kept under q/2, presumably for the stub's Montgomery
// arithmetic; confirm against the aarch64 stub generator.
private static final short[] montZetasForVectorInverseNttArr = new short[]{
// level 0
-1628, -1628, -1522, -1522, 1460, 1460, -958, -958,
-991, -991, -996, -996, 308, 308, 108, 108,
-478, -478, 870, 870, 854, 854, 1510, 1510,
-794, -794, 1278, 1278, 1530, 1530, 1185, 1185,
1659, 1659, 1187, 1187, -220, -220, 874, 874,
1335, 1335, -1218, -1218, 136, 136, 1215, 1215,
-384, -384, 1465, 1465, 1285, 1285, -1322, -1322,
-610, -610, -603, -603, -1097, -1097, -817, -817,
75, 75, 156, 156, -329, -329, -418, -418,
-349, -349, 872, 872, -644, -644, 1590, 1590,
-1119, -1119, 602, 602, -1483, -1483, 777, 777,
147, 147, -1159, -1159, -778, -778, 246, 246,
-1653, -1653, -1574, -1574, 460, 460, 291, 291,
235, 235, -177, -177, -587, -587, -422, -422,
-105, -105, -1550, -1550, -871, -871, 1251, 1251,
-843, -843, -555, -555, -430, -430, 1103, 1103,
// level 1
1275, 1275, 1275, 1275, -677, -677, -677, -677,
1065, 1065, 1065, 1065, -448, -448, -448, -448,
725, 725, 725, 725, 1508, 1508, 1508, 1508,
-961, -961, -961, -961, 398, 398, 398, 398,
951, 951, 951, 951, 247, 247, 247, 247,
1421, 1421, 1421, 1421, -107, -107, -107, -107,
-830, -830, -830, -830, 271, 271, 271, 271,
90, 90, 90, 90, 853, 853, 853, 853,
-1469, -1469, -1469, -1469, -126, -126, -126, -126,
1162, 1162, 1162, 1162, 1618, 1618, 1618, 1618,
666, 666, 666, 666, 320, 320, 320, 320,
8, 8, 8, 8, -516, -516, -516, -516,
1544, 1544, 1544, 1544, 282, 282, 282, 282,
-1491, -1491, -1491, -1491, 1293, 1293, 1293, 1293,
-1015, -1015, -1015, -1015, 552, 552, 552, 552,
-652, -652, -652, -652, -1223, -1223, -1223, -1223,
// level 2
1571, 1571, 1571, 1571, 1571, 1571, 1571, 1571,
205, 205, 205, 205, 205, 205, 205, 205,
-411, -411, -411, -411, -411, -411, -411, -411,
1542, 1542, 1542, 1542, 1542, 1542, 1542, 1542,
-608, -608, -608, -608, -608, -608, -608, -608,
-732, -732, -732, -732, -732, -732, -732, -732,
-1017, -1017, -1017, -1017, -1017, -1017, -1017, -1017,
681, 681, 681, 681, 681, 681, 681, 681,
130, 130, 130, 130, 130, 130, 130, 130,
1602, 1602, 1602, 1602, 1602, 1602, 1602, 1602,
-1458, -1458, -1458, -1458, -1458, -1458, -1458, -1458,
829, 829, 829, 829, 829, 829, 829, 829,
-383, -383, -383, -383, -383, -383, -383, -383,
-264, -264, -264, -264, -264, -264, -264, -264,
1325, 1325, 1325, 1325, 1325, 1325, 1325, 1325,
-573, -573, -573, -573, -573, -573, -573, -573,
// level 3
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
-962, -962, -962, -962, -962, -962, -962, -962,
-962, -962, -962, -962, -962, -962, -962, -962,
-182, -182, -182, -182, -182, -182, -182, -182,
-182, -182, -182, -182, -182, -182, -182, -182,
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
-622, -622, -622, -622, -622, -622, -622, -622,
-622, -622, -622, -622, -622, -622, -622, -622,
171, 171, 171, 171, 171, 171, 171, 171,
171, 171, 171, 171, 171, 171, 171, 171,
// level 4
-202, -202, -202, -202, -202, -202, -202, -202,
-202, -202, -202, -202, -202, -202, -202, -202,
-202, -202, -202, -202, -202, -202, -202, -202,
-202, -202, -202, -202, -202, -202, -202, -202,
-287, -287, -287, -287, -287, -287, -287, -287,
-287, -287, -287, -287, -287, -287, -287, -287,
-287, -287, -287, -287, -287, -287, -287, -287,
-287, -287, -287, -287, -287, -287, -287, -287,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
// level 5
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
// level 6
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758
};
private static final int[] MONT_ZETAS_FOR_NTT_MULT = new int[]{
-1003, 1003, 222, -222, -1107, 1107, 172, -172,
-42, 42, 620, -620, 1497, -1497, -1649, 1649,
@ -89,6 +352,24 @@ public final class ML_KEM {
1601, -1601, -1599, 1599, -709, 709, -789, 789,
-1317, 1317, -57, 57, 1049, -1049, -584, 584
};
// Zeta constants for the vectorized NTT-domain (base-case)
// multiplication intrinsic: 64 pairs interleaved as (z, -z) for vector
// lanes, matching the level-6 zetas of the forward table.
// NOTE(review): the final pair of the "-874, 874, ..." row reads
// "1670, 1659" rather than "1670, -1670"; 1659 == -1670 + 3329, i.e.
// the same residue mod ML_KEM_Q with magnitude kept under q/2 --
// presumably intentional for the stub's Montgomery arithmetic; confirm
// against the aarch64 stub generator before "fixing" it.
private static final short[] montZetasForVectorNttMultArr = new short[]{
-1103, 1103, 430, -430, 555, -555, 843, -843,
-1251, 1251, 871, -871, 1550, -1550, 105, -105,
422, -422, 587, -587, 177, -177, -235, 235,
-291, 291, -460, 460, 1574, -1574, 1653, -1653,
-246, 246, 778, -778, 1159, -1159, -147, 147,
-777, 777, 1483, -1483, -602, 602, 1119, -1119,
-1590, 1590, 644, -644, -872, 872, 349, -349,
418, -418, 329, -329, -156, 156, -75, 75,
817, -817, 1097, -1097, 603, -603, 610, -610,
1322, -1322, -1285, 1285, -1465, 1465, 384, -384,
-1215, 1215, -136, 136, 1218, -1218, -1335, 1335,
-874, 874, 220, -220, -1187, 1187, 1670, 1659,
-1185, 1185, -1530, 1530, -1278, 1278, 794, -794,
-1510, 1510, -854, 854, -870, 870, 478, -478,
-108, 108, -308, 308, 996, -996, 991, -991,
958, -958, -1460, 1460, 1522, -1522, 1628, -1628
};
private final int mlKem_k;
private final int mlKem_eta1;
@ -261,7 +542,7 @@ public final class ML_KEM {
try {
mlKemH = MessageDigest.getInstance(HASH_H_NAME);
mlKemG = MessageDigest.getInstance(HASH_G_NAME);
} catch (NoSuchAlgorithmException e){
} catch (NoSuchAlgorithmException e) {
// This should never happen.
throw new RuntimeException(e);
}
@ -527,7 +808,7 @@ public final class ML_KEM {
for (int i = 0; i < mlKem_k; i++) {
for (int j = 0; j < mlKem_k; j++) {
xofBufArr[parInd] = seedBuf.clone();
System.arraycopy(seedBuf, 0, xofBufArr[parInd], 0, seedBuf.length);
if (transposed) {
xofBufArr[parInd][rhoLen] = (byte) i;
xofBufArr[parInd][rhoLen + 1] = (byte) j;
@ -707,9 +988,13 @@ public final class ML_KEM {
return vector;
}
// The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q]
// The elements of poly at return will be in the range of [0, ML_KEM_Q]
private void mlKemNTT(short[] poly) {
// Forward-NTT entry point eligible for HotSpot intrinsification: the VM
// may replace this method with a vectorized stub that consumes
// ntt_zetas. The Java fallback delegates to implKyberNttJava and does
// not use the ntt_zetas argument. The int return value is ignored by
// the caller here (mlKemNTT).
@IntrinsicCandidate
static int implKyberNtt(short[] poly, short[] ntt_zetas) {
implKyberNttJava(poly);
return 1;
}
static void implKyberNttJava(short[] poly) {
int[] coeffs = new int[ML_KEM_N];
for (int m = 0; m < ML_KEM_N; m++) {
coeffs[m] = poly[m];
@ -718,12 +1003,23 @@ public final class ML_KEM {
for (int m = 0; m < ML_KEM_N; m++) {
poly[m] = (short) coeffs[m];
}
}
// The elements of poly should be in the range [-mlKem_q, mlKem_q]
// The elements of poly at return will be in the range of [0, mlKem_q]
// In-place forward NTT followed by a Barrett reduction.
// Per the comment above: input coefficients should be in
// [-mlKem_q, mlKem_q]; on return coefficients are in [0, mlKem_q].
private void mlKemNTT(short[] poly) {
assert poly.length == ML_KEM_N;
implKyberNtt(poly, montZetasForVectorNttArr);
mlKemBarrettReduce(poly);
}
// Works in place, but also returns its (modified) input so that it can
// be used in expressions
private short[] mlKemInverseNTT(short[] poly) {
// Inverse-NTT entry point eligible for HotSpot intrinsification; the
// Java fallback delegates to implKyberInverseNttJava and does not use
// the zetas argument. The int return value is ignored by the caller
// here (mlKemInverseNTT).
@IntrinsicCandidate
static int implKyberInverseNtt(short[] poly, short[] zetas) {
implKyberInverseNttJava(poly);
return 1;
}
static void implKyberInverseNttJava(short[] poly) {
int[] coeffs = new int[ML_KEM_N];
for (int m = 0; m < ML_KEM_N; m++) {
coeffs[m] = poly[m];
@ -732,6 +1028,13 @@ public final class ML_KEM {
for (int m = 0; m < ML_KEM_N; m++) {
poly[m] = (short) coeffs[m];
}
}
// Works in place, but also returns its (modified) input so that it can
// be used in expressions
// In-place inverse NTT. Works in place, but also returns the
// (modified) input array so it can be used in expressions (see the
// comment above).
private short[] mlKemInverseNTT(short[] poly) {
assert poly.length == ML_KEM_N;
implKyberInverseNtt(poly, montZetasForVectorInverseNttArr);
return poly;
}
@ -822,11 +1125,16 @@ public final class ML_KEM {
return result;
}
// Multiplies two polynomials represented in the NTT domain.
// The result is a representation of the product still in the NTT domain.
// The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q).
private void nttMult(short[] result, short[] ntta, short[] nttb) {
// NTT-domain multiplication entry point eligible for HotSpot
// intrinsification; the Java fallback delegates to implKyberNttMultJava
// and does not use the zetas argument. The int return value is ignored
// by the caller here (nttMult).
@IntrinsicCandidate
static int implKyberNttMult(short[] result, short[] ntta, short[] nttb,
short[] zetas) {
implKyberNttMultJava(result, ntta, nttb);
return 1;
}
static void implKyberNttMultJava(short[] result, short[] ntta, short[] nttb) {
for (int m = 0; m < ML_KEM_N / 2; m++) {
int a0 = ntta[2 * m];
int a1 = ntta[2 * m + 1];
int b0 = nttb[2 * m];
@ -839,6 +1147,15 @@ public final class ML_KEM {
}
}
// Multiplies two polynomials represented in the NTT domain.
// The result is a representation of the product still in the NTT domain.
// The coefficients in the result are in the range (-mlKem_q, mlKem_q).
// Multiplies two polynomials represented in the NTT domain; the result
// is still in the NTT domain, with coefficients in (-mlKem_q, mlKem_q)
// per the comment above.
private void nttMult(short[] result, short[] ntta, short[] nttb) {
assert (result.length == ML_KEM_N) && (ntta.length == ML_KEM_N) &&
(nttb.length == ML_KEM_N);
implKyberNttMult(result, ntta, nttb, montZetasForVectorNttMultArr);
}
// Adds the vector of polynomials b to a in place, i.e. a will hold
// the result. It also returns (the modified) a so that it can be used
// in an expression.
@ -853,15 +1170,41 @@ public final class ML_KEM {
return a;
}
// Two-operand polynomial addition entry point eligible for HotSpot
// intrinsification; the Java fallback delegates to
// implKyberAddPolyJava. The int return value is ignored by the caller
// here (mlKemAddPoly).
@IntrinsicCandidate
static int implKyberAddPoly(short[] result, short[] a, short[] b) {
implKyberAddPolyJava(result, a, b);
return 1;
}
// Java fallback for the two-operand polynomial addition intrinsic:
// result[m] = a[m] + b[m] + ML_KEM_Q for every coefficient. Adding
// ML_KEM_Q makes each result coefficient greater than -ML_KEM_Q when
// the inputs satisfy the contract documented on mlKemAddPoly below.
// Fix: write into result[] (the intrinsic's declared output) instead of
// mutating a[], and drop the stray Barrett reduction -- the documented
// contract only promises coefficients > -ML_KEM_Q, and the
// three-operand sibling likewise leaves reduction to its caller.
static void implKyberAddPolyJava(short[] result, short[] a, short[] b) {
    for (int m = 0; m < ML_KEM_N; m++) {
        int r = a[m] + b[m] + ML_KEM_Q; // This makes r > - ML_KEM_Q
        result[m] = (short) r;
    }
}
// Adds the polynomial b to a in place, i.e. (the modified) a will hold
// the result.
// The coefficients are supposed be greater than -ML_KEM_Q in a and
// greater than -ML_KEM_Q and less than ML_KEM_Q in b.
// The coefficients in the result are greater than -ML_KEM_Q.
private void mlKemAddPoly(short[] a, short[] b) {
// Adds the polynomial b to a in place and returns the modified a so it
// can be used in expressions. Contract per the comment above: a's
// coefficients > -ML_KEM_Q, b's in (-ML_KEM_Q, ML_KEM_Q); the result's
// coefficients are > -ML_KEM_Q.
private short[] mlKemAddPoly(short[] a, short[] b) {
assert (a.length == ML_KEM_N) && (b.length == ML_KEM_N);
implKyberAddPoly(a, a, b);
return a;
}
// Three-operand polynomial addition entry point eligible for HotSpot
// intrinsification; the Java fallback delegates to
// implKyberAddPolyJava. The int return value is ignored by the caller
// here (mlKemAddPoly).
@IntrinsicCandidate
static int implKyberAddPoly(short[] result, short[] a, short[] b, short[] c) {
implKyberAddPolyJava(result, a, b, c);
return 1;
}
// Java fallback for the three-operand polynomial addition intrinsic:
// result[m] = a[m] + b[m] + c[m] + 2 * ML_KEM_Q, which makes every
// result coefficient greater than -ML_KEM_Q when each input coefficient
// is greater than -ML_KEM_Q and less than ML_KEM_Q (see the contract on
// mlKemAddPoly below).
// Fix: the block contained two conflicting loop bodies -- a leftover of
// the two-operand version redeclared the local r (a compile error) and
// wrote into a[]; keep only the three-operand computation writing into
// result[].
static void implKyberAddPolyJava(short[] result, short[] a, short[] b, short[] c) {
    for (int m = 0; m < ML_KEM_N; m++) {
        int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
        result[m] = (short) r;
    }
}
@ -871,10 +1214,9 @@ public final class ML_KEM {
// greater than -ML_KEM_Q and less than ML_KEM_Q.
// The coefficients in the result are nonnegative and less than ML_KEM_Q.
// Adds the polynomials b and c to a in place and returns the modified a
// so it can be used in expressions. Per the comment above, inputs'
// coefficients are in (-ML_KEM_Q, ML_KEM_Q); after the Barrett
// reduction the result's coefficients are nonnegative and less than
// ML_KEM_Q.
// Fix: removed the leftover scalar addition loop that preceded the
// delegate call -- keeping both would have added b and c into a twice.
private short[] mlKemAddPoly(short[] a, short[] b, short[] c) {
    assert (a.length == ML_KEM_N) && (b.length == ML_KEM_N) &&
            (c.length == ML_KEM_N);
    implKyberAddPoly(a, a, b, c);
    mlKemBarrettReduce(a);
    return a;
}
@ -997,15 +1339,13 @@ public final class ML_KEM {
return result;
}
// The intrinsic implementations assume that the input and output buffers
// are such that condensed can be read in 192-byte chunks and
// parsed can be written in 128 shorts chunks. In other words,
// if (i - 1) * 128 < parsedLengths <= i * 128 then
// parsed.size should be at least i * 128 and
// condensed.size should be at least index + i * 192
private void twelve2Sixteen(byte[] condensed, int index,
short[] parsed, int parsedLength) {
// 12-bit-to-16-bit unpacking entry point eligible for HotSpot
// intrinsification (each pair of output shorts is built from three
// input bytes -- see implKyber12To16Java); the Java fallback delegates
// to implKyber12To16Java. The int return value is ignored by the
// caller here (twelve2Sixteen).
@IntrinsicCandidate
private static int implKyber12To16(byte[] condensed, int index, short[] parsed, int parsedLength) {
implKyber12To16Java(condensed, index, parsed, parsedLength);
return 1;
}
private static void implKyber12To16Java(byte[] condensed, int index, short[] parsed, int parsedLength) {
for (int i = 0; i < parsedLength * 3 / 2; i += 3) {
parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) +
256 * (condensed[i + index + 1] & 0xf));
@ -1014,6 +1354,25 @@ public final class ML_KEM {
}
}
// The intrinsic implementations assume that the input and output buffers
// are such that condensed can be read in 96-byte chunks and
// parsed can be written in 64 shorts chunks except for the last chunk
// that can be either 48 or 64 shorts. In other words,
// if (i - 1) * 64 < parsedLengths <= i * 64 then
// parsed.length should be either i * 64 or (i-1) * 64 + 48 and
// condensed.length should be at least index + i * 96.
// Unpacks parsedLength 12-bit values from condensed (starting at index)
// into parsed via the intrinsifiable implKyber12To16. The assert
// re-checks the buffer-size preconditions stated in the comment above:
// parsedLength must be either a multiple of 64 or leave a remainder of
// 48, and condensed must hold ceil(parsedLength / 64) 96-byte chunks
// beyond index.
private void twelve2Sixteen(byte[] condensed, int index,
                            short[] parsed, int parsedLength) {
    int remainder = parsedLength % 64;
    int chunks = parsedLength / 64 + ((remainder == 0) ? 0 : 1);
    assert ((remainder == 0) || (remainder == 48)) &&
            (index + chunks * 96 <= condensed.length);
    implKyber12To16(condensed, index, parsed, parsedLength);
}
private static void decodePoly5(byte[] condensed, int index, short[] parsed) {
int j = index;
for (int i = 0; i < ML_KEM_N; i += 8) {
@ -1152,6 +1511,19 @@ public final class ML_KEM {
return result;
}
// Barrett-reduction entry point eligible for HotSpot intrinsification;
// the Java fallback delegates to implKyberBarrettReduceJava. The int
// return value is ignored by the caller here (mlKemBarrettReduce).
@IntrinsicCandidate
static int implKyberBarrettReduce(short[] coeffs) {
implKyberBarrettReduceJava(coeffs);
return 1;
}
// Pure-Java Barrett reduction: for each coefficient, estimate the
// quotient by ML_KEM_Q with a multiply-and-shift and subtract that
// multiple of ML_KEM_Q, leaving a value congruent to the original
// modulo ML_KEM_Q (range contract documented on mlKemBarrettReduce).
static void implKyberBarrettReduceJava(short[] poly) {
    for (int m = 0; m < ML_KEM_N; m++) {
        int coeff = poly[m];
        int multiple = (coeff * BARRETT_MULTIPLIER) >> BARRETT_SHIFT;
        poly[m] = (short) (coeff - multiple * ML_KEM_Q);
    }
}
// The input elements can have any short value.
// Modifies poly such that upon return poly[i] will be
// in the range [0, ML_KEM_Q] and will be congruent with the original
@ -1161,11 +1533,9 @@ public final class ML_KEM {
// That means that if the original poly[i] > -ML_KEM_Q then at return it
// will be in the range [0, ML_KEM_Q), i.e. it will be the canonical
// representative of its residue class.
private void mlKemBarrettReduce(short[] poly) {
for (int m = 0; m < ML_KEM_N; m++) {
int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT;
poly[m] = (short) (poly[m] - tmp * ML_KEM_Q);
}
private static void mlKemBarrettReduce(short[] poly) {
assert poly.length == ML_KEM_N;
implKyberBarrettReduce(poly);
}
// Precondition: -(2^MONT_R_BITS -1) * MONT_Q <= b * c < (2^MONT_R_BITS - 1) * MONT_Q

View File

@ -1554,7 +1554,7 @@ public class ML_DSA {
// precondition: -2^31 * MONT_Q <= a, b < 2^31, -2^31 < a * b < 2^31 * MONT_Q
// computes a * b * 2^-32 mod MONT_Q
// the result is greater than -MONT_Q and less than MONT_Q
// see e.g. Algorithm 3 in https://eprint.iacr.org/2018/039.pdf
// See e.g. Algorithm 3 in https://eprint.iacr.org/2018/039.pdf
private static int montMul(int b, int c) {
long a = (long) b * (long) c;
int aHigh = (int) (a >> MONT_R_BITS);