diff --git a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java index 56a119893a7..6564f40545a 100644 --- a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java +++ b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -527,6 +527,8 @@ public final class ML_KEM { } catch (DigestException e) { // This should never happen. throw new RuntimeException(e); + } finally { + mlKemH.reset(); } // The 2nd 32-byte `z` is copied into decapsKey System.arraycopy(kem_d_z, 32, decapsKey, @@ -562,7 +564,6 @@ public final class ML_KEM { var randomCoins = Arrays.copyOfRange(kHatAndRandomCoins, 32, 64); var cipherText = kPkeEncrypt(new K_PKE_EncryptionKey(encapsulationKey.keyBytes), randomMessage, randomCoins); - Arrays.fill(randomCoins, (byte) 0); byte[] sharedSecret = Arrays.copyOfRange(kHatAndRandomCoins, 0, 32); Arrays.fill(kHatAndRandomCoins, (byte) 0); @@ -613,6 +614,7 @@ public final class ML_KEM { var fakeResult = mlKemJ.digest(); var computedCipherText = kPkeEncrypt( new K_PKE_EncryptionKey(encapsKeyBytes), mCandidate, coins); + Arrays.fill(mCandidate, (byte)0); // The rest of this method implements the following in constant time // @@ -648,9 +650,11 @@ public final class ML_KEM { MessageDigest mlKemG; SHAKE256 mlKemJ; + int cbdInputLen = 64 * mlKem_eta1; + byte[] cbdInput = new byte[cbdInputLen]; try { mlKemG = MessageDigest.getInstance(HASH_G_NAME); - mlKemJ = new SHAKE256(64 * mlKem_eta1); + mlKemJ = new SHAKE256(cbdInputLen); } catch (NoSuchAlgorithmException e) { // This should never happen. throw new RuntimeException(e); @@ -671,22 +675,26 @@ public final class ML_KEM { int keyGenN = 0; byte[] prfSeed = new byte[sigma.length + 1]; System.arraycopy(sigma, 0, prfSeed, 0, sigma.length); - byte[] cbdInput; short[][] keyGenS = new short[mlKem_k][]; short[][] keyGenE = new short[mlKem_k][]; - for (int i = 0; i < mlKem_k; i++) { - prfSeed[sigma.length] = (byte) (keyGenN++); - mlKemJ.update(prfSeed); - cbdInput = mlKemJ.digest(); - keyGenS[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput); - } - for (int i = 0; i < mlKem_k; i++) { - prfSeed[sigma.length] = (byte) (keyGenN++); - mlKemJ.update(prfSeed); - cbdInput = mlKemJ.digest(); - keyGenE[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput); + try { + for (int i = 0; i < mlKem_k; i++) { + prfSeed[sigma.length] = (byte) (keyGenN++); + mlKemJ.update(prfSeed); + mlKemJ.digest(cbdInput, 0, cbdInputLen); + keyGenS[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput); + } + for (int i = 0; i < mlKem_k; i++) { + prfSeed[sigma.length] = (byte) (keyGenN++); + mlKemJ.update(prfSeed); + mlKemJ.digest(cbdInput, 0, cbdInputLen); + keyGenE[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput); + } + } catch (DigestException e) { + throw new ProviderException("Internal error", e); } Arrays.fill(sigma, (byte)0); + Arrays.fill(cbdInput, (byte)0); short[][] keyGenSHat = mlKemVectorNTT(keyGenS); mlKemVectorReduce(keyGenSHat); @@ -700,7 +708,6 @@ public final class ML_KEM { for (int i = 0; i < mlKem_k; i++) { encodePoly12(keyGenTHat[i], pkEncoded, i * ((ML_KEM_N * 12) / 8)); encodePoly12(keyGenSHat[i], skEncoded, i * ((ML_KEM_N * 12) / 8)); - Arrays.fill(keyGenEHat[i], (short) 0); Arrays.fill(keyGenSHat[i], (short) 0); } System.arraycopy(rho, 0, @@ -723,39 +730,61 @@ public final class ML_KEM { var encryptA = generateA(rho, true); short[][] encryptR = new short[mlKem_k][]; short[][] encryptE1 = new short[mlKem_k][]; + short[] encryptE2; int encryptN = 0; byte[] prfSeed = new byte[sigma.length + 1]; System.arraycopy(sigma, 0, prfSeed, 0, sigma.length); + Arrays.fill(sigma, (byte)0); - var kPkePRFeta1 = new SHAKE256(64 * mlKem_eta1); - var kPkePRFeta2 = new SHAKE256(64 * mlKem_eta2); - for (int i = 0; i < mlKem_k; i++) { - prfSeed[sigma.length] = (byte) (encryptN++); - kPkePRFeta1.update(prfSeed); - byte[] cbdInput = kPkePRFeta1.digest(); - encryptR[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput); - } - for (int i = 0; i < mlKem_k; i++) { - prfSeed[sigma.length] = (byte) (encryptN++); + int cbdInput1Len = 64 * mlKem_eta1; + var kPkePRFeta1 = new SHAKE256(cbdInput1Len); + byte[] cbdInput1 = new byte[cbdInput1Len]; + int cbdInput2Len = 64 * mlKem_eta2; + var kPkePRFeta2 = new SHAKE256(cbdInput2Len); + byte[] cbdInput2 = new byte[cbdInput2Len]; + try { + for (int i = 0; i < mlKem_k; i++) { + prfSeed[sigma.length] = (byte) (encryptN++); + kPkePRFeta1.update(prfSeed); + kPkePRFeta1.digest(cbdInput1, 0, cbdInput1Len); + encryptR[i] = centeredBinomialDistribution(mlKem_eta1, cbdInput1); + } + for (int i = 0; i < mlKem_k; i++) { + prfSeed[sigma.length] = (byte) (encryptN++); + kPkePRFeta2.update(prfSeed); + kPkePRFeta2.digest(cbdInput2, 0, cbdInput2Len); + encryptE1[i] = centeredBinomialDistribution(mlKem_eta2, cbdInput2); + } + prfSeed[sigma.length] = (byte) encryptN; kPkePRFeta2.update(prfSeed); - byte[] cbdInput = kPkePRFeta2.digest(); - encryptE1[i] = centeredBinomialDistribution(mlKem_eta2, cbdInput); + kPkePRFeta2.digest(cbdInput2, 0, cbdInput2Len); + encryptE2 = centeredBinomialDistribution(mlKem_eta2, cbdInput2); + } catch (DigestException e) { + throw new ProviderException("Internal error", e); + } finally { + kPkePRFeta1.reset(); + kPkePRFeta2.reset(); + Arrays.fill(prfSeed, (byte)0); + Arrays.fill(cbdInput1, (byte)0); + Arrays.fill(cbdInput2, (byte)0); } - prfSeed[sigma.length] = (byte) encryptN; - kPkePRFeta2.reset(); - kPkePRFeta2.update(prfSeed); - byte[] cbdInput = kPkePRFeta2.digest(); - var encryptE2 = centeredBinomialDistribution(mlKem_eta2, cbdInput); var encryptRHat = mlKemVectorNTT(encryptR); var encryptUHat = mlKemMatrixVectorMuladd(encryptA, encryptRHat, zeroes); var encryptU = mlKemVectorInverseNTT(encryptUHat); encryptU = mlKemAddVec(encryptU, encryptE1); + + for (int i = 0; i < mlKem_k; i++) { + Arrays.fill(encryptE1[i], (short)0); + } + var encryptVHat = mlKemVectorScalarMult(encryptTHat, encryptRHat); var encryptV = mlKemInverseNTT(encryptVHat); encryptV = mlKemAddPoly(encryptV, encryptE2, decompressDecode(message)); var encryptC1 = encodeVector(mlKem_du, compressVector10_11(encryptU, mlKem_du)); var encryptC2 = encodePoly(mlKem_dv, compressPoly4_5(encryptV, mlKem_dv)); + Arrays.fill(encryptE2, (short)0); + Arrays.fill(encryptV, (short)0); byte[] result = new byte[encryptC1.length + encryptC2.length]; System.arraycopy(encryptC1, 0, @@ -783,9 +812,11 @@ public final class ML_KEM { Arrays.fill(decryptSHat[i], (short) 0); } decryptV = mlKemSubtractPoly(decryptV, decryptSU); + var result = encodeCompress(decryptV); Arrays.fill(decryptSU, (short) 0); + Arrays.fill(decryptV, (short) 0); - return encodeCompress(decryptV); + return result; } /* diff --git a/src/java.base/share/classes/sun/security/provider/SHA3.java b/src/java.base/share/classes/sun/security/provider/SHA3.java index 0578645c1cd..5eafface0d3 100644 --- a/src/java.base/share/classes/sun/security/provider/SHA3.java +++ b/src/java.base/share/classes/sun/security/provider/SHA3.java @@ -28,6 +28,7 @@ package sun.security.provider; import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; import java.nio.ByteOrder; +import java.security.DigestException; import java.security.ProviderException; import java.util.Arrays; import java.util.Objects; @@ -481,6 +482,11 @@ public abstract class SHA3 extends DigestBase { return engineDigest(); } + public int digest(byte[] out, int offs, int len) + throws DigestException { + return engineDigest(out, offs, len); + } + public void squeeze(byte[] output, int offset, int numBytes) { implSqueeze(output, offset, numBytes); }