8373059: Test sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java should pass on Aarch64

Reviewed-by: weijun, vpaprotski
This commit is contained in:
Ferenc Rakoczi 2025-12-12 16:04:56 +00:00 committed by Weijun Wang
parent a99f340e1b
commit 6ec36d348b
2 changed files with 63 additions and 29 deletions

View File

@ -1555,7 +1555,7 @@ public class ML_DSA {
return res;
}
// precondition: -2^31 * MONT_Q <= a, b < 2^31, -2^31 < a * b < 2^31 * MONT_Q
// precondition: -2^31 <= a, b < 2^31, -2^31 * MONT_Q <= a * b < 2^31 * MONT_Q
// computes a * b * 2^-32 mod MONT_Q
// the result is greater than -MONT_Q and less than MONT_Q
// See e.g. Algorithm 3 in https://eprint.iacr.org/2018/039.pdf

View File

@ -38,16 +38,17 @@ import java.util.HexFormat;
*/
/*
* @test
* @comment This test should be reenabled on aarch64
* @requires os.simpleArch == "x64"
* @library /test/lib
* @key randomness
* @modules java.base/sun.security.provider:+open
* @run main ML_DSA_Intrinsic_Test
*/
// To run manually: java --add-opens java.base/sun.security.provider=ALL-UNNAMED --add-exports java.base/sun.security.provider=ALL-UNNAMED
// -XX:+UnlockDiagnosticVMOptions -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java
// To run manually:
// java --add-opens java.base/sun.security.provider=ALL-UNNAMED
// --add-exports java.base/sun.security.provider=ALL-UNNAMED
// -XX:+UnlockDiagnosticVMOptions -XX:+UseDilithiumIntrinsics
// test/jdk/sun/security/provider/pqc/ML_DSA_Intrinsic_Test.java
public class ML_DSA_Intrinsic_Test {
public static void main(String[] args) throws Throwable {
@ -104,9 +105,10 @@ public class ML_DSA_Intrinsic_Test {
m.setAccessible(true);
MethodHandle inverseNttJava = lookup.unreflect(m);
// Hint: if test fails, you can hardcode the seed to make the test more reproducible
Random rnd = new Random();
long seed = rnd.nextLong();
// Hint: if a test fails, it prints the seed, so you can hardcode
// it here to reproduce the failure
rnd.setSeed(seed);
//Note: it might be useful to increase this number during development of new intrinsics
final int repeat = 10000;
@ -117,32 +119,49 @@ public class ML_DSA_Intrinsic_Test {
int[] prod3 = new int[ML_DSA_N];
int[] prod4 = new int[ML_DSA_N];
for (int i = 0; i < repeat; i++) {
// Hint: if test fails, you can hardcode the seed to make the test more reproducible:
// rnd.setSeed(seed);
testMult(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
testMult(prod1, prod2, coeffs1, coeffs2,
mult, multJava, rnd, seed, i);
testMultConst(prod1, prod2, multConst, multConstJava, rnd, seed, i);
testDecompose(prod1, prod2, prod3, prod4, coeffs1, coeffs2, decompose, decomposeJava, rnd, seed, i);
testDecompose(prod1, prod2, prod3, prod4, coeffs1, coeffs2,
decompose, decomposeJava, rnd, seed, i);
testAlmostNtt(coeffs1, coeffs2, almostNtt, almostNttJava, rnd, seed, i);
testInverseNtt(coeffs1, coeffs2, inverseNtt, inverseNttJava, rnd, seed, i);
}
System.out.println("Fuzz Success");
}
private static final int ML_DSA_N = 256;
public static void testMult(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
public static void testMult(int[] prod1, int[] prod2,
int[] coeffs1, int[] coeffs2,
MethodHandle mult, MethodHandle multJava, Random rnd,
long seed, int i) throws Throwable {
for (int j = 0; j<ML_DSA_N; j++) {
coeffs1[j] = rnd.nextInt();
coeffs2[j] = rnd.nextInt();
// This method is always called with arrays whose elements are between
// -ML_DSA_Q and ML_DSA_Q, so we only test for these here (although
// both versions work fine with array element sizes that satisfy the
// montMul() preconditions in sun.security.provider.ML_DSA.java
for (int j = 0; j < ML_DSA_N; j++) {
coeffs1[j] = rnd.nextInt(2 * ML_DSA_Q) - ML_DSA_Q;
coeffs2[j] = rnd.nextInt(2 * ML_DSA_Q) - ML_DSA_Q;
}
mult.invoke(prod1, coeffs1, coeffs2);
multJava.invoke(prod2, coeffs1, coeffs2);
if (!Arrays.equals(prod1, prod2)) {
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result mult mismatch: " + formatOf(prod1) + " != " + formatOf(prod2));
// The Java version and the intrinsic version should not produce
// the exact same result (although usually they do), it is enough
// if the corresponding array elements are congruent modulo ML_DSA_Q
boolean modQequal = true;
for (int j = 0; j < ML_DSA_N; j++) {
if (prod1[j] != prod2[j]) {
modQequal &= (((prod1[j] - prod2[j]) % ML_DSA_Q) == 0);
}
}
if (!modQequal) {
throw new RuntimeException("[Seed " + seed + "@" + i
+ "] Result mult mismatch: "
+ formatOf(prod1) + "\n != " + formatOf(prod2));
}
}
}
@ -150,26 +169,30 @@ public class ML_DSA_Intrinsic_Test {
MethodHandle multConst, MethodHandle multConstJava, Random rnd,
long seed, int i) throws Throwable {
for (int j = 0; j<ML_DSA_N; j++) {
for (int j = 0; j < ML_DSA_N; j++) {
prod1[j] = prod2[j] = rnd.nextInt();
}
// Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf, one of the inputs is bound, which prevents overflows
int dilithium_q = 8380417;
int c = rnd.nextInt(dilithium_q);
// Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf,
// one of the inputs is bound, which prevents overflows
int c = rnd.nextInt(ML_DSA_Q);
multConst.invoke(prod1, c);
multConstJava.invoke(prod2, c);
if (!Arrays.equals(prod1, prod2)) {
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result multConst mismatch: " + formatOf(prod1) + " != " + formatOf(prod2));
throw new RuntimeException("[Seed " + seed + "@" + i
+ "] Result multConst mismatch: "
+ formatOf(prod1) + " != " + formatOf(prod2));
}
}
public static void testDecompose(int[] low1, int[] high1, int[] low2, int[] high2, int[] coeffs1, int[] coeffs2,
public static void testDecompose(int[] low1, int[] high1, int[] low2,
int[] high2, int[] coeffs1, int[] coeffs2,
MethodHandle decompose, MethodHandle decomposeJava, Random rnd,
long seed, int i) throws Throwable {
for (int j = 0; j<ML_DSA_N; j++) {
for (int j = 0; j < ML_DSA_N; j++) {
coeffs1[j] = coeffs2[j] = rnd.nextInt();
}
int gamma2 = 95232;
@ -182,18 +205,22 @@ public class ML_DSA_Intrinsic_Test {
decomposeJava.invoke(coeffs2, low2, high2, 2 * gamma2, multiplier);
if (!Arrays.equals(low1, low2)) {
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result low mismatch: " + formatOf(low1) + " != " + formatOf(low2));
throw new RuntimeException("[Seed " + seed + "@" + i
+ "] Result low mismatch: "
+ formatOf(low1) + " != " + formatOf(low2));
}
if (!Arrays.equals(high1, high2)) {
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result high mismatch: " + formatOf(high1) + " != " + formatOf(high2));
throw new RuntimeException("[Seed " + seed + "@" + i
+ "] Result high mismatch: "
+ formatOf(high1) + " != " + formatOf(high2));
}
}
public static void testAlmostNtt(int[] coeffs1, int[] coeffs2,
MethodHandle almostNtt, MethodHandle almostNttJava, Random rnd,
long seed, int i) throws Throwable {
for (int j = 0; j<ML_DSA_N; j++) {
for (int j = 0; j < ML_DSA_N; j++) {
coeffs1[j] = coeffs2[j] = rnd.nextInt();
}
@ -201,14 +228,16 @@ public class ML_DSA_Intrinsic_Test {
almostNttJava.invoke(coeffs2);
if (!Arrays.equals(coeffs1, coeffs2)) {
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result AlmostNtt mismatch: " + formatOf(coeffs1) + " != " + formatOf(coeffs2));
throw new RuntimeException("[Seed " + seed + "@" + i
+"] Result AlmostNtt mismatch: "
+ formatOf(coeffs1) + " != " + formatOf(coeffs2));
}
}
public static void testInverseNtt(int[] coeffs1, int[] coeffs2,
MethodHandle inverseNtt, MethodHandle inverseNttJava, Random rnd,
long seed, int i) throws Throwable {
for (int j = 0; j<ML_DSA_N; j++) {
for (int j = 0; j < ML_DSA_N; j++) {
coeffs1[j] = coeffs2[j] = rnd.nextInt();
}
@ -216,7 +245,9 @@ public class ML_DSA_Intrinsic_Test {
inverseNttJava.invoke(coeffs2);
if (!Arrays.equals(coeffs1, coeffs2)) {
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result InverseNtt mismatch: " + formatOf(coeffs1) + " != " + formatOf(coeffs2));
throw new RuntimeException("[Seed " + seed+ "@" + i
+"] Result InverseNtt mismatch: "
+ formatOf(coeffs1) + " != " + formatOf(coeffs2));
}
}
@ -230,6 +261,9 @@ public class ML_DSA_Intrinsic_Test {
}
// Copied constants from sun.security.provider.ML_DSA
private static final int ML_DSA_N = 256;
private static final int ML_DSA_Q = 8380417;
private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{
-1976782, 846154, -1400424, -3937738, 1362209, 48306, -3919660, 554416,
3545687, -1612842, 976891, -183443, 2286327, 420899, 2235985, 2939036,