From ac3ad03a3f946fbff147732c5f403c8dc445eed8 Mon Sep 17 00:00:00 2001
From: Andrew Dinn
Date: Wed, 19 Mar 2025 17:23:23 +0000
Subject: [PATCH] 8350589: Investigate cleaner implementation of AArch64 ML-DSA
 intrinsic introduced in JDK-8348561

Reviewed-by: dlong
---
 src/hotspot/cpu/aarch64/register_aarch64.cpp |   20 +
 src/hotspot/cpu/aarch64/register_aarch64.hpp |   68 +
 .../cpu/aarch64/stubGenerator_aarch64.cpp    | 1149 +++++++++--------
 3 files changed, 666 insertions(+), 571 deletions(-)

diff --git a/src/hotspot/cpu/aarch64/register_aarch64.cpp b/src/hotspot/cpu/aarch64/register_aarch64.cpp
index 82683daae4f..349845154e2 100644
--- a/src/hotspot/cpu/aarch64/register_aarch64.cpp
+++ b/src/hotspot/cpu/aarch64/register_aarch64.cpp
@@ -58,3 +58,23 @@ const char* PRegister::PRegisterImpl::name() const {
   };
   return is_valid() ? names[encoding()] : "pnoreg";
 }
+
+// convenience methods for splitting 8-way vector register sequences
+// in half -- needed because vector operations can normally only
+// benefit from 4-way instruction parallelism
+
+VSeq<4> vs_front(const VSeq<8>& v) {
+  return VSeq<4>(v.base(), v.delta());
+}
+
+VSeq<4> vs_back(const VSeq<8>& v) {
+  return VSeq<4>(v.base() + 4 * v.delta(), v.delta());
+}
+
+VSeq<4> vs_even(const VSeq<8>& v) {
+  return VSeq<4>(v.base(), v.delta() * 2);
+}
+
+VSeq<4> vs_odd(const VSeq<8>& v) {
+  return VSeq<4>(v.base() + 1, v.delta() * 2);
+}
diff --git a/src/hotspot/cpu/aarch64/register_aarch64.hpp b/src/hotspot/cpu/aarch64/register_aarch64.hpp
index ab635e1be90..45578336cfe 100644
--- a/src/hotspot/cpu/aarch64/register_aarch64.hpp
+++ b/src/hotspot/cpu/aarch64/register_aarch64.hpp
@@ -412,4 +412,72 @@ inline Register as_Register(FloatRegister reg) {
 // High-level register class of an OptoReg or a VMReg register.
 enum RC { rc_bad, rc_int, rc_float, rc_predicate, rc_stack };
 
+// AArch64 Vector Register Sequence management support
+//
+// VSeq implements an indexable (by operator[]) vector register
+// sequence starting from a fixed base register and with a fixed delta
+// (defaulted to 1, but sometimes 0 or 2) e.g. VSeq<4>(16) will return
+// registers v16, ... v19 for indices 0, ... 3.
+//
+// Generator methods may iterate across sets of VSeq<4> to schedule an
+// operation 4 times using distinct input and output registers,
+// profiting from 4-way instruction parallelism.
+//
+// A VSeq<2> can be used to specify registers loaded with special
+// constants e.g. the constants qInv and q loaded into v30 and v31.
+//
+// A VSeq with base n and delta 0 can be used to generate code that
+// combines values in another VSeq with the constant in register vn.
+//
+// A VSeq with base n and delta 2 can be used to select an odd or even
+// indexed set of registers.
+//
+// Methods which accept arguments of type VSeq<8> may split their
+// inputs into front and back halves or odd and even halves (see
+// convenience methods below).
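The indexing rule described in the comment above is plain base-plus-stride arithmetic. The following standalone sketch is not part of the patch; VSeqModel is a hypothetical stand-in for the real VSeq template, showing the register numbers that fall out of a few typical sequences:

#include <cassert>
#include <cstdio>

// Minimal model of VSeq index arithmetic: register(i) = base + i * delta.
template <int N> struct VSeqModel {
  int base, delta;
  explicit VSeqModel(int b, int d = 1) : base(b), delta(d) {
    assert(b >= 0 && d >= 0 && b + (N - 1) * d < 32);
  }
  int reg(int i) const { assert(0 <= i && i < N); return base + i * delta; }
};

int main() {
  VSeqModel<4> quad(16);     // indices 0..3 -> v16, v17, v18, v19
  VSeqModel<4> odds(1, 2);   // delta 2 -> v1, v3, v5, v7
  VSeqModel<4> fixed(29, 0); // delta 0 -> v29 for every index
  for (int i = 0; i < 4; i++) {
    printf("quad[%d]=v%d  odds[%d]=v%d  fixed[%d]=v%d\n",
           i, quad.reg(i), i, odds.reg(i), i, fixed.reg(i));
  }
  return 0;
}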
+
+template<int N> class VSeq {
+  static_assert(N >= 2, "vector sequence length must be greater than 1");
+  static_assert(N <= 8, "vector sequence length must not exceed 8");
+  static_assert((N & (N - 1)) == 0, "vector sequence length must be power of two");
+private:
+  int _base;  // index of first register in sequence
+  int _delta; // increment to derive successive indices
+public:
+  VSeq(FloatRegister base_reg, int delta = 1) : VSeq(base_reg->encoding(), delta) { }
+  VSeq(int base, int delta = 1) : _base(base), _delta(delta) {
+    assert (_base >= 0, "invalid base register");
+    assert (_delta >= 0, "invalid register delta");
+    assert ((_base + (N - 1) * _delta) < 32, "range exceeded");
+  }
+  // indexed access to sequence
+  FloatRegister operator [](int i) const {
+    assert (0 <= i && i < N, "index out of bounds");
+    return as_FloatRegister(_base + i * _delta);
+  }
+  int mask() const {
+    int m = 0;
+    int bit = 1 << _base;
+    for (int i = 0; i < N; i++) {
+      m |= bit << (i * _delta);
+    }
+    return m;
+  }
+  int base() const { return _base; }
+  int delta() const { return _delta; }
+};
+
+// declare convenience methods for splitting vector register sequences
+
+VSeq<4> vs_front(const VSeq<8>& v);
+VSeq<4> vs_back(const VSeq<8>& v);
+VSeq<4> vs_even(const VSeq<8>& v);
+VSeq<4> vs_odd(const VSeq<8>& v);
+
+// methods for use in asserts to check VSeq inputs and outputs are
+// either disjoint or equal
+
+template<int N> bool vs_disjoint(const VSeq<N>& n, const VSeq<N>& m) { return (n.mask() & m.mask()) == 0; }
+template<int N> bool vs_same(const VSeq<N>& n, const VSeq<N>& m) { return n.mask() == m.mask(); }
+
 #endif // CPU_AARCH64_REGISTER_AARCH64_HPP
diff --git a/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp b/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp
index 7b8f5eca374..f0f145e3d76 100644
--- a/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp
+++ b/src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp
@@ -4643,144 +4643,290 @@ class StubGenerator: public StubCodeGenerator {
     return start;
   }
 
-  void dilithium_load16zetas(int o0, Register zetas) {
-    __ ldpq(as_FloatRegister(o0), as_FloatRegister(o0 + 1), __ post (zetas, 32));
-    __ ldpq(as_FloatRegister(o0 + 2), as_FloatRegister(o0 + 3), __ post (zetas, 32));
+  // Helpers to schedule parallel operation bundles across vector
+  // register sequences of size 2, 4 or 8.
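The overlap checks declared above (vs_disjoint / vs_same) reduce to comparing the register bit sets produced by mask(). A small standalone illustration, not patch code, using unsigned masks to stay clear of the sign bit:

#include <cstdio>

// One bit per register index, mirroring what VSeq<N>::mask() computes.
static unsigned seq_mask(int base, int delta, int n) {
  unsigned m = 0;
  for (int i = 0; i < n; i++) m |= 1u << (base + i * delta);
  return m;
}

int main() {
  unsigned vs1 = seq_mask(0, 1, 8);   // v0..v7
  unsigned vs3 = seq_mask(24, 1, 8);  // v24..v31
  unsigned vq  = seq_mask(30, 1, 2);  // v30, v31
  printf("vs1 vs vs3 disjoint: %s\n", (vs1 & vs3) == 0 ? "yes" : "no");  // yes
  printf("vq  vs vs3 disjoint: %s\n", (vq  & vs3) == 0 ? "yes" : "no");  // no: vq overlaps vs3
  printf("vs1 same as v0..v7:  %s\n", vs1 == seq_mask(0, 1, 8) ? "yes" : "no");
  return 0;
}

The generator code later in the patch notes "constants overlap vs3" and "tmp registers overlap vs3" for exactly this reason: their masks share bits, so they may only be used where that aliasing is intended.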
+  // Implement various primitive computations across vector sequences
+
+  template<int N>
+  void vs_addv(const VSeq<N>& v, Assembler::SIMD_Arrangement T,
+               const VSeq<N>& v1, const VSeq<N>& v2) {
+    for (int i = 0; i < N; i++) {
+      __ addv(v[i], T, v1[i], v2[i]);
+    }
   }
 
-  void dilithium_load32zetas(Register zetas) {
-    dilithium_load16zetas(16, zetas);
-    dilithium_load16zetas(20, zetas);
+  template<int N>
+  void vs_subv(const VSeq<N>& v, Assembler::SIMD_Arrangement T,
+               const VSeq<N>& v1, const VSeq<N>& v2) {
+    for (int i = 0; i < N; i++) {
+      __ subv(v[i], T, v1[i], v2[i]);
+    }
   }
 
-  // 2x16 32-bit Montgomery multiplications in parallel
+  template<int N>
+  void vs_mulv(const VSeq<N>& v, Assembler::SIMD_Arrangement T,
+               const VSeq<N>& v1, const VSeq<N>& v2) {
+    for (int i = 0; i < N; i++) {
+      __ mulv(v[i], T, v1[i], v2[i]);
+    }
+  }
+
+  template<int N>
+  void vs_negr(const VSeq<N>& v, Assembler::SIMD_Arrangement T, const VSeq<N>& v1) {
+    for (int i = 0; i < N; i++) {
+      __ negr(v[i], T, v1[i]);
+    }
+  }
+
+  template<int N>
+  void vs_sshr(const VSeq<N>& v, Assembler::SIMD_Arrangement T,
+               const VSeq<N>& v1, int shift) {
+    for (int i = 0; i < N; i++) {
+      __ sshr(v[i], T, v1[i], shift);
+    }
+  }
+
+  template<int N>
+  void vs_andr(const VSeq<N>& v, const VSeq<N>& v1, const VSeq<N>& v2) {
+    for (int i = 0; i < N; i++) {
+      __ andr(v[i], __ T16B, v1[i], v2[i]);
+    }
+  }
+
+  template<int N>
+  void vs_orr(const VSeq<N>& v, const VSeq<N>& v1, const VSeq<N>& v2) {
+    for (int i = 0; i < N; i++) {
+      __ orr(v[i], __ T16B, v1[i], v2[i]);
+    }
+  }
+
+  template<int N>
+  void vs_notr(const VSeq<N>& v, const VSeq<N>& v1) {
+    for (int i = 0; i < N; i++) {
+      __ notr(v[i], __ T16B, v1[i]);
+    }
+  }
+
+  // load N/2 successive pairs of quadword values from memory in order
+  // into N successive vector registers of the sequence via the
+  // address supplied in base.
+  template<int N>
+  void vs_ldpq(const VSeq<N>& v, Register base) {
+    for (int i = 0; i < N; i += 2) {
+      __ ldpq(v[i], v[i+1], Address(base, 32 * i));
+    }
+  }
+
+  // load N/2 successive pairs of quadword values from memory in order
+  // into N vector registers of the sequence via the address supplied
+  // in base using post-increment addressing
+  template<int N>
+  void vs_ldpq_post(const VSeq<N>& v, Register base) {
+    for (int i = 0; i < N; i += 2) {
+      __ ldpq(v[i], v[i+1], __ post(base, 32));
+    }
+  }
+
+  // store N successive vector registers of the sequence into N/2
+  // successive pairs of quadword memory locations via the address
+  // supplied in base using post-increment addressing
+  template<int N>
+  void vs_stpq_post(const VSeq<N>& v, Register base) {
+    for (int i = 0; i < N; i += 2) {
+      __ stpq(v[i], v[i+1], __ post(base, 32));
+    }
+  }
+
+  // load N/2 pairs of quadword values from memory into N vector
+  // registers via the address supplied in base with each pair indexed
+  // using the start offset plus the corresponding entry in the
+  // offsets array
+  template<int N>
+  void vs_ldpq_indexed(const VSeq<N>& v, Register base, int start, int (&offsets)[N/2]) {
+    for (int i = 0; i < N/2; i++) {
+      __ ldpq(v[2*i], v[2*i+1], Address(base, start + offsets[i]));
+    }
+  }
+
+  // store N vector registers into N/2 pairs of quadword memory
+  // locations via the address supplied in base with each pair indexed
+  // using the start offset plus the corresponding entry in the
+  // offsets array
+  template<int N>
+  void vs_stpq_indexed(const VSeq<N>& v, Register base, int start, int offsets[N/2]) {
+    for (int i = 0; i < N/2; i++) {
+      __ stpq(v[2*i], v[2*i+1], Address(base, start + offsets[i]));
+    }
+  }
+
+  // load N single quadword values from memory into N vector registers
+  // via the address supplied in base with each value indexed using
+  // the start offset plus the corresponding entry in the offsets
+  // array
+  template<int N>
+  void vs_ldr_indexed(const VSeq<N>& v, Assembler::SIMD_RegVariant T, Register base,
+                      int start, int (&offsets)[N]) {
+    for (int i = 0; i < N; i++) {
+      __ ldr(v[i], T, Address(base, start + offsets[i]));
+    }
+  }
+
+  // store N vector registers into N single quadword memory locations
+  // via the address supplied in base with each value indexed using
+  // the start offset plus the corresponding entry in the offsets
+  // array
+  template<int N>
+  void vs_str_indexed(const VSeq<N>& v, Assembler::SIMD_RegVariant T, Register base,
+                      int start, int (&offsets)[N]) {
+    for (int i = 0; i < N; i++) {
+      __ str(v[i], T, Address(base, start + offsets[i]));
+    }
+  }
+
+  // load N/2 pairs of quadword values from memory de-interleaved into
+  // N vector registers 2 at a time via the address supplied in base
+  // with each pair indexed using the start offset plus the
+  // corresponding entry in the offsets array
+  template<int N>
+  void vs_ld2_indexed(const VSeq<N>& v, Assembler::SIMD_Arrangement T, Register base,
+                      Register tmp, int start, int (&offsets)[N/2]) {
+    for (int i = 0; i < N/2; i++) {
+      __ add(tmp, base, start + offsets[i]);
+      __ ld2(v[2*i], v[2*i+1], T, tmp);
+    }
+  }
+
+  // store N vector registers 2 at a time interleaved into N/2 pairs
+  // of quadword memory locations via the address supplied in base
+  // with each pair indexed using the start offset plus the
+  // corresponding entry in the offsets array
+  template<int N>
+  void vs_st2_indexed(const VSeq<N>& v, Assembler::SIMD_Arrangement T, Register base,
+                      Register tmp, int start, int (&offsets)[N/2]) {
+    for (int i = 0; i < N/2; i++) {
+      __ add(tmp, base, start + offsets[i]);
+      __ st2(v[2*i], v[2*i+1], T, tmp);
+    }
+  }
+
+  // Helper routines for various flavours of dilithium montgomery
+  // multiply
+
+  // Perform 16 32-bit Montgomery multiplications in parallel
   // See the montMul() method of the sun.security.provider.ML_DSA class.
-  // Here MONT_R_BITS is 32, so the right shift by it is implicit.
-  // The constants qInv = MONT_Q_INV_MOD_R and q = MONT_Q are loaded in
-  // (all 32-bit chunks of) vector registers v30 and v31, resp.
-  // The inputs are b[i]s in v0-v7 and c[i]s v16-v23 and
-  // the results are a[i]s in v16-v23, four 32-bit values in each register
-  // and we do a_i = b_i * c_i * 2^-32 mod MONT_Q for all
-  void dilithium_montmul32(bool by_constant) {
-    FloatRegister vr0 = by_constant ? v29 : v0;
-    FloatRegister vr1 = by_constant ? v29 : v1;
-    FloatRegister vr2 = by_constant ? v29 : v2;
-    FloatRegister vr3 = by_constant ? v29 : v3;
-    FloatRegister vr4 = by_constant ? v29 : v4;
-    FloatRegister vr5 = by_constant ? v29 : v5;
-    FloatRegister vr6 = by_constant ? v29 : v6;
-    FloatRegister vr7 = by_constant ? v29 : v7;
+  //
+  // Computes 4x4S results
+  //    a = b * c * 2^-32 mod MONT_Q
+  // Inputs:  vb, vc - 4x4S vector register sequences
+  //          vq - 2x4S constants
+  // Temps:   vtmp - 4x4S vector sequence trashed after call
+  // Outputs: va - 4x4S vector register sequences
+  // vb, vc, vtmp and vq must all be disjoint
+  // va must be disjoint from all other inputs/temps or must equal vc
+  // n.b. MONT_R_BITS is 32, so the right shift by it is implicit.
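As a scalar cross-check on the reduction described above (a sketch, not patch code): sqdmulh supplies hi32(2*b*c), mulv the low 32 bits, and shsubv the final halving subtract. q is the ML-DSA modulus; qinv is derived inside the sketch rather than assumed.

#include <cassert>
#include <cstdint>
#include <cstdio>

static const int32_t Q = 8380417;  // MONT_Q

// q^-1 mod 2^32 by Newton iteration, so no inverse constant is hard-coded.
static uint32_t compute_qinv() {
  uint32_t inv = 1;
  for (int i = 0; i < 5; i++) inv *= 2 - (uint32_t)Q * inv;
  assert((uint32_t)Q * inv == 1u);
  return inv;
}

// Per-lane model of the sequence above: returns b * c * 2^-32 mod Q
// (up to a multiple of Q), assuming |b| and |c| stay well below 2^31.
static int32_t mont_mul(int32_t b, int32_t c) {
  int64_t prod = (int64_t)b * c;
  int32_t a_hi = (int32_t)((2 * prod) >> 32);                 // sqdmulh vb, vc
  int32_t a_lo = (int32_t)prod;                               // mulv vb, vc (low 32 bits)
  int32_t m    = (int32_t)((uint32_t)a_lo * compute_qinv());  // mulv by qinv
  int32_t n    = (int32_t)((2 * (int64_t)m * Q) >> 32);       // sqdmulh by q
  return (int32_t)(((int64_t)a_hi - n) >> 1);                 // shsubv
}

int main() {
  int32_t r = mont_mul(12345, 67890);
  int64_t lhs = (((int64_t)r * (1LL << 32)) % Q + Q) % Q;
  int64_t rhs = (((int64_t)12345 * 67890) % Q + Q) % Q;
  printf("r = %d, congruent mod q: %s\n", r, lhs == rhs ? "yes" : "no");
  return 0;
}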
+ void dilithium_montmul16(const VSeq<4>& va, const VSeq<4>& vb, const VSeq<4>& vc, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + assert(vs_disjoint(vb, vc), "vb and vc overlap"); + assert(vs_disjoint(vb, vq), "vb and vq overlap"); + assert(vs_disjoint(vb, vtmp), "vb and vtmp overlap"); - __ sqdmulh(v24, __ T4S, vr0, v16); // aHigh = hi32(2 * b * c) - __ mulv(v16, __ T4S, vr0, v16); // aLow = lo32(b * c) - __ sqdmulh(v25, __ T4S, vr1, v17); - __ mulv(v17, __ T4S, vr1, v17); - __ sqdmulh(v26, __ T4S, vr2, v18); - __ mulv(v18, __ T4S, vr2, v18); - __ sqdmulh(v27, __ T4S, vr3, v19); - __ mulv(v19, __ T4S, vr3, v19); + assert(vs_disjoint(vc, vq), "vc and vq overlap"); + assert(vs_disjoint(vc, vtmp), "vc and vtmp overlap"); - __ mulv(v16, __ T4S, v16, v30); // m = aLow * qinv - __ mulv(v17, __ T4S, v17, v30); - __ mulv(v18, __ T4S, v18, v30); - __ mulv(v19, __ T4S, v19, v30); + assert(vs_disjoint(vq, vtmp), "vq and vtmp overlap"); - __ sqdmulh(v16, __ T4S, v16, v31); // n = hi32(2 * m * q) - __ sqdmulh(v17, __ T4S, v17, v31); - __ sqdmulh(v18, __ T4S, v18, v31); - __ sqdmulh(v19, __ T4S, v19, v31); + assert(vs_disjoint(va, vc) || vs_same(va, vc), "va and vc neither disjoint nor equal"); + assert(vs_disjoint(va, vb), "va and vb overlap"); + assert(vs_disjoint(va, vq), "va and vq overlap"); + assert(vs_disjoint(va, vtmp), "va and vtmp overlap"); - __ shsubv(v16, __ T4S, v24, v16); // a = (aHigh - n) / 2 - __ shsubv(v17, __ T4S, v25, v17); - __ shsubv(v18, __ T4S, v26, v18); - __ shsubv(v19, __ T4S, v27, v19); + // schedule 4 streams of instructions across the vector sequences + for (int i = 0; i < 4; i++) { + __ sqdmulh(vtmp[i], __ T4S, vb[i], vc[i]); // aHigh = hi32(2 * b * c) + __ mulv(va[i], __ T4S, vb[i], vc[i]); // aLow = lo32(b * c) + } - __ sqdmulh(v24, __ T4S, vr4, v20); - __ mulv(v20, __ T4S, vr4, v20); - __ sqdmulh(v25, __ T4S, vr5, v21); - __ mulv(v21, __ T4S, vr5, v21); - __ sqdmulh(v26, __ T4S, vr6, v22); - __ mulv(v22, __ T4S, vr6, v22); - __ sqdmulh(v27, __ T4S, vr7, v23); - __ mulv(v23, __ T4S, vr7, v23); + for (int i = 0; i < 4; i++) { + __ mulv(va[i], __ T4S, va[i], vq[0]); // m = aLow * qinv + } - __ mulv(v20, __ T4S, v20, v30); - __ mulv(v21, __ T4S, v21, v30); - __ mulv(v22, __ T4S, v22, v30); - __ mulv(v23, __ T4S, v23, v30); + for (int i = 0; i < 4; i++) { + __ sqdmulh(va[i], __ T4S, va[i], vq[1]); // n = hi32(2 * m * q) + } - __ sqdmulh(v20, __ T4S, v20, v31); - __ sqdmulh(v21, __ T4S, v21, v31); - __ sqdmulh(v22, __ T4S, v22, v31); - __ sqdmulh(v23, __ T4S, v23, v31); - - __ shsubv(v20, __ T4S, v24, v20); - __ shsubv(v21, __ T4S, v25, v21); - __ shsubv(v22, __ T4S, v26, v22); - __ shsubv(v23, __ T4S, v27, v23); + for (int i = 0; i < 4; i++) { + __ shsubv(va[i], __ T4S, vtmp[i], va[i]); // a = (aHigh - n) / 2 + } } - // Do the addition and subtraction done in the ntt algorithm. - // See sun.security.provider.ML_DSA.implDilithiumAlmostNttJava() - void dilithium_add_sub32() { - __ addv(v24, __ T4S, v0, v16); // coeffs[j] = coeffs[j] + tmp; - __ addv(v25, __ T4S, v1, v17); - __ addv(v26, __ T4S, v2, v18); - __ addv(v27, __ T4S, v3, v19); - __ addv(v28, __ T4S, v4, v20); - __ addv(v29, __ T4S, v5, v21); - __ addv(v30, __ T4S, v6, v22); - __ addv(v31, __ T4S, v7, v23); + // Perform 2x16 32-bit Montgomery multiplications in parallel + // See the montMul() method of the sun.security.provider.ML_DSA class. 
+ // + // Computes 8x4S results + // a = b * c * 2^-32 mod MONT_Q + // Inputs: vb, vc - 8x4S vector register sequences + // vq - 2x4S constants + // Temps: vtmp - 4x4S vector sequence trashed after call + // Outputs: va - 8x4S vector register sequences + // vb, vc, vtmp and vq must all be disjoint + // va must be disjoint from all other inputs/temps or must equal vc + // n.b. MONT_R_BITS is 32, so the right shift by it is implicit. + void vs_montmul32(const VSeq<8>& va, const VSeq<8>& vb, const VSeq<8>& vc, + const VSeq<4>& vtmp, const VSeq<2>& vq) { + // vb, vc, vtmp and vq must be disjoint. va must either be + // disjoint from all other registers or equal vc - __ subv(v0, __ T4S, v0, v16); // coeffs[j + l] = coeffs[j] - tmp; - __ subv(v1, __ T4S, v1, v17); - __ subv(v2, __ T4S, v2, v18); - __ subv(v3, __ T4S, v3, v19); - __ subv(v4, __ T4S, v4, v20); - __ subv(v5, __ T4S, v5, v21); - __ subv(v6, __ T4S, v6, v22); - __ subv(v7, __ T4S, v7, v23); + assert(vs_disjoint(vb, vc), "vb and vc overlap"); + assert(vs_disjoint(vb, vq), "vb and vq overlap"); + assert(vs_disjoint(vb, vtmp), "vb and vtmp overlap"); + + assert(vs_disjoint(vc, vq), "vc and vq overlap"); + assert(vs_disjoint(vc, vtmp), "vc and vtmp overlap"); + + assert(vs_disjoint(vq, vtmp), "vq and vtmp overlap"); + + assert(vs_disjoint(va, vc) || vs_same(va, vc), "va and vc neither disjoint nor equal"); + assert(vs_disjoint(va, vb), "va and vb overlap"); + assert(vs_disjoint(va, vq), "va and vq overlap"); + assert(vs_disjoint(va, vtmp), "va and vtmp overlap"); + + // we need to multiply the front and back halves of each sequence + // 4x4S at a time because + // + // 1) we are currently only able to get 4-way instruction + // parallelism at best + // + // 2) we need registers for the constants in vq and temporary + // scratch registers to hold intermediate results so vtmp can only + // be a VSeq<4> which means we only have 4 scratch slots + + dilithium_montmul16(vs_front(va), vs_front(vb), vs_front(vc), vtmp, vq); + dilithium_montmul16(vs_back(va), vs_back(vb), vs_back(vc), vtmp, vq); } - // Do the same computation that - // dilithium_montmul32() and dilithium_add_sub32() does, - // except for only 4x4 32-bit vector elements and with - // different register usage. 
-  void dilithium_montmul_sub_add16() {
-    __ sqdmulh(v24, __ T4S, v1, v16);
-    __ mulv(v16, __ T4S, v1, v16);
-    __ sqdmulh(v25, __ T4S, v3, v17);
-    __ mulv(v17, __ T4S, v3, v17);
-    __ sqdmulh(v26, __ T4S, v5, v18);
-    __ mulv(v18, __ T4S, v5, v18);
-    __ sqdmulh(v27, __ T4S, v7, v19);
-    __ mulv(v19, __ T4S, v7, v19);
+  // perform combined montmul then add/sub on 4x4S vectors
 
-    __ mulv(v16, __ T4S, v16, v30);
-    __ mulv(v17, __ T4S, v17, v30);
-    __ mulv(v18, __ T4S, v18, v30);
-    __ mulv(v19, __ T4S, v19, v30);
+  void dilithium_montmul16_sub_add(const VSeq<4>& va0, const VSeq<4>& va1, const VSeq<4>& vc,
+                                   const VSeq<4>& vtmp, const VSeq<2>& vq) {
+    // compute a = montmul(a1, c)
+    dilithium_montmul16(vc, va1, vc, vtmp, vq);
+    // output a1 = a0 - a
+    vs_subv(va1, __ T4S, va0, vc);
+    // and a0 = a0 + a
+    vs_addv(va0, __ T4S, va0, vc);
+  }
 
-    __ sqdmulh(v16, __ T4S, v16, v31);
-    __ sqdmulh(v17, __ T4S, v17, v31);
-    __ sqdmulh(v18, __ T4S, v18, v31);
-    __ sqdmulh(v19, __ T4S, v19, v31);
+  // perform combined add/sub then montmul on 4x4S vectors
 
-    __ shsubv(v16, __ T4S, v24, v16);
-    __ shsubv(v17, __ T4S, v25, v17);
-    __ shsubv(v18, __ T4S, v26, v18);
-    __ shsubv(v19, __ T4S, v27, v19);
-
-    __ subv(v1, __ T4S, v0, v16);
-    __ subv(v3, __ T4S, v2, v17);
-    __ subv(v5, __ T4S, v4, v18);
-    __ subv(v7, __ T4S, v6, v19);
-
-    __ addv(v0, __ T4S, v0, v16);
-    __ addv(v2, __ T4S, v2, v17);
-    __ addv(v4, __ T4S, v4, v18);
-    __ addv(v6, __ T4S, v6, v19);
+  void dilithium_sub_add_montmul16(const VSeq<4>& va0, const VSeq<4>& va1, const VSeq<4>& vb,
+                                   const VSeq<4>& vtmp1, const VSeq<4>& vtmp2, const VSeq<2>& vq) {
+    // compute c = a0 - a1
+    vs_subv(vtmp1, __ T4S, va0, va1);
+    // output a0 = a0 + a1
+    vs_addv(va0, __ T4S, va0, va1);
+    // output a1 = b montmul c
+    dilithium_montmul16(va1, vtmp1, vb, vtmp2, vq);
   }
 
   // At these levels, the indices that correspond to the 'j's (and 'j+l's)
@@ -4798,44 +4944,47 @@
     int c1 = 0;
     int c2 = 512;
     int startIncr;
-    int incr1 = 32;
-    int incr2 = 64;
-    int incr3 = 96;
+    // don't use callee save registers v8 - v15
+    VSeq<8> vs1(0), vs2(16), vs3(24);  // 3 sets of 8x4s inputs/outputs
+    VSeq<4> vtmp = vs_front(vs3);      // n.b. tmp registers overlap vs3
+    VSeq<2> vq(30);                    // n.b. constants overlap vs3
+    int offsets[4] = { 0, 32, 64, 96 };
 
     for (int level = 0; level < 5; level++) {
       int c1Start = c1;
       int c2Start = c2;
       if (level == 3) {
-        incr1 = 32;
-        incr2 = 128;
-        incr3 = 160;
+        offsets[1] = 32;
+        offsets[2] = 128;
+        offsets[3] = 160;
       } else if (level == 4) {
-        incr1 = 64;
-        incr2 = 128;
-        incr3 = 192;
+        offsets[1] = 64;
+        offsets[2] = 128;
+        offsets[3] = 192;
       }
 
+      // for levels 1 - 4 we simply load 2 x 4 adjacent values at a
+      // time at 4 different offsets and multiply them in order by the
+      // next set of input values.
So we employ indexed load and store + // pair instructions with arrangement 4S for (int i = 0; i < 4; i++) { - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q - __ ldpq(v0, v1, Address(coeffs, c2Start)); - __ ldpq(v2, v3, Address(coeffs, c2Start + incr1)); - __ ldpq(v4, v5, Address(coeffs, c2Start + incr2)); - __ ldpq(v6, v7, Address(coeffs, c2Start + incr3)); - dilithium_load32zetas(zetas); - dilithium_montmul32(false); - __ ldpq(v0, v1, Address(coeffs, c1Start)); - __ ldpq(v2, v3, Address(coeffs, c1Start + incr1)); - __ ldpq(v4, v5, Address(coeffs, c1Start + incr2)); - __ ldpq(v6, v7, Address(coeffs, c1Start + incr3)); - dilithium_add_sub32(); - __ stpq(v24, v25, Address(coeffs, c1Start)); - __ stpq(v26, v27, Address(coeffs, c1Start + incr1)); - __ stpq(v28, v29, Address(coeffs, c1Start + incr2)); - __ stpq(v30, v31, Address(coeffs, c1Start + incr3)); - __ stpq(v0, v1, Address(coeffs, c2Start)); - __ stpq(v2, v3, Address(coeffs, c2Start + incr1)); - __ stpq(v4, v5, Address(coeffs, c2Start + incr2)); - __ stpq(v6, v7, Address(coeffs, c2Start + incr3)); + // reload q and qinv + vs_ldpq(vq, dilithiumConsts); // qInv, q + // load 8x4S coefficients via second start pos == c2 + vs_ldpq_indexed(vs1, coeffs, c2Start, offsets); + // load next 8x4S inputs == b + vs_ldpq_post(vs2, zetas); + // compute a == c2 * b mod MONT_Q + vs_montmul32(vs2, vs1, vs2, vtmp, vq); + // load 8x4s coefficients via first start pos == c1 + vs_ldpq_indexed(vs1, coeffs, c1Start, offsets); + // compute a1 = c1 + a + vs_addv(vs3, __ T4S, vs1, vs2); + // compute a2 = c1 - a + vs_subv(vs1, __ T4S, vs1, vs2); + // output a1 and a2 + vs_stpq_indexed(vs3, coeffs, c1Start, offsets); + vs_stpq_indexed(vs1, coeffs, c2Start, offsets); int k = 4 * level + i; @@ -4876,7 +5025,13 @@ class StubGenerator: public StubCodeGenerator { const Register tmpAddr = r9; const Register dilithiumConsts = r10; const Register result = r11; - + // don't use callee save registers v8 - v15 + VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x4s inputs/outputs + VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 + VSeq<2> vq(30); // n.b. constants overlap vs3 + int offsets[4] = {0, 32, 64, 96}; + int offsets1[8] = {16, 48, 80, 112, 144, 176, 208, 240 }; + int offsets2[8] = { 0, 32, 64, 96, 128, 160, 192, 224 }; __ add(result, coeffs, 0); __ lea(dilithiumConsts, ExternalAddress((address) StubRoutines::aarch64::_dilithiumConsts)); @@ -4886,200 +5041,156 @@ class StubGenerator: public StubCodeGenerator { dilithiumNttLevel0_4(dilithiumConsts, coeffs, zetas); // level 5 + + // at level 5 the coefficients we need to combine with the zetas + // are grouped in memory in blocks of size 4. So, for both sets of + // coefficients we load 4 adjacent values at 8 different offsets + // using an indexed ldr with register variant Q and multiply them + // in sequence order by the next set of inputs. Likewise we store + // the resuls using an indexed str with register variant Q. 
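For orientation, every one of these forward-NTT level loops is a 32-lane-wide rendering of the scalar butterfly from sun.security.provider.ML_DSA.implDilithiumAlmostNttJava(). A rough scalar outline follows (not patch code; mont_mul() is the hypothetical per-lane model sketched earlier and the zeta indexing is only indicative):

#include <cstdint>

int32_t mont_mul(int32_t b, int32_t c);  // per-lane model from the earlier sketch

// One forward NTT level over 256 coefficients; l is the j/j+l distance
// (128 at level 0 down to 1 at level 7).
static void ntt_level(int32_t coeffs[256], const int32_t* level_zetas, int l) {
  int m = 0;
  for (int start = 0; start < 256; start += 2 * l) {
    int32_t zeta = level_zetas[m++];
    for (int j = start; j < start + l; j++) {
      int32_t tmp = mont_mul(zeta, coeffs[j + l]);  // the vs_montmul32 step
      coeffs[j + l] = coeffs[j] - tmp;              // the vs_subv step
      coeffs[j]     = coeffs[j] + tmp;              // the vs_addv step
    }
  }
}

The vector code batches 32 such j values per iteration, which is why the coefficients are fetched as 8x4S register sequences.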
for (int i = 0; i < 1024; i += 256) { - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q - __ ldr(v0, __ Q, Address(coeffs, i + 16)); - __ ldr(v1, __ Q, Address(coeffs, i + 48)); - __ ldr(v2, __ Q, Address(coeffs, i + 80)); - __ ldr(v3, __ Q, Address(coeffs, i + 112)); - __ ldr(v4, __ Q, Address(coeffs, i + 144)); - __ ldr(v5, __ Q, Address(coeffs, i + 176)); - __ ldr(v6, __ Q, Address(coeffs, i + 208)); - __ ldr(v7, __ Q, Address(coeffs, i + 240)); - dilithium_load32zetas(zetas); - dilithium_montmul32(false); - __ ldr(v0, __ Q, Address(coeffs, i)); - __ ldr(v1, __ Q, Address(coeffs, i + 32)); - __ ldr(v2, __ Q, Address(coeffs, i + 64)); - __ ldr(v3, __ Q, Address(coeffs, i + 96)); - __ ldr(v4, __ Q, Address(coeffs, i + 128)); - __ ldr(v5, __ Q, Address(coeffs, i + 160)); - __ ldr(v6, __ Q, Address(coeffs, i + 192)); - __ ldr(v7, __ Q, Address(coeffs, i + 224)); - dilithium_add_sub32(); - __ str(v24, __ Q, Address(coeffs, i)); - __ str(v25, __ Q, Address(coeffs, i + 32)); - __ str(v26, __ Q, Address(coeffs, i + 64)); - __ str(v27, __ Q, Address(coeffs, i + 96)); - __ str(v28, __ Q, Address(coeffs, i + 128)); - __ str(v29, __ Q, Address(coeffs, i + 160)); - __ str(v30, __ Q, Address(coeffs, i + 192)); - __ str(v31, __ Q, Address(coeffs, i + 224)); - __ str(v0, __ Q, Address(coeffs, i + 16)); - __ str(v1, __ Q, Address(coeffs, i + 48)); - __ str(v2, __ Q, Address(coeffs, i + 80)); - __ str(v3, __ Q, Address(coeffs, i + 112)); - __ str(v4, __ Q, Address(coeffs, i + 144)); - __ str(v5, __ Q, Address(coeffs, i + 176)); - __ str(v6, __ Q, Address(coeffs, i + 208)); - __ str(v7, __ Q, Address(coeffs, i + 240)); + // reload constants q, qinv each iteration as they get clobbered later + vs_ldpq(vq, dilithiumConsts); // qInv, q + // load 32 (8x4S) coefficients via first offsets = c1 + vs_ldr_indexed(vs1, __ Q, coeffs, i, offsets1); + // load next 32 (8x4S) inputs = b + vs_ldpq_post(vs2, zetas); + // a = b montul c1 + vs_montmul32(vs2, vs1, vs2, vtmp, vq); + // load 32 (8x4S) coefficients via second offsets = c2 + vs_ldr_indexed(vs1, __ Q, coeffs, i, offsets2); + // add/sub with result of multiply + vs_addv(vs3, __ T4S, vs1, vs2); // a1 = a - c2 + vs_subv(vs1, __ T4S, vs1, vs2); // a0 = a + c1 + // write back new coefficients using same offsets + vs_str_indexed(vs3, __ Q, coeffs, i, offsets2); + vs_str_indexed(vs1, __ Q, coeffs, i, offsets1); } // level 6 + // at level 6 the coefficients we need to combine with the zetas + // are grouped in memory in pairs, the first two being montmul + // inputs and the second add/sub inputs. We can still implement + // the montmul+sub+add using 4-way parallelism but only if we + // combine the coefficients with the zetas 16 at a time. We load 8 + // adjacent values at 4 different offsets using an ld2 load with + // arrangement 2D. That interleaves the lower and upper halves of + // each pair of quadwords into successive vector registers. We + // then need to montmul the 4 even elements of the coefficients + // register sequence by the zetas in order and then add/sub the 4 + // odd elements of the coefficients register sequence. We use an + // equivalent st2 operation to store the results back into memory + // de-interleaved. 
    for (int i = 0; i < 1024; i += 128) {
-      __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q
-      __ add(tmpAddr, coeffs, i);
-      __ ld2(v0, v1, __ T2D, tmpAddr);
-      __ add(tmpAddr, coeffs, i + 32);
-      __ ld2(v2, v3, __ T2D, tmpAddr);
-      __ add(tmpAddr, coeffs, i + 64);
-      __ ld2(v4, v5, __ T2D, tmpAddr);
-      __ add(tmpAddr, coeffs, i + 96);
-      __ ld2(v6, v7, __ T2D, tmpAddr);
-      dilithium_load16zetas(16, zetas);
-      dilithium_montmul_sub_add16();
-      __ add(tmpAddr, coeffs, i);
-      __ st2(v0, v1, __ T2D, tmpAddr);
-      __ add(tmpAddr, coeffs, i + 32);
-      __ st2(v2, v3, __ T2D, tmpAddr);
-      __ add(tmpAddr, coeffs, i + 64);
-      __ st2(v4, v5, __ T2D, tmpAddr);
-      __ add(tmpAddr, coeffs, i + 96);
-      __ st2(v6, v7, __ T2D, tmpAddr);
+      // reload constants q, qinv each iteration as they get clobbered later
+      vs_ldpq(vq, dilithiumConsts); // qInv, q
+      // load interleaved 16 (4x2D) coefficients via offsets
+      vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, i, offsets);
+      // load next 16 (4x4S) inputs
+      vs_ldpq_post(vs_front(vs2), zetas);
+      // mont multiply odd elements of vs1 by vs2 and add/sub into odds/evens
+      dilithium_montmul16_sub_add(vs_even(vs1), vs_odd(vs1),
+                                  vs_front(vs2), vtmp, vq);
+      // store interleaved 16 (4x2D) coefficients via offsets
+      vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, i, offsets);
     }
 
     // level 7
+    // at level 7 the coefficients we need to combine with the zetas
+    // occur singly with montmul inputs alternating with add/sub
+    // inputs. Once again we can use 4-way parallelism to combine 16
+    // zetas at a time. However, we have to load 8 adjacent values at
+    // 4 different offsets using an ld2 load with arrangement 4S. That
+    // interleaves the odd words of each pair into one coefficients
+    // vector register and the even words of the pair into the next
+    // register. We then need to montmul the 4 even elements of the
+    // coefficients register sequence by the zetas in order and then
+    // add/sub the 4 odd elements of the coefficients register
+    // sequence. We use an equivalent st2 operation to store
+    // the results back into memory de-interleaved.
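The ld2/st2 de-interleaving relied on at levels 6 and 7 can be pictured with a scalar sketch (not patch code). With arrangement 4S, alternate 32-bit words go to the two destination registers; with 2D the same split happens on 64-bit lanes, i.e. on adjacent pairs of coefficients:

#include <cstdint>
#include <cstdio>

// Mimic ld2 {Va.4S, Vb.4S}, [addr]: element k of memory goes to
// register a when k is even and to register b when k is odd.
static void ld2_4s(const int32_t src[8], int32_t a[4], int32_t b[4]) {
  for (int i = 0; i < 4; i++) {
    a[i] = src[2 * i];
    b[i] = src[2 * i + 1];
  }
}

int main() {
  int32_t mem[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  int32_t a[4], b[4];
  ld2_4s(mem, a, b);
  printf("a: %d %d %d %d\n", a[0], a[1], a[2], a[3]);  // 0 2 4 6
  printf("b: %d %d %d %d\n", b[0], b[1], b[2], b[3]);  // 1 3 5 7
  return 0;
}

st2 performs the inverse interleaving on the way back to memory, which is why the loops can keep the two halves of each coefficient pair in separate register sequences while they are combined.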
+ for (int i = 0; i < 1024; i += 128) { - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q - __ add(tmpAddr, coeffs, i); - __ ld2(v0, v1, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 32); - __ ld2(v2, v3, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 64); - __ ld2(v4, v5, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 96); - __ ld2(v6, v7, __ T4S, tmpAddr); - dilithium_load16zetas(16, zetas); - dilithium_montmul_sub_add16(); - __ add(tmpAddr, coeffs, i); - __ st2(v0, v1, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 32); - __ st2(v2, v3, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 64); - __ st2(v4, v5, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 96); - __ st2(v6, v7, __ T4S, tmpAddr); + // reload constants q, qinv each iteration as they get clobbered later + vs_ldpq(vq, dilithiumConsts); // qInv, q + // load interleaved 16 (4x4S) coefficients via offsets + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, i, offsets); + // load next 16 (4x4S) inputs + vs_ldpq_post(vs_front(vs2), zetas); + // mont multiply odd elements of vs1 by vs2 and add/sub into odds/evens + dilithium_montmul16_sub_add(vs_even(vs1), vs_odd(vs1), + vs_front(vs2), vtmp, vq); + // store interleaved 16 (4x4S) coefficients via offsets + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, i, offsets); } __ leave(); // required for proper stackwalking of RuntimeStub frame __ mov(r0, zr); // return 0 __ ret(lr); return start; - - } - - // Do the computations that can be found in the body of the loop in - // sun.security.provider.ML_DSA.implDilithiumAlmostInverseNttJava() - // for 16 coefficients in parallel: - // tmp = coeffs[j]; - // coeffs[j] = (tmp + coeffs[j + l]); - // coeffs[j + l] = montMul(tmp - coeffs[j + l], -MONT_ZETAS_FOR_NTT[m]); - // coefss[j]s are loaded in v0, v2, v4 and v6, - // coeffs[j + l]s in v1, v3, v5 and v7, - // the corresponding zetas in v16, v17, v18 and v19. 
- void dilithium_sub_add_montmul16() { - __ subv(v20, __ T4S, v0, v1); - __ subv(v21, __ T4S, v2, v3); - __ subv(v22, __ T4S, v4, v5); - __ subv(v23, __ T4S, v6, v7); - - __ addv(v0, __ T4S, v0, v1); - __ addv(v2, __ T4S, v2, v3); - __ addv(v4, __ T4S, v4, v5); - __ addv(v6, __ T4S, v6, v7); - - __ sqdmulh(v24, __ T4S, v20, v16); // aHigh = hi32(2 * b * c) - __ mulv(v1, __ T4S, v20, v16); // aLow = lo32(b * c) - __ sqdmulh(v25, __ T4S, v21, v17); - __ mulv(v3, __ T4S, v21, v17); - __ sqdmulh(v26, __ T4S, v22, v18); - __ mulv(v5, __ T4S, v22, v18); - __ sqdmulh(v27, __ T4S, v23, v19); - __ mulv(v7, __ T4S, v23, v19); - - __ mulv(v1, __ T4S, v1, v30); // m = (aLow * q) - __ mulv(v3, __ T4S, v3, v30); - __ mulv(v5, __ T4S, v5, v30); - __ mulv(v7, __ T4S, v7, v30); - - __ sqdmulh(v1, __ T4S, v1, v31); // n = hi32(2 * m * q) - __ sqdmulh(v3, __ T4S, v3, v31); - __ sqdmulh(v5, __ T4S, v5, v31); - __ sqdmulh(v7, __ T4S, v7, v31); - - __ shsubv(v1, __ T4S, v24, v1); // a = (aHigh - n) / 2 - __ shsubv(v3, __ T4S, v25, v3); - __ shsubv(v5, __ T4S, v26, v5); - __ shsubv(v7, __ T4S, v27, v7); } // At these levels, the indices that correspond to the 'j's (and 'j+l's) // in the Java implementation come in sequences of at least 8, so we // can use ldpq to collect the corresponding data into pairs of vector // registers - // We collect the coefficients that correspond to the 'j's into v0-v7 - // the coefficiets that correspond to the 'j+l's into v16-v23 then - // do the additions into v24-v31 and the subtractions into v0-v7 then - // save the result of the additions, load the zetas into v16-v23 - // do the (Montgomery) multiplications by zeta in parallel into v16-v23 + // We collect the coefficients that correspond to the 'j's into vs1 + // the coefficiets that correspond to the 'j+l's into vs2 then + // do the additions into vs3 and the subtractions into vs1 then + // save the result of the additions, load the zetas into vs2 + // do the (Montgomery) multiplications by zeta in parallel into vs2 // finally save the results back to the coeffs array void dilithiumInverseNttLevel3_7(const Register dilithiumConsts, const Register coeffs, const Register zetas) { int c1 = 0; int c2 = 32; int startIncr; - int incr1; - int incr2; - int incr3; + int offsets[4]; + VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x4s inputs/outputs + VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 + VSeq<2> vq(30); // n.b. constants overlap vs3 + + offsets[0] = 0; for (int level = 3; level < 8; level++) { int c1Start = c1; int c2Start = c2; if (level == 3) { - incr1 = 64; - incr2 = 128; - incr3 = 192; + offsets[1] = 64; + offsets[2] = 128; + offsets[3] = 192; } else if (level == 4) { - incr1 = 32; - incr2 = 128; - incr3 = 160; + offsets[1] = 32; + offsets[2] = 128; + offsets[3] = 160; } else { - incr1 = 32; - incr2 = 64; - incr3 = 96; + offsets[1] = 32; + offsets[2] = 64; + offsets[3] = 96; } + // for levels 3 - 7 we simply load 2 x 4 adjacent values at a + // time at 4 different offsets and multiply them in order by the + // next set of input values. 
So we employ indexed load and store + // pair instructions with arrangement 4S for (int i = 0; i < 4; i++) { - __ ldpq(v0, v1, Address(coeffs, c1Start)); - __ ldpq(v2, v3, Address(coeffs, c1Start + incr1)); - __ ldpq(v4, v5, Address(coeffs, c1Start + incr2)); - __ ldpq(v6, v7, Address(coeffs, c1Start + incr3)); - __ ldpq(v16, v17, Address(coeffs, c2Start)); - __ ldpq(v18, v19, Address(coeffs, c2Start + incr1)); - __ ldpq(v20, v21, Address(coeffs, c2Start + incr2)); - __ ldpq(v22, v23, Address(coeffs, c2Start + incr3)); - dilithium_add_sub32(); - __ stpq(v24, v25, Address(coeffs, c1Start)); - __ stpq(v26, v27, Address(coeffs, c1Start + incr1)); - __ stpq(v28, v29, Address(coeffs, c1Start + incr2)); - __ stpq(v30, v31, Address(coeffs, c1Start + incr3)); - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q - dilithium_load32zetas(zetas); - dilithium_montmul32(false); - __ stpq(v16, v17, Address(coeffs, c2Start)); - __ stpq(v18, v19, Address(coeffs, c2Start + incr1)); - __ stpq(v20, v21, Address(coeffs, c2Start + incr2)); - __ stpq(v22, v23, Address(coeffs, c2Start + incr3)); + // load v1 32 (8x4S) coefficients relative to first start index + vs_ldpq_indexed(vs1, coeffs, c1Start, offsets); + // load v2 32 (8x4S) coefficients relative to second start index + vs_ldpq_indexed(vs2, coeffs, c2Start, offsets); + // a0 = v1 + v2 -- n.b. clobbers vqs + vs_addv(vs3, __ T4S, vs1, vs2); + // a1 = v1 - v2 + vs_subv(vs1, __ T4S, vs1, vs2); + // save a1 relative to first start index + vs_stpq_indexed(vs3, coeffs, c1Start, offsets); + // load constants q, qinv each iteration as they get clobbered above + vs_ldpq(vq, dilithiumConsts); // qInv, q + // load b next 32 (8x4S) inputs + vs_ldpq_post(vs2, zetas); + // a = a1 montmul b + vs_montmul32(vs2, vs1, vs2, vtmp, vq); + // save a relative to second start index + vs_stpq_indexed(vs2, coeffs, c2Start, offsets); int k = 4 * level + i; @@ -5120,94 +5231,84 @@ class StubGenerator: public StubCodeGenerator { const Register tmpAddr = r9; const Register dilithiumConsts = r10; const Register result = r11; + VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x4s inputs/outputs + VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 + VSeq<2> vq(30); // n.b. constants overlap vs3 + int offsets[4] = { 0, 32, 64, 96 }; + int offsets1[8] = { 0, 32, 64, 96, 128, 160, 192, 224 }; + int offsets2[8] = { 16, 48, 80, 112, 144, 176, 208, 240 }; __ add(result, coeffs, 0); __ lea(dilithiumConsts, ExternalAddress((address) StubRoutines::aarch64::_dilithiumConsts)); // Each level represents one iteration of the outer for loop of the Java version // level0 + + // level 0 + // At level 0 we need to interleave adjacent quartets of + // coefficients before we multiply and add/sub by the next 16 + // zetas just as we did for level 7 in the multiply code. 
So we + // load and store the values using an ld2/st2 with arrangement 4S for (int i = 0; i < 1024; i += 128) { - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q - __ add(tmpAddr, coeffs, i); - __ ld2(v0, v1, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 32); - __ ld2(v2, v3, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 64); - __ ld2(v4, v5, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 96); - __ ld2(v6, v7, __ T4S, tmpAddr); - dilithium_load16zetas(16, zetas); - dilithium_sub_add_montmul16(); - __ add(tmpAddr, coeffs, i); - __ st2(v0, v1, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 32); - __ st2(v2, v3, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 64); - __ st2(v4, v5, __ T4S, tmpAddr); - __ add(tmpAddr, coeffs, i + 96); - __ st2(v6, v7, __ T4S, tmpAddr); + // load constants q, qinv + // n.b. this can be moved out of the loop as they do not get + // clobbered by first two loops + vs_ldpq(vq, dilithiumConsts); // qInv, q + // a0/a1 load interleaved 32 (8x4S) coefficients + vs_ld2_indexed(vs1, __ T4S, coeffs, tmpAddr, i, offsets); + // b load next 32 (8x4S) inputs + vs_ldpq_post(vs_front(vs2), zetas); + // compute in parallel (a0, a1) = (a0 + a1, (a0 - a1) montmul b) + // n.b. second half of vs2 provides temporary register storage + dilithium_sub_add_montmul16(vs_even(vs1), vs_odd(vs1), + vs_front(vs2), vs_back(vs2), vtmp, vq); + // a0/a1 store interleaved 32 (8x4S) coefficients + vs_st2_indexed(vs1, __ T4S, coeffs, tmpAddr, i, offsets); } // level 1 + // At level 1 we need to interleave pairs of adjacent pairs of + // coefficients before we multiply by the next 16 zetas just as we + // did for level 6 in the multiply code. So we load and store the + // values an ld2/st2 with arrangement 2D for (int i = 0; i < 1024; i += 128) { - __ add(tmpAddr, coeffs, i); - __ ld2(v0, v1, __ T2D, tmpAddr); - __ add(tmpAddr, coeffs, i + 32); - __ ld2(v2, v3, __ T2D, tmpAddr); - __ add(tmpAddr, coeffs, i + 64); - __ ld2(v4, v5, __ T2D, tmpAddr); - __ add(tmpAddr, coeffs, i + 96); - __ ld2(v6, v7, __ T2D, tmpAddr); - dilithium_load16zetas(16, zetas); - dilithium_sub_add_montmul16(); - __ add(tmpAddr, coeffs, i); - __ st2(v0, v1, __ T2D, tmpAddr); - __ add(tmpAddr, coeffs, i + 32); - __ st2(v2, v3, __ T2D, tmpAddr); - __ add(tmpAddr, coeffs, i + 64); - __ st2(v4, v5, __ T2D, tmpAddr); - __ add(tmpAddr, coeffs, i + 96); - __ st2(v6, v7, __ T2D, tmpAddr); + // a0/a1 load interleaved 32 (8x2D) coefficients + vs_ld2_indexed(vs1, __ T2D, coeffs, tmpAddr, i, offsets); + // b load next 16 (4x4S) inputs + vs_ldpq_post(vs_front(vs2), zetas); + // compute in parallel (a0, a1) = (a0 + a1, (a0 - a1) montmul b) + // n.b. second half of vs2 provides temporary register storage + dilithium_sub_add_montmul16(vs_even(vs1), vs_odd(vs1), + vs_front(vs2), vs_back(vs2), vtmp, vq); + // a0/a1 store interleaved 32 (8x2D) coefficients + vs_st2_indexed(vs1, __ T2D, coeffs, tmpAddr, i, offsets); } - //level 2 + // level 2 + // At level 2 coefficients come in blocks of 4. So, we load 4 + // adjacent coefficients at 8 distinct offsets for both the first + // and second coefficient sequences, using an ldr with register + // variant Q then combine them with next set of 32 zetas. Likewise + // we store the results using an str with register variant Q. 
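All of the inverse-NTT level loops in this stub vectorize the scalar step quoted in the comment removed further up (implDilithiumAlmostInverseNttJava). A rough outline, not patch code, reusing the hypothetical mont_mul() from the earlier sketch:

#include <cstdint>

int32_t mont_mul(int32_t b, int32_t c);  // per-lane model from the earlier sketch

// One inverse NTT level over 256 coefficients; l is the j/j+l distance
// (1 at level 0 up to 128 at level 7).
static void inverse_ntt_level(int32_t coeffs[256], const int32_t* level_zetas, int l) {
  int m = 0;
  for (int start = 0; start < 256; start += 2 * l) {
    int32_t zeta = level_zetas[m++];  // stands in for -MONT_ZETAS_FOR_NTT[m]
    for (int j = start; j < start + l; j++) {
      int32_t tmp = coeffs[j];
      coeffs[j]     = tmp + coeffs[j + l];                  // the vs_addv step
      coeffs[j + l] = mont_mul(tmp - coeffs[j + l], zeta);  // subtract, then montmul
    }
  }
}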
for (int i = 0; i < 1024; i += 256) { - __ ldr(v0, __ Q, Address(coeffs, i)); - __ ldr(v1, __ Q, Address(coeffs, i + 32)); - __ ldr(v2, __ Q, Address(coeffs, i + 64)); - __ ldr(v3, __ Q, Address(coeffs, i + 96)); - __ ldr(v4, __ Q, Address(coeffs, i + 128)); - __ ldr(v5, __ Q, Address(coeffs, i + 160)); - __ ldr(v6, __ Q, Address(coeffs, i + 192)); - __ ldr(v7, __ Q, Address(coeffs, i + 224)); - __ ldr(v16, __ Q, Address(coeffs, i + 16)); - __ ldr(v17, __ Q, Address(coeffs, i + 48)); - __ ldr(v18, __ Q, Address(coeffs, i + 80)); - __ ldr(v19, __ Q, Address(coeffs, i + 112)); - __ ldr(v20, __ Q, Address(coeffs, i + 144)); - __ ldr(v21, __ Q, Address(coeffs, i + 176)); - __ ldr(v22, __ Q, Address(coeffs, i + 208)); - __ ldr(v23, __ Q, Address(coeffs, i + 240)); - dilithium_add_sub32(); - __ str(v24, __ Q, Address(coeffs, i)); - __ str(v25, __ Q, Address(coeffs, i + 32)); - __ str(v26, __ Q, Address(coeffs, i + 64)); - __ str(v27, __ Q, Address(coeffs, i + 96)); - __ str(v28, __ Q, Address(coeffs, i + 128)); - __ str(v29, __ Q, Address(coeffs, i + 160)); - __ str(v30, __ Q, Address(coeffs, i + 192)); - __ str(v31, __ Q, Address(coeffs, i + 224)); - dilithium_load32zetas(zetas); - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q - dilithium_montmul32(false); - __ str(v16, __ Q, Address(coeffs, i + 16)); - __ str(v17, __ Q, Address(coeffs, i + 48)); - __ str(v18, __ Q, Address(coeffs, i + 80)); - __ str(v19, __ Q, Address(coeffs, i + 112)); - __ str(v20, __ Q, Address(coeffs, i + 144)); - __ str(v21, __ Q, Address(coeffs, i + 176)); - __ str(v22, __ Q, Address(coeffs, i + 208)); - __ str(v23, __ Q, Address(coeffs, i + 240)); + // c0 load 32 (8x4S) coefficients via first offsets + vs_ldr_indexed(vs1, __ Q, coeffs, i, offsets1); + // c1 load 32 (8x4S) coefficients via second offsets + vs_ldr_indexed(vs2, __ Q,coeffs, i, offsets2); + // a0 = c0 + c1 n.b. clobbers vq which overlaps vs3 + vs_addv(vs3, __ T4S, vs1, vs2); + // c = c0 - c1 + vs_subv(vs1, __ T4S, vs1, vs2); + // store a0 32 (8x4S) coefficients via first offsets + vs_str_indexed(vs3, __ Q, coeffs, i, offsets1); + // b load 32 (8x4S) next inputs + vs_ldpq_post(vs2, zetas); + // reload constants q, qinv -- they were clobbered earlier + vs_ldpq(vq, dilithiumConsts); // qInv, q + // compute a1 = b montmul c + vs_montmul32(vs2, vs1, vs2, vtmp, vq); + // store a1 32 (8x4S) coefficients via second offsets + vs_str_indexed(vs2, __ Q, coeffs, i, offsets2); } // level 3-7 @@ -5232,7 +5333,7 @@ class StubGenerator: public StubCodeGenerator { // poly2 (int[256]) = c_rarg2 address generate_dilithiumNttMult() { - __ align(CodeEntryAlignment); + __ align(CodeEntryAlignment); StubGenStubId stub_id = StubGenStubId::dilithiumNttMult_id; StubCodeMark mark(this, stub_id); address start = __ pc(); @@ -5247,9 +5348,16 @@ class StubGenerator: public StubCodeGenerator { const Register dilithiumConsts = r10; const Register len = r11; + VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x4s inputs/outputs + VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 + VSeq<2> vq(30); // n.b. 
constants overlap vs3 + VSeq<8> vrsquare(29, 0); // for montmul by constant RSQUARE + __ lea(dilithiumConsts, ExternalAddress((address) StubRoutines::aarch64::_dilithiumConsts)); - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q + // load constants q, qinv + vs_ldpq(vq, dilithiumConsts); // qInv, q + // load constant rSquare into v29 __ ldr(v29, __ Q, Address(dilithiumConsts, 48)); // rSquare __ mov(len, zr); @@ -5257,20 +5365,16 @@ class StubGenerator: public StubCodeGenerator { __ BIND(L_loop); - __ ldpq(v0, v1, __ post(poly1, 32)); - __ ldpq(v2, v3, __ post(poly1, 32)); - __ ldpq(v4, v5, __ post(poly1, 32)); - __ ldpq(v6, v7, __ post(poly1, 32)); - __ ldpq(v16, v17, __ post(poly2, 32)); - __ ldpq(v18, v19, __ post(poly2, 32)); - __ ldpq(v20, v21, __ post(poly2, 32)); - __ ldpq(v22, v23, __ post(poly2, 32)); - dilithium_montmul32(false); - dilithium_montmul32(true); - __ stpq(v16, v17, __ post(result, 32)); - __ stpq(v18, v19, __ post(result, 32)); - __ stpq(v20, v21, __ post(result, 32)); - __ stpq(v22, v23, __ post(result, 32)); + // b load 32 (8x4S) next inputs from poly1 + vs_ldpq_post(vs1, poly1); + // c load 32 (8x4S) next inputs from poly2 + vs_ldpq_post(vs2, poly2); + // compute a = b montmul c + vs_montmul32(vs2, vs1, vs2, vtmp, vq); + // compute a = rsquare montmul a + vs_montmul32(vs2, vrsquare, vs2, vtmp, vq); + // save a 32 (8x4S) results + vs_stpq_post(vs2, result); __ sub(len, len, 128); __ cmp(len, (u1)128); @@ -5308,25 +5412,30 @@ class StubGenerator: public StubCodeGenerator { const Register result = r11; const Register len = r12; + VSeq<8> vs1(0), vs2(16), vs3(24); // 3 sets of 8x4s inputs/outputs + VSeq<4> vtmp = vs_front(vs3); // n.b. tmp registers overlap vs3 + VSeq<2> vq(30); // n.b. constants overlap vs3 + VSeq<8> vconst(29, 0); // for montmul by constant + + // results track inputs __ add(result, coeffs, 0); __ lea(dilithiumConsts, ExternalAddress((address) StubRoutines::aarch64::_dilithiumConsts)); - __ ldpq(v30, v31, Address(dilithiumConsts, 0)); // qInv, q - __ dup(v29, __ T4S, constant); + // load constants q, qinv -- they do not get clobbered by first two loops + vs_ldpq(vq, dilithiumConsts); // qInv, q + // copy caller supplied constant across vconst + __ dup(vconst[0], __ T4S, constant); __ mov(len, zr); __ add(len, len, 1024); __ BIND(L_loop); - __ ldpq(v16, v17, __ post(coeffs, 32)); - __ ldpq(v18, v19, __ post(coeffs, 32)); - __ ldpq(v20, v21, __ post(coeffs, 32)); - __ ldpq(v22, v23, __ post(coeffs, 32)); - dilithium_montmul32(true); - __ stpq(v16, v17, __ post(result, 32)); - __ stpq(v18, v19, __ post(result, 32)); - __ stpq(v20, v21, __ post(result, 32)); - __ stpq(v22, v23, __ post(result, 32)); + // load next 32 inputs + vs_ldpq_post(vs2, coeffs); + // mont mul by constant + vs_montmul32(vs2, vconst, vs2, vtmp, vq); + // write next 32 results + vs_stpq_post(vs2, result); __ sub(len, len, 128); __ cmp(len, (u1)128); @@ -5337,6 +5446,7 @@ class StubGenerator: public StubCodeGenerator { __ ret(lr); return start; + } // Dilithium decompose poly. 
@@ -5355,8 +5465,6 @@ class StubGenerator: public StubCodeGenerator { StubGenStubId stub_id = StubGenStubId::dilithiumDecomposePoly_id; StubCodeMark mark(this, stub_id); address start = __ pc(); - __ enter(); - Label L_loop; const Register input = c_rarg0; @@ -5369,6 +5477,18 @@ class StubGenerator: public StubCodeGenerator { const Register dilithiumConsts = r10; const Register tmp = r11; + VSeq<4> vs1(0), vs2(4), vs3(8); // 6 independent sets of 4x4s values + VSeq<4> vs4(12), vs5(16), vtmp(20); + VSeq<4> one(25, 0); // 7 constants for cross-multiplying + VSeq<4> qminus1(26, 0); + VSeq<4> g2(27, 0); + VSeq<4> twog2(28, 0); + VSeq<4> mult(29, 0); + VSeq<4> q(30, 0); + VSeq<4> qadd(31, 0); + + __ enter(); + __ lea(dilithiumConsts, ExternalAddress((address) StubRoutines::aarch64::_dilithiumConsts)); // save callee-saved registers @@ -5377,208 +5497,94 @@ class StubGenerator: public StubCodeGenerator { __ stpd(v12, v13, Address(sp, 32)); __ stpd(v14, v15, Address(sp, 48)); - + // populate constant registers __ mov(tmp, zr); __ add(tmp, tmp, 1); - __ dup(v25, __ T4S, tmp); // 1 - __ ldr(v30, __ Q, Address(dilithiumConsts, 16)); // q - __ ldr(v31, __ Q, Address(dilithiumConsts, 64)); // addend for mod q reduce - __ dup(v28, __ T4S, twoGamma2); // 2 * gamma2 - __ dup(v29, __ T4S, multiplier); // multiplier for mod 2 * gamma reduce - __ subv(v26, __ T4S, v30, v25); // q - 1 - __ sshr(v27, __ T4S, v28, 1); // gamma2 + __ dup(one[0], __ T4S, tmp); // 1 + __ ldr(q[0], __ Q, Address(dilithiumConsts, 16)); // q + __ ldr(qadd[0], __ Q, Address(dilithiumConsts, 64)); // addend for mod q reduce + __ dup(twog2[0], __ T4S, twoGamma2); // 2 * gamma2 + __ dup(mult[0], __ T4S, multiplier); // multiplier for mod 2 * gamma reduce + __ subv(qminus1[0], __ T4S, v30, v25); // q - 1 + __ sshr(g2[0], __ T4S, v28, 1); // gamma2 __ mov(len, zr); __ add(len, len, 1024); __ BIND(L_loop); - __ ld4(v0, v1, v2, v3, __ T4S, __ post(input, 64)); + // load next 4x4S inputs interleaved: rplus --> vs1 + __ ld4(vs1[0], vs1[1], vs1[2], vs1[3], __ T4S, __ post(input, 64)); - // rplus in v0 - // rplus = rplus - ((rplus + 5373807) >> 23) * dilithium_q; - __ addv(v4, __ T4S, v0, v31); - __ addv(v5, __ T4S, v1, v31); - __ addv(v6, __ T4S, v2, v31); - __ addv(v7, __ T4S, v3, v31); + // rplus = rplus - ((rplus + qadd) >> 23) * q + vs_addv(vtmp, __ T4S, vs1, qadd); + vs_sshr(vtmp, __ T4S, vtmp, 23); + vs_mulv(vtmp, __ T4S, vtmp, q); + vs_subv(vs1, __ T4S, vs1, vtmp); - __ sshr(v4, __ T4S, v4, 23); - __ sshr(v5, __ T4S, v5, 23); - __ sshr(v6, __ T4S, v6, 23); - __ sshr(v7, __ T4S, v7, 23); - - __ mulv(v4, __ T4S, v4, v30); - __ mulv(v5, __ T4S, v5, v30); - __ mulv(v6, __ T4S, v6, v30); - __ mulv(v7, __ T4S, v7, v30); - - __ subv(v0, __ T4S, v0, v4); - __ subv(v1, __ T4S, v1, v5); - __ subv(v2, __ T4S, v2, v6); - __ subv(v3, __ T4S, v3, v7); - - // rplus in v0 // rplus = rplus + ((rplus >> 31) & dilithium_q); - __ sshr(v4, __ T4S, v0, 31); - __ sshr(v5, __ T4S, v1, 31); - __ sshr(v6, __ T4S, v2, 31); - __ sshr(v7, __ T4S, v3, 31); + vs_sshr(vtmp, __ T4S, vs1, 31); + vs_andr(vtmp, vtmp, q); + vs_addv(vs1, __ T4S, vs1, vtmp); - __ andr(v4, __ T16B, v4, v30); - __ andr(v5, __ T16B, v5, v30); - __ andr(v6, __ T16B, v6, v30); - __ andr(v7, __ T16B, v7, v30); - - __ addv(v0, __ T4S, v0, v4); - __ addv(v1, __ T4S, v1, v5); - __ addv(v2, __ T4S, v2, v6); - __ addv(v3, __ T4S, v3, v7); - - // rplus in v0 + // quotient --> vs2 // int quotient = (rplus * multiplier) >> 22; - __ mulv(v4, __ T4S, v0, v29); - __ mulv(v5, __ T4S, v1, v29); - __ mulv(v6, 
__ T4S, v2, v29); - __ mulv(v7, __ T4S, v3, v29); + vs_mulv(vtmp, __ T4S, vs1, mult); + vs_sshr(vs2, __ T4S, vtmp, 22); - __ sshr(v4, __ T4S, v4, 22); - __ sshr(v5, __ T4S, v5, 22); - __ sshr(v6, __ T4S, v6, 22); - __ sshr(v7, __ T4S, v7, 22); - - // quotient in v4 + // r0 --> vs3 // int r0 = rplus - quotient * twoGamma2; - __ mulv(v8, __ T4S, v4, v28); - __ mulv(v9, __ T4S, v5, v28); - __ mulv(v10, __ T4S, v6, v28); - __ mulv(v11, __ T4S, v7, v28); + vs_mulv(vtmp, __ T4S, vs2, twog2); + vs_subv(vs3, __ T4S, vs1, vtmp); - __ subv(v8, __ T4S, v0, v8); - __ subv(v9, __ T4S, v1, v9); - __ subv(v10, __ T4S, v2, v10); - __ subv(v11, __ T4S, v3, v11); - - // r0 in v8 + // mask --> vs4 // int mask = (twoGamma2 - r0) >> 22; - __ subv(v12, __ T4S, v28, v8); - __ subv(v13, __ T4S, v28, v9); - __ subv(v14, __ T4S, v28, v10); - __ subv(v15, __ T4S, v28, v11); + vs_subv(vtmp, __ T4S, twog2, vs3); + vs_sshr(vs4, __ T4S, vtmp, 22); - __ sshr(v12, __ T4S, v12, 22); - __ sshr(v13, __ T4S, v13, 22); - __ sshr(v14, __ T4S, v14, 22); - __ sshr(v15, __ T4S, v15, 22); - - // mask in v12 // r0 -= (mask & twoGamma2); - __ andr(v16, __ T16B, v12, v28); - __ andr(v17, __ T16B, v13, v28); - __ andr(v18, __ T16B, v14, v28); - __ andr(v19, __ T16B, v15, v28); + vs_andr(vtmp, vs4, twog2); + vs_subv(vs3, __ T4S, vs3, vtmp); - __ subv(v8, __ T4S, v8, v16); - __ subv(v9, __ T4S, v9, v17); - __ subv(v10, __ T4S, v10, v18); - __ subv(v11, __ T4S, v11, v19); - - // r0 in v8 // quotient += (mask & 1); - __ andr(v16, __ T16B, v12, v25); - __ andr(v17, __ T16B, v13, v25); - __ andr(v18, __ T16B, v14, v25); - __ andr(v19, __ T16B, v15, v25); - - __ addv(v4, __ T4S, v4, v16); - __ addv(v5, __ T4S, v5, v17); - __ addv(v6, __ T4S, v6, v18); - __ addv(v7, __ T4S, v7, v19); + vs_andr(vtmp, vs4, one); + vs_addv(vs2, __ T4S, vs2, vtmp); // mask = (twoGamma2 / 2 - r0) >> 31; - __ subv(v12, __ T4S, v27, v8); - __ subv(v13, __ T4S, v27, v9); - __ subv(v14, __ T4S, v27, v10); - __ subv(v15, __ T4S, v27, v11); - - __ sshr(v12, __ T4S, v12, 31); - __ sshr(v13, __ T4S, v13, 31); - __ sshr(v14, __ T4S, v14, 31); - __ sshr(v15, __ T4S, v15, 31); + vs_subv(vtmp, __ T4S, g2, vs3); + vs_sshr(vs4, __ T4S, vtmp, 31); // r0 -= (mask & twoGamma2); - __ andr(v16, __ T16B, v12, v28); - __ andr(v17, __ T16B, v13, v28); - __ andr(v18, __ T16B, v14, v28); - __ andr(v19, __ T16B, v15, v28); - - __ subv(v8, __ T4S, v8, v16); - __ subv(v9, __ T4S, v9, v17); - __ subv(v10, __ T4S, v10, v18); - __ subv(v11, __ T4S, v11, v19); + vs_andr(vtmp, vs4, twog2); + vs_subv(vs3, __ T4S, vs3, vtmp); // quotient += (mask & 1); - __ andr(v16, __ T16B, v12, v25); - __ andr(v17, __ T16B, v13, v25); - __ andr(v18, __ T16B, v14, v25); - __ andr(v19, __ T16B, v15, v25); - - __ addv(v4, __ T4S, v4, v16); - __ addv(v5, __ T4S, v5, v17); - __ addv(v6, __ T4S, v6, v18); - __ addv(v7, __ T4S, v7, v19); + vs_andr(vtmp, vs4, one); + vs_addv(vs2, __ T4S, vs2, vtmp); + // r1 --> vs5 // int r1 = rplus - r0 - (dilithium_q - 1); - __ subv(v16, __ T4S, v0, v8); - __ subv(v17, __ T4S, v1, v9); - __ subv(v18, __ T4S, v2, v10); - __ subv(v19, __ T4S, v3, v11); + vs_subv(vtmp, __ T4S, vs1, vs3); + vs_subv(vs5, __ T4S, vtmp, qminus1); - __ subv(v16, __ T4S, v16, v26); - __ subv(v17, __ T4S, v17, v26); - __ subv(v18, __ T4S, v18, v26); - __ subv(v19, __ T4S, v19, v26); - - // r1 in v16 + // r1 --> vs1 (overwriting rplus) // r1 = (r1 | (-r1)) >> 31; // 0 if rplus - r0 == (dilithium_q - 1), -1 otherwise - __ negr(v20, __ T4S, v16); - __ negr(v21, __ T4S, v17); - __ negr(v22, __ T4S, v18); - __ 
negr(v23, __ T4S, v19); + vs_negr(vtmp, __ T4S, vs5); + vs_orr(vtmp, vs5, vtmp); + vs_sshr(vs1, __ T4S, vtmp, 31); - __ orr(v16, __ T16B, v16, v20); - __ orr(v17, __ T16B, v17, v21); - __ orr(v18, __ T16B, v18, v22); - __ orr(v19, __ T16B, v19, v23); - - __ sshr(v0, __ T4S, v16, 31); - __ sshr(v1, __ T4S, v17, 31); - __ sshr(v2, __ T4S, v18, 31); - __ sshr(v3, __ T4S, v19, 31); - - // r1 in v0 // r0 += ~r1; - __ notr(v20, __ T16B, v0); - __ notr(v21, __ T16B, v1); - __ notr(v22, __ T16B, v2); - __ notr(v23, __ T16B, v3); + vs_notr(vtmp, vs1); + vs_addv(vs3, __ T4S, vs3, vtmp); - __ addv(v8, __ T4S, v8, v20); - __ addv(v9, __ T4S, v9, v21); - __ addv(v10, __ T4S, v10, v22); - __ addv(v11, __ T4S, v11, v23); - - // r0 in v8 // r1 = r1 & quotient; - __ andr(v0, __ T16B, v4, v0); - __ andr(v1, __ T16B, v5, v1); - __ andr(v2, __ T16B, v6, v2); - __ andr(v3, __ T16B, v7, v3); + vs_andr(vs1, vs2, vs1); - // r1 in v0 + // store results inteleaved // lowPart[m] = r0; // highPart[m] = r1; - __ st4(v8, v9, v10, v11, __ T4S, __ post(lowPart, 64)); - __ st4(v0, v1, v2, v3, __ T4S, __ post(highPart, 64)); + __ st4(vs3[0], vs3[1], vs3[2], vs3[3], __ T4S, __ post(lowPart, 64)); + __ st4(vs1[0], vs1[1], vs1[2], vs1[3], __ T4S, __ post(highPart, 64)); __ sub(len, len, 64); @@ -5596,6 +5602,7 @@ class StubGenerator: public StubCodeGenerator { __ ret(lr); return start; + } /**