Tidy up handshake digest logic.
Use SSL_SESSION_get_digest instead of the lower level function where
applicable. Also, remove the failure case (Ivan Maidanski points out in
https://android-review.googlesource.com/c/337852/1/src/ssl/t1_enc.c that
this unreachable codepath is a memory leak) by passing in an SSL_CIPHER
to make it more locally obvious that other values are impossible.
Change-Id: Ie624049d47ab0d24f32b405390d6251c7343d7d6
Reviewed-on: https://boringssl-review.googlesource.com/19024
Commit-Queue: David Benjamin <davidben@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index aa1524b..dd09797 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -967,8 +967,8 @@
CBS_len(&session_id));
}
- const SSL_CIPHER *c = SSL_get_cipher_by_value(cipher_suite);
- if (c == NULL) {
+ const SSL_CIPHER *cipher = SSL_get_cipher_by_value(cipher_suite);
+ if (cipher == NULL) {
/* unknown cipher */
OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_CIPHER_RETURNED);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
@@ -978,10 +978,10 @@
/* The cipher must be allowed in the selected version and enabled. */
uint32_t mask_a, mask_k;
ssl_get_client_disabled(ssl, &mask_a, &mask_k);
- if ((c->algorithm_mkey & mask_k) || (c->algorithm_auth & mask_a) ||
- SSL_CIPHER_get_min_version(c) > ssl3_protocol_version(ssl) ||
- SSL_CIPHER_get_max_version(c) < ssl3_protocol_version(ssl) ||
- !sk_SSL_CIPHER_find(SSL_get_ciphers(ssl), NULL, c)) {
+ if ((cipher->algorithm_mkey & mask_k) || (cipher->algorithm_auth & mask_a) ||
+ SSL_CIPHER_get_min_version(cipher) > ssl3_protocol_version(ssl) ||
+ SSL_CIPHER_get_max_version(cipher) < ssl3_protocol_version(ssl) ||
+ !sk_SSL_CIPHER_find(SSL_get_ciphers(ssl), NULL, cipher)) {
OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CIPHER_RETURNED);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return -1;
@@ -993,7 +993,7 @@
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return -1;
}
- if (ssl->session->cipher != c) {
+ if (ssl->session->cipher != cipher) {
OPENSSL_PUT_ERROR(SSL, SSL_R_OLD_SESSION_CIPHER_NOT_RETURNED);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return -1;
@@ -1006,13 +1006,13 @@
return -1;
}
} else {
- hs->new_session->cipher = c;
+ hs->new_session->cipher = cipher;
}
- hs->new_cipher = c;
+ hs->new_cipher = cipher;
/* Now that the cipher is known, initialize the handshake hash and hash the
* ServerHello. */
- if (!hs->transcript.InitHash(ssl3_protocol_version(ssl), c->algorithm_prf) ||
+ if (!hs->transcript.InitHash(ssl3_protocol_version(ssl), hs->new_cipher) ||
!ssl_hash_message(hs, msg)) {
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
return -1;
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index 6cccff9..2d5b85e 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -917,8 +917,7 @@
/* Now that all parameters are known, initialize the handshake hash and hash
* the ClientHello. */
- if (!hs->transcript.InitHash(ssl3_protocol_version(ssl),
- hs->new_cipher->algorithm_prf) ||
+ if (!hs->transcript.InitHash(ssl3_protocol_version(ssl), hs->new_cipher) ||
!ssl_hash_message(hs, msg)) {
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
return -1;
diff --git a/ssl/internal.h b/ssl/internal.h
index e323049..931ac82 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -338,10 +338,10 @@
size_t *out_fixed_iv_len, const SSL_CIPHER *cipher,
uint16_t version, int is_dtls);
-/* ssl_get_handshake_digest returns the |EVP_MD| corresponding to
- * |algorithm_prf| and the |version|. */
-const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf,
- uint16_t version);
+/* ssl_get_handshake_digest returns the |EVP_MD| corresponding to |version| and
+ * |cipher|. */
+const EVP_MD *ssl_get_handshake_digest(uint16_t version,
+ const SSL_CIPHER *cipher);
/* ssl_create_cipher_list evaluates |rule_str| according to the ciphers in
* |ssl_method|. It sets |*out_cipher_list| to a newly-allocated
@@ -397,7 +397,7 @@
* the handshake transcript. Subsequent calls to |Update| will update the
* rolling hash. It returns one on success and zero on failure. It is an error
* to call this function after the handshake buffer is released. */
- bool InitHash(uint16_t version, int algorithm_prf);
+ bool InitHash(uint16_t version, const SSL_CIPHER *cipher);
const uint8_t *buffer_data() const {
return reinterpret_cast<const uint8_t *>(buffer_->data);
diff --git a/ssl/ssl_cipher.cc b/ssl/ssl_cipher.cc
index f1a215f..fbcabd5 100644
--- a/ssl/ssl_cipher.cc
+++ b/ssl/ssl_cipher.cc
@@ -742,9 +742,9 @@
return 1;
}
-const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf,
- uint16_t version) {
- switch (algorithm_prf) {
+const EVP_MD *ssl_get_handshake_digest(uint16_t version,
+ const SSL_CIPHER *cipher) {
+ switch (cipher->algorithm_prf) {
case SSL_HANDSHAKE_MAC_DEFAULT:
return version >= TLS1_2_VERSION ? EVP_sha256() : EVP_md5_sha1();
case SSL_HANDSHAKE_MAC_SHA256:
@@ -752,6 +752,7 @@
case SSL_HANDSHAKE_MAC_SHA384:
return EVP_sha384();
default:
+ assert(0);
return NULL;
}
}
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index 1830723..a1c21dc 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -368,8 +368,8 @@
}
const EVP_MD *SSL_SESSION_get_digest(const SSL_SESSION *session) {
- return ssl_get_handshake_digest(session->cipher->algorithm_prf,
- SSL_SESSION_protocol_version(session));
+ return ssl_get_handshake_digest(SSL_SESSION_protocol_version(session),
+ session->cipher);
}
int ssl_get_new_session(SSL_HANDSHAKE *hs, int is_server) {
diff --git a/ssl/ssl_transcript.cc b/ssl/ssl_transcript.cc
index 4a00d0f..2dfaf76 100644
--- a/ssl/ssl_transcript.cc
+++ b/ssl/ssl_transcript.cc
@@ -178,8 +178,8 @@
return true;
}
-bool SSLTranscript::InitHash(uint16_t version, int algorithm_prf) {
- const EVP_MD *md = ssl_get_handshake_digest(algorithm_prf, version);
+bool SSLTranscript::InitHash(uint16_t version, const SSL_CIPHER *cipher) {
+ const EVP_MD *md = ssl_get_handshake_digest(version, cipher);
/* To support SSL 3.0's Finished and CertificateVerify constructions,
* EVP_md5_sha1() is split into MD5 and SHA-1 halves. When SSL 3.0 is removed,
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 2349df0..d4a6ee9 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -483,25 +483,19 @@
}
int SSL_generate_key_block(const SSL *ssl, uint8_t *out, size_t out_len) {
+ const SSL_SESSION *session = SSL_get_session(ssl);
if (ssl3_protocol_version(ssl) == SSL3_VERSION) {
- return ssl3_prf(out, out_len, SSL_get_session(ssl)->master_key,
- SSL_get_session(ssl)->master_key_length,
- TLS_MD_KEY_EXPANSION_CONST, TLS_MD_KEY_EXPANSION_CONST_SIZE,
- ssl->s3->server_random, SSL3_RANDOM_SIZE,
- ssl->s3->client_random, SSL3_RANDOM_SIZE);
+ return ssl3_prf(out, out_len, session->master_key,
+ session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
+ TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
+ SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
}
- const EVP_MD *digest = ssl_get_handshake_digest(
- SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
- if (digest == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
- return 0;
- }
- return tls1_prf(digest, out, out_len, SSL_get_session(ssl)->master_key,
- SSL_get_session(ssl)->master_key_length,
- TLS_MD_KEY_EXPANSION_CONST, TLS_MD_KEY_EXPANSION_CONST_SIZE,
- ssl->s3->server_random, SSL3_RANDOM_SIZE,
- ssl->s3->client_random, SSL3_RANDOM_SIZE);
+ const EVP_MD *digest = SSL_SESSION_get_digest(session);
+ return tls1_prf(digest, out, out_len, session->master_key,
+ session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
+ TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
+ SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
}
int SSL_export_keying_material(SSL *ssl, uint8_t *out, size_t out_len,
@@ -545,15 +539,11 @@
OPENSSL_memcpy(seed + 2 * SSL3_RANDOM_SIZE + 2, context, context_len);
}
- const EVP_MD *digest = ssl_get_handshake_digest(
- SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
- if (digest == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
- return 0;
- }
- int ret = tls1_prf(digest, out, out_len, SSL_get_session(ssl)->master_key,
- SSL_get_session(ssl)->master_key_length, label, label_len,
- seed, seed_len, NULL, 0);
+ const SSL_SESSION *session = SSL_get_session(ssl);
+ const EVP_MD *digest = SSL_SESSION_get_digest(session);
+ int ret = tls1_prf(digest, out, out_len, session->master_key,
+ session->master_key_length, label, label_len, seed,
+ seed_len, NULL, 0);
OPENSSL_free(seed);
return ret;
}
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index f744cf8..39e80be 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -33,8 +33,8 @@
namespace bssl {
static int init_key_schedule(SSL_HANDSHAKE *hs, uint16_t version,
- int algorithm_prf) {
- if (!hs->transcript.InitHash(version, algorithm_prf)) {
+ const SSL_CIPHER *cipher) {
+ if (!hs->transcript.InitHash(version, cipher)) {
return 0;
}
@@ -47,8 +47,7 @@
}
int tls13_init_key_schedule(SSL_HANDSHAKE *hs) {
- if (!init_key_schedule(hs, ssl3_protocol_version(hs->ssl),
- hs->new_cipher->algorithm_prf)) {
+ if (!init_key_schedule(hs, ssl3_protocol_version(hs->ssl), hs->new_cipher)) {
return 0;
}
@@ -59,7 +58,7 @@
int tls13_init_early_key_schedule(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
return init_key_schedule(hs, SSL_SESSION_protocol_version(ssl->session),
- ssl->session->cipher->algorithm_prf);
+ ssl->session->cipher);
}
int tls13_advance_key_schedule(SSL_HANDSHAKE *hs, const uint8_t *in,
@@ -243,9 +242,6 @@
"application traffic secret";
int tls13_rotate_traffic_key(SSL *ssl, enum evp_aead_direction_t direction) {
- const EVP_MD *digest = ssl_get_handshake_digest(
- SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
-
uint8_t *secret;
size_t secret_len;
if (direction == evp_aead_open) {
@@ -256,6 +252,7 @@
secret_len = ssl->s3->write_traffic_secret_len;
}
+ const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
if (!hkdf_expand_label(secret, digest, secret, secret_len,
(const uint8_t *)kTLS13LabelApplicationTraffic,
strlen(kTLS13LabelApplicationTraffic), NULL, 0,
@@ -323,15 +320,14 @@
const char *label, size_t label_len,
const uint8_t *context, size_t context_len,
int use_context) {
- const EVP_MD *digest = ssl_get_handshake_digest(
- SSL_get_session(ssl)->cipher->algorithm_prf, ssl3_protocol_version(ssl));
-
const uint8_t *hash = NULL;
size_t hash_len = 0;
if (use_context) {
hash = context;
hash_len = context_len;
}
+
+ const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
return hkdf_expand_label(out, digest, ssl->s3->exporter_secret,
ssl->s3->exporter_secret_len, (const uint8_t *)label,
label_len, hash, hash_len, out_len);