From a09b75abd00610df2b37f92a781af2cbb9a83675 Mon Sep 17 00:00:00 2001 From: Weijun Wang Date: Tue, 21 Dec 2010 17:35:47 +0800 Subject: [PATCH] 6996367: improve HandshakeHash Reviewed-by: xuelei --- .../sun/security/ssl/ClientHandshaker.java | 5 +- .../sun/security/ssl/HandshakeHash.java | 110 +++--------------- .../sun/security/ssl/ServerHandshaker.java | 3 +- 3 files changed, 16 insertions(+), 102 deletions(-) diff --git a/jdk/src/share/classes/sun/security/ssl/ClientHandshaker.java b/jdk/src/share/classes/sun/security/ssl/ClientHandshaker.java index af835db7899..81cdfd1e5b8 100644 --- a/jdk/src/share/classes/sun/security/ssl/ClientHandshaker.java +++ b/jdk/src/share/classes/sun/security/ssl/ClientHandshaker.java @@ -381,8 +381,7 @@ final class ClientHandshaker extends Handshaker { mesgVersion); } - handshakeHash.protocolDetermined( - mesgVersion.v >= ProtocolVersion.TLS12.v); + handshakeHash.protocolDetermined(mesgVersion); // Set protocolVersion and propagate to SSLSocket and the // Handshake streams @@ -1223,7 +1222,7 @@ final class ClientHandshaker extends Handshaker { // not follow the spec that HandshakeHash.reset() can be only be // called before protocolDetermined. // if (maxProtocolVersion.v < ProtocolVersion.TLS12.v) { - // handshakeHash.protocolDetermined(false); + // handshakeHash.protocolDetermined(maxProtocolVersion); // } // create the ClientHello message diff --git a/jdk/src/share/classes/sun/security/ssl/HandshakeHash.java b/jdk/src/share/classes/sun/security/ssl/HandshakeHash.java index 86ebb51970f..ffb8f322606 100644 --- a/jdk/src/share/classes/sun/security/ssl/HandshakeHash.java +++ b/jdk/src/share/classes/sun/security/ssl/HandshakeHash.java @@ -49,27 +49,27 @@ import java.util.Set; * * You need to obey these conventions when using this class: * - * 1. protocolDetermined(boolean isTLS12) should be called when the negotiated + * 1. protocolDetermined(version) should be called when the negotiated * protocol version is determined. * * 2. Before protocolDetermined() is called, only update(), reset(), * restrictCertificateVerifyAlgs(), setFinishedAlg(), and * setCertificateVerifyAlg() can be called. * - * 3. After protocolDetermined(*) is called. reset() cannot be called. + * 3. After protocolDetermined() is called, reset() cannot be called. * - * 4. After protocolDetermined(false) is called, getFinishedHash() and - * getCertificateVerifyHash() cannot be called. After protocolDetermined(true) - * is called, getMD5Clone() and getSHAClone() cannot be called. + * 4. After protocolDetermined() is called, if the version is pre-TLS 1.2, + * getFinishedHash() and getCertificateVerifyHash() cannot be called. Otherwise, + * getMD5Clone() and getSHAClone() cannot be called. * * 5. getMD5Clone() and getSHAClone() can only be called after - * protocolDetermined(false) is called. + * protocolDetermined() is called and version is pre-TLS 1.2. * * 6. getFinishedHash() and getCertificateVerifyHash() can only be called after - * all protocolDetermined(true), setCertificateVerifyAlg() and setFinishedAlg() - * have been called. If a CertificateVerify message is to be used, call - * setCertificateVerifyAlg() with the hash algorithm as the argument. - * Otherwise, you still must call setCertificateVerifyAlg(null) before + * all protocolDetermined(), setCertificateVerifyAlg() and setFinishedAlg() + * have been called and the version is TLS 1.2. If a CertificateVerify message + * is to be used, call setCertificateVerifyAlg() with the hash algorithm as the + * argument. Otherwise, you still must call setCertificateVerifyAlg(null) before * calculating any hash value. * * Suggestions: Call protocolDetermined(), restrictCertificateVerifyAlgs(), @@ -78,6 +78,7 @@ import java.util.Set; * Example: *
  * HandshakeHash hh = new HandshakeHash(...)
+ * hh.protocolDetermined(ProtocolVersion.TLS12);
  * hh.update(clientHelloBytes);
  * hh.setFinishedAlg("SHA-256");
  * hh.update(serverHelloBytes);
@@ -161,12 +162,12 @@ final class HandshakeHash {
     }
 
 
