8372044: Implementation Review Based on ML-KEM Security Considerations

Reviewed-by: weijun
This commit is contained in:
Ferenc Rakoczi 2026-05-12 13:07:36 +00:00 committed by Weijun Wang
parent 1ec74dca70
commit bcbde75d4e
2 changed files with 71 additions and 34 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -527,6 +527,8 @@ public final class ML_KEM {
} catch (DigestException e) {
// This should never happen.
throw new RuntimeException(e);
} finally {
mlKemH.reset();
}
// The 2nd 32-byte `z` is copied into decapsKey
System.arraycopy(kem_d_z, 32, decapsKey,
@ -562,7 +564,6 @@ public final class ML_KEM {
var randomCoins = Arrays.copyOfRange(kHatAndRandomCoins, 32, 64);
var cipherText = kPkeEncrypt(new K_PKE_EncryptionKey(encapsulationKey.keyBytes),
randomMessage, randomCoins);
Arrays.fill(randomCoins, (byte) 0);
byte[] sharedSecret = Arrays.copyOfRange(kHatAndRandomCoins, 0, 32);
Arrays.fill(kHatAndRandomCoins, (byte) 0);
@ -613,6 +614,7 @@ public final class ML_KEM {
var fakeResult = mlKemJ.digest();
var computedCipherText = kPkeEncrypt(
new K_PKE_EncryptionKey(encapsKeyBytes), mCandidate, coins);
Arrays.fill(mCandidate, (byte)0);
// The rest of this method implements the following in constant time
//
@ -648,9 +650,11 @@ public final class ML_KEM {
MessageDigest mlKemG;
SHAKE256 mlKemJ;
int cbdInputLen = 64 * mlKem_eta1;
byte[] cbdInput = new byte[cbdInputLen];
try {
mlKemG = MessageDigest.getInstance(HASH_G_NAME);
mlKemJ = new SHAKE256(64 * mlKem_eta1);
mlKemJ = new SHAKE256(cbdInputLen);
} catch (NoSuchAlgorithmException e) {
// This should never happen.
throw new RuntimeException(e);
@ -671,22 +675,26 @@ public final class ML_KEM {
int keyGenN = 0;
byte[] prfSeed = new byte[sigma.length + 1];
System.arraycopy(sigma, 0, prfSeed, 0, sigma.length);
byte[] cbdInput;
short[][] keyGenS = new short[mlKem_k][];
short[][] keyGenE = new short[mlKem_k][];
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (keyGenN++);
mlKemJ.update(prfSeed);
cbdInput = mlKemJ.digest();
keyGenS[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
}
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (keyGenN++);
mlKemJ.update(prfSeed);
cbdInput = mlKemJ.digest();
keyGenE[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
try {
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (keyGenN++);
mlKemJ.update(prfSeed);
mlKemJ.digest(cbdInput, 0, cbdInputLen);
keyGenS[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
}
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (keyGenN++);
mlKemJ.update(prfSeed);
mlKemJ.digest(cbdInput, 0, cbdInputLen);
keyGenE[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
}
} catch (DigestException e) {
throw new ProviderException("Internal error", e);
}
Arrays.fill(sigma, (byte)0);
Arrays.fill(cbdInput, (byte)0);
short[][] keyGenSHat = mlKemVectorNTT(keyGenS);
mlKemVectorReduce(keyGenSHat);
@ -700,7 +708,6 @@ public final class ML_KEM {
for (int i = 0; i < mlKem_k; i++) {
encodePoly12(keyGenTHat[i], pkEncoded, i * ((ML_KEM_N * 12) / 8));
encodePoly12(keyGenSHat[i], skEncoded, i * ((ML_KEM_N * 12) / 8));
Arrays.fill(keyGenEHat[i], (short) 0);
Arrays.fill(keyGenSHat[i], (short) 0);
}
System.arraycopy(rho, 0,
@ -723,39 +730,61 @@ public final class ML_KEM {
var encryptA = generateA(rho, true);
short[][] encryptR = new short[mlKem_k][];
short[][] encryptE1 = new short[mlKem_k][];
short[] encryptE2;
int encryptN = 0;
byte[] prfSeed = new byte[sigma.length + 1];
System.arraycopy(sigma, 0, prfSeed, 0, sigma.length);
Arrays.fill(sigma, (byte)0);
var kPkePRFeta1 = new SHAKE256(64 * mlKem_eta1);
var kPkePRFeta2 = new SHAKE256(64 * mlKem_eta2);
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (encryptN++);
kPkePRFeta1.update(prfSeed);
byte[] cbdInput = kPkePRFeta1.digest();
encryptR[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
}
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (encryptN++);
int cbdInput1Len = 64 * mlKem_eta1;
var kPkePRFeta1 = new SHAKE256(cbdInput1Len);
byte[] cbdInput1 = new byte[cbdInput1Len];
int cbdInput2Len = 64 * mlKem_eta2;
var kPkePRFeta2 = new SHAKE256(cbdInput2Len);
byte[] cbdInput2 = new byte[cbdInput2Len];
try {
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (encryptN++);
kPkePRFeta1.update(prfSeed);
kPkePRFeta1.digest(cbdInput1, 0, cbdInput1Len);
encryptR[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput1);
}
for (int i = 0; i < mlKem_k; i++) {
prfSeed[sigma.length] = (byte) (encryptN++);
kPkePRFeta2.update(prfSeed);
kPkePRFeta2.digest(cbdInput2, 0, cbdInput2Len);
encryptE1[i] = centeredBinomialDistribution(mlKem_eta2, cbdInput2);
}
prfSeed[sigma.length] = (byte) encryptN;
kPkePRFeta2.update(prfSeed);
byte[] cbdInput = kPkePRFeta2.digest();
encryptE1[i] = centeredBinomialDistribution(mlKem_eta2, cbdInput);
kPkePRFeta2.digest(cbdInput2, 0, cbdInput2Len);
encryptE2 = centeredBinomialDistribution(mlKem_eta2, cbdInput2);
} catch (DigestException e) {
throw new ProviderException("Internal error", e);
} finally {
kPkePRFeta1.reset();
kPkePRFeta2.reset();
Arrays.fill(prfSeed, (byte)0);
Arrays.fill(cbdInput1, (byte)0);
Arrays.fill(cbdInput2, (byte)0);
}
prfSeed[sigma.length] = (byte) encryptN;
kPkePRFeta2.reset();
kPkePRFeta2.update(prfSeed);
byte[] cbdInput = kPkePRFeta2.digest();
var encryptE2 = centeredBinomialDistribution(mlKem_eta2, cbdInput);
var encryptRHat = mlKemVectorNTT(encryptR);
var encryptUHat = mlKemMatrixVectorMuladd(encryptA, encryptRHat, zeroes);
var encryptU = mlKemVectorInverseNTT(encryptUHat);
encryptU = mlKemAddVec(encryptU, encryptE1);
for (int i = 0; i < mlKem_k; i++) {
Arrays.fill(encryptE1[i], (short)0);
}
var encryptVHat = mlKemVectorScalarMult(encryptTHat, encryptRHat);
var encryptV = mlKemInverseNTT(encryptVHat);
encryptV = mlKemAddPoly(encryptV, encryptE2, decompressDecode(message));
var encryptC1 = encodeVector(mlKem_du, compressVector10_11(encryptU, mlKem_du));
var encryptC2 = encodePoly(mlKem_dv, compressPoly4_5(encryptV, mlKem_dv));
Arrays.fill(encryptE2, (short)0);
Arrays.fill(encryptV, (short)0);
byte[] result = new byte[encryptC1.length + encryptC2.length];
System.arraycopy(encryptC1, 0,
@ -783,9 +812,11 @@ public final class ML_KEM {
Arrays.fill(decryptSHat[i], (short) 0);
}
decryptV = mlKemSubtractPoly(decryptV, decryptSU);
var result = encodeCompress(decryptV);
Arrays.fill(decryptSU, (short) 0);
Arrays.fill(decryptV, (short) 0);
return encodeCompress(decryptV);
return result;
}
/*

View File

@ -28,6 +28,7 @@ package sun.security.provider;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;
import java.security.DigestException;
import java.security.ProviderException;
import java.util.Arrays;
import java.util.Objects;
@ -481,6 +482,11 @@ public abstract class SHA3 extends DigestBase {
return engineDigest();
}
public int digest(byte[] out, int offs, int len)
throws DigestException {
return engineDigest(out, offs, len);
}
public void squeeze(byte[] output, int offset, int numBytes) {
implSqueeze(output, offset, numBytes);
}