From 10dcdf1b4738efc6b4deaf96f4d123aff4dab832 Mon Sep 17 00:00:00 2001 From: Ben Perez Date: Tue, 13 May 2025 22:31:55 +0000 Subject: [PATCH] 8347606: Optimize Java implementation of ML-DSA Reviewed-by: weijun --- .../classes/sun/security/provider/ML_DSA.java | 84 ++++++++++++------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/src/java.base/share/classes/sun/security/provider/ML_DSA.java b/src/java.base/share/classes/sun/security/provider/ML_DSA.java index ff25eb527ef..af64ef399a8 100644 --- a/src/java.base/share/classes/sun/security/provider/ML_DSA.java +++ b/src/java.base/share/classes/sun/security/provider/ML_DSA.java @@ -541,21 +541,21 @@ public class ML_DSA { int[][][] keygenA = generateA(rho); //A is in NTT domain //Sample S1 and S2 - int[][] s1 = new int[mlDsa_l][ML_DSA_N]; - int[][] s2 = new int[mlDsa_k][ML_DSA_N]; + int[][] s1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N); + int[][] s2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); //hash is reset before being used in sampleS1S2 sampleS1S2(s1, s2, hash, rhoPrime); //Compute t and tr mlDsaVectorNtt(s1); //s1 now in NTT domain - int[][] As1 = new int[mlDsa_k][ML_DSA_N]; + int[][] As1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); matrixVectorPointwiseMultiply(As1, keygenA, s1); mlDsaVectorInverseNtt(s1); //take s1 out of NTT domain mlDsaVectorInverseNtt(As1); int[][] t = vectorAddPos(As1, s2); - int[][] t0 = new int[mlDsa_k][ML_DSA_N]; - int[][] t1 = new int[mlDsa_k][ML_DSA_N]; + int[][] t0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); + int[][] t1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); power2Round(t, t0, t1); //Encode PK and SK @@ -593,19 +593,19 @@ public class ML_DSA { hash.reset(); //Initialize vectors used in loop - int[][] z = new int[mlDsa_l][ML_DSA_N]; - boolean[][] h = new boolean[mlDsa_k][ML_DSA_N]; + int[][] z = integerMatrixAlloc(mlDsa_l, ML_DSA_N); + boolean[][] h = booleanMatrixAlloc(mlDsa_k, ML_DSA_N); byte[] commitmentHash = new byte[lambda/4]; - int[][] y = new int[mlDsa_l][ML_DSA_N]; - int[][] yy = new int[mlDsa_l][ML_DSA_N]; - int[][] w = new int[mlDsa_k][ML_DSA_N]; - int[][] w0 = new int[mlDsa_k][ML_DSA_N]; - int[][] w1 = new int[mlDsa_k][ML_DSA_N]; - int[][] w_ct0 = new int[mlDsa_k][ML_DSA_N]; + int[][] y = integerMatrixAlloc(mlDsa_l, ML_DSA_N); + int[][] yy = integerMatrixAlloc(mlDsa_l, ML_DSA_N); + int[][] w = integerMatrixAlloc(mlDsa_k, ML_DSA_N); + int[][] w0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); + int[][] w1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); + int[][] w_ct0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); int[] c = new int[ML_DSA_N]; - int[][] cs1 = new int[mlDsa_l][ML_DSA_N]; - int[][] cs2 = new int[mlDsa_k][ML_DSA_N]; - int[][] ct0 = new int[mlDsa_k][ML_DSA_N]; + int[][] cs1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N); + int[][] cs2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); + int[][] ct0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); int kappa = 0; while (true) { @@ -621,7 +621,6 @@ public class ML_DSA { matrixVectorPointwiseMultiply(w, aHat, y); mlDsaVectorInverseNtt(w); //w is now in normal domain decompose(w, w0, w1); - //mlDsaVectorInverseNtt(y); //Get commitment hash hash.update(mu); @@ -693,13 +692,13 @@ public class ML_DSA { mlDsaVectorNtt(sig.response()); //Reconstruct signer's commitment - int[][] aHatZ = new int[mlDsa_k][ML_DSA_N]; + int[][] aHatZ = integerMatrixAlloc(mlDsa_k, ML_DSA_N); matrixVectorPointwiseMultiply(aHatZ, aHat, sig.response()); int[][] t1Hat = vectorConstMul(1 << ML_DSA_D, pk.t1()); mlDsaVectorNtt(t1Hat); - int[][] ct1 = new int[mlDsa_k][ML_DSA_N]; + int[][] ct1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); nttConstMultiply(ct1, cHat, t1Hat); int[][] wApprox = vectorSub(aHatZ, ct1, true); @@ -763,7 +762,7 @@ public class ML_DSA { //This is simpleBitUnpack from FIPS 204. Since it is only called on the //vector t1 we can optimize for that case public int[][] t1Unpack(byte[] v) { - int[][] t1 = new int[mlDsa_k][ML_DSA_N]; + int[][] t1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); for (int i = 0; i < mlDsa_k; i++) { for (int j = 0; j < ML_DSA_N / 4; j++) { int tOffset = j*4; @@ -872,7 +871,7 @@ public class ML_DSA { } private boolean[][] hintBitUnpack(byte[] y, int offset) { - boolean[][] h = new boolean[mlDsa_k][ML_DSA_N]; + boolean[][] h = booleanMatrixAlloc(mlDsa_k, ML_DSA_N); int idx = 0; for (int i = 0; i < mlDsa_k; i++) { int j = y[offset + omega + i]; @@ -956,18 +955,18 @@ public class ML_DSA { //Parse s1 int start = A_SEED_LEN + K_LEN + TR_LEN; int end = start + (32 * mlDsa_l * s1s2CoeffSize); - int[][] s1 = new int[mlDsa_l][ML_DSA_N]; + int[][] s1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N); bitUnpack(s1, sk, start, mlDsa_l, eta, s1s2CoeffSize); //Parse s2 start = end; end += 32 * s1s2CoeffSize * mlDsa_k; - int[][] s2 = new int[mlDsa_k][ML_DSA_N]; + int[][] s2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); bitUnpack(s2, sk, start, mlDsa_k, eta, s1s2CoeffSize); //Parse t0 start = end; - int[][] t0 = new int[mlDsa_k][ML_DSA_N]; + int[][] t0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N); bitUnpack(t0, sk, start, mlDsa_k, 1 << 12, T0_COEFF_SIZE); return new ML_DSA_PrivateKey(rho, k, tr, s1, s2, t0); @@ -1002,7 +1001,7 @@ public class ML_DSA { //Decode z int start = cSize; int end = start + zSize; - int[][] z = new int[mlDsa_l][ML_DSA_N]; + int[][] z = integerMatrixAlloc(mlDsa_l, ML_DSA_N); bitUnpack(z, sig, start, mlDsa_l, gamma1, gamma1Bits + 1); //Decode h @@ -1092,7 +1091,12 @@ public class ML_DSA { } private int[][][] generateA(byte[] seed) { - int[][][] a = new int[mlDsa_k][mlDsa_l][]; + + // Manually do multidimensional array initialization for performance + int[][][] a = new int[mlDsa_k][][]; + for (int i = 0; i < mlDsa_k; i++) { + a[i] = new int[mlDsa_l][]; + } int nrPar = 2; int rhoLen = seed.length; @@ -1270,8 +1274,8 @@ public class ML_DSA { } private int[][] highBits(int[][] input) { - int[][] lowPart = new int[mlDsa_k][ML_DSA_N]; - int[][] highPart = new int[mlDsa_k][ML_DSA_N]; + int[][] lowPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N); + int[][] highPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N); decompose(input, lowPart, highPart); return highPart; } @@ -1297,7 +1301,7 @@ public class ML_DSA { private int[][] useHint(boolean[][] h, int[][] r) { int m = (ML_DSA_Q - 1) / (2*gamma2); int[][] lowPart = r; - int[][] highPart = new int[mlDsa_k][ML_DSA_N]; + int[][] highPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N); decompose(r, lowPart, highPart); for (int i = 0; i < mlDsa_k; i++) { @@ -1482,7 +1486,7 @@ public class ML_DSA { } private int[][] vectorConstMul(int c, int[][] vec) { - int[][] res = new int[vec.length][vec[0].length]; + int[][] res = integerMatrixAlloc(vec.length, vec[0].length); for (int i = 0; i < vec.length; i++) { for (int j = 0; j < vec[0].length; j++) { res[i][j] = montMul(c, toMont(vec[i][j])); @@ -1496,7 +1500,7 @@ public class ML_DSA { // The coefficients in the output will be nonnegative and less than MONT_Q int[][] vectorAddPos(int[][] vec1, int[][] vec2) { int dim = vec1.length; - int[][] result = new int[dim][ML_DSA_N]; + int[][] result = integerMatrixAlloc(dim, ML_DSA_N); for (int i = 0; i < dim; i++) { for (int m = 0; m < ML_DSA_N; m++) { int r = vec1[i][m] + vec2[i][m]; // -2 * MONT_Q < r < 2 * MONT_Q @@ -1568,4 +1572,22 @@ public class ML_DSA { static int toMont(int a) { return montMul(a, MONT_R_SQUARE_MOD_Q); } + + // For multidimensional array initialization, manually allocating each entry is + // faster than doing the entire initialization in one go + static boolean[][] booleanMatrixAlloc(int first, int second) { + boolean[][] res = new boolean[first][]; + for (int i = 0; i < first; i++) { + res[i] = new boolean[second]; + } + return res; + } + + static int[][] integerMatrixAlloc(int first, int second) { + int[][] res = new int[first][]; + for (int i = 0; i < first; i++) { + res[i] = new int[second]; + } + return res; + } }