-    void protocolDetermined(boolean isTLS12) {
+    void protocolDetermined(ProtocolVersion pv) {
 
         // Do not set again, will ignore
         if (version != -1) return;
 
-        version = isTLS12 ? 2 : 1;
+        version = pv.compareTo(ProtocolVersion.TLS12) >= 0 ? 2 : 1;
         switch (version) {
             case 1:
                 // initiate md5, sha and call update on saved array
@@ -310,91 +311,6 @@ final class HandshakeHash {
             throw new Error("BAD");
         }
     }
-
-    ////////////////////////////////////////////////////////////////
-    // TEST
-    ////////////////////////////////////////////////////////////////
-
-    public static void main(String[] args) throws Exception {
-        Test t = new Test();
-        t.test(null, "SHA-256");
-        t.test("", "SHA-256");
-        t.test("SHA-1", "SHA-256");
-        t.test("SHA-256", "SHA-256");
-        t.test("SHA-384", "SHA-256");
-        t.test("SHA-512", "SHA-256");
-        t.testSame("sha", "SHA-1");
-        t.testSame("SHA", "SHA-1");
-        t.testSame("SHA1", "SHA-1");
-        t.testSame("SHA-1", "SHA-1");
-        t.testSame("SHA256", "SHA-256");
-        t.testSame("SHA-256", "SHA-256");
-    }
-
-    static class Test {
-        void update(HandshakeHash hh, String s) {
-            hh.update(s.getBytes(), 0, s.length());
-        }
-        static byte[] digest(String alg, String data) throws Exception {
-            return MessageDigest.getInstance(alg).digest(data.getBytes());
-        }
-        static void equals(byte[] b1, byte[] b2) {
-            if (!Arrays.equals(b1, b2)) {
-                throw new RuntimeException("Bad");
-            }
-        }
-        void testSame(String a, String a2) {
-            System.out.println("testSame: " + a + " " + a2);
-            if (!HandshakeHash.normalizeAlgName(a).equals(a2)) {
-                throw new RuntimeException("Bad");
-            }
-        }
-        /**
-         * Special convention: when it's certain that CV will not be used at the
-         * very beginning, use null as cvAlg. If known at a late stage, use "".
-         */
-        void test(String cvAlg, String finAlg) throws Exception {
-            System.out.println("test: " + cvAlg + " " + finAlg);
-            byte[] cv = null, f1, f2;
-            HandshakeHash hh = new HandshakeHash(true, true, null);
-            if (cvAlg == null) {
-                hh.setCertificateVerifyAlg(cvAlg);
-            }
-
-            update(hh, "ClientHello,");
-            hh.reset();
-            update(hh, "ClientHellov2,");
-            hh.setFinishedAlg(finAlg);
-
-            // Useless calls
-            hh.setFinishedAlg("SHA-1");
-            hh.setFinishedAlg("SHA-512");
-
-            update(hh, "More,");
-            if (cvAlg != null) {
-                if (cvAlg.isEmpty()) cvAlg = null;
-                hh.setCertificateVerifyAlg(cvAlg);
-            }
-
-            // Useless calls
-            hh.setCertificateVerifyAlg("SHA-1");
-            hh.setCertificateVerifyAlg(null);
-
-            hh.protocolDetermined(true);
-
-            if (cvAlg != null) {
-                cv = hh.getAllHandshakeMessages();
-                equals(cv, "ClientHellov2,More,".getBytes());
-            }
-
-            update(hh, "FIN1,");
-            f1 = hh.getFinishedHash();
-            equals(f1, digest(finAlg, "ClientHellov2,More,FIN1,"));
-            update(hh, "FIN2,");
-            f2 = hh.getFinishedHash();
-            equals(f2, digest(finAlg, "ClientHellov2,More,FIN1,FIN2,"));
-        }
-    }
 }
 
 /**
diff --git a/jdk/src/share/classes/sun/security/ssl/ServerHandshaker.java b/jdk/src/share/classes/sun/security/ssl/ServerHandshaker.java
index c393314257b..ff1c8d7d5c9 100644
--- a/jdk/src/share/classes/sun/security/ssl/ServerHandshaker.java
+++ b/jdk/src/share/classes/sun/security/ssl/ServerHandshaker.java
@@ -424,8 +424,7 @@ final class ServerHandshaker extends Handshaker {
                 " not enabled or not supported");
         }
 
-        handshakeHash.protocolDetermined(
-            selectedVersion.v >= ProtocolVersion.TLS12.v);
+        handshakeHash.protocolDetermined(selectedVersion);
         setVersion(selectedVersion);
 
         m1.protocolVersion = protocolVersion;