Simplify handshake hash handling.
Rather than support arbitrarily many handshake hashes in the general
case (which the PRF logic assumes is capped at two), special-case the
MD5/SHA1 two-hash combination and otherwise maintain a single rolling
hash.
Change-Id: Ide9475565b158f6839bb10b8b22f324f89399f92
Reviewed-on: https://boringssl-review.googlesource.com/5618
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl3.h b/include/openssl/ssl3.h
index 7249f51..e07488d 100644
--- a/include/openssl/ssl3.h
+++ b/include/openssl/ssl3.h
@@ -387,9 +387,13 @@
/* handshake_buffer, if non-NULL, contains the handshake transcript. */
BUF_MEM *handshake_buffer;
- /* When set of handshake digests is determined, buffer is hashed and freed
- * and MD_CTX-es for all required digests are stored in this array */
- EVP_MD_CTX **handshake_dgst;
+ /* handshake_hash, if initialized with an |EVP_MD|, maintains the handshake
+ * hash. For TLS 1.1 and below, it is the SHA-1 half. */
+ EVP_MD_CTX handshake_hash;
+ /* handshake_md5, if initialized with an |EVP_MD|, maintains the MD5 half of
+ * the handshake hash for TLS 1.1 and below. */
+ EVP_MD_CTX handshake_md5;
+
/* this is set whenerver we see a change_cipher_spec message come in when we
* are not looking for one */
int change_cipher_spec;
diff --git a/ssl/internal.h b/ssl/internal.h
index a898422..4acd301 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -203,12 +203,9 @@
#define SSL_TLSV1_2 0x00000004L
/* Bits for |algorithm_prf| (handshake digest). */
-#define SSL_HANDSHAKE_MAC_MD5 0x10
-#define SSL_HANDSHAKE_MAC_SHA 0x20
-#define SSL_HANDSHAKE_MAC_SHA256 0x40
-#define SSL_HANDSHAKE_MAC_SHA384 0x80
-#define SSL_HANDSHAKE_MAC_DEFAULT \
- (SSL_HANDSHAKE_MAC_MD5 | SSL_HANDSHAKE_MAC_SHA)
+#define SSL_HANDSHAKE_MAC_DEFAULT 0x1
+#define SSL_HANDSHAKE_MAC_SHA256 0x2
+#define SSL_HANDSHAKE_MAC_SHA384 0x4
/* SSL_MAX_DIGEST is the number of digest types which exist. When adding a new
* one, update the table in ssl_cipher.c. */
@@ -229,11 +226,11 @@
size_t *out_fixed_iv_len,
const SSL_CIPHER *cipher, uint16_t version);
-/* ssl_get_handshake_digest looks up the |i|th handshake digest type and sets
- * |*out_mask| to the |SSL_HANDSHAKE_MAC_*| mask and |*out_md| to the
- * |EVP_MD|. It returns one on successs and zero if |i| >= |SSL_MAX_DIGEST|. */
-int ssl_get_handshake_digest(uint32_t *out_mask, const EVP_MD **out_md,
- size_t i);
+/* ssl_get_handshake_digest returns the |EVP_MD| corresponding to
+ * |algorithm_prf|. It returns SHA-1 for |SSL_HANDSHAKE_DEFAULT|. The caller is
+ * responsible for maintaining the additional MD5 digest and switching to
+ * SHA-256 in TLS 1.2. */
+const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf);
/* ssl_create_cipher_list evaluates |rule_str| according to the ciphers in
* |ssl_method|. It sets |*out_cipher_list| to a newly-allocated
diff --git a/ssl/s3_enc.c b/ssl/s3_enc.c
index f860609..ba9883b 100644
--- a/ssl/s3_enc.c
+++ b/ssl/s3_enc.c
@@ -242,58 +242,39 @@
return ssl->s3->handshake_buffer != NULL;
}
-int ssl3_init_handshake_hash(SSL *ssl) {
- int i;
- uint32_t mask;
- const EVP_MD *md;
-
- /* Allocate handshake_dgst array */
- ssl3_free_handshake_hash(ssl);
- ssl->s3->handshake_dgst = OPENSSL_malloc(SSL_MAX_DIGEST *
- sizeof(EVP_MD_CTX *));
- if (ssl->s3->handshake_dgst == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+/* init_digest_with_data calls |EVP_DigestInit_ex| on |ctx| with |md| and then
+ * writes the data in |buf| to it. */
+static int init_digest_with_data(EVP_MD_CTX *ctx, const EVP_MD *md,
+ const BUF_MEM *buf) {
+ if (!EVP_DigestInit_ex(ctx, md, NULL)) {
return 0;
}
- memset(ssl->s3->handshake_dgst, 0, SSL_MAX_DIGEST * sizeof(EVP_MD_CTX *));
+ EVP_DigestUpdate(ctx, buf->data, buf->length);
+ return 1;
+}
- /* Loop through bits of algorithm_prf field and create MD_CTX-es */
- for (i = 0; ssl_get_handshake_digest(&mask, &md, i); i++) {
- if ((mask & ssl_get_algorithm_prf(ssl)) && md) {
- ssl->s3->handshake_dgst[i] = EVP_MD_CTX_create();
- if (ssl->s3->handshake_dgst[i] == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP);
- return 0;
- }
- if (!EVP_DigestInit_ex(ssl->s3->handshake_dgst[i], md, NULL)) {
- EVP_MD_CTX_destroy(ssl->s3->handshake_dgst[i]);
- ssl->s3->handshake_dgst[i] = NULL;
- OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP);
- return 0;
- }
- EVP_DigestUpdate(ssl->s3->handshake_dgst[i],
- ssl->s3->handshake_buffer->data,
- ssl->s3->handshake_buffer->length);
- } else {
- ssl->s3->handshake_dgst[i] = NULL;
- }
+int ssl3_init_handshake_hash(SSL *ssl) {
+ ssl3_free_handshake_hash(ssl);
+
+ uint32_t algorithm_prf = ssl_get_algorithm_prf(ssl);
+ if (!init_digest_with_data(&ssl->s3->handshake_hash,
+ ssl_get_handshake_digest(algorithm_prf),
+ ssl->s3->handshake_buffer)) {
+ return 0;
+ }
+
+ if (algorithm_prf == SSL_HANDSHAKE_MAC_DEFAULT &&
+ !init_digest_with_data(&ssl->s3->handshake_md5, EVP_md5(),
+ ssl->s3->handshake_buffer)) {
+ return 0;
}
return 1;
}
void ssl3_free_handshake_hash(SSL *ssl) {
- int i;
- if (!ssl->s3->handshake_dgst) {
- return;
- }
- for (i = 0; i < SSL_MAX_DIGEST; i++) {
- if (ssl->s3->handshake_dgst[i]) {
- EVP_MD_CTX_destroy(ssl->s3->handshake_dgst[i]);
- }
- }
- OPENSSL_free(ssl->s3->handshake_dgst);
- ssl->s3->handshake_dgst = NULL;
+ EVP_MD_CTX_cleanup(&ssl->s3->handshake_hash);
+ EVP_MD_CTX_cleanup(&ssl->s3->handshake_md5);
}
void ssl3_free_handshake_buffer(SSL *ssl) {
@@ -317,13 +298,11 @@
memcpy(ssl->s3->handshake_buffer->data + new_len - in_len, in, in_len);
}
- if (ssl->s3->handshake_dgst != NULL) {
- int i;
- for (i = 0; i < SSL_MAX_DIGEST; i++) {
- if (ssl->s3->handshake_dgst[i] != NULL) {
- EVP_DigestUpdate(ssl->s3->handshake_dgst[i], in, in_len);
- }
- }
+ if (EVP_MD_CTX_md(&ssl->s3->handshake_hash) != NULL) {
+ EVP_DigestUpdate(&ssl->s3->handshake_hash, in, in_len);
+ }
+ if (EVP_MD_CTX_md(&ssl->s3->handshake_md5) != NULL) {
+ EVP_DigestUpdate(&ssl->s3->handshake_md5, in, in_len);
}
return 1;
}
@@ -356,24 +335,20 @@
int npad, n;
unsigned int i;
uint8_t md_buf[EVP_MAX_MD_SIZE];
- EVP_MD_CTX ctx, *d = NULL;
+ EVP_MD_CTX ctx;
+ const EVP_MD_CTX *ctx_template;
- /* Search for digest of specified type in the handshake_dgst array. */
- for (i = 0; i < SSL_MAX_DIGEST; i++) {
- if (s->s3->handshake_dgst[i] &&
- EVP_MD_CTX_type(s->s3->handshake_dgst[i]) == md_nid) {
- d = s->s3->handshake_dgst[i];
- break;
- }
- }
-
- if (!d) {
+ if (md_nid == NID_md5) {
+ ctx_template = &s->s3->handshake_md5;
+ } else if (md_nid == EVP_MD_CTX_type(&s->s3->handshake_hash)) {
+ ctx_template = &s->s3->handshake_hash;
+ } else {
OPENSSL_PUT_ERROR(SSL, SSL_R_NO_REQUIRED_DIGEST);
return 0;
}
EVP_MD_CTX_init(&ctx);
- if (!EVP_MD_CTX_copy_ex(&ctx, d)) {
+ if (!EVP_MD_CTX_copy_ex(&ctx, ctx_template)) {
EVP_MD_CTX_cleanup(&ctx);
OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP);
return 0;
diff --git a/ssl/s3_lib.c b/ssl/s3_lib.c
index 4ee68dd..2330fbd 100644
--- a/ssl/s3_lib.c
+++ b/ssl/s3_lib.c
@@ -152,6 +152,7 @@
#include <openssl/buf.h>
#include <openssl/dh.h>
+#include <openssl/digest.h>
#include <openssl/err.h>
#include <openssl/md5.h>
#include <openssl/mem.h>
@@ -201,6 +202,9 @@
}
memset(s3, 0, sizeof *s3);
+ EVP_MD_CTX_init(&s3->handshake_hash);
+ EVP_MD_CTX_init(&s3->handshake_md5);
+
s->s3 = s3;
/* Set the version to the highest supported version for TLS. This controls the
@@ -661,11 +665,10 @@
/* If we are using default SHA1+MD5 algorithms switch to new SHA256 PRF and
* handshake macs if required. */
uint32_t ssl_get_algorithm_prf(SSL *s) {
- static const uint32_t kMask = SSL_HANDSHAKE_MAC_DEFAULT;
- uint32_t alg2 = s->s3->tmp.new_cipher->algorithm_prf;
+ uint32_t algorithm_prf = s->s3->tmp.new_cipher->algorithm_prf;
if (s->enc_method->enc_flags & SSL_ENC_FLAG_SHA256_PRF &&
- (alg2 & kMask) == kMask) {
+ algorithm_prf == SSL_HANDSHAKE_MAC_DEFAULT) {
return SSL_HANDSHAKE_MAC_SHA256;
}
- return alg2;
+ return algorithm_prf;
}
diff --git a/ssl/ssl_cipher.c b/ssl/ssl_cipher.c
index 29824aa..a715409 100644
--- a/ssl/ssl_cipher.c
+++ b/ssl/ssl_cipher.c
@@ -469,18 +469,6 @@
static const size_t kCiphersLen = sizeof(kCiphers) / sizeof(kCiphers[0]);
-struct handshake_digest {
- uint32_t mask;
- const EVP_MD *(*md_func)(void);
-};
-
-static const struct handshake_digest ssl_handshake_digests[SSL_MAX_DIGEST] = {
- {SSL_HANDSHAKE_MAC_MD5, EVP_md5},
- {SSL_HANDSHAKE_MAC_SHA, EVP_sha1},
- {SSL_HANDSHAKE_MAC_SHA256, EVP_sha256},
- {SSL_HANDSHAKE_MAC_SHA384, EVP_sha384},
-};
-
#define CIPHER_ADD 1
#define CIPHER_KILL 2
#define CIPHER_DEL 3
@@ -718,14 +706,17 @@
}
}
-int ssl_get_handshake_digest(uint32_t *out_mask, const EVP_MD **out_md,
- size_t idx) {
- if (idx >= SSL_MAX_DIGEST) {
- return 0;
+const EVP_MD *ssl_get_handshake_digest(uint32_t algorithm_prf) {
+ switch (algorithm_prf) {
+ case SSL_HANDSHAKE_MAC_DEFAULT:
+ return EVP_sha1();
+ case SSL_HANDSHAKE_MAC_SHA256:
+ return EVP_sha256();
+ case SSL_HANDSHAKE_MAC_SHA384:
+ return EVP_sha384();
+ default:
+ return NULL;
}
- *out_mask = ssl_handshake_digests[idx].mask;
- *out_md = ssl_handshake_digests[idx].md_func();
- return 1;
}
#define ITEM_SEP(a) \
@@ -1456,27 +1447,24 @@
}
static const char *ssl_cipher_get_prf_name(const SSL_CIPHER *cipher) {
- if ((cipher->algorithm_prf & SSL_HANDSHAKE_MAC_DEFAULT) ==
- SSL_HANDSHAKE_MAC_DEFAULT) {
- /* Before TLS 1.2, the PRF component is the hash used in the HMAC, which is
- * only ever MD5 or SHA-1. */
- switch (cipher->algorithm_mac) {
- case SSL_MD5:
- return "MD5";
- case SSL_SHA1:
- return "SHA";
- default:
- assert(0);
- return "UNKNOWN";
- }
- } else if (cipher->algorithm_prf & SSL_HANDSHAKE_MAC_SHA256) {
- return "SHA256";
- } else if (cipher->algorithm_prf & SSL_HANDSHAKE_MAC_SHA384) {
- return "SHA384";
- } else {
- assert(0);
- return "UNKNOWN";
+ switch (cipher->algorithm_prf) {
+ case SSL_HANDSHAKE_MAC_DEFAULT:
+ /* Before TLS 1.2, the PRF component is the hash used in the HMAC, which is
+ * only ever MD5 or SHA-1. */
+ switch (cipher->algorithm_mac) {
+ case SSL_MD5:
+ return "MD5";
+ case SSL_SHA1:
+ return "SHA";
+ }
+ break;
+ case SSL_HANDSHAKE_MAC_SHA256:
+ return "SHA256";
+ case SSL_HANDSHAKE_MAC_SHA384:
+ return "SHA384";
}
+ assert(0);
+ return "UNKNOWN";
}
char *SSL_CIPHER_get_rfc_name(const SSL_CIPHER *cipher) {
diff --git a/ssl/t1_enc.c b/ssl/t1_enc.c
index febd54d..aa6095d 100644
--- a/ssl/t1_enc.c
+++ b/ssl/t1_enc.c
@@ -149,7 +149,7 @@
/* tls1_P_hash computes the TLS P_<hash> function as described in RFC 5246,
- * section 5. It writes |out_len| bytes to |out|, using |md| as the hash and
+ * section 5. It XORs |out_len| bytes to |out|, using |md| as the hash and
* |secret| as the secret. |seed1| through |seed3| are concatenated to form the
* seed parameter. It returns one on success and zero on failure. */
static int tls1_P_hash(uint8_t *out, size_t out_len, const EVP_MD *md,
@@ -188,26 +188,32 @@
goto err;
}
- if (out_len > chunk) {
- unsigned len;
- if (!HMAC_Final(&ctx, out, &len)) {
- goto err;
- }
- assert(len == chunk);
- out += len;
- out_len -= len;
- /* Calculate the next A1 value. */
- if (!HMAC_Final(&ctx_tmp, A1, &A1_len)) {
- goto err;
- }
- } else {
- /* Last chunk. */
- if (!HMAC_Final(&ctx, A1, &A1_len)) {
- goto err;
- }
- memcpy(out, A1, out_len);
+ unsigned len;
+ uint8_t hmac[EVP_MAX_MD_SIZE];
+ if (!HMAC_Final(&ctx, hmac, &len)) {
+ goto err;
+ }
+ assert(len == chunk);
+
+ /* XOR the result into |out|. */
+ if (len > out_len) {
+ len = out_len;
+ }
+ unsigned i;
+ for (i = 0; i < len; i++) {
+ out[i] ^= hmac[i];
+ }
+ out += len;
+ out_len -= len;
+
+ if (out_len == 0) {
break;
}
+
+ /* Calculate the next A1 value. */
+ if (!HMAC_Final(&ctx_tmp, A1, &A1_len)) {
+ goto err;
+ }
}
ret = 1;
@@ -224,62 +230,36 @@
size_t secret_len, const char *label, size_t label_len,
const uint8_t *seed1, size_t seed1_len,
const uint8_t *seed2, size_t seed2_len) {
- size_t idx, len, count, i;
- const uint8_t *S1;
- uint32_t m;
- const EVP_MD *md;
- int ret = 0;
- uint8_t *tmp;
if (out_len == 0) {
return 1;
}
- /* Allocate a temporary buffer. */
- tmp = OPENSSL_malloc(out_len);
- if (tmp == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+ memset(out, 0, out_len);
+
+ uint32_t algorithm_prf = ssl_get_algorithm_prf(s);
+ if (algorithm_prf == SSL_HANDSHAKE_MAC_DEFAULT) {
+ /* If using the MD5/SHA1 PRF, |secret| is partitioned between SHA-1 and
+ * MD5, MD5 first. */
+ size_t secret_half = secret_len - (secret_len / 2);
+ if (!tls1_P_hash(out, out_len, EVP_md5(), secret, secret_half,
+ (const uint8_t *)label, label_len, seed1, seed1_len, seed2,
+ seed2_len)) {
+ return 0;
+ }
+
+ /* Note that, if |secret_len| is odd, the two halves share a byte. */
+ secret = secret + (secret_len - secret_half);
+ secret_len = secret_half;
+ }
+
+ if (!tls1_P_hash(out, out_len, ssl_get_handshake_digest(algorithm_prf),
+ secret, secret_len, (const uint8_t *)label, label_len,
+ seed1, seed1_len, seed2, seed2_len)) {
return 0;
}
- /* Count number of digests and partition |secret| evenly. */
- count = 0;
- for (idx = 0; ssl_get_handshake_digest(&m, &md, idx); idx++) {
- if (m & ssl_get_algorithm_prf(s)) {
- count++;
- }
- }
- /* TODO(davidben): The only case where count isn't 1 is the old MD5/SHA-1
- * combination. The logic around multiple handshake digests can probably be
- * simplified. */
- assert(count == 1 || count == 2);
- len = secret_len / count;
- if (count == 1) {
- secret_len = 0;
- }
- S1 = secret;
- memset(out, 0, out_len);
- for (idx = 0; ssl_get_handshake_digest(&m, &md, idx); idx++) {
- if (m & ssl_get_algorithm_prf(s)) {
- /* If |count| is 2 and |secret_len| is odd, |secret| is partitioned into
- * two halves with an overlapping byte. */
- if (!tls1_P_hash(tmp, out_len, md, S1, len + (secret_len & 1),
- (const uint8_t *)label, label_len, seed1, seed1_len,
- seed2, seed2_len)) {
- goto err;
- }
- S1 += len;
- for (i = 0; i < out_len; i++) {
- out[i] ^= tmp[i];
- }
- }
- }
- ret = 1;
-
-err:
- OPENSSL_cleanse(tmp, out_len);
- OPENSSL_free(tmp);
- return ret;
+ return 1;
}
static int tls1_generate_key_block(SSL *s, uint8_t *out, size_t out_len) {
@@ -469,31 +449,50 @@
}
int tls1_cert_verify_mac(SSL *s, int md_nid, uint8_t *out) {
- unsigned int ret;
- EVP_MD_CTX ctx, *d = NULL;
- int i;
-
- for (i = 0; i < SSL_MAX_DIGEST; i++) {
- if (s->s3->handshake_dgst[i] &&
- EVP_MD_CTX_type(s->s3->handshake_dgst[i]) == md_nid) {
- d = s->s3->handshake_dgst[i];
- break;
- }
- }
-
- if (!d) {
+ const EVP_MD_CTX *ctx_template;
+ if (md_nid == NID_md5) {
+ ctx_template = &s->s3->handshake_md5;
+ } else if (md_nid == EVP_MD_CTX_type(&s->s3->handshake_hash)) {
+ ctx_template = &s->s3->handshake_hash;
+ } else {
OPENSSL_PUT_ERROR(SSL, SSL_R_NO_REQUIRED_DIGEST);
return 0;
}
+ EVP_MD_CTX ctx;
EVP_MD_CTX_init(&ctx);
- if (!EVP_MD_CTX_copy_ex(&ctx, d)) {
+ if (!EVP_MD_CTX_copy_ex(&ctx, ctx_template)) {
EVP_MD_CTX_cleanup(&ctx);
return 0;
}
+ unsigned ret;
EVP_DigestFinal_ex(&ctx, out, &ret);
EVP_MD_CTX_cleanup(&ctx);
+ return ret;
+}
+static int append_digest(const EVP_MD_CTX *ctx, uint8_t *out, size_t *out_len,
+ size_t max_out) {
+ int ret = 0;
+ EVP_MD_CTX ctx_copy;
+ EVP_MD_CTX_init(&ctx_copy);
+
+ if (EVP_MD_CTX_size(ctx) > max_out) {
+ OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
+ goto err;
+ }
+ unsigned len;
+ if (!EVP_MD_CTX_copy_ex(&ctx_copy, ctx) ||
+ !EVP_DigestFinal_ex(&ctx_copy, out, &len)) {
+ goto err;
+ }
+ assert(len == EVP_MD_CTX_size(ctx));
+
+ *out_len = len;
+ ret = 1;
+
+err:
+ EVP_MD_CTX_cleanup(&ctx_copy);
return ret;
}
@@ -503,44 +502,19 @@
* underlying digests so can be called multiple times and prior to the final
* update etc. */
int tls1_handshake_digest(SSL *s, uint8_t *out, size_t out_len) {
- const EVP_MD *md;
- EVP_MD_CTX ctx;
- int err = 0, len = 0;
- size_t i;
- uint32_t mask;
-
- EVP_MD_CTX_init(&ctx);
-
- for (i = 0; ssl_get_handshake_digest(&mask, &md, i); i++) {
- size_t hash_size;
- unsigned int digest_len;
- EVP_MD_CTX *hdgst = s->s3->handshake_dgst[i];
-
- if ((mask & ssl_get_algorithm_prf(s)) == 0) {
- continue;
- }
-
- hash_size = EVP_MD_size(md);
- if (!hdgst ||
- hash_size > out_len ||
- !EVP_MD_CTX_copy_ex(&ctx, hdgst) ||
- !EVP_DigestFinal_ex(&ctx, out, &digest_len) ||
- digest_len != hash_size /* internal error */) {
- err = 1;
- break;
- }
-
- out += digest_len;
- out_len -= digest_len;
- len += digest_len;
- }
-
- EVP_MD_CTX_cleanup(&ctx);
-
- if (err != 0) {
+ size_t md5_len = 0;
+ if (EVP_MD_CTX_md(&s->s3->handshake_md5) != NULL &&
+ !append_digest(&s->s3->handshake_md5, out, &md5_len, out_len)) {
return -1;
}
- return len;
+
+ size_t len;
+ if (!append_digest(&s->s3->handshake_hash, out + md5_len, &len,
+ out_len - md5_len)) {
+ return -1;
+ }
+
+ return (int)(md5_len + len);
}
int tls1_final_finish_mac(SSL *s, const char *str, int slen, uint8_t *out) {