8347606: Optimize Java implementation of ML-DSA

Reviewed-by: weijun
This commit is contained in:
Ben Perez 2025-05-13 22:31:55 +00:00
parent 1bded7188f
commit 10dcdf1b47

View File

@ -541,21 +541,21 @@ public class ML_DSA {
int[][][] keygenA = generateA(rho); //A is in NTT domain
//Sample S1 and S2
int[][] s1 = new int[mlDsa_l][ML_DSA_N];
int[][] s2 = new int[mlDsa_k][ML_DSA_N];
int[][] s1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
int[][] s2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
//hash is reset before being used in sampleS1S2
sampleS1S2(s1, s2, hash, rhoPrime);
//Compute t and tr
mlDsaVectorNtt(s1); //s1 now in NTT domain
int[][] As1 = new int[mlDsa_k][ML_DSA_N];
int[][] As1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
matrixVectorPointwiseMultiply(As1, keygenA, s1);
mlDsaVectorInverseNtt(s1); //take s1 out of NTT domain
mlDsaVectorInverseNtt(As1);
int[][] t = vectorAddPos(As1, s2);
int[][] t0 = new int[mlDsa_k][ML_DSA_N];
int[][] t1 = new int[mlDsa_k][ML_DSA_N];
int[][] t0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int[][] t1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
power2Round(t, t0, t1);
//Encode PK and SK
@ -593,19 +593,19 @@ public class ML_DSA {
hash.reset();
//Initialize vectors used in loop
int[][] z = new int[mlDsa_l][ML_DSA_N];
boolean[][] h = new boolean[mlDsa_k][ML_DSA_N];
int[][] z = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
boolean[][] h = booleanMatrixAlloc(mlDsa_k, ML_DSA_N);
byte[] commitmentHash = new byte[lambda/4];
int[][] y = new int[mlDsa_l][ML_DSA_N];
int[][] yy = new int[mlDsa_l][ML_DSA_N];
int[][] w = new int[mlDsa_k][ML_DSA_N];
int[][] w0 = new int[mlDsa_k][ML_DSA_N];
int[][] w1 = new int[mlDsa_k][ML_DSA_N];
int[][] w_ct0 = new int[mlDsa_k][ML_DSA_N];
int[][] y = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
int[][] yy = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
int[][] w = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int[][] w0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int[][] w1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int[][] w_ct0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int[] c = new int[ML_DSA_N];
int[][] cs1 = new int[mlDsa_l][ML_DSA_N];
int[][] cs2 = new int[mlDsa_k][ML_DSA_N];
int[][] ct0 = new int[mlDsa_k][ML_DSA_N];
int[][] cs1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
int[][] cs2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int[][] ct0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int kappa = 0;
while (true) {
@ -621,7 +621,6 @@ public class ML_DSA {
matrixVectorPointwiseMultiply(w, aHat, y);
mlDsaVectorInverseNtt(w); //w is now in normal domain
decompose(w, w0, w1);
//mlDsaVectorInverseNtt(y);
//Get commitment hash
hash.update(mu);
@ -693,13 +692,13 @@ public class ML_DSA {
mlDsaVectorNtt(sig.response());
//Reconstruct signer's commitment
int[][] aHatZ = new int[mlDsa_k][ML_DSA_N];
int[][] aHatZ = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
matrixVectorPointwiseMultiply(aHatZ, aHat, sig.response());
int[][] t1Hat = vectorConstMul(1 << ML_DSA_D, pk.t1());
mlDsaVectorNtt(t1Hat);
int[][] ct1 = new int[mlDsa_k][ML_DSA_N];
int[][] ct1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
nttConstMultiply(ct1, cHat, t1Hat);
int[][] wApprox = vectorSub(aHatZ, ct1, true);
@ -763,7 +762,7 @@ public class ML_DSA {
//This is simpleBitUnpack from FIPS 204. Since it is only called on the
//vector t1 we can optimize for that case
public int[][] t1Unpack(byte[] v) {
int[][] t1 = new int[mlDsa_k][ML_DSA_N];
int[][] t1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
for (int i = 0; i < mlDsa_k; i++) {
for (int j = 0; j < ML_DSA_N / 4; j++) {
int tOffset = j*4;
@ -872,7 +871,7 @@ public class ML_DSA {
}
private boolean[][] hintBitUnpack(byte[] y, int offset) {
boolean[][] h = new boolean[mlDsa_k][ML_DSA_N];
boolean[][] h = booleanMatrixAlloc(mlDsa_k, ML_DSA_N);
int idx = 0;
for (int i = 0; i < mlDsa_k; i++) {
int j = y[offset + omega + i];
@ -956,18 +955,18 @@ public class ML_DSA {
//Parse s1
int start = A_SEED_LEN + K_LEN + TR_LEN;
int end = start + (32 * mlDsa_l * s1s2CoeffSize);
int[][] s1 = new int[mlDsa_l][ML_DSA_N];
int[][] s1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
bitUnpack(s1, sk, start, mlDsa_l, eta, s1s2CoeffSize);
//Parse s2
start = end;
end += 32 * s1s2CoeffSize * mlDsa_k;
int[][] s2 = new int[mlDsa_k][ML_DSA_N];
int[][] s2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
bitUnpack(s2, sk, start, mlDsa_k, eta, s1s2CoeffSize);
//Parse t0
start = end;
int[][] t0 = new int[mlDsa_k][ML_DSA_N];
int[][] t0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
bitUnpack(t0, sk, start, mlDsa_k, 1 << 12, T0_COEFF_SIZE);
return new ML_DSA_PrivateKey(rho, k, tr, s1, s2, t0);
@ -1002,7 +1001,7 @@ public class ML_DSA {
//Decode z
int start = cSize;
int end = start + zSize;
int[][] z = new int[mlDsa_l][ML_DSA_N];
int[][] z = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
bitUnpack(z, sig, start, mlDsa_l, gamma1, gamma1Bits + 1);
//Decode h
@ -1092,7 +1091,12 @@ public class ML_DSA {
}
private int[][][] generateA(byte[] seed) {
int[][][] a = new int[mlDsa_k][mlDsa_l][];
// Manually do multidimensional array initialization for performance
int[][][] a = new int[mlDsa_k][][];
for (int i = 0; i < mlDsa_k; i++) {
a[i] = new int[mlDsa_l][];
}
int nrPar = 2;
int rhoLen = seed.length;
@ -1270,8 +1274,8 @@ public class ML_DSA {
}
// Returns the high part that decompose(input, lowPart, highPart) writes into
// highPart. The low part is written into a scratch matrix that is not used
// after the call.
// NOTE(review): lowPart and highPart are each declared twice below (once with
// a direct "new int[..][..]" and once through integerMatrixAlloc). This looks
// like the removed and added lines of a diff hunk shown together -- confirm
// which pair is the live one before compiling.
private int[][] highBits(int[][] input) {
int[][] lowPart = new int[mlDsa_k][ML_DSA_N];
int[][] highPart = new int[mlDsa_k][ML_DSA_N];
int[][] lowPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
int[][] highPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
decompose(input, lowPart, highPart);
return highPart;
}
@ -1297,7 +1301,7 @@ public class ML_DSA {
private int[][] useHint(boolean[][] h, int[][] r) {
int m = (ML_DSA_Q - 1) / (2*gamma2);
int[][] lowPart = r;
int[][] highPart = new int[mlDsa_k][ML_DSA_N];
int[][] highPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
decompose(r, lowPart, highPart);
for (int i = 0; i < mlDsa_k; i++) {
@ -1482,7 +1486,7 @@ public class ML_DSA {
}
private int[][] vectorConstMul(int c, int[][] vec) {
int[][] res = new int[vec.length][vec[0].length];
int[][] res = integerMatrixAlloc(vec.length, vec[0].length);
for (int i = 0; i < vec.length; i++) {
for (int j = 0; j < vec[0].length; j++) {
res[i][j] = montMul(c, toMont(vec[i][j]));
@ -1496,7 +1500,7 @@ public class ML_DSA {
// The coefficients in the output will be nonnegative and less than MONT_Q
int[][] vectorAddPos(int[][] vec1, int[][] vec2) {
int dim = vec1.length;
int[][] result = new int[dim][ML_DSA_N];
int[][] result = integerMatrixAlloc(dim, ML_DSA_N);
for (int i = 0; i < dim; i++) {
for (int m = 0; m < ML_DSA_N; m++) {
int r = vec1[i][m] + vec2[i][m]; // -2 * MONT_Q < r < 2 * MONT_Q
@ -1568,4 +1572,22 @@ public class ML_DSA {
// Multiplies a by the constant MONT_R_SQUARE_MOD_Q using montMul and returns
// the product. Presumably this maps a into Montgomery representation
// (TODO confirm against the definition of montMul).
static int toMont(int a) {
    final int product = montMul(a, MONT_R_SQUARE_MOD_Q);
    return product;
}
// For multidimensional array initialization, manually allocating each entry is
// faster than doing the entire initialization in one go
// Allocates a boolean matrix with "first" rows of "second" entries each.
// Only the outer array is created up front; every row is then allocated
// individually instead of using a single new boolean[first][second].
// All entries start out as false (the Java default for boolean arrays).
static boolean[][] booleanMatrixAlloc(int first, int second) {
    final boolean[][] rows = new boolean[first][];
    int next = 0;
    while (next < first) {
        rows[next] = new boolean[second];
        next++;
    }
    return rows;
}
// Allocates an int matrix with "first" rows of "second" entries each.
// Only the outer array is created up front; every row is then allocated
// individually instead of using a single new int[first][second].
// All entries start out as 0 (the Java default for int arrays).
static int[][] integerMatrixAlloc(int first, int second) {
    final int[][] rows = new int[first][];
    int next = 0;
    while (next < first) {
        rows[next] = new int[second];
        next++;
    }
    return rows;
}
}