mirror of
https://github.com/openjdk/jdk.git
synced 2026-01-28 12:09:14 +00:00
8347606: Optimize Java implementation of ML-DSA
Reviewed-by: weijun
This commit is contained in:
parent
1bded7188f
commit
10dcdf1b47
@ -541,21 +541,21 @@ public class ML_DSA {
|
||||
int[][][] keygenA = generateA(rho); //A is in NTT domain
|
||||
|
||||
//Sample S1 and S2
|
||||
int[][] s1 = new int[mlDsa_l][ML_DSA_N];
|
||||
int[][] s2 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] s1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
|
||||
int[][] s2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
//hash is reset before being used in sampleS1S2
|
||||
sampleS1S2(s1, s2, hash, rhoPrime);
|
||||
|
||||
//Compute t and tr
|
||||
mlDsaVectorNtt(s1); //s1 now in NTT domain
|
||||
int[][] As1 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] As1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
matrixVectorPointwiseMultiply(As1, keygenA, s1);
|
||||
mlDsaVectorInverseNtt(s1); //take s1 out of NTT domain
|
||||
|
||||
mlDsaVectorInverseNtt(As1);
|
||||
int[][] t = vectorAddPos(As1, s2);
|
||||
int[][] t0 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] t1 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] t0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int[][] t1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
power2Round(t, t0, t1);
|
||||
|
||||
//Encode PK and SK
|
||||
@ -593,19 +593,19 @@ public class ML_DSA {
|
||||
hash.reset();
|
||||
|
||||
//Initialize vectors used in loop
|
||||
int[][] z = new int[mlDsa_l][ML_DSA_N];
|
||||
boolean[][] h = new boolean[mlDsa_k][ML_DSA_N];
|
||||
int[][] z = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
|
||||
boolean[][] h = booleanMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
byte[] commitmentHash = new byte[lambda/4];
|
||||
int[][] y = new int[mlDsa_l][ML_DSA_N];
|
||||
int[][] yy = new int[mlDsa_l][ML_DSA_N];
|
||||
int[][] w = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] w0 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] w1 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] w_ct0 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] y = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
|
||||
int[][] yy = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
|
||||
int[][] w = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int[][] w0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int[][] w1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int[][] w_ct0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int[] c = new int[ML_DSA_N];
|
||||
int[][] cs1 = new int[mlDsa_l][ML_DSA_N];
|
||||
int[][] cs2 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] ct0 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] cs1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
|
||||
int[][] cs2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int[][] ct0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
|
||||
int kappa = 0;
|
||||
while (true) {
|
||||
@ -621,7 +621,6 @@ public class ML_DSA {
|
||||
matrixVectorPointwiseMultiply(w, aHat, y);
|
||||
mlDsaVectorInverseNtt(w); //w is now in normal domain
|
||||
decompose(w, w0, w1);
|
||||
//mlDsaVectorInverseNtt(y);
|
||||
|
||||
//Get commitment hash
|
||||
hash.update(mu);
|
||||
@ -693,13 +692,13 @@ public class ML_DSA {
|
||||
mlDsaVectorNtt(sig.response());
|
||||
|
||||
//Reconstruct signer's commitment
|
||||
int[][] aHatZ = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] aHatZ = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
matrixVectorPointwiseMultiply(aHatZ, aHat, sig.response());
|
||||
|
||||
int[][] t1Hat = vectorConstMul(1 << ML_DSA_D, pk.t1());
|
||||
mlDsaVectorNtt(t1Hat);
|
||||
|
||||
int[][] ct1 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] ct1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
nttConstMultiply(ct1, cHat, t1Hat);
|
||||
|
||||
int[][] wApprox = vectorSub(aHatZ, ct1, true);
|
||||
@ -763,7 +762,7 @@ public class ML_DSA {
|
||||
//This is simpleBitUnpack from FIPS 204. Since it is only called on the
|
||||
//vector t1 we can optimize for that case
|
||||
public int[][] t1Unpack(byte[] v) {
|
||||
int[][] t1 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] t1 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
for (int i = 0; i < mlDsa_k; i++) {
|
||||
for (int j = 0; j < ML_DSA_N / 4; j++) {
|
||||
int tOffset = j*4;
|
||||
@ -872,7 +871,7 @@ public class ML_DSA {
|
||||
}
|
||||
|
||||
private boolean[][] hintBitUnpack(byte[] y, int offset) {
|
||||
boolean[][] h = new boolean[mlDsa_k][ML_DSA_N];
|
||||
boolean[][] h = booleanMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int idx = 0;
|
||||
for (int i = 0; i < mlDsa_k; i++) {
|
||||
int j = y[offset + omega + i];
|
||||
@ -956,18 +955,18 @@ public class ML_DSA {
|
||||
//Parse s1
|
||||
int start = A_SEED_LEN + K_LEN + TR_LEN;
|
||||
int end = start + (32 * mlDsa_l * s1s2CoeffSize);
|
||||
int[][] s1 = new int[mlDsa_l][ML_DSA_N];
|
||||
int[][] s1 = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
|
||||
bitUnpack(s1, sk, start, mlDsa_l, eta, s1s2CoeffSize);
|
||||
|
||||
//Parse s2
|
||||
start = end;
|
||||
end += 32 * s1s2CoeffSize * mlDsa_k;
|
||||
int[][] s2 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] s2 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
bitUnpack(s2, sk, start, mlDsa_k, eta, s1s2CoeffSize);
|
||||
|
||||
//Parse t0
|
||||
start = end;
|
||||
int[][] t0 = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] t0 = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
bitUnpack(t0, sk, start, mlDsa_k, 1 << 12, T0_COEFF_SIZE);
|
||||
|
||||
return new ML_DSA_PrivateKey(rho, k, tr, s1, s2, t0);
|
||||
@ -1002,7 +1001,7 @@ public class ML_DSA {
|
||||
//Decode z
|
||||
int start = cSize;
|
||||
int end = start + zSize;
|
||||
int[][] z = new int[mlDsa_l][ML_DSA_N];
|
||||
int[][] z = integerMatrixAlloc(mlDsa_l, ML_DSA_N);
|
||||
bitUnpack(z, sig, start, mlDsa_l, gamma1, gamma1Bits + 1);
|
||||
|
||||
//Decode h
|
||||
@ -1092,7 +1091,12 @@ public class ML_DSA {
|
||||
}
|
||||
|
||||
private int[][][] generateA(byte[] seed) {
|
||||
int[][][] a = new int[mlDsa_k][mlDsa_l][];
|
||||
|
||||
// Manually do multidimensional array initialization for performance
|
||||
int[][][] a = new int[mlDsa_k][][];
|
||||
for (int i = 0; i < mlDsa_k; i++) {
|
||||
a[i] = new int[mlDsa_l][];
|
||||
}
|
||||
|
||||
int nrPar = 2;
|
||||
int rhoLen = seed.length;
|
||||
@ -1270,8 +1274,8 @@ public class ML_DSA {
|
||||
}
|
||||
|
||||
private int[][] highBits(int[][] input) {
|
||||
int[][] lowPart = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] highPart = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] lowPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
int[][] highPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
decompose(input, lowPart, highPart);
|
||||
return highPart;
|
||||
}
|
||||
@ -1297,7 +1301,7 @@ public class ML_DSA {
|
||||
private int[][] useHint(boolean[][] h, int[][] r) {
|
||||
int m = (ML_DSA_Q - 1) / (2*gamma2);
|
||||
int[][] lowPart = r;
|
||||
int[][] highPart = new int[mlDsa_k][ML_DSA_N];
|
||||
int[][] highPart = integerMatrixAlloc(mlDsa_k, ML_DSA_N);
|
||||
decompose(r, lowPart, highPart);
|
||||
|
||||
for (int i = 0; i < mlDsa_k; i++) {
|
||||
@ -1482,7 +1486,7 @@ public class ML_DSA {
|
||||
}
|
||||
|
||||
private int[][] vectorConstMul(int c, int[][] vec) {
|
||||
int[][] res = new int[vec.length][vec[0].length];
|
||||
int[][] res = integerMatrixAlloc(vec.length, vec[0].length);
|
||||
for (int i = 0; i < vec.length; i++) {
|
||||
for (int j = 0; j < vec[0].length; j++) {
|
||||
res[i][j] = montMul(c, toMont(vec[i][j]));
|
||||
@ -1496,7 +1500,7 @@ public class ML_DSA {
|
||||
// The coefficients in the output will be nonnegative and less than MONT_Q
|
||||
int[][] vectorAddPos(int[][] vec1, int[][] vec2) {
|
||||
int dim = vec1.length;
|
||||
int[][] result = new int[dim][ML_DSA_N];
|
||||
int[][] result = integerMatrixAlloc(dim, ML_DSA_N);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
for (int m = 0; m < ML_DSA_N; m++) {
|
||||
int r = vec1[i][m] + vec2[i][m]; // -2 * MONT_Q < r < 2 * MONT_Q
|
||||
@ -1568,4 +1572,22 @@ public class ML_DSA {
|
||||
static int toMont(int a) {
|
||||
return montMul(a, MONT_R_SQUARE_MOD_Q);
|
||||
}
|
||||
|
||||
// For multidimensional array initialization, manually allocating each entry is
|
||||
// faster than doing the entire initialization in one go
|
||||
static boolean[][] booleanMatrixAlloc(int first, int second) {
|
||||
boolean[][] res = new boolean[first][];
|
||||
for (int i = 0; i < first; i++) {
|
||||
res[i] = new boolean[second];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static int[][] integerMatrixAlloc(int first, int second) {
|
||||
int[][] res = new int[first][];
|
||||
for (int i = 0; i < first; i++) {
|
||||
res[i] = new int[second];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user