mirror of
https://github.com/openjdk/jdk.git
synced 2026-03-14 18:03:44 +00:00
8349721: Add aarch64 intrinsics for ML-KEM
Reviewed-by: adinn
This commit is contained in:
parent
1ad869f844
commit
465c8e6583
@ -58,23 +58,3 @@ const char* PRegister::PRegisterImpl::name() const {
|
||||
};
|
||||
return is_valid() ? names[encoding()] : "pnoreg";
|
||||
}
|
||||
|
||||
// convenience methods for splitting 8-way vector register sequences
|
||||
// in half -- needed because vector operations can normally only be
|
||||
// benefit from 4-way instruction parallelism
|
||||
|
||||
// first four registers of an 8-way sequence, same stride
VSeq<4> vs_front(const VSeq<8>& v) {
  const int first = v.base();
  const int stride = v.delta();
  return VSeq<4>(first, stride);
}
|
||||
|
||||
// last four registers of an 8-way sequence, same stride
VSeq<4> vs_back(const VSeq<8>& v) {
  const int stride = v.delta();
  const int first = v.base() + 4 * stride;
  return VSeq<4>(first, stride);
}
|
||||
|
||||
// elements 0, 2, 4 and 6 of an 8-way sequence (stride doubled)
VSeq<4> vs_even(const VSeq<8>& v) {
  const int first = v.base();
  const int stride = v.delta() * 2;
  return VSeq<4>(first, stride);
}
|
||||
|
||||
// elements 1, 3, 5 and 7 of an 8-way sequence (stride doubled).
// Element 1 of the input lives at base + delta, not base + 1, so the
// input stride must be used to locate the first odd element.
VSeq<4> vs_odd(const VSeq<8>& v) {
  const int stride = v.delta();
  return VSeq<4>(v.base() + stride, stride * 2);
}
|
||||
|
||||
@ -436,19 +436,20 @@ enum RC { rc_bad, rc_int, rc_float, rc_predicate, rc_stack };
|
||||
// inputs into front and back halves or odd and even halves (see
|
||||
// convenience methods below).
|
||||
|
||||
// helper macro for computing register masks: yields an int with only
// bit (base + delta * i) set. Every argument is parenthesized so that
// compound argument expressions expand with the intended precedence.
#define VS_MASK_BIT(base, delta, i) (1 << ((base) + (delta) * (i)))
|
||||
|
||||
template<int N> class VSeq {
|
||||
static_assert(N >= 2, "vector sequence length must be greater than 1");
|
||||
static_assert(N <= 8, "vector sequence length must not exceed 8");
|
||||
static_assert((N & (N - 1)) == 0, "vector sequence length must be power of two");
|
||||
private:
|
||||
int _base; // index of first register in sequence
|
||||
int _delta; // increment to derive successive indices
|
||||
public:
|
||||
VSeq(FloatRegister base_reg, int delta = 1) : VSeq(base_reg->encoding(), delta) { }
|
||||
// Construct a sequence of N registers whose first register has index
// base and whose successive registers are delta apart. delta may be
// zero (every element is the same register) or negative (descending
// indices, as passed by vs_reverse), so both ends of the index range
// are checked rather than requiring delta >= 0.
VSeq(int base, int delta = 1) : _base(base), _delta(delta) {
  assert (_base >= 0 && _base <= 31, "invalid base register");
  assert ((_base + (N - 1) * _delta) >= 0, "register range underflow");
  assert ((_base + (N - 1) * _delta) < 32, "register range overflow");
}
|
||||
// indexed access to sequence
|
||||
FloatRegister operator [](int i) const {
|
||||
@ -457,27 +458,89 @@ public:
|
||||
}
|
||||
int mask() const {
|
||||
int m = 0;
|
||||
int bit = 1 << _base;
|
||||
for (int i = 0; i < N; i++) {
|
||||
m |= bit << (i * _delta);
|
||||
m |= VS_MASK_BIT(_base, _delta, i);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
int base() const { return _base; }
|
||||
int delta() const { return _delta; }
|
||||
bool is_constant() const { return _delta == 0; }
|
||||
};
|
||||
|
||||
// declare convenience methods for splitting vector register sequences
|
||||
|
||||
VSeq<4> vs_front(const VSeq<8>& v);
|
||||
VSeq<4> vs_back(const VSeq<8>& v);
|
||||
VSeq<4> vs_even(const VSeq<8>& v);
|
||||
VSeq<4> vs_odd(const VSeq<8>& v);
|
||||
|
||||
// methods for use in asserts to check VSeq inputs and outputs are
|
||||
// either disjoint or equal
|
||||
|
||||
// true when the two sequences have no register index in common
template<int N, int M> bool vs_disjoint(const VSeq<N>& n, const VSeq<M>& m) {
  const int shared = n.mask() & m.mask();
  return shared == 0;
}
|
||||
// true when both sequences have identical register masks
template<int N> bool vs_same(const VSeq<N>& n, const VSeq<N>& m) {
  const int lhs = n.mask();
  const int rhs = m.mask();
  return lhs == rhs;
}
|
||||
|
||||
// method for use in asserts to check whether registers appearing in
|
||||
// an output sequence will be written before they are read from an
|
||||
// input sequence.
|
||||
|
||||
template<int N> bool vs_write_before_read(const VSeq<N>& vout, const VSeq<N>& vin) {
|
||||
int b_in = vin.base();
|
||||
int d_in = vin.delta();
|
||||
int b_out = vout.base();
|
||||
int d_out = vout.delta();
|
||||
int bit_in = 1 << b_in;
|
||||
int bit_out = 1 << b_out;
|
||||
int mask_read = vin.mask(); // all pending reads
|
||||
int mask_write = 0; // no writes as yet
|
||||
|
||||
|
||||
for (int i = 0; i < N - 1; i++) {
|
||||
// check whether a pending read clashes with a write
|
||||
if ((mask_write & mask_read) != 0) {
|
||||
return true;
|
||||
}
|
||||
// remove the pending input (so long as this is a constant
|
||||
// sequence)
|
||||
if (d_in != 0) {
|
||||
mask_read ^= VS_MASK_BIT(b_in, d_in, i);
|
||||
}
|
||||
// record the next write
|
||||
mask_write |= VS_MASK_BIT(b_out, d_out, i);
|
||||
}
|
||||
// no write before read
|
||||
return false;
|
||||
}
|
||||
|
||||
// convenience methods for splitting 8-way or 4-way vector register
|
||||
// sequences in half -- needed because vector operations can normally
|
||||
// benefit from 4-way instruction parallelism or, occasionally, 2-way
|
||||
// parallelism
|
||||
|
||||
template<int N>
|
||||
VSeq<N/2> vs_front(const VSeq<N>& v) {
|
||||
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
|
||||
return VSeq<N/2>(v.base(), v.delta());
|
||||
}
|
||||
|
||||
template<int N>
|
||||
VSeq<N/2> vs_back(const VSeq<N>& v) {
|
||||
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
|
||||
return VSeq<N/2>(v.base() + N / 2 * v.delta(), v.delta());
|
||||
}
|
||||
|
||||
template<int N>
|
||||
VSeq<N/2> vs_even(const VSeq<N>& v) {
|
||||
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
|
||||
return VSeq<N/2>(v.base(), v.delta() * 2);
|
||||
}
|
||||
|
||||
template<int N>
|
||||
VSeq<N/2> vs_odd(const VSeq<N>& v) {
|
||||
static_assert(N > 0 && ((N & 1) == 0), "sequence length must be even");
|
||||
return VSeq<N/2>(v.base() + v.delta(), v.delta() * 2);
|
||||
}
|
||||
|
||||
// convenience method to construct a vector register sequence that
|
||||
// indexes its elements in reverse order to the original
|
||||
|
||||
template<int N>
|
||||
VSeq<N> vs_reverse(const VSeq<N>& v) {
|
||||
return VSeq<N>(v.base() + (N - 1) * v.delta(), -v.delta());
|
||||
}
|
||||
|
||||
#endif // CPU_AARCH64_REGISTER_AARCH64_HPP
|
||||
|
||||
@ -44,7 +44,7 @@
|
||||
do_arch_blob, \
|
||||
do_arch_entry, \
|
||||
do_arch_entry_init) \
|
||||
do_arch_blob(compiler, 55000 ZGC_ONLY(+5000)) \
|
||||
do_arch_blob(compiler, 75000 ZGC_ONLY(+5000)) \
|
||||
do_stub(compiler, vector_iota_indices) \
|
||||
do_arch_entry(aarch64, compiler, vector_iota_indices, \
|
||||
vector_iota_indices, vector_iota_indices) \
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -48,6 +48,17 @@ STUBGEN_ARCH_ENTRIES_DO(DEFINE_ARCH_ENTRY, DEFINE_ARCH_ENTRY_INIT)
|
||||
|
||||
bool StubRoutines::aarch64::_completed = false;
|
||||
|
||||
// Table of 16-bit constants for the Kyber (ML-KEM) code. Each row
// holds one constant repeated eight times; the trailing comment on
// each row names the constant.
ATTRIBUTE_ALIGNED(64) uint16_t StubRoutines::aarch64::_kyberConsts[] =
{
    // Because we sometimes load these in pairs, montQInvModR, kyber_q
    // and kyberBarrettMultiplier should stay together and in this order.
    0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, 0xF301, // montQInvModR
    0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, 0x0D01, // kyber_q (0x0D01 == 3329)
    0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, // kyberBarrettMultiplier
    0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, 0x0200, // toMont((kyber_n / 2)^-1 (mod kyber_q))
    0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549, 0x0549  // montRSquareModQ
};
|
||||
|
||||
ATTRIBUTE_ALIGNED(64) uint32_t StubRoutines::aarch64::_dilithiumConsts[] =
|
||||
{
|
||||
58728449, 58728449, 58728449, 58728449, // montQInvModR
|
||||
|
||||
@ -110,6 +110,7 @@ private:
|
||||
}
|
||||
|
||||
private:
|
||||
static uint16_t _kyberConsts[];
|
||||
static uint32_t _dilithiumConsts[];
|
||||
static juint _crc_table[];
|
||||
static jubyte _adler_table[];
|
||||
|
||||
@ -414,13 +414,24 @@ void VM_Version::initialize() {
|
||||
FLAG_SET_DEFAULT(UseChaCha20Intrinsics, false);
|
||||
}
|
||||
|
||||
if (_features & CPU_ASIMD) {
|
||||
if (FLAG_IS_DEFAULT(UseKyberIntrinsics)) {
|
||||
UseKyberIntrinsics = true;
|
||||
}
|
||||
} else if (UseKyberIntrinsics) {
|
||||
if (!FLAG_IS_DEFAULT(UseKyberIntrinsics)) {
|
||||
warning("Kyber intrinsics require ASIMD instructions");
|
||||
}
|
||||
FLAG_SET_DEFAULT(UseKyberIntrinsics, false);
|
||||
}
|
||||
|
||||
if (_features & CPU_ASIMD) {
|
||||
if (FLAG_IS_DEFAULT(UseDilithiumIntrinsics)) {
|
||||
UseDilithiumIntrinsics = true;
|
||||
}
|
||||
} else if (UseDilithiumIntrinsics) {
|
||||
if (!FLAG_IS_DEFAULT(UseDilithiumIntrinsics)) {
|
||||
// the flag is not at its default value but ASIMD is unavailable:
// report it once, using the plural wording that matches the Kyber warning
warning("Dilithium intrinsics require ASIMD instructions");
|
||||
}
|
||||
FLAG_SET_DEFAULT(UseDilithiumIntrinsics, false);
|
||||
}
|
||||
@ -703,6 +714,7 @@ void VM_Version::initialize_cpu_information(void) {
|
||||
get_compatible_board(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len);
|
||||
desc_len = (int)strlen(_cpu_desc);
|
||||
// append the feature string to the detailed CPU description; nothing
// is written to stderr here
snprintf(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len, " %s", _features_string);
|
||||
|
||||
_initialized = true;
|
||||
}
|
||||
|
||||
@ -488,6 +488,14 @@ bool vmIntrinsics::disabled_by_jvm_flags(vmIntrinsics::ID id) {
|
||||
case vmIntrinsics::_chacha20Block:
|
||||
if (!UseChaCha20Intrinsics) return true;
|
||||
break;
|
||||
case vmIntrinsics::_kyberNtt:
|
||||
case vmIntrinsics::_kyberInverseNtt:
|
||||
case vmIntrinsics::_kyberNttMult:
|
||||
case vmIntrinsics::_kyberAddPoly_2:
|
||||
case vmIntrinsics::_kyberAddPoly_3:
|
||||
case vmIntrinsics::_kyber12To16:
|
||||
case vmIntrinsics::_kyberBarrettReduce:
|
||||
if (!UseKyberIntrinsics) return true;
|
||||
case vmIntrinsics::_dilithiumAlmostNtt:
|
||||
case vmIntrinsics::_dilithiumAlmostInverseNtt:
|
||||
case vmIntrinsics::_dilithiumNttMult:
|
||||
|
||||
@ -569,6 +569,27 @@ class methodHandle;
|
||||
do_name(chacha20Block_name, "implChaCha20Block") \
|
||||
do_signature(chacha20Block_signature, "([I[B)I") \
|
||||
\
|
||||
/* support for com.sun.crypto.provider.ML_KEM */ \
|
||||
do_class(com_sun_crypto_provider_ML_KEM, "com/sun/crypto/provider/ML_KEM") \
|
||||
do_signature(SaSaSaSaI_signature, "([S[S[S[S)I") \
|
||||
do_signature(BaISaII_signature, "([BI[SI)I") \
|
||||
do_signature(SaSaSaI_signature, "([S[S[S)I") \
|
||||
do_signature(SaSaI_signature, "([S[S)I") \
|
||||
do_signature(SaI_signature, "([S)I") \
|
||||
do_name(kyberAddPoly_name, "implKyberAddPoly") \
|
||||
do_intrinsic(_kyberNtt, com_sun_crypto_provider_ML_KEM, kyberNtt_name, SaSaI_signature, F_S) \
|
||||
do_name(kyberNtt_name, "implKyberNtt") \
|
||||
do_intrinsic(_kyberInverseNtt, com_sun_crypto_provider_ML_KEM, kyberInverseNtt_name, SaSaI_signature, F_S) \
|
||||
do_name(kyberInverseNtt_name, "implKyberInverseNtt") \
|
||||
do_intrinsic(_kyberNttMult, com_sun_crypto_provider_ML_KEM, kyberNttMult_name, SaSaSaSaI_signature, F_S) \
|
||||
do_name(kyberNttMult_name, "implKyberNttMult") \
|
||||
do_intrinsic(_kyberAddPoly_2, com_sun_crypto_provider_ML_KEM, kyberAddPoly_name, SaSaSaI_signature, F_S) \
|
||||
do_intrinsic(_kyberAddPoly_3, com_sun_crypto_provider_ML_KEM, kyberAddPoly_name, SaSaSaSaI_signature, F_S) \
|
||||
do_intrinsic(_kyber12To16, com_sun_crypto_provider_ML_KEM, kyber12To16_name, BaISaII_signature, F_S) \
|
||||
do_name(kyber12To16_name, "implKyber12To16") \
|
||||
do_intrinsic(_kyberBarrettReduce, com_sun_crypto_provider_ML_KEM, kyberBarrettReduce_name, SaI_signature, F_S) \
|
||||
do_name(kyberBarrettReduce_name, "implKyberBarrettReduce") \
|
||||
\
|
||||
/* support for sun.security.provider.ML_DSA */ \
|
||||
do_class(sun_security_provider_ML_DSA, "sun/security/provider/ML_DSA") \
|
||||
do_signature(IaII_signature, "([II)I") \
|
||||
|
||||
@ -395,6 +395,13 @@
|
||||
static_field(StubRoutines, _sha3_implCompress, address) \
|
||||
static_field(StubRoutines, _double_keccak, address) \
|
||||
static_field(StubRoutines, _sha3_implCompressMB, address) \
|
||||
static_field(StubRoutines, _kyberNtt, address) \
|
||||
static_field(StubRoutines, _kyberInverseNtt, address) \
|
||||
static_field(StubRoutines, _kyberNttMult, address) \
|
||||
static_field(StubRoutines, _kyberAddPoly_2, address) \
|
||||
static_field(StubRoutines, _kyberAddPoly_3, address) \
|
||||
static_field(StubRoutines, _kyber12To16, address) \
|
||||
static_field(StubRoutines, _kyberBarrettReduce, address) \
|
||||
static_field(StubRoutines, _dilithiumAlmostNtt, address) \
|
||||
static_field(StubRoutines, _dilithiumAlmostInverseNtt, address) \
|
||||
static_field(StubRoutines, _dilithiumNttMult, address) \
|
||||
|
||||
@ -792,6 +792,13 @@ bool C2Compiler::is_intrinsic_supported(vmIntrinsics::ID id) {
|
||||
case vmIntrinsics::_vectorizedMismatch:
|
||||
case vmIntrinsics::_ghash_processBlocks:
|
||||
case vmIntrinsics::_chacha20Block:
|
||||
case vmIntrinsics::_kyberNtt:
|
||||
case vmIntrinsics::_kyberInverseNtt:
|
||||
case vmIntrinsics::_kyberNttMult:
|
||||
case vmIntrinsics::_kyberAddPoly_2:
|
||||
case vmIntrinsics::_kyberAddPoly_3:
|
||||
case vmIntrinsics::_kyber12To16:
|
||||
case vmIntrinsics::_kyberBarrettReduce:
|
||||
case vmIntrinsics::_dilithiumAlmostNtt:
|
||||
case vmIntrinsics::_dilithiumAlmostInverseNtt:
|
||||
case vmIntrinsics::_dilithiumNttMult:
|
||||
|
||||
@ -2192,6 +2192,13 @@ void ConnectionGraph::process_call_arguments(CallNode *call) {
|
||||
strcmp(call->as_CallLeaf()->_name, "intpoly_assign") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "ghash_processBlocks") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "chacha20Block") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "kyberNtt") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "kyberInverseNtt") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "kyberNttMult") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "kyberAddPoly_2") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "kyberAddPoly_3") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "kyber12To16") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "kyberBarrettReduce") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "dilithiumAlmostNtt") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "dilithiumAlmostInverseNtt") == 0 ||
|
||||
strcmp(call->as_CallLeaf()->_name, "dilithiumNttMult") == 0 ||
|
||||
|
||||
@ -626,6 +626,20 @@ bool LibraryCallKit::try_to_inline(int predicate) {
|
||||
return inline_ghash_processBlocks();
|
||||
case vmIntrinsics::_chacha20Block:
|
||||
return inline_chacha20Block();
|
||||
case vmIntrinsics::_kyberNtt:
|
||||
return inline_kyberNtt();
|
||||
case vmIntrinsics::_kyberInverseNtt:
|
||||
return inline_kyberInverseNtt();
|
||||
case vmIntrinsics::_kyberNttMult:
|
||||
return inline_kyberNttMult();
|
||||
case vmIntrinsics::_kyberAddPoly_2:
|
||||
return inline_kyberAddPoly_2();
|
||||
case vmIntrinsics::_kyberAddPoly_3:
|
||||
return inline_kyberAddPoly_3();
|
||||
case vmIntrinsics::_kyber12To16:
|
||||
return inline_kyber12To16();
|
||||
case vmIntrinsics::_kyberBarrettReduce:
|
||||
return inline_kyberBarrettReduce();
|
||||
case vmIntrinsics::_dilithiumAlmostNtt:
|
||||
return inline_dilithiumAlmostNtt();
|
||||
case vmIntrinsics::_dilithiumAlmostInverseNtt:
|
||||
@ -7640,6 +7654,245 @@ bool LibraryCallKit::inline_chacha20Block() {
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------inline_kyberNtt
|
||||
bool LibraryCallKit::inline_kyberNtt() {
|
||||
address stubAddr;
|
||||
const char *stubName;
|
||||
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
|
||||
assert(callee()->signature()->size() == 2, "kyberNtt has 2 parameters");
|
||||
|
||||
stubAddr = StubRoutines::kyberNtt();
|
||||
stubName = "kyberNtt";
|
||||
if (!stubAddr) return false;
|
||||
|
||||
Node* coeffs = argument(0);
|
||||
Node* ntt_zetas = argument(1);
|
||||
|
||||
coeffs = must_be_not_null(coeffs, true);
|
||||
ntt_zetas = must_be_not_null(ntt_zetas, true);
|
||||
|
||||
Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT);
|
||||
assert(coeffs_start, "coeffs is null");
|
||||
Node* ntt_zetas_start = array_element_address(ntt_zetas, intcon(0), T_SHORT);
|
||||
assert(ntt_zetas_start, "ntt_zetas is null");
|
||||
Node* kyberNtt = make_runtime_call(RC_LEAF|RC_NO_FP,
|
||||
OptoRuntime::kyberNtt_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
coeffs_start, ntt_zetas_start);
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(kyberNtt, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------inline_kyberInverseNtt
|
||||
bool LibraryCallKit::inline_kyberInverseNtt() {
|
||||
address stubAddr;
|
||||
const char *stubName;
|
||||
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
|
||||
assert(callee()->signature()->size() == 2, "kyberInverseNtt has 2 parameters");
|
||||
|
||||
stubAddr = StubRoutines::kyberInverseNtt();
|
||||
stubName = "kyberInverseNtt";
|
||||
if (!stubAddr) return false;
|
||||
|
||||
Node* coeffs = argument(0);
|
||||
Node* zetas = argument(1);
|
||||
|
||||
coeffs = must_be_not_null(coeffs, true);
|
||||
zetas = must_be_not_null(zetas, true);
|
||||
|
||||
Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT);
|
||||
assert(coeffs_start, "coeffs is null");
|
||||
Node* zetas_start = array_element_address(zetas, intcon(0), T_SHORT);
|
||||
assert(zetas_start, "inverseNtt_zetas is null");
|
||||
Node* kyberInverseNtt = make_runtime_call(RC_LEAF|RC_NO_FP,
|
||||
OptoRuntime::kyberInverseNtt_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
coeffs_start, zetas_start);
|
||||
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(kyberInverseNtt, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------inline_kyberNttMult
|
||||
bool LibraryCallKit::inline_kyberNttMult() {
|
||||
address stubAddr;
|
||||
const char *stubName;
|
||||
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
|
||||
assert(callee()->signature()->size() == 4, "kyberNttMult has 4 parameters");
|
||||
|
||||
stubAddr = StubRoutines::kyberNttMult();
|
||||
stubName = "kyberNttMult";
|
||||
if (!stubAddr) return false;
|
||||
|
||||
Node* result = argument(0);
|
||||
Node* ntta = argument(1);
|
||||
Node* nttb = argument(2);
|
||||
Node* zetas = argument(3);
|
||||
|
||||
result = must_be_not_null(result, true);
|
||||
ntta = must_be_not_null(ntta, true);
|
||||
nttb = must_be_not_null(nttb, true);
|
||||
zetas = must_be_not_null(zetas, true);
|
||||
Node* result_start = array_element_address(result, intcon(0), T_SHORT);
|
||||
assert(result_start, "result is null");
|
||||
Node* ntta_start = array_element_address(ntta, intcon(0), T_SHORT);
|
||||
assert(ntta_start, "ntta is null");
|
||||
Node* nttb_start = array_element_address(nttb, intcon(0), T_SHORT);
|
||||
assert(nttb_start, "nttb is null");
|
||||
Node* zetas_start = array_element_address(zetas, intcon(0), T_SHORT);
|
||||
assert(zetas_start, "nttMult_zetas is null");
|
||||
Node* kyberNttMult = make_runtime_call(RC_LEAF|RC_NO_FP,
|
||||
OptoRuntime::kyberNttMult_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
result_start, ntta_start, nttb_start,
|
||||
zetas_start);
|
||||
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(kyberNttMult, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------inline_kyberAddPoly_2
|
||||
bool LibraryCallKit::inline_kyberAddPoly_2() {
|
||||
address stubAddr;
|
||||
const char *stubName;
|
||||
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
|
||||
assert(callee()->signature()->size() == 3, "kyberAddPoly_2 has 3 parameters");
|
||||
|
||||
stubAddr = StubRoutines::kyberAddPoly_2();
|
||||
stubName = "kyberAddPoly_2";
|
||||
if (!stubAddr) return false;
|
||||
|
||||
Node* result = argument(0);
|
||||
Node* a = argument(1);
|
||||
Node* b = argument(2);
|
||||
|
||||
result = must_be_not_null(result, true);
|
||||
a = must_be_not_null(a, true);
|
||||
b = must_be_not_null(b, true);
|
||||
|
||||
Node* result_start = array_element_address(result, intcon(0), T_SHORT);
|
||||
assert(result_start, "result is null");
|
||||
Node* a_start = array_element_address(a, intcon(0), T_SHORT);
|
||||
assert(a_start, "a is null");
|
||||
Node* b_start = array_element_address(b, intcon(0), T_SHORT);
|
||||
assert(b_start, "b is null");
|
||||
Node* kyberAddPoly_2 = make_runtime_call(RC_LEAF|RC_NO_FP,
|
||||
OptoRuntime::kyberAddPoly_2_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
result_start, a_start, b_start);
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(kyberAddPoly_2, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------inline_kyberAddPoly_3
|
||||
bool LibraryCallKit::inline_kyberAddPoly_3() {
|
||||
address stubAddr;
|
||||
const char *stubName;
|
||||
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
|
||||
assert(callee()->signature()->size() == 4, "kyberAddPoly_3 has 4 parameters");
|
||||
|
||||
stubAddr = StubRoutines::kyberAddPoly_3();
|
||||
stubName = "kyberAddPoly_3";
|
||||
if (!stubAddr) return false;
|
||||
|
||||
Node* result = argument(0);
|
||||
Node* a = argument(1);
|
||||
Node* b = argument(2);
|
||||
Node* c = argument(3);
|
||||
|
||||
result = must_be_not_null(result, true);
|
||||
a = must_be_not_null(a, true);
|
||||
b = must_be_not_null(b, true);
|
||||
c = must_be_not_null(c, true);
|
||||
|
||||
Node* result_start = array_element_address(result, intcon(0), T_SHORT);
|
||||
assert(result_start, "result is null");
|
||||
Node* a_start = array_element_address(a, intcon(0), T_SHORT);
|
||||
assert(a_start, "a is null");
|
||||
Node* b_start = array_element_address(b, intcon(0), T_SHORT);
|
||||
assert(b_start, "b is null");
|
||||
Node* c_start = array_element_address(c, intcon(0), T_SHORT);
|
||||
assert(c_start, "c is null");
|
||||
Node* kyberAddPoly_3 = make_runtime_call(RC_LEAF|RC_NO_FP,
|
||||
OptoRuntime::kyberAddPoly_3_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
result_start, a_start, b_start, c_start);
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(kyberAddPoly_3, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------inline_kyber12To16
|
||||
bool LibraryCallKit::inline_kyber12To16() {
|
||||
address stubAddr;
|
||||
const char *stubName;
|
||||
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
|
||||
assert(callee()->signature()->size() == 4, "kyber12To16 has 4 parameters");
|
||||
|
||||
stubAddr = StubRoutines::kyber12To16();
|
||||
stubName = "kyber12To16";
|
||||
if (!stubAddr) return false;
|
||||
|
||||
Node* condensed = argument(0);
|
||||
Node* condensedOffs = argument(1);
|
||||
Node* parsed = argument(2);
|
||||
Node* parsedLength = argument(3);
|
||||
|
||||
condensed = must_be_not_null(condensed, true);
|
||||
parsed = must_be_not_null(parsed, true);
|
||||
|
||||
Node* condensed_start = array_element_address(condensed, intcon(0), T_BYTE);
|
||||
assert(condensed_start, "condensed is null");
|
||||
Node* parsed_start = array_element_address(parsed, intcon(0), T_SHORT);
|
||||
assert(parsed_start, "parsed is null");
|
||||
Node* kyber12To16 = make_runtime_call(RC_LEAF|RC_NO_FP,
|
||||
OptoRuntime::kyber12To16_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
condensed_start, condensedOffs, parsed_start, parsedLength);
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(kyber12To16, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
//------------------------------inline_kyberBarrettReduce
|
||||
bool LibraryCallKit::inline_kyberBarrettReduce() {
|
||||
address stubAddr;
|
||||
const char *stubName;
|
||||
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
|
||||
assert(callee()->signature()->size() == 1, "kyberBarrettReduce has 1 parameters");
|
||||
|
||||
stubAddr = StubRoutines::kyberBarrettReduce();
|
||||
stubName = "kyberBarrettReduce";
|
||||
if (!stubAddr) return false;
|
||||
|
||||
Node* coeffs = argument(0);
|
||||
|
||||
coeffs = must_be_not_null(coeffs, true);
|
||||
|
||||
Node* coeffs_start = array_element_address(coeffs, intcon(0), T_SHORT);
|
||||
assert(coeffs_start, "coeffs is null");
|
||||
Node* kyberBarrettReduce = make_runtime_call(RC_LEAF|RC_NO_FP,
|
||||
OptoRuntime::kyberBarrettReduce_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
coeffs_start);
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(kyberBarrettReduce, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------inline_dilithiumAlmostNtt
|
||||
bool LibraryCallKit::inline_dilithiumAlmostNtt() {
|
||||
address stubAddr;
|
||||
@ -7696,7 +7949,6 @@ bool LibraryCallKit::inline_dilithiumAlmostInverseNtt() {
|
||||
OptoRuntime::dilithiumAlmostInverseNtt_Type(),
|
||||
stubAddr, stubName, TypePtr::BOTTOM,
|
||||
coeffs_start, zetas_start);
|
||||
|
||||
// return an int
|
||||
Node* retvalue = _gvn.transform(new ProjNode(dilithiumAlmostInverseNtt, TypeFunc::Parms));
|
||||
set_result(retvalue);
|
||||
@ -7717,10 +7969,12 @@ bool LibraryCallKit::inline_dilithiumNttMult() {
|
||||
Node* result = argument(0);
|
||||
Node* ntta = argument(1);
|
||||
Node* nttb = argument(2);
|
||||
Node* zetas = argument(3);
|
||||
|
||||
result = must_be_not_null(result, true);
|
||||
ntta = must_be_not_null(ntta, true);
|
||||
nttb = must_be_not_null(nttb, true);
|
||||
zetas = must_be_not_null(zetas, true);
|
||||
|
||||
Node* result_start = array_element_address(result, intcon(0), T_INT);
|
||||
assert(result_start, "result is null");
|
||||
|
||||
@ -314,6 +314,13 @@ class LibraryCallKit : public GraphKit {
|
||||
Node* get_key_start_from_aescrypt_object(Node* aescrypt_object);
|
||||
bool inline_ghash_processBlocks();
|
||||
bool inline_chacha20Block();
|
||||
bool inline_kyberNtt();
|
||||
bool inline_kyberInverseNtt();
|
||||
bool inline_kyberNttMult();
|
||||
bool inline_kyberAddPoly_2();
|
||||
bool inline_kyberAddPoly_3();
|
||||
bool inline_kyber12To16();
|
||||
bool inline_kyberBarrettReduce();
|
||||
bool inline_dilithiumAlmostNtt();
|
||||
bool inline_dilithiumAlmostInverseNtt();
|
||||
bool inline_dilithiumNttMult();
|
||||
|
||||
@ -242,13 +242,18 @@ const TypeFunc* OptoRuntime::_bigIntegerShift_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_vectorizedMismatch_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_ghash_processBlocks_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_chacha20Block_Type = nullptr;
|
||||
|
||||
const TypeFunc* OptoRuntime::_kyberNtt_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_kyberInverseNtt_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_kyberNttMult_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_kyberAddPoly_2_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_kyberAddPoly_3_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_kyber12To16_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_kyberBarrettReduce_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_dilithiumAlmostNtt_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_dilithiumAlmostInverseNtt_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_dilithiumNttMult_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_dilithiumMontMulByConstant_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_dilithiumDecomposePoly_Type = nullptr;
|
||||
|
||||
const TypeFunc* OptoRuntime::_base64_encodeBlock_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_base64_decodeBlock_Type = nullptr;
|
||||
const TypeFunc* OptoRuntime::_string_IndexOf_Type = nullptr;
|
||||
@ -1409,6 +1414,146 @@ static const TypeFunc* make_chacha20Block_Type() {
|
||||
return TypeFunc::make(domain, range);
|
||||
}
|
||||
|
||||
// Kyber NTT function
|
||||
// Builds the call signature (non-null ptr, non-null ptr) -> int.
static const TypeFunc* make_kyberNtt_Type() {
  const int argcnt = 2;

  // domain: one slot per stub argument
  const Type** arg_fields = TypeTuple::fields(argcnt);
  int argp = TypeFunc::Parms;
  arg_fields[argp++] = TypePtr::NOTNULL;      // coeffs
  arg_fields[argp++] = TypePtr::NOTNULL;      // NTT zetas
  assert(argp == TypeFunc::Parms + argcnt, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, arg_fields);

  // range: a single int result
  const Type** ret_fields = TypeTuple::fields(1);
  ret_fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, ret_fields);

  return TypeFunc::make(domain, range);
}
|
||||
|
||||
// Kyber inverse NTT function
|
||||
// Builds the call signature (non-null ptr, non-null ptr) -> int.
static const TypeFunc* make_kyberInverseNtt_Type() {
  const int argcnt = 2;

  // domain: one slot per stub argument
  const Type** arg_fields = TypeTuple::fields(argcnt);
  int argp = TypeFunc::Parms;
  arg_fields[argp++] = TypePtr::NOTNULL;      // coeffs
  arg_fields[argp++] = TypePtr::NOTNULL;      // inverse NTT zetas
  assert(argp == TypeFunc::Parms + argcnt, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, arg_fields);

  // range: a single int result
  const Type** ret_fields = TypeTuple::fields(1);
  ret_fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, ret_fields);

  return TypeFunc::make(domain, range);
}
|
||||
|
||||
// Kyber NTT multiply function
|
||||
// Builds the call signature (four non-null ptrs) -> int.
static const TypeFunc* make_kyberNttMult_Type() {
  const int argcnt = 4;

  // domain: one slot per stub argument
  const Type** arg_fields = TypeTuple::fields(argcnt);
  int argp = TypeFunc::Parms;
  arg_fields[argp++] = TypePtr::NOTNULL;      // result
  arg_fields[argp++] = TypePtr::NOTNULL;      // ntta
  arg_fields[argp++] = TypePtr::NOTNULL;      // nttb
  arg_fields[argp++] = TypePtr::NOTNULL;      // NTT multiply zetas
  assert(argp == TypeFunc::Parms + argcnt, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, arg_fields);

  // range: a single int result
  const Type** ret_fields = TypeTuple::fields(1);
  ret_fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, ret_fields);

  return TypeFunc::make(domain, range);
}
|
||||
// Kyber add 2 polynomials function
|
||||
// Builds the call signature (three non-null ptrs) -> int.
static const TypeFunc* make_kyberAddPoly_2_Type() {
  const int argcnt = 3;

  // domain: one slot per stub argument
  const Type** arg_fields = TypeTuple::fields(argcnt);
  int argp = TypeFunc::Parms;
  arg_fields[argp++] = TypePtr::NOTNULL;      // result
  arg_fields[argp++] = TypePtr::NOTNULL;      // a
  arg_fields[argp++] = TypePtr::NOTNULL;      // b
  assert(argp == TypeFunc::Parms + argcnt, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, arg_fields);

  // range: a single int result
  const Type** ret_fields = TypeTuple::fields(1);
  ret_fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, ret_fields);

  return TypeFunc::make(domain, range);
}
|
||||
|
||||
|
||||
// Kyber add 3 polynomials function
|
||||
// Builds the call signature (four non-null ptrs) -> int.
static const TypeFunc* make_kyberAddPoly_3_Type() {
  const int argcnt = 4;

  // domain: one slot per stub argument
  const Type** arg_fields = TypeTuple::fields(argcnt);
  int argp = TypeFunc::Parms;
  arg_fields[argp++] = TypePtr::NOTNULL;      // result
  arg_fields[argp++] = TypePtr::NOTNULL;      // a
  arg_fields[argp++] = TypePtr::NOTNULL;      // b
  arg_fields[argp++] = TypePtr::NOTNULL;      // c
  assert(argp == TypeFunc::Parms + argcnt, "correct decoding");
  const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, arg_fields);

  // range: a single int result
  const Type** ret_fields = TypeTuple::fields(1);
  ret_fields[TypeFunc::Parms + 0] = TypeInt::INT;
  const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, ret_fields);

  return TypeFunc::make(domain, range);
}
|
||||
|
||||
|
||||
// Kyber XOF output parsing into polynomial coefficients candidates
|
||||
// or decompress(12,...) function
|
||||
static const TypeFunc* make_kyber12To16_Type() {
|
||||
int argcnt = 4;
|
||||
|
||||
const Type** fields = TypeTuple::fields(argcnt);
|
||||
int argp = TypeFunc::Parms;
|
||||
fields[argp++] = TypePtr::NOTNULL; // condensed
|
||||
fields[argp++] = TypeInt::INT; // condensedOffs
|
||||
fields[argp++] = TypePtr::NOTNULL; // parsed
|
||||
fields[argp++] = TypeInt::INT; // parsedLength
|
||||
|
||||
assert(argp == TypeFunc::Parms + argcnt, "correct decoding");
|
||||
const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields);
|
||||
|
||||
// result type needed
|
||||
fields = TypeTuple::fields(1);
|
||||
fields[TypeFunc::Parms + 0] = TypeInt::INT;
|
||||
const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);
|
||||
return TypeFunc::make(domain, range);
|
||||
}
|
||||
|
||||
// Kyber Barrett reduce function
|
||||
static const TypeFunc* make_kyberBarrettReduce_Type() {
|
||||
int argcnt = 1;
|
||||
|
||||
const Type** fields = TypeTuple::fields(argcnt);
|
||||
int argp = TypeFunc::Parms;
|
||||
fields[argp++] = TypePtr::NOTNULL; // coeffs
|
||||
assert(argp == TypeFunc::Parms + argcnt, "correct decoding");
|
||||
const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms + argcnt, fields);
|
||||
|
||||
// result type needed
|
||||
fields = TypeTuple::fields(1);
|
||||
fields[TypeFunc::Parms + 0] = TypeInt::INT;
|
||||
const TypeTuple* range = TypeTuple::make(TypeFunc::Parms + 1, fields);
|
||||
return TypeFunc::make(domain, range);
|
||||
}
|
||||
|
||||
// Dilithium NTT function except for the final "normalization" to |coeff| < Q
|
||||
static const TypeFunc* make_dilithiumAlmostNtt_Type() {
|
||||
int argcnt = 2;
|
||||
@ -2120,13 +2265,18 @@ void OptoRuntime::initialize_types() {
|
||||
_vectorizedMismatch_Type = make_vectorizedMismatch_Type();
|
||||
_ghash_processBlocks_Type = make_ghash_processBlocks_Type();
|
||||
_chacha20Block_Type = make_chacha20Block_Type();
|
||||
|
||||
_kyberNtt_Type = make_kyberNtt_Type();
|
||||
_kyberInverseNtt_Type = make_kyberInverseNtt_Type();
|
||||
_kyberNttMult_Type = make_kyberNttMult_Type();
|
||||
_kyberAddPoly_2_Type = make_kyberAddPoly_2_Type();
|
||||
_kyberAddPoly_3_Type = make_kyberAddPoly_3_Type();
|
||||
_kyber12To16_Type = make_kyber12To16_Type();
|
||||
_kyberBarrettReduce_Type = make_kyberBarrettReduce_Type();
|
||||
_dilithiumAlmostNtt_Type = make_dilithiumAlmostNtt_Type();
|
||||
_dilithiumAlmostInverseNtt_Type = make_dilithiumAlmostInverseNtt_Type();
|
||||
_dilithiumNttMult_Type = make_dilithiumNttMult_Type();
|
||||
_dilithiumMontMulByConstant_Type = make_dilithiumMontMulByConstant_Type();
|
||||
_dilithiumDecomposePoly_Type = make_dilithiumDecomposePoly_Type();
|
||||
|
||||
_base64_encodeBlock_Type = make_base64_encodeBlock_Type();
|
||||
_base64_decodeBlock_Type = make_base64_decodeBlock_Type();
|
||||
_string_IndexOf_Type = make_string_IndexOf_Type();
|
||||
|
||||
@ -180,6 +180,13 @@ class OptoRuntime : public AllStatic {
|
||||
static const TypeFunc* _vectorizedMismatch_Type;
|
||||
static const TypeFunc* _ghash_processBlocks_Type;
|
||||
static const TypeFunc* _chacha20Block_Type;
|
||||
static const TypeFunc* _kyberNtt_Type;
|
||||
static const TypeFunc* _kyberInverseNtt_Type;
|
||||
static const TypeFunc* _kyberNttMult_Type;
|
||||
static const TypeFunc* _kyberAddPoly_2_Type;
|
||||
static const TypeFunc* _kyberAddPoly_3_Type;
|
||||
static const TypeFunc* _kyber12To16_Type;
|
||||
static const TypeFunc* _kyberBarrettReduce_Type;
|
||||
static const TypeFunc* _dilithiumAlmostNtt_Type;
|
||||
static const TypeFunc* _dilithiumAlmostInverseNtt_Type;
|
||||
static const TypeFunc* _dilithiumNttMult_Type;
|
||||
@ -468,6 +475,10 @@ private:
|
||||
return _unsafe_setmemory_Type;
|
||||
}
|
||||
|
||||
// static const TypeFunc* digestBase_implCompress_Type(bool is_sha3);
|
||||
// static const TypeFunc* digestBase_implCompressMB_Type(bool is_sha3);
|
||||
// static const TypeFunc* double_keccak_Type();
|
||||
|
||||
static inline const TypeFunc* array_fill_Type() {
|
||||
assert(_array_fill_Type != nullptr, "should be initialized");
|
||||
return _array_fill_Type;
|
||||
@ -584,6 +595,41 @@ private:
|
||||
return _chacha20Block_Type;
|
||||
}
|
||||
|
||||
// Returns the stored kyberNtt TypeFunc; asserts that it has been set.
static const TypeFunc* kyberNtt_Type() {
  const TypeFunc* const cached = _kyberNtt_Type;
  assert(cached != nullptr, "should be initialized");
  return cached;
}
|
||||
|
||||
// Returns the stored kyberInverseNtt TypeFunc; asserts that it has been set.
static const TypeFunc* kyberInverseNtt_Type() {
  const TypeFunc* const cached = _kyberInverseNtt_Type;
  assert(cached != nullptr, "should be initialized");
  return cached;
}
|
||||
|
||||
// Returns the stored kyberNttMult TypeFunc; asserts that it has been set.
static const TypeFunc* kyberNttMult_Type() {
  const TypeFunc* const cached = _kyberNttMult_Type;
  assert(cached != nullptr, "should be initialized");
  return cached;
}
|
||||
|
||||
// Returns the stored kyberAddPoly_2 TypeFunc; asserts that it has been set.
static const TypeFunc* kyberAddPoly_2_Type() {
  const TypeFunc* const cached = _kyberAddPoly_2_Type;
  assert(cached != nullptr, "should be initialized");
  return cached;
}
|
||||
|
||||
// Returns the stored kyberAddPoly_3 TypeFunc; asserts that it has been set.
static const TypeFunc* kyberAddPoly_3_Type() {
  const TypeFunc* const cached = _kyberAddPoly_3_Type;
  assert(cached != nullptr, "should be initialized");
  return cached;
}
|
||||
|
||||
// Returns the stored kyber12To16 TypeFunc; asserts that it has been set.
static const TypeFunc* kyber12To16_Type() {
  const TypeFunc* const cached = _kyber12To16_Type;
  assert(cached != nullptr, "should be initialized");
  return cached;
}
|
||||
|
||||
// Returns the stored kyberBarrettReduce TypeFunc; asserts that it has been set.
static const TypeFunc* kyberBarrettReduce_Type() {
  const TypeFunc* const cached = _kyberBarrettReduce_Type;
  assert(cached != nullptr, "should be initialized");
  return cached;
}
|
||||
|
||||
static inline const TypeFunc* dilithiumAlmostNtt_Type() {
|
||||
assert(_dilithiumAlmostNtt_Type != nullptr, "should be initialized");
|
||||
return _dilithiumAlmostNtt_Type;
|
||||
|
||||
@ -325,6 +325,8 @@ const int ObjectAlignmentInBytes = 8;
|
||||
product(bool, UseChaCha20Intrinsics, false, DIAGNOSTIC, \
|
||||
"Use intrinsics for the vectorized version of ChaCha20") \
|
||||
\
|
||||
product(bool, UseKyberIntrinsics, false, DIAGNOSTIC, \
|
||||
"Use intrinsics for the vectorized version of Kyber") \
|
||||
product(bool, UseDilithiumIntrinsics, false, DIAGNOSTIC, \
|
||||
"Use intrinsics for the vectorized version of Dilithium") \
|
||||
\
|
||||
|
||||
@ -678,6 +678,21 @@
|
||||
ghash_processBlocks) \
|
||||
do_stub(compiler, chacha20Block) \
|
||||
do_entry(compiler, chacha20Block, chacha20Block, chacha20Block) \
|
||||
do_stub(compiler, kyberNtt) \
|
||||
do_entry(compiler, kyberNtt, kyberNtt, kyberNtt) \
|
||||
do_stub(compiler, kyberInverseNtt) \
|
||||
do_entry(compiler, kyberInverseNtt, kyberInverseNtt, kyberInverseNtt) \
|
||||
do_stub(compiler, kyberNttMult) \
|
||||
do_entry(compiler, kyberNttMult, kyberNttMult, kyberNttMult) \
|
||||
do_stub(compiler, kyberAddPoly_2) \
|
||||
do_entry(compiler, kyberAddPoly_2, kyberAddPoly_2, kyberAddPoly_2) \
|
||||
do_stub(compiler, kyberAddPoly_3) \
|
||||
do_entry(compiler, kyberAddPoly_3, kyberAddPoly_3, kyberAddPoly_3) \
|
||||
do_stub(compiler, kyber12To16) \
|
||||
do_entry(compiler, kyber12To16, kyber12To16, kyber12To16) \
|
||||
do_stub(compiler, kyberBarrettReduce) \
|
||||
do_entry(compiler, kyberBarrettReduce, kyberBarrettReduce, \
|
||||
kyberBarrettReduce) \
|
||||
do_stub(compiler, dilithiumAlmostNtt) \
|
||||
do_entry(compiler, dilithiumAlmostNtt, \
|
||||
dilithiumAlmostNtt, dilithiumAlmostNtt) \
|
||||
|
||||
@ -28,6 +28,7 @@ package com.sun.crypto.provider;
|
||||
import java.security.*;
|
||||
import java.util.Arrays;
|
||||
import javax.crypto.DecapsulateException;
|
||||
import jdk.internal.vm.annotation.IntrinsicCandidate;
|
||||
|
||||
import sun.security.provider.SHA3.SHAKE256;
|
||||
import sun.security.provider.SHA3Parallel.Shake128Parallel;
|
||||
@ -71,6 +72,268 @@ public final class ML_KEM {
|
||||
-1599, -709, -789, -1317, -57, 1049, -584
|
||||
};
|
||||
|
||||
private static final short[] montZetasForVectorNttArr = new short[]{
|
||||
// level 0
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
-758, -758, -758, -758, -758, -758, -758, -758,
|
||||
// level 1
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-359, -359, -359, -359, -359, -359, -359, -359,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
|
||||
// level 2
|
||||
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
|
||||
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
|
||||
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
|
||||
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
|
||||
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
|
||||
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
|
||||
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
|
||||
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
|
||||
287, 287, 287, 287, 287, 287, 287, 287,
|
||||
287, 287, 287, 287, 287, 287, 287, 287,
|
||||
287, 287, 287, 287, 287, 287, 287, 287,
|
||||
287, 287, 287, 287, 287, 287, 287, 287,
|
||||
202, 202, 202, 202, 202, 202, 202, 202,
|
||||
202, 202, 202, 202, 202, 202, 202, 202,
|
||||
202, 202, 202, 202, 202, 202, 202, 202,
|
||||
202, 202, 202, 202, 202, 202, 202, 202,
|
||||
// level 3
|
||||
-171, -171, -171, -171, -171, -171, -171, -171,
|
||||
-171, -171, -171, -171, -171, -171, -171, -171,
|
||||
622, 622, 622, 622, 622, 622, 622, 622,
|
||||
622, 622, 622, 622, 622, 622, 622, 622,
|
||||
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
|
||||
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
|
||||
182, 182, 182, 182, 182, 182, 182, 182,
|
||||
182, 182, 182, 182, 182, 182, 182, 182,
|
||||
962, 962, 962, 962, 962, 962, 962, 962,
|
||||
962, 962, 962, 962, 962, 962, 962, 962,
|
||||
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
|
||||
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
|
||||
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
|
||||
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
|
||||
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
|
||||
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
|
||||
// level 4
|
||||
573, 573, 573, 573, 573, 573, 573, 573,
|
||||
-1325, -1325, -1325, -1325, -1325, -1325, -1325, -1325,
|
||||
264, 264, 264, 264, 264, 264, 264, 264,
|
||||
383, 383, 383, 383, 383, 383, 383, 383,
|
||||
-829, -829, -829, -829, -829, -829, -829, -829,
|
||||
1458, 1458, 1458, 1458, 1458, 1458, 1458, 1458,
|
||||
-1602, -1602, -1602, -1602, -1602, -1602, -1602, -1602,
|
||||
-130, -130, -130, -130, -130, -130, -130, -130,
|
||||
-681, -681, -681, -681, -681, -681, -681, -681,
|
||||
1017, 1017, 1017, 1017, 1017, 1017, 1017, 1017,
|
||||
732, 732, 732, 732, 732, 732, 732, 732,
|
||||
608, 608, 608, 608, 608, 608, 608, 608,
|
||||
-1542, -1542, -1542, -1542, -1542, -1542, -1542, -1542,
|
||||
411, 411, 411, 411, 411, 411, 411, 411,
|
||||
-205, -205, -205, -205, -205, -205, -205, -205,
|
||||
-1571, -1571, -1571, -1571, -1571, -1571, -1571, -1571,
|
||||
// level 5
|
||||
1223, 1223, 1223, 1223, 652, 652, 652, 652,
|
||||
-552, -552, -552, -552, 1015, 1015, 1015, 1015,
|
||||
-1293, -1293, -1293, -1293, 1491, 1491, 1491, 1491,
|
||||
-282, -282, -282, -282, -1544, -1544, -1544, -1544,
|
||||
516, 516, 516, 516, -8, -8, -8, -8,
|
||||
-320, -320, -320, -320, -666, -666, -666, -666,
|
||||
1711, 1711, 1711, 1711, -1162, -1162, -1162, -1162,
|
||||
126, 126, 126, 126, 1469, 1469, 1469, 1469,
|
||||
-853, -853, -853, -853, -90, -90, -90, -90,
|
||||
-271, -271, -271, -271, 830, 830, 830, 830,
|
||||
107, 107, 107, 107, -1421, -1421, -1421, -1421,
|
||||
-247, -247, -247, -247, -951, -951, -951, -951,
|
||||
-398, -398, -398, -398, 961, 961, 961, 961,
|
||||
-1508, -1508, -1508, -1508, -725, -725, -725, -725,
|
||||
448, 448, 448, 448, -1065, -1065, -1065, -1065,
|
||||
677, 677, 677, 677, -1275, -1275, -1275, -1275,
|
||||
// level 6
|
||||
-1103, -1103, 430, 430, 555, 555, 843, 843,
|
||||
-1251, -1251, 871, 871, 1550, 1550, 105, 105,
|
||||
422, 422, 587, 587, 177, 177, -235, -235,
|
||||
-291, -291, -460, -460, 1574, 1574, 1653, 1653,
|
||||
-246, -246, 778, 778, 1159, 1159, -147, -147,
|
||||
-777, -777, 1483, 1483, -602, -602, 1119, 1119,
|
||||
-1590, -1590, 644, 644, -872, -872, 349, 349,
|
||||
418, 418, 329, 329, -156, -156, -75, -75,
|
||||
817, 817, 1097, 1097, 603, 603, 610, 610,
|
||||
1322, 1322, -1285, -1285, -1465, -1465, 384, 384,
|
||||
-1215, -1215, -136, -136, 1218, 1218, -1335, -1335,
|
||||
-874, -874, 220, 220, -1187, -1187, 1670, 1670,
|
||||
-1185, -1185, -1530, -1530, -1278, -1278, 794, 794,
|
||||
-1510, -1510, -854, -854, -870, -870, 478, 478,
|
||||
-108, -108, -308, -308, 996, 996, 991, 991,
|
||||
958, 958, -1460, -1460, 1522, 1522, 1628, 1628
|
||||
};
|
||||
private static final int[] MONT_ZETAS_FOR_INVERSE_NTT = new int[]{
|
||||
584, -1049, 57, 1317, 789, 709, 1599, -1601,
|
||||
-990, 604, 348, 857, 612, 474, 1177, -1014,
|
||||
-88, -982, -191, 668, 1386, 486, -1153, -534,
|
||||
514, 137, 586, -1178, 227, 339, -907, 244,
|
||||
1200, -833, 1394, -30, 1074, 636, -317, -1192,
|
||||
-1259, -355, -425, -884, -977, 1430, 868, 607,
|
||||
184, 1448, 702, 1327, 431, 497, 595, -94,
|
||||
1649, -1497, -620, 42, -172, 1107, -222, 1003,
|
||||
426, -845, 395, -510, 1613, 825, 1269, -290,
|
||||
-1429, 623, -567, 1617, 36, 1007, 1440, 332,
|
||||
-201, 1313, -1382, -744, 669, -1538, 128, -1598,
|
||||
1401, 1183, -553, 714, 405, -1155, -445, 406,
|
||||
-1496, -49, 82, 1369, 259, 1604, 373, 909,
|
||||
-1249, -1000, -25, -52, 530, -895, 1226, 819,
|
||||
-185, 281, -742, 1253, 417, 1400, 35, -593,
|
||||
97, -1263, 551, -585, 969, -914, -1188
|
||||
};
|
||||
|
||||
private static final short[] montZetasForVectorInverseNttArr = new short[]{
|
||||
// level 0
|
||||
-1628, -1628, -1522, -1522, 1460, 1460, -958, -958,
|
||||
-991, -991, -996, -996, 308, 308, 108, 108,
|
||||
-478, -478, 870, 870, 854, 854, 1510, 1510,
|
||||
-794, -794, 1278, 1278, 1530, 1530, 1185, 1185,
|
||||
1659, 1659, 1187, 1187, -220, -220, 874, 874,
|
||||
1335, 1335, -1218, -1218, 136, 136, 1215, 1215,
|
||||
-384, -384, 1465, 1465, 1285, 1285, -1322, -1322,
|
||||
-610, -610, -603, -603, -1097, -1097, -817, -817,
|
||||
75, 75, 156, 156, -329, -329, -418, -418,
|
||||
-349, -349, 872, 872, -644, -644, 1590, 1590,
|
||||
-1119, -1119, 602, 602, -1483, -1483, 777, 777,
|
||||
147, 147, -1159, -1159, -778, -778, 246, 246,
|
||||
-1653, -1653, -1574, -1574, 460, 460, 291, 291,
|
||||
235, 235, -177, -177, -587, -587, -422, -422,
|
||||
-105, -105, -1550, -1550, -871, -871, 1251, 1251,
|
||||
-843, -843, -555, -555, -430, -430, 1103, 1103,
|
||||
// level 1
|
||||
1275, 1275, 1275, 1275, -677, -677, -677, -677,
|
||||
1065, 1065, 1065, 1065, -448, -448, -448, -448,
|
||||
725, 725, 725, 725, 1508, 1508, 1508, 1508,
|
||||
-961, -961, -961, -961, 398, 398, 398, 398,
|
||||
951, 951, 951, 951, 247, 247, 247, 247,
|
||||
1421, 1421, 1421, 1421, -107, -107, -107, -107,
|
||||
-830, -830, -830, -830, 271, 271, 271, 271,
|
||||
90, 90, 90, 90, 853, 853, 853, 853,
|
||||
-1469, -1469, -1469, -1469, -126, -126, -126, -126,
|
||||
1162, 1162, 1162, 1162, 1618, 1618, 1618, 1618,
|
||||
666, 666, 666, 666, 320, 320, 320, 320,
|
||||
8, 8, 8, 8, -516, -516, -516, -516,
|
||||
1544, 1544, 1544, 1544, 282, 282, 282, 282,
|
||||
-1491, -1491, -1491, -1491, 1293, 1293, 1293, 1293,
|
||||
-1015, -1015, -1015, -1015, 552, 552, 552, 552,
|
||||
-652, -652, -652, -652, -1223, -1223, -1223, -1223,
|
||||
// level 2
|
||||
1571, 1571, 1571, 1571, 1571, 1571, 1571, 1571,
|
||||
205, 205, 205, 205, 205, 205, 205, 205,
|
||||
-411, -411, -411, -411, -411, -411, -411, -411,
|
||||
1542, 1542, 1542, 1542, 1542, 1542, 1542, 1542,
|
||||
-608, -608, -608, -608, -608, -608, -608, -608,
|
||||
-732, -732, -732, -732, -732, -732, -732, -732,
|
||||
-1017, -1017, -1017, -1017, -1017, -1017, -1017, -1017,
|
||||
681, 681, 681, 681, 681, 681, 681, 681,
|
||||
130, 130, 130, 130, 130, 130, 130, 130,
|
||||
1602, 1602, 1602, 1602, 1602, 1602, 1602, 1602,
|
||||
-1458, -1458, -1458, -1458, -1458, -1458, -1458, -1458,
|
||||
829, 829, 829, 829, 829, 829, 829, 829,
|
||||
-383, -383, -383, -383, -383, -383, -383, -383,
|
||||
-264, -264, -264, -264, -264, -264, -264, -264,
|
||||
1325, 1325, 1325, 1325, 1325, 1325, 1325, 1325,
|
||||
-573, -573, -573, -573, -573, -573, -573, -573,
|
||||
// level 3
|
||||
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
|
||||
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
|
||||
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
|
||||
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
|
||||
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
|
||||
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
|
||||
-962, -962, -962, -962, -962, -962, -962, -962,
|
||||
-962, -962, -962, -962, -962, -962, -962, -962,
|
||||
-182, -182, -182, -182, -182, -182, -182, -182,
|
||||
-182, -182, -182, -182, -182, -182, -182, -182,
|
||||
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
|
||||
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
|
||||
-622, -622, -622, -622, -622, -622, -622, -622,
|
||||
-622, -622, -622, -622, -622, -622, -622, -622,
|
||||
171, 171, 171, 171, 171, 171, 171, 171,
|
||||
171, 171, 171, 171, 171, 171, 171, 171,
|
||||
// level 4
|
||||
-202, -202, -202, -202, -202, -202, -202, -202,
|
||||
-202, -202, -202, -202, -202, -202, -202, -202,
|
||||
-202, -202, -202, -202, -202, -202, -202, -202,
|
||||
-202, -202, -202, -202, -202, -202, -202, -202,
|
||||
-287, -287, -287, -287, -287, -287, -287, -287,
|
||||
-287, -287, -287, -287, -287, -287, -287, -287,
|
||||
-287, -287, -287, -287, -287, -287, -287, -287,
|
||||
-287, -287, -287, -287, -287, -287, -287, -287,
|
||||
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
|
||||
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
|
||||
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
|
||||
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
|
||||
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
|
||||
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
|
||||
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
|
||||
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
|
||||
// level 5
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
359, 359, 359, 359, 359, 359, 359, 359,
|
||||
// level 6
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758,
|
||||
758, 758, 758, 758, 758, 758, 758, 758
|
||||
};
|
||||
|
||||
private static final int[] MONT_ZETAS_FOR_NTT_MULT = new int[]{
|
||||
-1003, 1003, 222, -222, -1107, 1107, 172, -172,
|
||||
-42, 42, 620, -620, 1497, -1497, -1649, 1649,
|
||||
@ -89,6 +352,24 @@ public final class ML_KEM {
|
||||
1601, -1601, -1599, 1599, -709, 709, -789, 789,
|
||||
-1317, 1317, -57, 57, 1049, -1049, -584, 584
|
||||
};
|
||||
private static final short[] montZetasForVectorNttMultArr = new short[]{
|
||||
-1103, 1103, 430, -430, 555, -555, 843, -843,
|
||||
-1251, 1251, 871, -871, 1550, -1550, 105, -105,
|
||||
422, -422, 587, -587, 177, -177, -235, 235,
|
||||
-291, 291, -460, 460, 1574, -1574, 1653, -1653,
|
||||
-246, 246, 778, -778, 1159, -1159, -147, 147,
|
||||
-777, 777, 1483, -1483, -602, 602, 1119, -1119,
|
||||
-1590, 1590, 644, -644, -872, 872, 349, -349,
|
||||
418, -418, 329, -329, -156, 156, -75, 75,
|
||||
817, -817, 1097, -1097, 603, -603, 610, -610,
|
||||
1322, -1322, -1285, 1285, -1465, 1465, 384, -384,
|
||||
-1215, 1215, -136, 136, 1218, -1218, -1335, 1335,
|
||||
-874, 874, 220, -220, -1187, 1187, 1670, 1659,
|
||||
-1185, 1185, -1530, 1530, -1278, 1278, 794, -794,
|
||||
-1510, 1510, -854, 854, -870, 870, 478, -478,
|
||||
-108, 108, -308, 308, 996, -996, 991, -991,
|
||||
958, -958, -1460, 1460, 1522, -1522, 1628, -1628
|
||||
};
|
||||
|
||||
private final int mlKem_k;
|
||||
private final int mlKem_eta1;
|
||||
@ -261,7 +542,7 @@ public final class ML_KEM {
|
||||
try {
|
||||
mlKemH = MessageDigest.getInstance(HASH_H_NAME);
|
||||
mlKemG = MessageDigest.getInstance(HASH_G_NAME);
|
||||
} catch (NoSuchAlgorithmException e){
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
// This should never happen.
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
@ -527,7 +808,7 @@ public final class ML_KEM {
|
||||
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
for (int j = 0; j < mlKem_k; j++) {
|
||||
xofBufArr[parInd] = seedBuf.clone();
|
||||
System.arraycopy(seedBuf, 0, xofBufArr[parInd], 0, seedBuf.length);
|
||||
if (transposed) {
|
||||
xofBufArr[parInd][rhoLen] = (byte) i;
|
||||
xofBufArr[parInd][rhoLen + 1] = (byte) j;
|
||||
@ -707,9 +988,13 @@ public final class ML_KEM {
|
||||
return vector;
|
||||
}
|
||||
|
||||
// The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q]
|
||||
// The elements of poly at return will be in the range of [0, ML_KEM_Q]
|
||||
private void mlKemNTT(short[] poly) {
|
||||
@IntrinsicCandidate
|
||||
static int implKyberNtt(short[] poly, short[] ntt_zetas) {
|
||||
implKyberNttJava(poly);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static void implKyberNttJava(short[] poly) {
|
||||
int[] coeffs = new int[ML_KEM_N];
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
coeffs[m] = poly[m];
|
||||
@ -718,12 +1003,23 @@ public final class ML_KEM {
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
poly[m] = (short) coeffs[m];
|
||||
}
|
||||
}
|
||||
|
||||
// The elements of poly should be in the range [-mlKem_q, mlKem_q]
|
||||
// The elements of poly at return will be in the range of [0, mlKem_q]
|
||||
private void mlKemNTT(short[] poly) {
|
||||
assert poly.length == ML_KEM_N;
|
||||
implKyberNtt(poly, montZetasForVectorNttArr);
|
||||
mlKemBarrettReduce(poly);
|
||||
}
|
||||
|
||||
// Works in place, but also returns its (modified) input so that it can
|
||||
// be used in expressions
|
||||
private short[] mlKemInverseNTT(short[] poly) {
|
||||
@IntrinsicCandidate
|
||||
static int implKyberInverseNtt(short[] poly, short[] zetas) {
|
||||
implKyberInverseNttJava(poly);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static void implKyberInverseNttJava(short[] poly) {
|
||||
int[] coeffs = new int[ML_KEM_N];
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
coeffs[m] = poly[m];
|
||||
@ -732,6 +1028,13 @@ public final class ML_KEM {
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
poly[m] = (short) coeffs[m];
|
||||
}
|
||||
}
|
||||
|
||||
// Works in place, but also returns its (modified) input so that it can
|
||||
// be used in expressions
|
||||
private short[] mlKemInverseNTT(short[] poly) {
|
||||
assert poly.length == ML_KEM_N;
|
||||
implKyberInverseNtt(poly, montZetasForVectorInverseNttArr);
|
||||
return poly;
|
||||
}
|
||||
|
||||
@ -822,11 +1125,16 @@ public final class ML_KEM {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Multiplies two polynomials represented in the NTT domain.
|
||||
// The result is a representation of the product still in the NTT domain.
|
||||
// The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q).
|
||||
private void nttMult(short[] result, short[] ntta, short[] nttb) {
|
||||
@IntrinsicCandidate
|
||||
static int implKyberNttMult(short[] result, short[] ntta, short[] nttb,
|
||||
short[] zetas) {
|
||||
implKyberNttMultJava(result, ntta, nttb);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static void implKyberNttMultJava(short[] result, short[] ntta, short[] nttb) {
|
||||
for (int m = 0; m < ML_KEM_N / 2; m++) {
|
||||
|
||||
int a0 = ntta[2 * m];
|
||||
int a1 = ntta[2 * m + 1];
|
||||
int b0 = nttb[2 * m];
|
||||
@ -839,6 +1147,15 @@ public final class ML_KEM {
|
||||
}
|
||||
}
|
||||
|
||||
// Multiplies two polynomials represented in the NTT domain.
|
||||
// The result is a representation of the product still in the NTT domain.
|
||||
// The coefficients in the result are in the range (-mlKem_q, mlKem_q).
|
||||
private void nttMult(short[] result, short[] ntta, short[] nttb) {
|
||||
assert (result.length == ML_KEM_N) && (ntta.length == ML_KEM_N) &&
|
||||
(nttb.length == ML_KEM_N);
|
||||
implKyberNttMult(result, ntta, nttb, montZetasForVectorNttMultArr);
|
||||
}
|
||||
|
||||
// Adds the vector of polynomials b to a in place, i.e. a will hold
|
||||
// the result. It also returns (the modified) a so that it can be used
|
||||
// in an expression.
|
||||
@ -853,15 +1170,41 @@ public final class ML_KEM {
|
||||
return a;
|
||||
}
|
||||
|
||||
@IntrinsicCandidate
|
||||
static int implKyberAddPoly(short[] result, short[] a, short[] b) {
|
||||
implKyberAddPolyJava(result, a, b);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static void implKyberAddPolyJava(short[] result, short[] a, short[] b) {
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
int r = a[m] + b[m] + ML_KEM_Q; // This makes r > - ML_KEM_Q
|
||||
a[m] = (short) r;
|
||||
}
|
||||
mlKemBarrettReduce(a);
|
||||
}
|
||||
|
||||
// Adds the polynomial b to a in place, i.e. (the modified) a will hold
|
||||
// the result.
|
||||
// The coefficients are supposed be greater than -ML_KEM_Q in a and
|
||||
// greater than -ML_KEM_Q and less than ML_KEM_Q in b.
|
||||
// The coefficients in the result are greater than -ML_KEM_Q.
|
||||
private void mlKemAddPoly(short[] a, short[] b) {
|
||||
private short[] mlKemAddPoly(short[] a, short[] b) {
|
||||
assert (a.length == ML_KEM_N) && (b.length == ML_KEM_N);
|
||||
implKyberAddPoly(a, a, b);
|
||||
return a;
|
||||
}
|
||||
|
||||
@IntrinsicCandidate
|
||||
static int implKyberAddPoly(short[] result, short[] a, short[] b, short[] c) {
|
||||
implKyberAddPolyJava(result, a, b, c);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static void implKyberAddPolyJava(short[] result, short[] a, short[] b, short[] c) {
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q
|
||||
a[m] = (short) r;
|
||||
int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
|
||||
result[m] = (short) r;
|
||||
}
|
||||
}
|
||||
|
||||
@ -871,10 +1214,9 @@ public final class ML_KEM {
|
||||
// greater than -ML_KEM_Q and less than ML_KEM_Q.
|
||||
// The coefficients in the result are nonnegative and less than ML_KEM_Q.
|
||||
private short[] mlKemAddPoly(short[] a, short[] b, short[] c) {
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
|
||||
a[m] = (short) r;
|
||||
}
|
||||
assert (a.length == ML_KEM_N) && (b.length == ML_KEM_N) &&
|
||||
(c.length == ML_KEM_N);
|
||||
implKyberAddPoly(a, a, b, c);
|
||||
mlKemBarrettReduce(a);
|
||||
return a;
|
||||
}
|
||||
@ -997,15 +1339,13 @@ public final class ML_KEM {
|
||||
return result;
|
||||
}
|
||||
|
||||
// The intrinsic implementations assume that the input and output buffers
|
||||
// are such that condensed can be read in 192-byte chunks and
|
||||
// parsed can be written in 128 shorts chunks. In other words,
|
||||
// if (i - 1) * 128 < parsedLengths <= i * 128 then
|
||||
// parsed.size should be at least i * 128 and
|
||||
// condensed.size should be at least index + i * 192
|
||||
private void twelve2Sixteen(byte[] condensed, int index,
|
||||
short[] parsed, int parsedLength) {
|
||||
@IntrinsicCandidate
|
||||
private static int implKyber12To16(byte[] condensed, int index, short[] parsed, int parsedLength) {
|
||||
implKyber12To16Java(condensed, index, parsed, parsedLength);
|
||||
return 1;
|
||||
}
|
||||
|
||||
private static void implKyber12To16Java(byte[] condensed, int index, short[] parsed, int parsedLength) {
|
||||
for (int i = 0; i < parsedLength * 3 / 2; i += 3) {
|
||||
parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) +
|
||||
256 * (condensed[i + index + 1] & 0xf));
|
||||
@ -1014,6 +1354,25 @@ public final class ML_KEM {
|
||||
}
|
||||
}
|
||||
|
||||
// The intrinsic implementations assume that the input and output buffers
|
||||
// are such that condensed can be read in 96-byte chunks and
|
||||
// parsed can be written in 64 shorts chunks except for the last chunk
|
||||
// that can be either 48 or 64 shorts. In other words,
|
||||
// if (i - 1) * 64 < parsedLengths <= i * 64 then
|
||||
// parsed.length should be either i * 64 or (i-1) * 64 + 48 and
|
||||
// condensed.length should be at least index + i * 96.
|
||||
private void twelve2Sixteen(byte[] condensed, int index,
|
||||
short[] parsed, int parsedLength) {
|
||||
int i = parsedLength / 64;
|
||||
int remainder = parsedLength - i * 64;
|
||||
if (remainder != 0) {
|
||||
i++;
|
||||
}
|
||||
assert ((remainder == 0) || (remainder == 48)) &&
|
||||
(index + i * 96 <= condensed.length);
|
||||
implKyber12To16(condensed, index, parsed, parsedLength);
|
||||
}
|
||||
|
||||
private static void decodePoly5(byte[] condensed, int index, short[] parsed) {
|
||||
int j = index;
|
||||
for (int i = 0; i < ML_KEM_N; i += 8) {
|
||||
@ -1152,6 +1511,19 @@ public final class ML_KEM {
|
||||
return result;
|
||||
}
|
||||
|
||||
@IntrinsicCandidate
|
||||
static int implKyberBarrettReduce(short[] coeffs) {
|
||||
implKyberBarrettReduceJava(coeffs);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static void implKyberBarrettReduceJava(short[] poly) {
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT;
|
||||
poly[m] = (short) (poly[m] - tmp * ML_KEM_Q);
|
||||
}
|
||||
}
|
||||
|
||||
// The input elements can have any short value.
|
||||
// Modifies poly such that upon return poly[i] will be
|
||||
// in the range [0, ML_KEM_Q] and will be congruent with the original
|
||||
@ -1161,11 +1533,9 @@ public final class ML_KEM {
|
||||
// That means that if the original poly[i] > -ML_KEM_Q then at return it
|
||||
// will be in the range [0, ML_KEM_Q), i.e. it will be the canonical
|
||||
// representative of its residue class.
|
||||
private void mlKemBarrettReduce(short[] poly) {
|
||||
for (int m = 0; m < ML_KEM_N; m++) {
|
||||
int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT;
|
||||
poly[m] = (short) (poly[m] - tmp * ML_KEM_Q);
|
||||
}
|
||||
private static void mlKemBarrettReduce(short[] poly) {
|
||||
assert poly.length == ML_KEM_N;
|
||||
implKyberBarrettReduce(poly);
|
||||
}
|
||||
|
||||
// Precondition: -(2^MONT_R_BITS -1) * MONT_Q <= b * c < (2^MONT_R_BITS - 1) * MONT_Q
|
||||
|
||||
@ -1554,7 +1554,7 @@ public class ML_DSA {
|
||||
// precondition: -2^31 * MONT_Q <= a, b < 2^31, -2^31 < a * b < 2^31 * MONT_Q
|
||||
// computes a * b * 2^-32 mod MONT_Q
|
||||
// the result is greater than -MONT_Q and less than MONT_Q
|
||||
// see e.g. Algorithm 3 in https://eprint.iacr.org/2018/039.pdf
|
||||
// See e.g. Algorithm 3 in https://eprint.iacr.org/2018/039.pdf
|
||||
private static int montMul(int b, int c) {
|
||||
long a = (long) b * (long) c;
|
||||
int aHigh = (int) (a >> MONT_R_BITS);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user