Forbid reusing HMAC key without reusing the hash function.

There's no good reason to do this, and it doesn't work; HMAC checks the length
of the key and runs it through the hash function if too long. The reuse occurs
after this check.

This allows us to shave 132 bytes off HMAC_CTX as this was the only reason it
ever stored the original key. It also slightly simplifies HMAC_Init_ex's
logic.

Change-Id: Ib56aabc3630b7178f1ee7c38ef6370c9638efbab
Reviewed-on: https://boringssl-review.googlesource.com/3733
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/hmac/hmac.c b/crypto/hmac/hmac.c
index 21479c2..eabfcd7 100644
--- a/crypto/hmac/hmac.c
+++ b/crypto/hmac/hmac.c
@@ -88,7 +88,6 @@
 
 void HMAC_CTX_init(HMAC_CTX *ctx) {
   ctx->md = NULL;
-  ctx->key_length = 0;
   EVP_MD_CTX_init(&ctx->i_ctx);
   EVP_MD_CTX_init(&ctx->o_ctx);
   EVP_MD_CTX_init(&ctx->md_ctx);
@@ -103,48 +102,44 @@
 
 int HMAC_Init_ex(HMAC_CTX *ctx, const void *key, size_t key_len,
                  const EVP_MD *md, ENGINE *impl) {
-  unsigned i, reset = 0;
-  uint8_t pad[HMAC_MAX_MD_CBLOCK];
-
-  if (md != NULL) {
-    if (ctx->md == NULL && key == NULL && ctx->key_length == 0) {
-      /* TODO(eroman): Change the API instead of this hack.
-       * If a key hasn't yet been assigned to the context, then default to using
-       * an all-zero key. This is to work around callers of
-       * HMAC_Init_ex(key=NULL, key_len=0) intending to set a zero-length key.
-       * Rather than resulting in uninitialized memory reads, it will
-       * predictably use a zero key. */
-      memset(ctx->key, 0, sizeof(ctx->key));
-    }
-    reset = 1;
-    ctx->md = md;
-  } else {
+  if (md == NULL) {
     md = ctx->md;
   }
 
-  if (key != NULL) {
+  /* If either |key| is non-NULL or |md| has changed, initialize with a new key
+   * rather than rewinding the previous one.
+   *
+   * TODO(davidben,eroman): Passing the previous |md| with a NULL |key| is
+   * ambiguous between using the empty key and reusing the previous key. There
+   * exist callers which intend the latter, but the former is an awkward edge
+   * case. Fix to API to avoid this. */
+  if (md != ctx->md || key != NULL) {
+    size_t i;
+    uint8_t pad[HMAC_MAX_MD_CBLOCK];
+    uint8_t key_block[HMAC_MAX_MD_CBLOCK];
+    unsigned key_block_len;
+
     size_t block_size = EVP_MD_block_size(md);
-    reset = 1;
-    assert(block_size <= sizeof(ctx->key));
+    assert(block_size <= sizeof(key_block));
     if (block_size < key_len) {
+      /* Long keys are hashed. */
       if (!EVP_DigestInit_ex(&ctx->md_ctx, md, impl) ||
           !EVP_DigestUpdate(&ctx->md_ctx, key, key_len) ||
-          !EVP_DigestFinal_ex(&(ctx->md_ctx), ctx->key, &ctx->key_length)) {
+          !EVP_DigestFinal_ex(&ctx->md_ctx, key_block, &key_block_len)) {
         goto err;
       }
     } else {
-      assert(key_len >= 0 && key_len <= sizeof(ctx->key));
-      memcpy(ctx->key, key, key_len);
-      ctx->key_length = key_len;
+      assert(key_len >= 0 && key_len <= sizeof(key_block));
+      memcpy(key_block, key, key_len);
+      key_block_len = (unsigned)key_len;
     }
-    if (ctx->key_length != HMAC_MAX_MD_CBLOCK) {
-      memset(&ctx->key[ctx->key_length], 0, sizeof(ctx->key) - ctx->key_length);
+    /* Keys are then padded with zeros. */
+    if (key_block_len != HMAC_MAX_MD_CBLOCK) {
+      memset(&key_block[key_block_len], 0, sizeof(key_block) - key_block_len);
     }
-  }
 
-  if (reset) {
     for (i = 0; i < HMAC_MAX_MD_CBLOCK; i++) {
-      pad[i] = 0x36 ^ ctx->key[i];
+      pad[i] = 0x36 ^ key_block[i];
     }
     if (!EVP_DigestInit_ex(&ctx->i_ctx, md, impl) ||
         !EVP_DigestUpdate(&ctx->i_ctx, pad, EVP_MD_block_size(md))) {
@@ -152,12 +147,14 @@
     }
 
     for (i = 0; i < HMAC_MAX_MD_CBLOCK; i++) {
-      pad[i] = 0x5c ^ ctx->key[i];
+      pad[i] = 0x5c ^ key_block[i];
     }
     if (!EVP_DigestInit_ex(&ctx->o_ctx, md, impl) ||
         !EVP_DigestUpdate(&ctx->o_ctx, pad, EVP_MD_block_size(md))) {
       goto err;
     }
+
+    ctx->md = md;
   }
 
   if (!EVP_MD_CTX_copy_ex(&ctx->md_ctx, &ctx->i_ctx)) {
@@ -200,8 +197,6 @@
     return 0;
   }
 
-  memcpy(dest->key, src->key, HMAC_MAX_MD_CBLOCK);
-  dest->key_length = src->key_length;
   dest->md = src->md;
   return 1;
 }
diff --git a/include/openssl/hmac.h b/include/openssl/hmac.h
index 6c34cdc..89cdf8f 100644
--- a/include/openssl/hmac.h
+++ b/include/openssl/hmac.h
@@ -94,9 +94,14 @@
 OPENSSL_EXPORT void HMAC_CTX_cleanup(HMAC_CTX *ctx);
 
 /* HMAC_Init_ex sets up an initialised |HMAC_CTX| to use |md| as the hash
- * function and |key| as the key. Any of |md| or |key| can be NULL, in which
- * case the previous value will be used. It returns one on success or zero
- * otherwise. */
+ * function and |key| as the key. For a non-initial call, |md| may be NULL, in
+ * which case the previous hash function will be used. If the hash function has
+ * not changed and |key| is NULL, |ctx| reuses the previous key. It returns one
+ * on success or zero otherwise.
+ *
+ * WARNING: NULL and empty keys are ambiguous on non-initial calls. Passing NULL
+ * |key| but repeating the previous |md| reuses the previous key rather than the
+ * empty key. */
 OPENSSL_EXPORT int HMAC_Init_ex(HMAC_CTX *ctx, const void *key, size_t key_len,
                                 const EVP_MD *md, ENGINE *impl);
 
@@ -152,8 +157,6 @@
   EVP_MD_CTX md_ctx;
   EVP_MD_CTX i_ctx;
   EVP_MD_CTX o_ctx;
-  unsigned int key_length;
-  unsigned char key[HMAC_MAX_MD_CBLOCK];
 } /* HMAC_CTX */;