Simplify handshake hash handling. Rather than support arbitrarily many handshake hashes in the general case (which the PRF logic assumes is capped at two), special-case the MD5/SHA1 two-hash combination and otherwise maintain a single rolling hash. Change-Id: Ide9475565b158f6839bb10b8b22f324f89399f92 Reviewed-on: https://boringssl-review.googlesource.com/5618 Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl3.h b/include/openssl/ssl3.h index 7249f51..e07488d 100644 --- a/include/openssl/ssl3.h +++ b/include/openssl/ssl3.h
@@ -387,9 +387,13 @@ /* handshake_buffer, if non-NULL, contains the handshake transcript. */ BUF_MEM *handshake_buffer; - /* When set of handshake digests is determined, buffer is hashed and freed - * and MD_CTX-es for all required digests are stored in this array */ - EVP_MD_CTX **handshake_dgst; + /* handshake_hash, if initialized with an |EVP_MD|, maintains the handshake + * hash. For TLS 1.1 and below, it is the SHA-1 half. */ + EVP_MD_CTX handshake_hash; + /* handshake_md5, if initialized with an |EVP_MD|, maintains the MD5 half of + * the handshake hash for TLS 1.1 and below. */ + EVP_MD_CTX handshake_md5; + /* this is set whenerver we see a change_cipher_spec message come in when we * are not looking for one */ int change_cipher_spec;
diff --git a/ssl/internal.h b/ssl/internal.h index a898422..4acd301 100644 --- a/ssl/internal.h +++ b/ssl/internal.h
@@ -203,12 +203,9 @@ #define SSL_TLSV1_2 0x00000004L /* Bits for |algorithm_prf| (handshake digest). */ -#define SSL_HANDSHAKE_MAC_MD5 0x10 -#define SSL_HANDSHAKE_MAC_SHA 0x20 -#define SSL_HANDSHAKE_MAC_SHA256 0x40 -#define SSL_HANDSHAKE_MAC_SHA384 0x80 -#define SSL_HANDSHAKE_MAC_DEFAULT \ - (SSL_HANDSHAKE_MAC_MD5 | SSL_HANDSHAKE_MAC_SHA) +#define SSL_HANDSHAKE_MAC_DEFAULT 0x1 +#define SSL_HANDSHAKE_MAC_SHA256 0x2 +#define SSL_HANDSHAKE_MAC_SHA384 0x4 /* SSL_MAX_DIGEST is the number of digest types which exist. When adding a new * one, update the table in ssl_cipher.c. */ @@ -229,11 +226,11 @@ size_t *out_fixed_iv_len, const SSL_CIPHER *cipher, uint16_t version); -/* ssl_get_handshake_digest looks up the |i|th handshake digest type and sets - * |*out_mask| to the |SSL_HANDSHAKE_MAC_*| mask and |*out_md| to the - * |EVP_MD|. It returns one on successs and zero if |i| >= |SSL_MAX_DIGEST|. */ -int ssl_get_handshake_digest(uint32_t *out_mask, const EVP_MD **out_md, - size_t i); +/* ssl_get_handshake_digest returns the |EVP_MD| corresponding to + * |algorithm_prf|. It returns SHA-1 for |SSL_HANDSHAKE_DEFAULT|. The caller is + * responsible for maintaining the additional MD5 digest and switching to + * SHA-256 in TLS 1.2. */ +const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf); /* ssl_create_cipher_list evaluates |rule_str| according to the ciphers in * |ssl_method|. It sets |*out_cipher_list| to a newly-allocated
diff --git a/ssl/s3_enc.c b/ssl/s3_enc.c index f860609..ba9883b 100644 --- a/ssl/s3_enc.c +++ b/ssl/s3_enc.c
@@ -242,58 +242,39 @@ return ssl->s3->handshake_buffer != NULL; } -int ssl3_init_handshake_hash(SSL *ssl) { - int i; - uint32_t mask; - const EVP_MD *md; - - /* Allocate handshake_dgst array */ - ssl3_free_handshake_hash(ssl); - ssl->s3->handshake_dgst = OPENSSL_malloc(SSL_MAX_DIGEST * - sizeof(EVP_MD_CTX *)); - if (ssl->s3->handshake_dgst == NULL) { - OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); +/* init_digest_with_data calls |EVP_DigestInit_ex| on |ctx| with |md| and then + * writes the data in |buf| to it. */ +static int init_digest_with_data(EVP_MD_CTX *ctx, const EVP_MD *md, + const BUF_MEM *buf) { + if (!EVP_DigestInit_ex(ctx, md, NULL)) { return 0; } - memset(ssl->s3->handshake_dgst, 0, SSL_MAX_DIGEST * sizeof(EVP_MD_CTX *)); + EVP_DigestUpdate(ctx, buf->data, buf->length); + return 1; +} - /* Loop through bits of algorithm_prf field and create MD_CTX-es */ - for (i = 0; ssl_get_handshake_digest(&mask, &md, i); i++) { - if ((mask & ssl_get_algorithm_prf(ssl)) && md) { - ssl->s3->handshake_dgst[i] = EVP_MD_CTX_create(); - if (ssl->s3->handshake_dgst[i] == NULL) { - OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP); - return 0; - } - if (!EVP_DigestInit_ex(ssl->s3->handshake_dgst[i], md, NULL)) { - EVP_MD_CTX_destroy(ssl->s3->handshake_dgst[i]); - ssl->s3->handshake_dgst[i] = NULL; - OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP); - return 0; - } - EVP_DigestUpdate(ssl->s3->handshake_dgst[i], - ssl->s3->handshake_buffer->data, - ssl->s3->handshake_buffer->length); - } else { - ssl->s3->handshake_dgst[i] = NULL; - } +int ssl3_init_handshake_hash(SSL *ssl) { + ssl3_free_handshake_hash(ssl); + + uint32_t algorithm_prf = ssl_get_algorithm_prf(ssl); + if (!init_digest_with_data(&ssl->s3->handshake_hash, + ssl_get_handshake_digest(algorithm_prf), + ssl->s3->handshake_buffer)) { + return 0; + } + + if (algorithm_prf == SSL_HANDSHAKE_MAC_DEFAULT && + !init_digest_with_data(&ssl->s3->handshake_md5, EVP_md5(), + ssl->s3->handshake_buffer)) { + return 0; } return 1; } void ssl3_free_handshake_hash(SSL *ssl) { - int i; - if (!ssl->s3->handshake_dgst) { - return; - } - for (i = 0; i < SSL_MAX_DIGEST; i++) { - if (ssl->s3->handshake_dgst[i]) { - EVP_MD_CTX_destroy(ssl->s3->handshake_dgst[i]); - } - } - OPENSSL_free(ssl->s3->handshake_dgst); - ssl->s3->handshake_dgst = NULL; + EVP_MD_CTX_cleanup(&ssl->s3->handshake_hash); + EVP_MD_CTX_cleanup(&ssl->s3->handshake_md5); } void ssl3_free_handshake_buffer(SSL *ssl) { @@ -317,13 +298,11 @@ memcpy(ssl->s3->handshake_buffer->data + new_len - in_len, in, in_len); } - if (ssl->s3->handshake_dgst != NULL) { - int i; - for (i = 0; i < SSL_MAX_DIGEST; i++) { - if (ssl->s3->handshake_dgst[i] != NULL) { - EVP_DigestUpdate(ssl->s3->handshake_dgst[i], in, in_len); - } - } + if (EVP_MD_CTX_md(&ssl->s3->handshake_hash) != NULL) { + EVP_DigestUpdate(&ssl->s3->handshake_hash, in, in_len); + } + if (EVP_MD_CTX_md(&ssl->s3->handshake_md5) != NULL) { + EVP_DigestUpdate(&ssl->s3->handshake_md5, in, in_len); } return 1; } @@ -356,24 +335,20 @@ int npad, n; unsigned int i; uint8_t md_buf[EVP_MAX_MD_SIZE]; - EVP_MD_CTX ctx, *d = NULL; + EVP_MD_CTX ctx; + const EVP_MD_CTX *ctx_template; - /* Search for digest of specified type in the handshake_dgst array. */ - for (i = 0; i < SSL_MAX_DIGEST; i++) { - if (s->s3->handshake_dgst[i] && - EVP_MD_CTX_type(s->s3->handshake_dgst[i]) == md_nid) { - d = s->s3->handshake_dgst[i]; - break; - } - } - - if (!d) { + if (md_nid == NID_md5) { + ctx_template = &s->s3->handshake_md5; + } else if (md_nid == EVP_MD_CTX_type(&s->s3->handshake_hash)) { + ctx_template = &s->s3->handshake_hash; + } else { OPENSSL_PUT_ERROR(SSL, SSL_R_NO_REQUIRED_DIGEST); return 0; } EVP_MD_CTX_init(&ctx); - if (!EVP_MD_CTX_copy_ex(&ctx, d)) { + if (!EVP_MD_CTX_copy_ex(&ctx, ctx_template)) { EVP_MD_CTX_cleanup(&ctx); OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP); return 0;
diff --git a/ssl/s3_lib.c b/ssl/s3_lib.c index 4ee68dd..2330fbd 100644 --- a/ssl/s3_lib.c +++ b/ssl/s3_lib.c
@@ -152,6 +152,7 @@ #include <openssl/buf.h> #include <openssl/dh.h> +#include <openssl/digest.h> #include <openssl/err.h> #include <openssl/md5.h> #include <openssl/mem.h> @@ -201,6 +202,9 @@ } memset(s3, 0, sizeof *s3); + EVP_MD_CTX_init(&s3->handshake_hash); + EVP_MD_CTX_init(&s3->handshake_md5); + s->s3 = s3; /* Set the version to the highest supported version for TLS. This controls the @@ -661,11 +665,10 @@ /* If we are using default SHA1+MD5 algorithms switch to new SHA256 PRF and * handshake macs if required. */ uint32_t ssl_get_algorithm_prf(SSL *s) { - static const uint32_t kMask = SSL_HANDSHAKE_MAC_DEFAULT; - uint32_t alg2 = s->s3->tmp.new_cipher->algorithm_prf; + uint32_t algorithm_prf = s->s3->tmp.new_cipher->algorithm_prf; if (s->enc_method->enc_flags & SSL_ENC_FLAG_SHA256_PRF && - (alg2 & kMask) == kMask) { + algorithm_prf == SSL_HANDSHAKE_MAC_DEFAULT) { return SSL_HANDSHAKE_MAC_SHA256; } - return alg2; + return algorithm_prf; }
diff --git a/ssl/ssl_cipher.c b/ssl/ssl_cipher.c index 29824aa..a715409 100644 --- a/ssl/ssl_cipher.c +++ b/ssl/ssl_cipher.c
@@ -469,18 +469,6 @@ static const size_t kCiphersLen = sizeof(kCiphers) / sizeof(kCiphers[0]); -struct handshake_digest { - uint32_t mask; - const EVP_MD *(*md_func)(void); -}; - -static const struct handshake_digest ssl_handshake_digests[SSL_MAX_DIGEST] = { - {SSL_HANDSHAKE_MAC_MD5, EVP_md5}, - {SSL_HANDSHAKE_MAC_SHA, EVP_sha1}, - {SSL_HANDSHAKE_MAC_SHA256, EVP_sha256}, - {SSL_HANDSHAKE_MAC_SHA384, EVP_sha384}, -}; - #define CIPHER_ADD 1 #define CIPHER_KILL 2 #define CIPHER_DEL 3 @@ -718,14 +706,17 @@ } } -int ssl_get_handshake_digest(uint32_t *out_mask, const EVP_MD **out_md, - size_t idx) { - if (idx >= SSL_MAX_DIGEST) { - return 0; +const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf) { + switch (algorithm_prf) { + case SSL_HANDSHAKE_MAC_DEFAULT: + return EVP_sha1(); + case SSL_HANDSHAKE_MAC_SHA256: + return EVP_sha256(); + case SSL_HANDSHAKE_MAC_SHA384: + return EVP_sha384(); + default: + return NULL; } - *out_mask = ssl_handshake_digests[idx].mask; - *out_md = ssl_handshake_digests[idx].md_func(); - return 1; } #define ITEM_SEP(a) \ @@ -1456,27 +1447,24 @@ } static const char *ssl_cipher_get_prf_name(const SSL_CIPHER *cipher) { - if ((cipher->algorithm_prf & SSL_HANDSHAKE_MAC_DEFAULT) == - SSL_HANDSHAKE_MAC_DEFAULT) { - /* Before TLS 1.2, the PRF component is the hash used in the HMAC, which is - * only ever MD5 or SHA-1. */ - switch (cipher->algorithm_mac) { - case SSL_MD5: - return "MD5"; - case SSL_SHA1: - return "SHA"; - default: - assert(0); - return "UNKNOWN"; - } - } else if (cipher->algorithm_prf & SSL_HANDSHAKE_MAC_SHA256) { - return "SHA256"; - } else if (cipher->algorithm_prf & SSL_HANDSHAKE_MAC_SHA384) { - return "SHA384"; - } else { - assert(0); - return "UNKNOWN"; + switch (cipher->algorithm_prf) { + case SSL_HANDSHAKE_MAC_DEFAULT: + /* Before TLS 1.2, the PRF component is the hash used in the HMAC, which is + * only ever MD5 or SHA-1. */ + switch (cipher->algorithm_mac) { + case SSL_MD5: + return "MD5"; + case SSL_SHA1: + return "SHA"; + } + break; + case SSL_HANDSHAKE_MAC_SHA256: + return "SHA256"; + case SSL_HANDSHAKE_MAC_SHA384: + return "SHA384"; } + assert(0); + return "UNKNOWN"; } char *SSL_CIPHER_get_rfc_name(const SSL_CIPHER *cipher) {
diff --git a/ssl/t1_enc.c b/ssl/t1_enc.c index febd54d..aa6095d 100644 --- a/ssl/t1_enc.c +++ b/ssl/t1_enc.c
@@ -149,7 +149,7 @@ /* tls1_P_hash computes the TLS P_<hash> function as described in RFC 5246, - * section 5. It writes |out_len| bytes to |out|, using |md| as the hash and + * section 5. It XORs |out_len| bytes to |out|, using |md| as the hash and * |secret| as the secret. |seed1| through |seed3| are concatenated to form the * seed parameter. It returns one on success and zero on failure. */ static int tls1_P_hash(uint8_t *out, size_t out_len, const EVP_MD *md, @@ -188,26 +188,32 @@ goto err; } - if (out_len > chunk) { - unsigned len; - if (!HMAC_Final(&ctx, out, &len)) { - goto err; - } - assert(len == chunk); - out += len; - out_len -= len; - /* Calculate the next A1 value. */ - if (!HMAC_Final(&ctx_tmp, A1, &A1_len)) { - goto err; - } - } else { - /* Last chunk. */ - if (!HMAC_Final(&ctx, A1, &A1_len)) { - goto err; - } - memcpy(out, A1, out_len); + unsigned len; + uint8_t hmac[EVP_MAX_MD_SIZE]; + if (!HMAC_Final(&ctx, hmac, &len)) { + goto err; + } + assert(len == chunk); + + /* XOR the result into |out|. */ + if (len > out_len) { + len = out_len; + } + unsigned i; + for (i = 0; i < len; i++) { + out[i] ^= hmac[i]; + } + out += len; + out_len -= len; + + if (out_len == 0) { break; } + + /* Calculate the next A1 value. */ + if (!HMAC_Final(&ctx_tmp, A1, &A1_len)) { + goto err; + } } ret = 1; @@ -224,62 +230,36 @@ size_t secret_len, const char *label, size_t label_len, const uint8_t *seed1, size_t seed1_len, const uint8_t *seed2, size_t seed2_len) { - size_t idx, len, count, i; - const uint8_t *S1; - uint32_t m; - const EVP_MD *md; - int ret = 0; - uint8_t *tmp; if (out_len == 0) { return 1; } - /* Allocate a temporary buffer. */ - tmp = OPENSSL_malloc(out_len); - if (tmp == NULL) { - OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); + memset(out, 0, out_len); + + uint32_t algorithm_prf = ssl_get_algorithm_prf(s); + if (algorithm_prf == SSL_HANDSHAKE_MAC_DEFAULT) { + /* If using the MD5/SHA1 PRF, |secret| is partitioned between SHA-1 and + * MD5, MD5 first. */ + size_t secret_half = secret_len - (secret_len / 2); + if (!tls1_P_hash(out, out_len, EVP_md5(), secret, secret_half, + (const uint8_t *)label, label_len, seed1, seed1_len, seed2, + seed2_len)) { + return 0; + } + + /* Note that, if |secret_len| is odd, the two halves share a byte. */ + secret = secret + (secret_len - secret_half); + secret_len = secret_half; + } + + if (!tls1_P_hash(out, out_len, ssl_get_handshake_digest(algorithm_prf), + secret, secret_len, (const uint8_t *)label, label_len, + seed1, seed1_len, seed2, seed2_len)) { return 0; } - /* Count number of digests and partition |secret| evenly. */ - count = 0; - for (idx = 0; ssl_get_handshake_digest(&m, &md, idx); idx++) { - if (m & ssl_get_algorithm_prf(s)) { - count++; - } - } - /* TODO(davidben): The only case where count isn't 1 is the old MD5/SHA-1 - * combination. The logic around multiple handshake digests can probably be - * simplified. */ - assert(count == 1 || count == 2); - len = secret_len / count; - if (count == 1) { - secret_len = 0; - } - S1 = secret; - memset(out, 0, out_len); - for (idx = 0; ssl_get_handshake_digest(&m, &md, idx); idx++) { - if (m & ssl_get_algorithm_prf(s)) { - /* If |count| is 2 and |secret_len| is odd, |secret| is partitioned into - * two halves with an overlapping byte. */ - if (!tls1_P_hash(tmp, out_len, md, S1, len + (secret_len & 1), - (const uint8_t *)label, label_len, seed1, seed1_len, - seed2, seed2_len)) { - goto err; - } - S1 += len; - for (i = 0; i < out_len; i++) { - out[i] ^= tmp[i]; - } - } - } - ret = 1; - -err: - OPENSSL_cleanse(tmp, out_len); - OPENSSL_free(tmp); - return ret; + return 1; } static int tls1_generate_key_block(SSL *s, uint8_t *out, size_t out_len) { @@ -469,31 +449,50 @@ } int tls1_cert_verify_mac(SSL *s, int md_nid, uint8_t *out) { - unsigned int ret; - EVP_MD_CTX ctx, *d = NULL; - int i; - - for (i = 0; i < SSL_MAX_DIGEST; i++) { - if (s->s3->handshake_dgst[i] && - EVP_MD_CTX_type(s->s3->handshake_dgst[i]) == md_nid) { - d = s->s3->handshake_dgst[i]; - break; - } - } - - if (!d) { + const EVP_MD_CTX *ctx_template; + if (md_nid == NID_md5) { + ctx_template = &s->s3->handshake_md5; + } else if (md_nid == EVP_MD_CTX_type(&s->s3->handshake_hash)) { + ctx_template = &s->s3->handshake_hash; + } else { OPENSSL_PUT_ERROR(SSL, SSL_R_NO_REQUIRED_DIGEST); return 0; } + EVP_MD_CTX ctx; EVP_MD_CTX_init(&ctx); - if (!EVP_MD_CTX_copy_ex(&ctx, d)) { + if (!EVP_MD_CTX_copy_ex(&ctx, ctx_template)) { EVP_MD_CTX_cleanup(&ctx); return 0; } + unsigned ret; EVP_DigestFinal_ex(&ctx, out, &ret); EVP_MD_CTX_cleanup(&ctx); + return ret; +} +static int append_digest(const EVP_MD_CTX *ctx, uint8_t *out, size_t *out_len, + size_t max_out) { + int ret = 0; + EVP_MD_CTX ctx_copy; + EVP_MD_CTX_init(&ctx_copy); + + if (EVP_MD_CTX_size(ctx) > max_out) { + OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL); + goto err; + } + unsigned len; + if (!EVP_MD_CTX_copy_ex(&ctx_copy, ctx) || + !EVP_DigestFinal_ex(&ctx_copy, out, &len)) { + goto err; + } + assert(len == EVP_MD_CTX_size(ctx)); + + *out_len = len; + ret = 1; + +err: + EVP_MD_CTX_cleanup(&ctx_copy); return ret; } @@ -503,44 +502,19 @@ * underlying digests so can be called multiple times and prior to the final * update etc. */ int tls1_handshake_digest(SSL *s, uint8_t *out, size_t out_len) { - const EVP_MD *md; - EVP_MD_CTX ctx; - int err = 0, len = 0; - size_t i; - uint32_t mask; - - EVP_MD_CTX_init(&ctx); - - for (i = 0; ssl_get_handshake_digest(&mask, &md, i); i++) { - size_t hash_size; - unsigned int digest_len; - EVP_MD_CTX *hdgst = s->s3->handshake_dgst[i]; - - if ((mask & ssl_get_algorithm_prf(s)) == 0) { - continue; - } - - hash_size = EVP_MD_size(md); - if (!hdgst || - hash_size > out_len || - !EVP_MD_CTX_copy_ex(&ctx, hdgst) || - !EVP_DigestFinal_ex(&ctx, out, &digest_len) || - digest_len != hash_size /* internal error */) { - err = 1; - break; - } - - out += digest_len; - out_len -= digest_len; - len += digest_len; - } - - EVP_MD_CTX_cleanup(&ctx); - - if (err != 0) { + size_t md5_len = 0; + if (EVP_MD_CTX_md(&s->s3->handshake_md5) != NULL && + !append_digest(&s->s3->handshake_md5, out, &md5_len, out_len)) { return -1; } - return len; + + size_t len; + if (!append_digest(&s->s3->handshake_hash, out + md5_len, &len, + out_len - md5_len)) { + return -1; + } + + return (int)(md5_len + len); } int tls1_final_finish_mac(SSL *s, const char *str, int slen, uint8_t *out) {