Tidy up handshake digest logic.

Use SSL_SESSION_get_digest instead of the lower level function where
applicable. Also, remove the failure case (Ivan Maidanski points out in
https://android-review.googlesource.com/c/337852/1/src/ssl/t1_enc.c that
this unreachable codepath is a memory leak) by passing in an SSL_CIPHER
to make it more locally obvious that other values are impossible.

Change-Id: Ie624049d47ab0d24f32b405390d6251c7343d7d6
Reviewed-on: https://boringssl-review.googlesource.com/19024
Commit-Queue: David Benjamin <davidben@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index aa1524b..dd09797 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -967,8 +967,8 @@
                    CBS_len(&session_id));
   }
 
-  const SSL_CIPHER *c = SSL_get_cipher_by_value(cipher_suite);
-  if (c == NULL) {
+  const SSL_CIPHER *cipher = SSL_get_cipher_by_value(cipher_suite);
+  if (cipher == NULL) {
     /* unknown cipher */
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_CIPHER_RETURNED);
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
@@ -978,10 +978,10 @@
   /* The cipher must be allowed in the selected version and enabled. */
   uint32_t mask_a, mask_k;
   ssl_get_client_disabled(ssl, &mask_a, &mask_k);
-  if ((c->algorithm_mkey & mask_k) || (c->algorithm_auth & mask_a) ||
-      SSL_CIPHER_get_min_version(c) > ssl3_protocol_version(ssl) ||
-      SSL_CIPHER_get_max_version(c) < ssl3_protocol_version(ssl) ||
-      !sk_SSL_CIPHER_find(SSL_get_ciphers(ssl), NULL, c)) {
+  if ((cipher->algorithm_mkey & mask_k) || (cipher->algorithm_auth & mask_a) ||
+      SSL_CIPHER_get_min_version(cipher) > ssl3_protocol_version(ssl) ||
+      SSL_CIPHER_get_max_version(cipher) < ssl3_protocol_version(ssl) ||
+      !sk_SSL_CIPHER_find(SSL_get_ciphers(ssl), NULL, cipher)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CIPHER_RETURNED);
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
     return -1;
@@ -993,7 +993,7 @@
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
       return -1;
     }
-    if (ssl->session->cipher != c) {
+    if (ssl->session->cipher != cipher) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_OLD_SESSION_CIPHER_NOT_RETURNED);
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
       return -1;
@@ -1006,13 +1006,13 @@
       return -1;
     }
   } else {
-    hs->new_session->cipher = c;
+    hs->new_session->cipher = cipher;
   }
-  hs->new_cipher = c;
+  hs->new_cipher = cipher;
 
   /* Now that the cipher is known, initialize the handshake hash and hash the
    * ServerHello. */
