jdk/src/hotspot/cpu/x86/stubGenerator_x86_64_kyber.cpp
Shawn M Emery a0e6f028a8 8360934: Add AVX-512 intrinsics for ML-KEM - enhancement on AVX512_VBMI
Co-authored-by: Sandhya Viswanathan <sviswanathan@openjdk.org>
Reviewed-by: jbhateja, vpaprotski
2026-01-17 11:08:30 +00:00

1041 lines
38 KiB
C++

/*
* Copyright (c) 2025, 2026, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*
*/
#include "asm/assembler.hpp"
#include "asm/assembler.inline.hpp"
#include "runtime/stubRoutines.hpp"
#include "macroAssembler_x86.hpp"
#include "stubGenerator_x86_64.hpp"
#define __ _masm->
#define xmm(i) as_XMMRegister(i)
#ifdef PRODUCT
#define BLOCK_COMMENT(str) /* nothing */
#else
#define BLOCK_COMMENT(str) __ block_comment(str)
#endif // PRODUCT
#define BIND(label) bind(label); BLOCK_COMMENT(#label ":")
// Constants
//
ATTRIBUTE_ALIGNED(64) static const uint16_t kyberAvx512Consts[] = {
0xF301, 0xF301, 0xF301, 0xF301, // q^-1 mod montR
0x0D01, 0x0D01, 0x0D01, 0x0D01, // q
0x4EBF, 0x4EBF, 0x4EBF, 0x4EBF, // Barrett multiplier
0x0200, 0x0200, 0x0200, 0x0200, //(dim/2)^-1 mod q
0x0549, 0x0549, 0x0549, 0x0549, // montR^2 mod q
0x0F00, 0x0F00, 0x0F00, 0x0F00 // mask for kyber12to16
};
static int qInvModROffset = 0;
static int qOffset = 8;
static int barretMultiplierOffset = 16;
static int dimHalfInverseOffset = 24;
static int montRSquareModqOffset = 32;
static int f00Offset = 40;
static address kyberAvx512ConstsAddr(int offset) {
return ((address) kyberAvx512Consts) + offset;
}
const Register scratch = r10;
ATTRIBUTE_ALIGNED(64) static const uint8_t kyberAvx512_12To16Dup[] = {
// 0 - 63
0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11, 12, 13, 13, 14, 15, 16,
16, 17, 18, 19, 19, 20, 21, 22, 22, 23, 24, 25, 25, 26, 27, 28, 28, 29, 30,
31, 31, 32, 33, 34, 34, 35, 36, 37, 37, 38, 39, 40, 40, 41, 42, 43, 43, 44,
45, 46, 46, 47
};
static address kyberAvx512_12To16DupAddr() {
return (address) kyberAvx512_12To16Dup;
}
ATTRIBUTE_ALIGNED(64) static const uint16_t kyberAvx512_12To16Shift[] = {
// 0 - 31
0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0,
4, 0, 4, 0, 4, 0, 4
};
static address kyberAvx512_12To16ShiftAddr() {
return (address) kyberAvx512_12To16Shift;
}
ATTRIBUTE_ALIGNED(64) static const uint64_t kyberAvx512_12To16And[] = {
// 0 - 7
0x0FFF0FFF0FFF0FFF, 0x0FFF0FFF0FFF0FFF, 0x0FFF0FFF0FFF0FFF,
0x0FFF0FFF0FFF0FFF, 0x0FFF0FFF0FFF0FFF, 0x0FFF0FFF0FFF0FFF,
0x0FFF0FFF0FFF0FFF, 0x0FFF0FFF0FFF0FFF
};
static address kyberAvx512_12To16AndAddr() {
return (address) kyberAvx512_12To16And;
}
ATTRIBUTE_ALIGNED(64) static const uint16_t kyberAvx512NttPerms[] = {
// 0
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F,
// 128
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F,
0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
// 256
0x00, 0x01, 0x02, 0x03, 0x20, 0x21, 0x22, 0x23,
0x08, 0x09, 0x0A, 0x0B, 0x28, 0x29, 0x2A, 0x2B,
0x10, 0x11, 0x12, 0x13, 0x30, 0x31, 0x32, 0x33,
0x18, 0x19, 0x1A, 0x1B, 0x38, 0x39, 0x3A, 0x3B,
0x04, 0x05, 0x06, 0x07, 0x24, 0x25, 0x26, 0x27,
0x0C, 0x0D, 0x0E, 0x0F, 0x2C, 0x2D, 0x2E, 0x2F,
0x14, 0x15, 0x16, 0x17, 0x34, 0x35, 0x36, 0x37,
0x1C, 0x1D, 0x1E, 0x1F, 0x3C, 0x3D, 0x3E, 0x3F,
// 384
0x00, 0x01, 0x20, 0x21, 0x04, 0x05, 0x24, 0x25,
0x08, 0x09, 0x28, 0x29, 0x0C, 0x0D, 0x2C, 0x2D,
0x10, 0x11, 0x30, 0x31, 0x14, 0x15, 0x34, 0x35,
0x18, 0x19, 0x38, 0x39, 0x1C, 0x1D, 0x3C, 0x3D,
0x02, 0x03, 0x22, 0x23, 0x06, 0x07, 0x26, 0x27,
0x0A, 0x0B, 0x2A, 0x2B, 0x0E, 0x0F, 0x2E, 0x2F,
0x12, 0x13, 0x32, 0x33, 0x16, 0x17, 0x36, 0x37,
0x1A, 0x1B, 0x3A, 0x3B, 0x1E, 0x1F, 0x3E, 0x3F,
// 512
0x10, 0x11, 0x30, 0x31, 0x12, 0x13, 0x32, 0x33,
0x14, 0x15, 0x34, 0x35, 0x16, 0x17, 0x36, 0x37,
0x18, 0x19, 0x38, 0x39, 0x1A, 0x1B, 0x3A, 0x3B,
0x1C, 0x1D, 0x3C, 0x3D, 0x1E, 0x1F, 0x3E, 0x3F,
0x00, 0x01, 0x20, 0x21, 0x02, 0x03, 0x22, 0x23,
0x04, 0x05, 0x24, 0x25, 0x06, 0x07, 0x26, 0x27,
0x08, 0x09, 0x28, 0x29, 0x0A, 0x0B, 0x2A, 0x2B,
0x0C, 0x0D, 0x2C, 0x2D, 0x0E, 0x0F, 0x2E, 0x2F
};
static address kyberAvx512NttPermsAddr() {
return (address) kyberAvx512NttPerms;
}
ATTRIBUTE_ALIGNED(64) static const uint16_t kyberAvx512InverseNttPerms[] = {
// 0
0x02, 0x03, 0x06, 0x07, 0x0A, 0x0B, 0x0E, 0x0F,
0x12, 0x13, 0x16, 0x17, 0x1A, 0x1B, 0x1E, 0x1F,
0x22, 0x23, 0x26, 0x27, 0x2A, 0x2B, 0x2E, 0x2F,
0x32, 0x33, 0x36, 0x37, 0x3A, 0x3B, 0x3E, 0x3F,
0x00, 0x01, 0x04, 0x05, 0x08, 0x09, 0x0C, 0x0D,
0x10, 0x11, 0x14, 0x15, 0x18, 0x19, 0x1C, 0x1D,
0x20, 0x21, 0x24, 0x25, 0x28, 0x29, 0x2C, 0x2D,
0x30, 0x31, 0x34, 0x35, 0x38, 0x39, 0x3C, 0x3D,
// 128
0x00, 0x01, 0x20, 0x21, 0x04, 0x05, 0x24, 0x25,
0x08, 0x09, 0x28, 0x29, 0x0C, 0x0D, 0x2C, 0x2D,
0x10, 0x11, 0x30, 0x31, 0x14, 0x15, 0x34, 0x35,
0x18, 0x19, 0x38, 0x39, 0x1C, 0x1D, 0x3C, 0x3D,
0x02, 0x03, 0x22, 0x23, 0x06, 0x07, 0x26, 0x27,
0x0A, 0x0B, 0x2A, 0x2B, 0x0E, 0x0F, 0x2E, 0x2F,
0x12, 0x13, 0x32, 0x33, 0x16, 0x17, 0x36, 0x37,
0x1A, 0x1B, 0x3A, 0x3B, 0x1E, 0x1F, 0x3E, 0x3F,
// 256
0x00, 0x01, 0x02, 0x03, 0x20, 0x21, 0x22, 0x23,
0x08, 0x09, 0x0A, 0x0B, 0x28, 0x29, 0x2A, 0x2B,
0x10, 0x11, 0x12, 0x13, 0x30, 0x31, 0x32, 0x33,
0x18, 0x19, 0x1A, 0x1B, 0x38, 0x39, 0x3A, 0x3B,
0x04, 0x05, 0x06, 0x07, 0x24, 0x25, 0x26, 0x27,
0x0C, 0x0D, 0x0E, 0x0F, 0x2C, 0x2D, 0x2E, 0x2F,
0x14, 0x15, 0x16, 0x17, 0x34, 0x35, 0x36, 0x37,
0x1C, 0x1D, 0x1E, 0x1F, 0x3C, 0x3D, 0x3E, 0x3F,
// 384
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F,
0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
// 512
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F
};
static address kyberAvx512InverseNttPermsAddr() {
return (address) kyberAvx512InverseNttPerms;
}
ATTRIBUTE_ALIGNED(64) static const uint16_t kyberAvx512_nttMultPerms[] = {
0x00, 0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E,
0x10, 0x12, 0x14, 0x16, 0x18, 0x1A, 0x1C, 0x1E,
0x20, 0x22, 0x24, 0x26, 0x28, 0x2A, 0x2C, 0x2E,
0x30, 0x32, 0x34, 0x36, 0x38, 0x3A, 0x3C, 0x3E,
0x01, 0x03, 0x05, 0x07, 0x09, 0x0B, 0x0D, 0x0F,
0x11, 0x13, 0x15, 0x17, 0x19, 0x1B, 0x1D, 0x1F,
0x21, 0x23, 0x25, 0x27, 0x29, 0x2B, 0x2D, 0x2F,
0x31, 0x33, 0x35, 0x37, 0x39, 0x3B, 0x3D, 0x3F,
0x00, 0x20, 0x01, 0x21, 0x02, 0x22, 0x03, 0x23,
0x04, 0x24, 0x05, 0x25, 0x06, 0x26, 0x07, 0x27,
0x08, 0x28, 0x09, 0x29, 0x0A, 0x2A, 0x0B, 0x2B,
0x0C, 0x2C, 0x0D, 0x2D, 0x0E, 0x2E, 0x0F, 0x2F,
0x10, 0x30, 0x11, 0x31, 0x12, 0x32, 0x13, 0x33,
0x14, 0x34, 0x15, 0x35, 0x16, 0x36, 0x17, 0x37,
0x18, 0x38, 0x19, 0x39, 0x1A, 0x3A, 0x1B, 0x3B,
0x1C, 0x3C, 0x1D, 0x3D, 0x1E, 0x3E, 0x1F, 0x3F
};
static address kyberAvx512_nttMultPermsAddr() {
return (address) kyberAvx512_nttMultPerms;
}
ATTRIBUTE_ALIGNED(64) static const uint16_t kyberAvx512_12To16Perms[] = {
// 0
0x00, 0x03, 0x06, 0x09, 0x0C, 0x0F, 0x12, 0x15,
0x18, 0x1B, 0x1E, 0x21, 0x24, 0x27, 0x2A, 0x2D,
0x30, 0x33, 0x36, 0x39, 0x3C, 0x3F, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x04, 0x07, 0x0A, 0x0D, 0x10, 0x13, 0x16,
0x19, 0x1C, 0x1F, 0x22, 0x25, 0x28, 0x2B, 0x2E,
0x31, 0x34, 0x37, 0x3A, 0x3D, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// 128
0x02, 0x05, 0x08, 0x0B, 0x0E, 0x11, 0x14, 0x17,
0x1A, 0x1D, 0x20, 0x23, 0x26, 0x29, 0x2C, 0x2F,
0x32, 0x35, 0x38, 0x3B, 0x3E, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x22, 0x25,
0x28, 0x2B, 0x2E, 0x31, 0x34, 0x37, 0x3A, 0x3D,
// 256
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0x10, 0x11, 0x12, 0x13, 0x14, 0x20, 0x23, 0x26,
0x29, 0x2C, 0x2F, 0x32, 0x35, 0x38, 0x3B, 0x3E,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0x10, 0x11, 0x12, 0x13, 0x14, 0x21, 0x24, 0x27,
0x2A, 0x2D, 0x30, 0x33, 0x36, 0x39, 0x3C, 0x3F,
// 384
0x00, 0x20, 0x01, 0x21, 0x02, 0x22, 0x03, 0x23,
0x04, 0x24, 0x05, 0x25, 0x06, 0x26, 0x07, 0x27,
0x08, 0x28, 0x09, 0x29, 0x0A, 0x2A, 0x0B, 0x2B,
0x0C, 0x2C, 0x0D, 0x2D, 0x0E, 0x2E, 0x0F, 0x2F,
0x10, 0x30, 0x11, 0x31, 0x12, 0x32, 0x13, 0x33,
0x14, 0x34, 0x15, 0x35, 0x16, 0x36, 0x17, 0x37,
0x18, 0x38, 0x19, 0x39, 0x1A, 0x3A, 0x1B, 0x3B,
0x1C, 0x3C, 0x1D, 0x3D, 0x1E, 0x3E, 0x1F, 0x3F
};
static address kyberAvx512_12To16PermsAddr() {
return (address) kyberAvx512_12To16Perms;
}
static void load4regs(int destRegs[], Register address, int offset,
MacroAssembler *_masm) {
for (int i = 0; i < 4; i++) {
__ evmovdquw(xmm(destRegs[i]), Address(address, offset + i * 64),
Assembler::AVX_512bit);
}
}
// For z = montmul(a,b), z will be between -q and q and congruent
// to a * b * R^-1 mod q, where R > 2 * q, R is a power of 2,
// -R/2 * q <= a * b < R/2 * q.
// (See e.g. Algorithm 3 in https://eprint.iacr.org/2018/039.pdf)
// For the Java code, we use R = 2^20 and for the intrinsic, R = 2^16.
// In our computations, b is always c * R mod q, so the montmul() really
// computes a * c mod q. In the Java code, we use 32-bit numbers for the
// computations, and we use R = 2^20 because that way the a * b numbers
// that occur during all computations stay in the required range.
// For the intrinsics, we use R = 2^16, because this way we can do twice
// as much work in parallel, the only drawback is that we should do some Barrett
// reductions in kyberInverseNtt so that the numbers stay in the required range.
static void montmul(int outputRegs[], int inputRegs1[], int inputRegs2[],
int scratchRegs1[], int scratchRegs2[], MacroAssembler *_masm) {
for (int i = 0; i < 4; i++) {
__ evpmullw(xmm(scratchRegs1[i]), k0, xmm(inputRegs1[i]),
xmm(inputRegs2[i]), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
__ evpmulhw(xmm(scratchRegs2[i]), k0, xmm(inputRegs1[i]),
xmm(inputRegs2[i]), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
__ evpmullw(xmm(scratchRegs1[i]), k0, xmm(scratchRegs1[i]),
xmm31, false, Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
__ evpmulhw(xmm(scratchRegs1[i]), k0, xmm(scratchRegs1[i]),
xmm30, false, Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
__ evpsubw(xmm(outputRegs[i]), k0, xmm(scratchRegs2[i]),
xmm(scratchRegs1[i]), false, Assembler::AVX_512bit);
}
}
static void sub_add(int subResult[], int addResult[], int input1[], int input2[],
MacroAssembler *_masm) {
for (int i = 0; i < 4; i++) {
__ evpsubw(xmm(subResult[i]), k0, xmm(input1[i]), xmm(input2[i]),
false, Assembler::AVX_512bit);
__ evpaddw(xmm(addResult[i]), k0, xmm(input1[i]), xmm(input2[i]),
false, Assembler::AVX_512bit);
}
}
// result2 also acts as input1
// result1 also acts as perm1
static void permute(int result1[], int result2[], int input2[], int perm2,
MacroAssembler *_masm) {
for (int i = 1; i < 4; i++) {
__ evmovdquw(xmm(result1[i]), xmm(result1[0]), Assembler::AVX_512bit);
}
for (int i = 0; i < 4; i++) {
__ evpermi2w(xmm(result1[i]), xmm(result2[i]), xmm(input2[i]),
Assembler::AVX_512bit);
__ evpermt2w(xmm(result2[i]), xmm(perm2), xmm(input2[i]),
Assembler::AVX_512bit);
}
}
static void store4regs(Register address, int offset, int sourceRegs[],
MacroAssembler *_masm) {
for (int i = 0; i < 4; i++) {
__ evmovdquw(Address(address, offset + i * 64), xmm(sourceRegs[i]),
Assembler::AVX_512bit);
}
}
// In all 3 invocations of this function we use the same registers:
// xmm0-xmm7 for the input and the result,
// xmm8-xmm15 as scratch registers and
// xmm16-xmm17 for the constants,
// so we don't pass register arguments.
static void barrettReduce(MacroAssembler *_masm) {
for (int i = 0; i < 8; i++) {
__ evpmulhw(xmm(i + 8), k0, xmm(i), xmm16, false, Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
__ evpsraw(xmm(i + 8), k0, xmm(i + 8), 10, false, Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
__ evpmullw(xmm(i + 8), k0, xmm(i + 8), xmm17, false, Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
__ evpsubw(xmm(i), k0, xmm(i), xmm(i + 8), false, Assembler::AVX_512bit);
}
}
static int xmm0_3[] = {0, 1, 2, 3};
static int xmm0145[] = {0, 1, 4, 5};
static int xmm0246[] = {0, 2, 4, 6};
static int xmm0829[] = {0, 8, 2, 9};
static int xmm1001[] = {1, 0, 0, 1};
static int xmm1357[] = {1, 3, 5, 7};
static int xmm2367[] = {2, 3, 6, 7};
static int xmm2_0_10_8[] = {2, 0, 10, 8};
static int xmm3223[] = {3, 2, 2, 3};
static int xmm4_7[] = {4, 5, 6, 7};
static int xmm5454[] = {5, 4, 5, 4};
static int xmm7676[] = {7, 6, 7, 6};
static int xmm8_11[] = {8, 9, 10, 11};
static int xmm12_15[] = {12, 13, 14, 15};
static int xmm16_19[] = {16, 17, 18, 19};
static int xmm20_23[] = {20, 21, 22, 23};
static int xmm23_23[] = {23, 23, 23, 23};
static int xmm24_27[] = {24, 25, 26, 27};
static int xmm26_29[] = {26, 27, 28, 29};
static int xmm28_31[] = {28, 29, 30, 31};
static int xmm29_29[] = {29, 29, 29, 29};
// Kyber NTT function.
//
// coeffs (short[256]) = c_rarg0
// ntt_zetas (short[256]) = c_rarg1
address generate_kyberNtt_avx512(StubGenerator *stubgen,
MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_kyberNtt_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
__ enter();
const Register coeffs = c_rarg0;
const Register zetas = c_rarg1;
const Register perms = r11;
__ lea(perms, ExternalAddress(kyberAvx512NttPermsAddr()));
load4regs(xmm4_7, coeffs, 256, _masm);
load4regs(xmm20_23, zetas, 0, _masm);
__ vpbroadcastq(xmm30,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
__ vpbroadcastq(xmm31,
ExternalAddress(kyberAvx512ConstsAddr(qInvModROffset)),
Assembler::AVX_512bit, scratch); // q^-1 mod montR
load4regs(xmm0_3, coeffs, 0, _masm);
// Each level represents one iteration of the outer for loop of the Java version.
// level 0
montmul(xmm8_11, xmm4_7, xmm20_23, xmm8_11, xmm4_7, _masm);
load4regs(xmm20_23, zetas, 256, _masm);
sub_add(xmm4_7, xmm0_3, xmm0_3, xmm8_11, _masm);
//level 1
montmul(xmm12_15, xmm2367, xmm20_23, xmm12_15, xmm8_11, _masm);
load4regs(xmm20_23, zetas, 512, _masm);
sub_add(xmm2367, xmm0145, xmm0145, xmm12_15, _masm);
// level 2
montmul(xmm8_11, xmm1357, xmm20_23, xmm12_15, xmm8_11, _masm);
__ evmovdquw(xmm12, Address(perms, 0), Assembler::AVX_512bit);
__ evmovdquw(xmm16, Address(perms, 64), Assembler::AVX_512bit);
load4regs(xmm20_23, zetas, 768, _masm);
sub_add(xmm1357, xmm0246, xmm0246, xmm8_11, _masm);
//level 3
permute(xmm12_15, xmm0246, xmm1357, 16, _masm);
montmul(xmm8_11, xmm12_15, xmm20_23, xmm16_19, xmm8_11, _masm);
__ evmovdquw(xmm16, Address(perms, 128), Assembler::AVX_512bit);
__ evmovdquw(xmm24, Address(perms, 192), Assembler::AVX_512bit);
load4regs(xmm20_23, zetas, 1024, _masm);
sub_add(xmm1357, xmm0246, xmm0246, xmm8_11, _masm);
// level 4
permute(xmm16_19, xmm0246, xmm1357, 24, _masm);
montmul(xmm8_11, xmm0246, xmm20_23, xmm24_27, xmm8_11, _masm);
__ evmovdquw(xmm1, Address(perms, 256), Assembler::AVX_512bit);
__ evmovdquw(xmm24, Address(perms, 320), Assembler::AVX_512bit);
load4regs(xmm20_23, zetas, 1280, _masm);
sub_add(xmm12_15, xmm0246, xmm16_19, xmm8_11, _masm);
// level 5
permute(xmm1357, xmm0246, xmm12_15, 24, _masm);
montmul(xmm16_19, xmm0246, xmm20_23, xmm16_19, xmm8_11, _masm);
__ evmovdquw(xmm12, Address(perms, 384), Assembler::AVX_512bit);
__ evmovdquw(xmm8, Address(perms, 448), Assembler::AVX_512bit);
load4regs(xmm20_23, zetas, 1536, _masm);
sub_add(xmm24_27, xmm0246, xmm1357, xmm16_19, _masm);
// level 6
permute(xmm12_15, xmm0246, xmm24_27, 8, _masm);
__ evmovdquw(xmm1, Address(perms, 512), Assembler::AVX_512bit);
__ evmovdquw(xmm24, Address(perms, 576), Assembler::AVX_512bit);
montmul(xmm16_19, xmm0246, xmm20_23, xmm16_19, xmm8_11, _masm);
sub_add(xmm20_23, xmm0246, xmm12_15, xmm16_19, _masm);
permute(xmm1357, xmm0246, xmm20_23, 24, _masm);
store4regs(coeffs, 0, xmm0_3, _masm);
store4regs(coeffs, 256, xmm4_7, _masm);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
// Kyber Inverse NTT function
//
// coeffs (short[256]) = c_rarg0
// ntt_zetas (short[256]) = c_rarg1
address generate_kyberInverseNtt_avx512(StubGenerator *stubgen,
MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_kyberInverseNtt_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
__ enter();
const Register coeffs = c_rarg0;
const Register zetas = c_rarg1;
const Register perms = r11;
__ lea(perms, ExternalAddress(kyberAvx512InverseNttPermsAddr()));
__ evmovdquw(xmm12, Address(perms, 0), Assembler::AVX_512bit);
__ evmovdquw(xmm16, Address(perms, 64), Assembler::AVX_512bit);
__ vpbroadcastq(xmm31,
ExternalAddress(kyberAvx512ConstsAddr(qInvModROffset)),
Assembler::AVX_512bit, scratch); // q^-1 mod montR
__ vpbroadcastq(xmm30,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
__ vpbroadcastq(xmm29,
ExternalAddress(kyberAvx512ConstsAddr(dimHalfInverseOffset)),
Assembler::AVX_512bit, scratch); // (dim/2)^-1 mod q
load4regs(xmm0_3, coeffs, 0, _masm);
load4regs(xmm4_7, coeffs, 256, _masm);
// Each level represents one iteration of the outer for loop of the Java version.
// level 0
load4regs(xmm8_11, zetas, 0, _masm);
permute(xmm12_15, xmm0246, xmm1357, 16, _masm);
__ evmovdquw(xmm1, Address(perms, 128), Assembler::AVX_512bit);
__ evmovdquw(xmm20, Address(perms, 192), Assembler::AVX_512bit);
sub_add(xmm16_19, xmm0246, xmm0246, xmm12_15, _masm);
montmul(xmm12_15, xmm16_19, xmm8_11, xmm12_15, xmm8_11, _masm);
// level 1
load4regs(xmm8_11, zetas, 256, _masm);
permute(xmm1357, xmm0246, xmm12_15, 20, _masm);
sub_add(xmm16_19, xmm0246, xmm1357, xmm0246, _masm);
__ evmovdquw(xmm1, Address(perms, 256), Assembler::AVX_512bit);
__ evmovdquw(xmm20, Address(perms, 320), Assembler::AVX_512bit);
montmul(xmm12_15, xmm16_19, xmm8_11, xmm12_15, xmm8_11, _masm);
// level2
load4regs(xmm8_11, zetas, 512, _masm);
permute(xmm1357, xmm0246, xmm12_15, 20, _masm);
sub_add(xmm16_19, xmm0246, xmm1357, xmm0246,_masm);
__ evmovdquw(xmm1, Address(perms, 384), Assembler::AVX_512bit);
__ evmovdquw(xmm20, Address(perms, 448), Assembler::AVX_512bit);
montmul(xmm12_15, xmm16_19, xmm8_11, xmm12_15, xmm8_11, _masm);
__ vpbroadcastq(xmm16,
ExternalAddress(kyberAvx512ConstsAddr(barretMultiplierOffset)),
Assembler::AVX_512bit, scratch); // Barrett multiplier
__ vpbroadcastq(xmm17,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
permute(xmm1357, xmm0246, xmm12_15, 20, _masm);
barrettReduce(_masm);
// level 3
load4regs(xmm8_11, zetas, 768, _masm);
sub_add(xmm16_19, xmm0246, xmm1357, xmm0246, _masm);
__ evmovdquw(xmm1, Address(perms, 512), Assembler::AVX_512bit);
__ evmovdquw(xmm20, Address(perms, 576), Assembler::AVX_512bit);
montmul(xmm12_15, xmm16_19, xmm8_11, xmm12_15, xmm8_11, _masm);
permute(xmm1357, xmm0246, xmm12_15, 20, _masm);
// level 4
load4regs(xmm8_11, zetas, 1024, _masm);
__ vpbroadcastq(xmm16,
ExternalAddress(kyberAvx512ConstsAddr(barretMultiplierOffset)),
Assembler::AVX_512bit, scratch); // Barrett multiplier
__ vpbroadcastq(xmm17,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
sub_add(xmm12_15, xmm0246, xmm0246, xmm1357, _masm);
montmul(xmm1357, xmm12_15, xmm8_11, xmm1357, xmm8_11, _masm);
barrettReduce(_masm);
// level 5
load4regs(xmm8_11, zetas, 1280, _masm);
sub_add(xmm12_15, xmm0145, xmm0145, xmm2367, _masm);
montmul(xmm2367, xmm12_15, xmm8_11, xmm2367, xmm8_11, _masm);
// level 6
load4regs(xmm8_11, zetas, 1536, _masm);
sub_add(xmm12_15, xmm0_3, xmm0_3, xmm4_7, _masm);
montmul(xmm4_7, xmm12_15, xmm8_11, xmm4_7, xmm8_11, _masm);
montmul(xmm8_11, xmm29_29, xmm0_3, xmm8_11, xmm0_3, _masm);
montmul(xmm12_15, xmm29_29, xmm4_7, xmm12_15, xmm4_7, _masm);
store4regs(coeffs, 0, xmm8_11, _masm);
store4regs(coeffs, 256, xmm12_15, _masm);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
// Kyber multiply polynomials in the NTT domain.
//
// result (short[256]) = c_rarg0
// ntta (short[256]) = c_rarg1
// nttb (short[256]) = c_rarg2
// zetas (short[128]) = c_rarg3
address generate_kyberNttMult_avx512(StubGenerator *stubgen,
MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_kyberNttMult_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
__ enter();
const Register result = c_rarg0;
const Register ntta = c_rarg1;
const Register nttb = c_rarg2;
const Register zetas = c_rarg3;
const Register perms = r11;
const Register loopCnt = r12;
__ push_ppx(r12);
__ movl(loopCnt, 2);
Label Loop;
__ lea(perms, ExternalAddress(kyberAvx512_nttMultPermsAddr()));
load4regs(xmm26_29, perms, 0, _masm);
__ vpbroadcastq(xmm31,
ExternalAddress(kyberAvx512ConstsAddr(qInvModROffset)),
Assembler::AVX_512bit, scratch); // q^-1 mod montR
__ vpbroadcastq(xmm30,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
__ vpbroadcastq(xmm23,
ExternalAddress(kyberAvx512ConstsAddr(montRSquareModqOffset)),
Assembler::AVX_512bit, scratch); // montR^2 mod q
__ BIND(Loop);
__ evmovdquw(xmm1, Address(ntta, 0), Assembler::AVX_512bit);
__ evmovdquw(xmm8, Address(ntta, 64), Assembler::AVX_512bit);
__ evmovdquw(xmm3, Address(ntta, 128), Assembler::AVX_512bit);
__ evmovdquw(xmm9, Address(ntta, 192), Assembler::AVX_512bit);
__ evmovdquw(xmm5, Address(nttb, 0), Assembler::AVX_512bit);
__ evmovdquw(xmm10, Address(nttb, 64), Assembler::AVX_512bit);
__ evmovdquw(xmm7, Address(nttb, 128), Assembler::AVX_512bit);
__ evmovdquw(xmm11, Address(nttb, 192), Assembler::AVX_512bit);
__ evmovdquw(xmm0, xmm26, Assembler::AVX_512bit);
__ evmovdquw(xmm2, xmm26, Assembler::AVX_512bit);
__ evmovdquw(xmm4, xmm26, Assembler::AVX_512bit);
__ evmovdquw(xmm6, xmm26, Assembler::AVX_512bit);
__ evpermi2w(xmm0, xmm1, xmm8, Assembler::AVX_512bit);
__ evpermt2w(xmm1, xmm27, xmm8, Assembler::AVX_512bit);
__ evpermi2w(xmm2, xmm3, xmm9, Assembler::AVX_512bit);
__ evpermt2w(xmm3, xmm27, xmm9, Assembler::AVX_512bit);
__ evpermi2w(xmm4, xmm5, xmm10, Assembler::AVX_512bit);
__ evpermt2w(xmm5, xmm27, xmm10, Assembler::AVX_512bit);
__ evpermi2w(xmm6, xmm7, xmm11, Assembler::AVX_512bit);
__ evpermt2w(xmm7, xmm27, xmm11, Assembler::AVX_512bit);
__ evmovdquw(xmm24, Address(zetas, 0), Assembler::AVX_512bit);
__ evmovdquw(xmm25, Address(zetas, 64), Assembler::AVX_512bit);
montmul(xmm16_19, xmm1001, xmm5454, xmm16_19, xmm12_15, _masm);
montmul(xmm0145, xmm3223, xmm7676, xmm0145, xmm12_15, _masm);
__ evpmullw(xmm2, k0, xmm16, xmm24, false, Assembler::AVX_512bit);
__ evpmullw(xmm3, k0, xmm0, xmm25, false, Assembler::AVX_512bit);
__ evpmulhw(xmm12, k0, xmm16, xmm24, false, Assembler::AVX_512bit);
__ evpmulhw(xmm13, k0, xmm0, xmm25, false, Assembler::AVX_512bit);
__ evpmullw(xmm2, k0, xmm2, xmm31, false, Assembler::AVX_512bit);
__ evpmullw(xmm3, k0, xmm3, xmm31, false, Assembler::AVX_512bit);
__ evpmulhw(xmm2, k0, xmm30, xmm2, false, Assembler::AVX_512bit);
__ evpmulhw(xmm3, k0, xmm30, xmm3, false, Assembler::AVX_512bit);
__ evpsubw(xmm2, k0, xmm12, xmm2, false, Assembler::AVX_512bit);
__ evpsubw(xmm3, k0, xmm13, xmm3, false, Assembler::AVX_512bit);
__ evpaddw(xmm0, k0, xmm2, xmm17, false, Assembler::AVX_512bit);
__ evpaddw(xmm8, k0, xmm3, xmm1, false, Assembler::AVX_512bit);
__ evpaddw(xmm2, k0, xmm18, xmm19, false, Assembler::AVX_512bit);
__ evpaddw(xmm9, k0, xmm4, xmm5, false, Assembler::AVX_512bit);
montmul(xmm1357, xmm0829, xmm23_23, xmm1357, xmm0829, _masm);
__ evmovdquw(xmm0, xmm28, Assembler::AVX_512bit);
__ evmovdquw(xmm2, xmm28, Assembler::AVX_512bit);
__ evpermi2w(xmm0, xmm1, xmm5, Assembler::AVX_512bit);
__ evpermt2w(xmm1, xmm29, xmm5, Assembler::AVX_512bit);
__ evpermi2w(xmm2, xmm3, xmm7, Assembler::AVX_512bit);
__ evpermt2w(xmm3, xmm29, xmm7, Assembler::AVX_512bit);
store4regs(result, 0, xmm0_3, _masm);
__ addptr(ntta, 256);
__ addptr(nttb, 256);
__ addptr(result, 256);
__ addptr(zetas, 128);
__ subl(loopCnt, 1);
__ jcc(Assembler::greater, Loop);
__ pop_ppx(r12);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
// Kyber add 2 polynomials.
//
// result (short[256]) = c_rarg0
// a (short[256]) = c_rarg1
// b (short[256]) = c_rarg2
address generate_kyberAddPoly_2_avx512(StubGenerator *stubgen,
MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_kyberAddPoly_2_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
__ enter();
const Register result = c_rarg0;
const Register a = c_rarg1;
const Register b = c_rarg2;
__ vpbroadcastq(xmm31,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
for (int i = 0; i < 8; i++) {
__ evmovdquw(xmm(i), Address(a, 64 * i), Assembler::AVX_512bit);
__ evmovdquw(xmm(i + 8), Address(b, 64 * i), Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
__ evpaddw(xmm(i), k0, xmm(i), xmm(i + 8), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
__ evpaddw(xmm(i), k0, xmm(i), xmm31, false, Assembler::AVX_512bit);
}
store4regs(result, 0, xmm0_3, _masm);
store4regs(result, 256, xmm4_7, _masm);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
// Kyber add 3 polynomials.
//
// result (short[256]) = c_rarg0
// a (short[256]) = c_rarg1
// b (short[256]) = c_rarg2
// c (short[256]) = c_rarg3
address generate_kyberAddPoly_3_avx512(StubGenerator *stubgen,
MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_kyberAddPoly_3_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
__ enter();
const Register result = c_rarg0;
const Register a = c_rarg1;
const Register b = c_rarg2;
const Register c = c_rarg3;
__ vpbroadcastq(xmm31,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
for (int i = 0; i < 8; i++) {
__ evmovdquw(xmm(i), Address(a, 64 * i), Assembler::AVX_512bit);
__ evmovdquw(xmm(i + 8), Address(b, 64 * i), Assembler::AVX_512bit);
__ evmovdquw(xmm(i + 16), Address(c, 64 * i), Assembler::AVX_512bit);
}
__ evpaddw(xmm31, k0, xmm31, xmm31, false, Assembler::AVX_512bit);
for (int i = 0; i < 8; i++) {
__ evpaddw(xmm(i), k0, xmm(i), xmm(i + 8), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
__ evpaddw(xmm(i), k0, xmm(i), xmm(i + 16), false, Assembler::AVX_512bit);
}
for (int i = 0; i < 8; i++) {
__ evpaddw(xmm(i), k0, xmm(i), xmm31, false, Assembler::AVX_512bit);
}
store4regs(result, 0, xmm0_3, _masm);
store4regs(result, 256, xmm4_7, _masm);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
// Kyber parse XOF output to polynomial coefficient candidates.
//
// condensed (byte[168]) = c_rarg0
// condensedOffs (int) = c_rarg1
// parsed (short[112]) = c_rarg2
// parsedLength (int) = c_rarg3
address generate_kyber12To16_avx512(StubGenerator *stubgen,
MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_kyber12To16_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
__ enter();
const Register condensed = c_rarg0;
const Register condensedOffs = c_rarg1;
const Register parsed = c_rarg2;
const Register parsedLength = c_rarg3;
const Register perms = r11;
Label Loop, VBMILoop;
__ addptr(condensed, condensedOffs);
if (VM_Version::supports_avx512_vbmi()) {
// mask load for the first 48 bytes of each vector
__ mov64(rax, 0x0000FFFFFFFFFFFF);
__ kmovql(k1, rax);
__ lea(perms, ExternalAddress(kyberAvx512_12To16DupAddr()));
__ evmovdqub(xmm20, Address(perms), Assembler::AVX_512bit);
__ lea(perms, ExternalAddress(kyberAvx512_12To16ShiftAddr()));
__ evmovdquw(xmm21, Address(perms), Assembler::AVX_512bit);
__ lea(perms, ExternalAddress(kyberAvx512_12To16AndAddr()));
__ evmovdquq(xmm22, Address(perms), Assembler::AVX_512bit);
__ align(OptoLoopAlignment);
__ BIND(VBMILoop);
__ evmovdqub(xmm0, k1, Address(condensed, 0), false,
Assembler::AVX_512bit);
__ evmovdqub(xmm1, k1, Address(condensed, 48), false,
Assembler::AVX_512bit);
__ evmovdqub(xmm2, k1, Address(condensed, 96), false,
Assembler::AVX_512bit);
__ evmovdqub(xmm3, k1, Address(condensed, 144), false,
Assembler::AVX_512bit);
__ evpermb(xmm4, k0, xmm20, xmm0, false, Assembler::AVX_512bit);
__ evpermb(xmm5, k0, xmm20, xmm1, false, Assembler::AVX_512bit);
__ evpermb(xmm6, k0, xmm20, xmm2, false, Assembler::AVX_512bit);
__ evpermb(xmm7, k0, xmm20, xmm3, false, Assembler::AVX_512bit);
__ evpsrlvw(xmm4, xmm4, xmm21, Assembler::AVX_512bit);
__ evpsrlvw(xmm5, xmm5, xmm21, Assembler::AVX_512bit);
__ evpsrlvw(xmm6, xmm6, xmm21, Assembler::AVX_512bit);
__ evpsrlvw(xmm7, xmm7, xmm21, Assembler::AVX_512bit);
__ evpandq(xmm0, xmm22, xmm4, Assembler::AVX_512bit);
__ evpandq(xmm1, xmm22, xmm5, Assembler::AVX_512bit);
__ evpandq(xmm2, xmm22, xmm6, Assembler::AVX_512bit);
__ evpandq(xmm3, xmm22, xmm7, Assembler::AVX_512bit);
store4regs(parsed, 0, xmm0_3, _masm);
__ addptr(condensed, 192);
__ addptr(parsed, 256);
__ subl(parsedLength, 128);
__ jcc(Assembler::greater, VBMILoop);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
__ lea(perms, ExternalAddress(kyberAvx512_12To16PermsAddr()));
load4regs(xmm24_27, perms, 0, _masm);
load4regs(xmm28_31, perms, 256, _masm);
__ vpbroadcastq(xmm23,
ExternalAddress(kyberAvx512ConstsAddr(f00Offset)),
Assembler::AVX_512bit, scratch); // 0xF00
__ BIND(Loop);
__ evmovdqub(xmm0, Address(condensed, 0),Assembler::AVX_256bit);
__ evmovdqub(xmm1, Address(condensed, 32),Assembler::AVX_256bit);
__ evmovdqub(xmm2, Address(condensed, 64),Assembler::AVX_256bit);
__ evmovdqub(xmm8, Address(condensed, 96),Assembler::AVX_256bit);
__ evmovdqub(xmm9, Address(condensed, 128),Assembler::AVX_256bit);
__ evmovdqub(xmm10, Address(condensed, 160),Assembler::AVX_256bit);
__ vpmovzxbw(xmm0, xmm0, Assembler::AVX_512bit);
__ vpmovzxbw(xmm1, xmm1, Assembler::AVX_512bit);
__ vpmovzxbw(xmm2, xmm2, Assembler::AVX_512bit);
__ vpmovzxbw(xmm8, xmm8, Assembler::AVX_512bit);
__ vpmovzxbw(xmm9, xmm9, Assembler::AVX_512bit);
__ vpmovzxbw(xmm10, xmm10, Assembler::AVX_512bit);
__ evmovdquw(xmm3, xmm24, Assembler::AVX_512bit);
__ evmovdquw(xmm4, xmm25, Assembler::AVX_512bit);
__ evmovdquw(xmm5, xmm26, Assembler::AVX_512bit);
__ evmovdquw(xmm11, xmm24, Assembler::AVX_512bit);
__ evmovdquw(xmm12, xmm25, Assembler::AVX_512bit);
__ evmovdquw(xmm13, xmm26, Assembler::AVX_512bit);
__ evpermi2w(xmm3, xmm0, xmm1, Assembler::AVX_512bit);
__ evpermi2w(xmm4, xmm0, xmm1, Assembler::AVX_512bit);
__ evpermi2w(xmm5, xmm0, xmm1, Assembler::AVX_512bit);
__ evpermi2w(xmm11, xmm8, xmm9, Assembler::AVX_512bit);
__ evpermi2w(xmm12, xmm8, xmm9, Assembler::AVX_512bit);
__ evpermi2w(xmm13, xmm8, xmm9, Assembler::AVX_512bit);
__ evpermt2w(xmm3, xmm27, xmm2, Assembler::AVX_512bit);
__ evpermt2w(xmm4, xmm28, xmm2, Assembler::AVX_512bit);
__ evpermt2w(xmm5, xmm29, xmm2, Assembler::AVX_512bit);
__ evpermt2w(xmm11, xmm27, xmm10, Assembler::AVX_512bit);
__ evpermt2w(xmm12, xmm28, xmm10, Assembler::AVX_512bit);
__ evpermt2w(xmm13, xmm29, xmm10, Assembler::AVX_512bit);
__ evpsraw(xmm2, k0, xmm4, 4, false, Assembler::AVX_512bit);
__ evpsllw(xmm0, k0, xmm4, 8, false, Assembler::AVX_512bit);
__ evpsllw(xmm1, k0, xmm5, 4, false, Assembler::AVX_512bit);
__ evpsllw(xmm8, k0, xmm12, 8, false, Assembler::AVX_512bit);
__ evpsraw(xmm10, k0, xmm12, 4, false, Assembler::AVX_512bit);
__ evpsllw(xmm9, k0, xmm13, 4, false, Assembler::AVX_512bit);
__ evpandq(xmm0, k0, xmm0, xmm23, false, Assembler::AVX_512bit);
__ evpandq(xmm8, k0, xmm8, xmm23, false, Assembler::AVX_512bit);
__ evpaddw(xmm1, k0, xmm1, xmm2, false, Assembler::AVX_512bit);
__ evpaddw(xmm0, k0, xmm0, xmm3, false, Assembler::AVX_512bit);
__ evmovdquw(xmm2, xmm30, Assembler::AVX_512bit);
__ evpaddw(xmm9, k0, xmm9, xmm10, false, Assembler::AVX_512bit);
__ evpaddw(xmm8, k0, xmm8, xmm11, false, Assembler::AVX_512bit);
__ evmovdquw(xmm10, xmm30, Assembler::AVX_512bit);
__ evpermi2w(xmm2, xmm0, xmm1, Assembler::AVX_512bit);
__ evpermt2w(xmm0, xmm31, xmm1, Assembler::AVX_512bit);
__ evpermi2w(xmm10, xmm8, xmm9, Assembler::AVX_512bit);
__ evpermt2w(xmm8, xmm31, xmm9, Assembler::AVX_512bit);
store4regs(parsed, 0, xmm2_0_10_8, _masm);
__ addptr(condensed, 192);
__ addptr(parsed, 256);
__ subl(parsedLength, 128);
__ jcc(Assembler::greater, Loop);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
// Kyber barrett reduce function.
//
// coeffs (short[256]) = c_rarg0
address generate_kyberBarrettReduce_avx512(StubGenerator *stubgen,
MacroAssembler *_masm) {
__ align(CodeEntryAlignment);
StubId stub_id = StubId::stubgen_kyberBarrettReduce_id;
StubCodeMark mark(stubgen, stub_id);
address start = __ pc();
__ enter();
const Register coeffs = c_rarg0;
__ vpbroadcastq(xmm16,
ExternalAddress(kyberAvx512ConstsAddr(barretMultiplierOffset)),
Assembler::AVX_512bit, scratch); // Barrett multiplier
__ vpbroadcastq(xmm17,
ExternalAddress(kyberAvx512ConstsAddr(qOffset)),
Assembler::AVX_512bit, scratch); // q
load4regs(xmm0_3, coeffs, 0, _masm);
load4regs(xmm4_7, coeffs, 256, _masm);
barrettReduce(_masm);
store4regs(coeffs, 0, xmm0_3, _masm);
store4regs(coeffs, 256, xmm4_7, _masm);
__ leave(); // required for proper stackwalking of RuntimeStub frame
__ mov64(rax, 0); // return 0
__ ret(0);
return start;
}
void StubGenerator::generate_kyber_stubs() {
// Generate Kyber intrinsics code
if (UseKyberIntrinsics) {
if (VM_Version::supports_evex()) {
StubRoutines::_kyberNtt = generate_kyberNtt_avx512(this, _masm);
StubRoutines::_kyberInverseNtt = generate_kyberInverseNtt_avx512(this, _masm);
StubRoutines::_kyberNttMult = generate_kyberNttMult_avx512(this, _masm);
StubRoutines::_kyberAddPoly_2 = generate_kyberAddPoly_2_avx512(this, _masm);
StubRoutines::_kyberAddPoly_3 = generate_kyberAddPoly_3_avx512(this, _masm);
StubRoutines::_kyber12To16 = generate_kyber12To16_avx512(this, _masm);
StubRoutines::_kyberBarrettReduce = generate_kyberBarrettReduce_avx512(this, _masm);
}
}
}