mirror of
https://github.com/openjdk/jdk.git
synced 2026-05-16 08:29:34 +00:00
8372044: Implementation Review Based on ML-KEM Security Considerations
Reviewed-by: weijun
This commit is contained in:
parent
1ec74dca70
commit
bcbde75d4e
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
@ -527,6 +527,8 @@ public final class ML_KEM {
|
||||
} catch (DigestException e) {
|
||||
// This should never happen.
|
||||
throw new RuntimeException(e);
|
||||
} finally {
|
||||
mlKemH.reset();
|
||||
}
|
||||
// The 2nd 32-byte `z` is copied into decapsKey
|
||||
System.arraycopy(kem_d_z, 32, decapsKey,
|
||||
@ -562,7 +564,6 @@ public final class ML_KEM {
|
||||
var randomCoins = Arrays.copyOfRange(kHatAndRandomCoins, 32, 64);
|
||||
var cipherText = kPkeEncrypt(new K_PKE_EncryptionKey(encapsulationKey.keyBytes),
|
||||
randomMessage, randomCoins);
|
||||
Arrays.fill(randomCoins, (byte) 0);
|
||||
byte[] sharedSecret = Arrays.copyOfRange(kHatAndRandomCoins, 0, 32);
|
||||
Arrays.fill(kHatAndRandomCoins, (byte) 0);
|
||||
|
||||
@ -613,6 +614,7 @@ public final class ML_KEM {
|
||||
var fakeResult = mlKemJ.digest();
|
||||
var computedCipherText = kPkeEncrypt(
|
||||
new K_PKE_EncryptionKey(encapsKeyBytes), mCandidate, coins);
|
||||
Arrays.fill(mCandidate, (byte)0);
|
||||
|
||||
// The rest of this method implements the following in constant time
|
||||
//
|
||||
@ -648,9 +650,11 @@ public final class ML_KEM {
|
||||
|
||||
MessageDigest mlKemG;
|
||||
SHAKE256 mlKemJ;
|
||||
int cbdInputLen = 64 * mlKem_eta1;
|
||||
byte[] cbdInput = new byte[cbdInputLen];
|
||||
try {
|
||||
mlKemG = MessageDigest.getInstance(HASH_G_NAME);
|
||||
mlKemJ = new SHAKE256(64 * mlKem_eta1);
|
||||
mlKemJ = new SHAKE256(cbdInputLen);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
// This should never happen.
|
||||
throw new RuntimeException(e);
|
||||
@ -671,22 +675,26 @@ public final class ML_KEM {
|
||||
int keyGenN = 0;
|
||||
byte[] prfSeed = new byte[sigma.length + 1];
|
||||
System.arraycopy(sigma, 0, prfSeed, 0, sigma.length);
|
||||
byte[] cbdInput;
|
||||
short[][] keyGenS = new short[mlKem_k][];
|
||||
short[][] keyGenE = new short[mlKem_k][];
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (keyGenN++);
|
||||
mlKemJ.update(prfSeed);
|
||||
cbdInput = mlKemJ.digest();
|
||||
keyGenS[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
|
||||
}
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (keyGenN++);
|
||||
mlKemJ.update(prfSeed);
|
||||
cbdInput = mlKemJ.digest();
|
||||
keyGenE[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
|
||||
try {
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (keyGenN++);
|
||||
mlKemJ.update(prfSeed);
|
||||
mlKemJ.digest(cbdInput, 0, cbdInputLen);
|
||||
keyGenS[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
|
||||
}
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (keyGenN++);
|
||||
mlKemJ.update(prfSeed);
|
||||
mlKemJ.digest(cbdInput, 0, cbdInputLen);
|
||||
keyGenE[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
|
||||
}
|
||||
} catch (DigestException e) {
|
||||
throw new ProviderException("Internal error", e);
|
||||
}
|
||||
Arrays.fill(sigma, (byte)0);
|
||||
Arrays.fill(cbdInput, (byte)0);
|
||||
|
||||
short[][] keyGenSHat = mlKemVectorNTT(keyGenS);
|
||||
mlKemVectorReduce(keyGenSHat);
|
||||
@ -700,7 +708,6 @@ public final class ML_KEM {
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
encodePoly12(keyGenTHat[i], pkEncoded, i * ((ML_KEM_N * 12) / 8));
|
||||
encodePoly12(keyGenSHat[i], skEncoded, i * ((ML_KEM_N * 12) / 8));
|
||||
Arrays.fill(keyGenEHat[i], (short) 0);
|
||||
Arrays.fill(keyGenSHat[i], (short) 0);
|
||||
}
|
||||
System.arraycopy(rho, 0,
|
||||
@ -723,39 +730,61 @@ public final class ML_KEM {
|
||||
var encryptA = generateA(rho, true);
|
||||
short[][] encryptR = new short[mlKem_k][];
|
||||
short[][] encryptE1 = new short[mlKem_k][];
|
||||
short[] encryptE2;
|
||||
int encryptN = 0;
|
||||
byte[] prfSeed = new byte[sigma.length + 1];
|
||||
System.arraycopy(sigma, 0, prfSeed, 0, sigma.length);
|
||||
Arrays.fill(sigma, (byte)0);
|
||||
|
||||
var kPkePRFeta1 = new SHAKE256(64 * mlKem_eta1);
|
||||
var kPkePRFeta2 = new SHAKE256(64 * mlKem_eta2);
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (encryptN++);
|
||||
kPkePRFeta1.update(prfSeed);
|
||||
byte[] cbdInput = kPkePRFeta1.digest();
|
||||
encryptR[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput);
|
||||
}
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (encryptN++);
|
||||
int cbdInput1Len = 64 * mlKem_eta1;
|
||||
var kPkePRFeta1 = new SHAKE256(cbdInput1Len);
|
||||
byte[] cbdInput1 = new byte[cbdInput1Len];
|
||||
int cbdInput2Len = 64 * mlKem_eta2;
|
||||
var kPkePRFeta2 = new SHAKE256(cbdInput2Len);
|
||||
byte[] cbdInput2 = new byte[cbdInput2Len];
|
||||
try {
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (encryptN++);
|
||||
kPkePRFeta1.update(prfSeed);
|
||||
kPkePRFeta1.digest(cbdInput1, 0, cbdInput1Len);
|
||||
encryptR[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput1);
|
||||
}
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
prfSeed[sigma.length] = (byte) (encryptN++);
|
||||
kPkePRFeta2.update(prfSeed);
|
||||
kPkePRFeta2.digest(cbdInput2, 0, cbdInput2Len);
|
||||
encryptE1[i] = centeredBinomialDistribution(mlKem_eta2, cbdInput2);
|
||||
}
|
||||
prfSeed[sigma.length] = (byte) encryptN;
|
||||
kPkePRFeta2.update(prfSeed);
|
||||
byte[] cbdInput = kPkePRFeta2.digest();
|
||||
encryptE1[i] = centeredBinomialDistribution(mlKem_eta2, cbdInput);
|
||||
kPkePRFeta2.digest(cbdInput2, 0, cbdInput2Len);
|
||||
encryptE2 = centeredBinomialDistribution(mlKem_eta2, cbdInput2);
|
||||
} catch (DigestException e) {
|
||||
throw new ProviderException("Internal error", e);
|
||||
} finally {
|
||||
kPkePRFeta1.reset();
|
||||
kPkePRFeta2.reset();
|
||||
Arrays.fill(prfSeed, (byte)0);
|
||||
Arrays.fill(cbdInput1, (byte)0);
|
||||
Arrays.fill(cbdInput2, (byte)0);
|
||||
}
|
||||
prfSeed[sigma.length] = (byte) encryptN;
|
||||
kPkePRFeta2.reset();
|
||||
kPkePRFeta2.update(prfSeed);
|
||||
byte[] cbdInput = kPkePRFeta2.digest();
|
||||
var encryptE2 = centeredBinomialDistribution(mlKem_eta2, cbdInput);
|
||||
|
||||
var encryptRHat = mlKemVectorNTT(encryptR);
|
||||
var encryptUHat = mlKemMatrixVectorMuladd(encryptA, encryptRHat, zeroes);
|
||||
var encryptU = mlKemVectorInverseNTT(encryptUHat);
|
||||
encryptU = mlKemAddVec(encryptU, encryptE1);
|
||||
|
||||
for (int i = 0; i < mlKem_k; i++) {
|
||||
Arrays.fill(encryptE1[i], (short)0);
|
||||
}
|
||||
|
||||
var encryptVHat = mlKemVectorScalarMult(encryptTHat, encryptRHat);
|
||||
var encryptV = mlKemInverseNTT(encryptVHat);
|
||||
encryptV = mlKemAddPoly(encryptV, encryptE2, decompressDecode(message));
|
||||
var encryptC1 = encodeVector(mlKem_du, compressVector10_11(encryptU, mlKem_du));
|
||||
var encryptC2 = encodePoly(mlKem_dv, compressPoly4_5(encryptV, mlKem_dv));
|
||||
Arrays.fill(encryptE2, (short)0);
|
||||
Arrays.fill(encryptV, (short)0);
|
||||
|
||||
byte[] result = new byte[encryptC1.length + encryptC2.length];
|
||||
System.arraycopy(encryptC1, 0,
|
||||
@ -783,9 +812,11 @@ public final class ML_KEM {
|
||||
Arrays.fill(decryptSHat[i], (short) 0);
|
||||
}
|
||||
decryptV = mlKemSubtractPoly(decryptV, decryptSU);
|
||||
var result = encodeCompress(decryptV);
|
||||
Arrays.fill(decryptSU, (short) 0);
|
||||
Arrays.fill(decryptV, (short) 0);
|
||||
|
||||
return encodeCompress(decryptV);
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@ -28,6 +28,7 @@ package sun.security.provider;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.lang.invoke.VarHandle;
|
||||
import java.nio.ByteOrder;
|
||||
import java.security.DigestException;
|
||||
import java.security.ProviderException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
@ -481,6 +482,11 @@ public abstract class SHA3 extends DigestBase {
|
||||
return engineDigest();
|
||||
}
|
||||
|
||||
public int digest(byte[] out, int offs, int len)
|
||||
throws DigestException {
|
||||
return engineDigest(out, offs, len);
|
||||
}
|
||||
|
||||
public void squeeze(byte[] output, int offset, int numBytes) {
|
||||
implSqueeze(output, offset, numBytes);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user