-  if (!hs->transcript.InitHash(ssl3_protocol_version(ssl), c->algorithm_prf) ||
+  if (!hs->transcript.InitHash(ssl3_protocol_version(ssl), hs->new_cipher) ||
       !ssl_hash_message(hs, msg)) {
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
     return -1;
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index 6cccff9..2d5b85e 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -917,8 +917,7 @@
 
   /* Now that all parameters are known, initialize the handshake hash and hash
    * the ClientHello. */
-  if (!hs->transcript.InitHash(ssl3_protocol_version(ssl),
-                               hs->new_cipher->algorithm_prf) ||
+  if (!hs->transcript.InitHash(ssl3_protocol_version(ssl), hs->new_cipher) ||
       !ssl_hash_message(hs, msg)) {
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
     return -1;
diff --git a/ssl/internal.h b/ssl/internal.h
index e323049..931ac82 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -338,10 +338,10 @@
                             size_t *out_fixed_iv_len, const SSL_CIPHER *cipher,
                             uint16_t version, int is_dtls);
 
-/* ssl_get_handshake_digest returns the |EVP_MD| corresponding to
- * |algorithm_prf| and the |version|. */
-const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf,
-                                       uint16_t version);
+/* ssl_get_handshake_digest returns the |EVP_MD| corresponding to |version| and
+ * |cipher|. */
+const EVP_MD *ssl_get_handshake_digest(uint16_t version,
+                                       const SSL_CIPHER *cipher);
 
 /* ssl_create_cipher_list evaluates |rule_str| according to the ciphers in
  * |ssl_method|. It sets |*out_cipher_list| to a newly-allocated
@@ -397,7 +397,7 @@
    * the handshake transcript. Subsequent calls to |Update| will update the
    * rolling hash. It returns one on success and zero on failure. It is an error
    * to call this function after the handshake buffer is released. */
-  bool InitHash(uint16_t version, int algorithm_prf);
+  bool InitHash(uint16_t version, const SSL_CIPHER *cipher);
 
   const uint8_t *buffer_data() const {
     return reinterpret_cast<const uint8_t *>(buffer_->data);
diff --git a/ssl/ssl_cipher.cc b/ssl/ssl_cipher.cc
index f1a215f..fbcabd5 100644
--- a/ssl/ssl_cipher.cc
+++ b/ssl/ssl_cipher.cc
@@ -742,9 +742,9 @@
   return 1;
 }
 
-const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf,
-                                       uint16_t version) {
-  switch (algorithm_prf) {
+const EVP_MD *ssl_get_handshake_digest(uint16_t version,
+                                       const SSL_CIPHER *cipher) {
+  switch (cipher->algorithm_prf) {
     case SSL_HANDSHAKE_MAC_DEFAULT:
       return version >= TLS1_2_VERSION ? EVP_sha256() : EVP_md5_sha1();
     case SSL_HANDSHAKE_MAC_SHA256:
@@ -752,6 +752,7 @@
     case SSL_HANDSHAKE_MAC_SHA384:
       return EVP_sha384();
     default:
+      assert(0);
       return NULL;
   }
 }
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index 1830723..a1c21dc 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -368,8 +368,8 @@
 }
 
 const EVP_MD *SSL_SESSION_get_digest(const SSL_SESSION *session) {
-  return ssl_get_handshake_digest(session->cipher->algorithm_prf,
-                                  SSL_SESSION_protocol_version(session));
+  return ssl_get_handshake_digest(SSL_SESSION_protocol_version(session),
+                                  session->cipher);
 }
 
 int ssl_get_new_session(SSL_HANDSHAKE *hs, int is_server) {
diff --git a/ssl/ssl_transcript.cc b/ssl/ssl_transcript.cc
index 4a00d0f..2dfaf76 100644
--- a/ssl/ssl_transcript.cc
+++ b/ssl/ssl_transcript.cc
@@ -178,8 +178,8 @@
   return true;
 }
 
-bool SSLTranscript::InitHash(uint16_t version, int algorithm_prf) {
-  const EVP_MD *md = ssl_get_handshake_digest(algorithm_prf, version);
+bool SSLTranscript::InitHash(uint16_t version, const SSL_CIPHER *cipher) {
+  const EVP_MD *md = ssl_get_handshake_digest(version, cipher);
 
   /* To support SSL 3.0's Finished and CertificateVerify constructions,
    * EVP_md5_sha1() is split into MD5 and SHA-1 halves. When SSL 3.0 is removed,
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 2349df0..d4a6ee9 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -483,25 +483,19 @@
 }
 
 int SSL_generate_key_block(const SSL *ssl, uint8_t *out, size_t out_len) {
+  const SSL_SESSION *session = SSL_get_session(ssl);
   if (ssl3_protocol_version(ssl) == SSL3_VERSION) {
-    return ssl3_prf(out, out_len, SSL_get_session(ssl)->master_key,
-                    SSL_get_session(ssl)->master_key_length,
-                    TLS_MD_KEY_EXPANSION_CONST, TLS_MD_KEY_EXPANSION_CONST_SIZE,
-                    ssl->s3->server_random, SSL3_RANDOM_SIZE,
-                    ssl->s3->client_random, SSL3_RANDOM_SIZE);
+    return ssl3_prf(out, out_len, session->master_key,
+                    session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
+                    TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
+                    SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
   }
 
-  const EVP_MD *digest = ssl_get_handshake_digest(
-      SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
-  if (digest == NULL) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    return 0;
-  }
-  return tls1_prf(digest, out, out_len, SSL_get_session(ssl)->master_key,
-                  SSL_get_session(ssl)->master_key_length,
-                  TLS_MD_KEY_EXPANSION_CONST, TLS_MD_KEY_EXPANSION_CONST_SIZE,
-                  ssl->s3->server_random, SSL3_RANDOM_SIZE,
-                  ssl->s3->client_random, SSL3_RANDOM_SIZE);
+  const EVP_MD *digest = SSL_SESSION_get_digest(session);
+  return tls1_prf(digest, out, out_len, session->master_key,
+                  session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
+                  TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
+                  SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
 }
 
 int SSL_export_keying_material(SSL *ssl, uint8_t *out, size_t out_len,
@@ -545,15 +539,11 @@
     OPENSSL_memcpy(seed + 2 * SSL3_RANDOM_SIZE + 2, context, context_len);
   }
 
-  const EVP_MD *digest = ssl_get_handshake_digest(
-      SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
-  if (digest == NULL) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    return 0;
-  }
-  int ret = tls1_prf(digest, out, out_len, SSL_get_session(ssl)->master_key,
-                     SSL_get_session(ssl)->master_key_length, label, label_len,
-                     seed, seed_len, NULL, 0);
+  const SSL_SESSION *session = SSL_get_session(ssl);
+  const EVP_MD *digest = SSL_SESSION_get_digest(session);
+  int ret = tls1_prf(digest, out, out_len, session->master_key,
+                     session->master_key_length, label, label_len, seed,
+                     seed_len, NULL, 0);
   OPENSSL_free(seed);
   return ret;
 }
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index f744cf8..39e80be 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -33,8 +33,8 @@
 namespace bssl {
 
 static int init_key_schedule(SSL_HANDSHAKE *hs, uint16_t version,
-                              int algorithm_prf) {
-  if (!hs->transcript.InitHash(version, algorithm_prf)) {
+                             const SSL_CIPHER *cipher) {
+  if (!hs->transcript.InitHash(version, cipher)) {
     return 0;
   }
 
@@ -47,8 +47,7 @@
 }
 
 int tls13_init_key_schedule(SSL_HANDSHAKE *hs) {
-  if (!init_key_schedule(hs, ssl3_protocol_version(hs->ssl),
-                         hs->new_cipher->algorithm_prf)) {
+  if (!init_key_schedule(hs, ssl3_protocol_version(hs->ssl), hs->new_cipher)) {
     return 0;
   }
 
@@ -59,7 +58,7 @@
 int tls13_init_early_key_schedule(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
   return init_key_schedule(hs, SSL_SESSION_protocol_version(ssl->session),
-                           ssl->session->cipher->algorithm_prf);
+                           ssl->session->cipher);
 }
 
 int tls13_advance_key_schedule(SSL_HANDSHAKE *hs, const uint8_t *in,
@@ -243,9 +242,6 @@
     "application traffic secret";
 
 int tls13_rotate_traffic_key(SSL *ssl, enum evp_aead_direction_t direction) {
-  const EVP_MD *digest = ssl_get_handshake_digest(
-      SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
-
   uint8_t *secret;
   size_t secret_len;
   if (direction == evp_aead_open) {
@@ -256,6 +252,7 @@
     secret_len = ssl->s3->write_traffic_secret_len;
   }
 
+  const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
   if (!hkdf_expand_label(secret, digest, secret, secret_len,
                          (const uint8_t *)kTLS13LabelApplicationTraffic,
                          strlen(kTLS13LabelApplicationTraffic), NULL, 0,
@@ -323,15 +320,14 @@
                                  const char *label, size_t label_len,
                                  const uint8_t *context, size_t context_len,
                                  int use_context) {
-  const EVP_MD *digest = ssl_get_handshake_digest(
-      SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
-
   const uint8_t *hash = NULL;
   size_t hash_len = 0;
   if (use_context) {
     hash = context;
     hash_len = context_len;
   }
+
+  const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
   return hkdf_expand_label(out, digest, ssl->s3->exporter_secret,
                            ssl->s3->exporter_secret_len, (const uint8_t *)label,
                            label_len, hash, hash_len, out_len);