Tidy up keyblock and CCS logic slightly.

Move the actual SSL_AEAD_CTX swap into the record layer. Also revise the
intermediate state we store between setup_key_block and
change_cipher_state. With SSL_AEAD_CTX_new abstracted out, keeping the
EVP_AEAD around doesn't make much sense. Just store enough to partition
the key block.

Change-Id: I773fb46a2cb78fa570f00c0a89339c15bbb1d719
Reviewed-on: https://boringssl-review.googlesource.com/6832
Reviewed-by: Adam Langley <alangley@gmail.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 27701b1..8f76c38 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -4057,13 +4057,12 @@
     uint8_t *certificate_types;
     size_t num_certificate_types;
 
-    int key_block_length;
     uint8_t *key_block;
+    uint8_t key_block_length;
 
-    const EVP_AEAD *new_aead;
     uint8_t new_mac_secret_len;
+    uint8_t new_key_len;
     uint8_t new_fixed_iv_len;
-    uint8_t new_variable_iv_len;
 
     /* Server-only: cert_request is true if a client certificate was
      * requested. */
diff --git a/ssl/internal.h b/ssl/internal.h
index 15f590a..1096cb7 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -436,6 +436,14 @@
                      uint8_t type, const uint8_t *in, size_t in_len,
                      enum dtls1_use_epoch_t use_epoch);
 
+/* ssl_set_read_state sets |ssl|'s read cipher state to |aead_ctx|. It takes
+ * ownership of |aead_ctx|. */
+void ssl_set_read_state(SSL *ssl, SSL_AEAD_CTX *aead_ctx);
+
+/* ssl_set_write_state sets |ssl|'s write cipher state to |aead_ctx|. It takes
+ * ownership of |aead_ctx|. */
+void ssl_set_write_state(SSL *ssl, SSL_AEAD_CTX *aead_ctx);
+
 
 /* Private key operations. */
 
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index 01a7e8c..fcbe78d 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -277,13 +277,6 @@
   return 0;
 }
 
-/* for these 2 messages, we need to
- * ssl->enc_read_ctx      re-init
- * ssl->s3->read_sequence   zero
- * ssl->s3->read_mac_secret   re-init
- * ssl->session->read_sym_enc   assign
- * ssl->session->read_compression assign
- * ssl->session->read_hash    assign */
 int ssl3_send_change_cipher_spec(SSL *ssl, int a, int b) {
   if (ssl->state == a) {
     *((uint8_t *)ssl->init_buf->data) = SSL3_MT_CCS;
diff --git a/ssl/t1_enc.c b/ssl/t1_enc.c
index 0f1d683..7dd6f0e 100644
--- a/ssl/t1_enc.c
+++ b/ssl/t1_enc.c
@@ -277,63 +277,28 @@
    * or a server reading a client's ChangeCipherSpec. */
   const char use_client_keys = which == SSL3_CHANGE_CIPHER_CLIENT_WRITE ||
                                which == SSL3_CHANGE_CIPHER_SERVER_READ;
-  const uint8_t *client_write_mac_secret, *server_write_mac_secret, *mac_secret;
-  const uint8_t *client_write_key, *server_write_key, *key;
-  const uint8_t *client_write_iv, *server_write_iv, *iv;
-  const EVP_AEAD *aead = ssl->s3->tmp.new_aead;
-  size_t key_len, iv_len, mac_secret_len;
-  const uint8_t *key_data;
 
-  /* Reset sequence number to zero. */
-  if (is_read) {
-    if (SSL_IS_DTLS(ssl)) {
-      ssl->d1->r_epoch++;
-      memset(&ssl->d1->bitmap, 0, sizeof(ssl->d1->bitmap));
-    }
-    memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
-  } else {
-    if (SSL_IS_DTLS(ssl)) {
-      ssl->d1->w_epoch++;
-      memcpy(ssl->d1->last_write_sequence, ssl->s3->write_sequence,
-             sizeof(ssl->s3->write_sequence));
-    }
-    memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
-  }
+  size_t mac_secret_len = ssl->s3->tmp.new_mac_secret_len;
+  size_t key_len = ssl->s3->tmp.new_key_len;
+  size_t iv_len = ssl->s3->tmp.new_fixed_iv_len;
+  assert((mac_secret_len + key_len + iv_len) * 2 ==
+         ssl->s3->tmp.key_block_length);
 
-  mac_secret_len = ssl->s3->tmp.new_mac_secret_len;
-  iv_len = ssl->s3->tmp.new_fixed_iv_len;
-
-  if (aead == NULL) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    return 0;
-  }
-
-  key_len = EVP_AEAD_key_length(aead);
-  if (mac_secret_len > 0) {
-    /* For "stateful" AEADs (i.e. compatibility with pre-AEAD cipher
-     * suites) the key length reported by |EVP_AEAD_key_length| will
-     * include the MAC and IV key bytes. */
-    if (key_len < mac_secret_len + iv_len) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-      return 0;
-    }
-    key_len -= mac_secret_len + iv_len;
-  }
-
-  key_data = ssl->s3->tmp.key_block;
-  client_write_mac_secret = key_data;
+  const uint8_t *key_data = ssl->s3->tmp.key_block;
+  const uint8_t *client_write_mac_secret = key_data;
   key_data += mac_secret_len;
-  server_write_mac_secret = key_data;
+  const uint8_t *server_write_mac_secret = key_data;
   key_data += mac_secret_len;
-  client_write_key = key_data;
+  const uint8_t *client_write_key = key_data;
   key_data += key_len;
-  server_write_key = key_data;
+  const uint8_t *server_write_key = key_data;
   key_data += key_len;
-  client_write_iv = key_data;
+  const uint8_t *client_write_iv = key_data;
   key_data += iv_len;
-  server_write_iv = key_data;
+  const uint8_t *server_write_iv = key_data;
   key_data += iv_len;
 
+  const uint8_t *mac_secret, *key, *iv;
   if (use_client_keys) {
     mac_secret = client_write_mac_secret;
     key = client_write_key;
@@ -344,50 +309,37 @@
     iv = server_write_iv;
   }
 
-  if (key_data - ssl->s3->tmp.key_block != ssl->s3->tmp.key_block_length) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+  SSL_AEAD_CTX *aead_ctx = SSL_AEAD_CTX_new(
+      is_read ? evp_aead_open : evp_aead_seal,
+      ssl3_version_from_wire(ssl, ssl->version), ssl->s3->tmp.new_cipher, key,
+      key_len, mac_secret, mac_secret_len, iv, iv_len);
+  if (aead_ctx == NULL) {
     return 0;
   }
 
   if (is_read) {
-    SSL_AEAD_CTX_free(ssl->aead_read_ctx);
-    ssl->aead_read_ctx = SSL_AEAD_CTX_new(
-        evp_aead_open, ssl3_version_from_wire(ssl, ssl->version),
-        ssl->s3->tmp.new_cipher, key, key_len, mac_secret, mac_secret_len, iv,
-        iv_len);
-    return ssl->aead_read_ctx != NULL;
+    ssl_set_read_state(ssl, aead_ctx);
+  } else {
+    ssl_set_write_state(ssl, aead_ctx);
   }
-
-  SSL_AEAD_CTX_free(ssl->aead_write_ctx);
-  ssl->aead_write_ctx = SSL_AEAD_CTX_new(
-      evp_aead_seal, ssl3_version_from_wire(ssl, ssl->version),
-      ssl->s3->tmp.new_cipher, key, key_len, mac_secret, mac_secret_len, iv,
-      iv_len);
-  return ssl->aead_write_ctx != NULL;
+  return 1;
 }
 
 int tls1_setup_key_block(SSL *ssl) {
-  uint8_t *p;
-  const EVP_AEAD *aead = NULL;
-  int ret = 0;
-  size_t mac_secret_len, fixed_iv_len, variable_iv_len, key_len;
-  size_t key_block_len;
-
   if (ssl->s3->tmp.key_block_length != 0) {
     return 1;
   }
 
-  if (ssl->session->cipher == NULL) {
-    goto cipher_unavailable_err;
-  }
-
-  if (!ssl_cipher_get_evp_aead(&aead, &mac_secret_len, &fixed_iv_len,
+  const EVP_AEAD *aead = NULL;
+  size_t mac_secret_len, fixed_iv_len;
+  if (ssl->session->cipher == NULL ||
+      !ssl_cipher_get_evp_aead(&aead, &mac_secret_len, &fixed_iv_len,
                                ssl->session->cipher,
                                ssl3_version_from_wire(ssl, ssl->version))) {
-    goto cipher_unavailable_err;
+    OPENSSL_PUT_ERROR(SSL, SSL_R_CIPHER_OR_HASH_UNAVAILABLE);
+    return 0;
   }
-  key_len = EVP_AEAD_key_length(aead);
-  variable_iv_len = EVP_AEAD_nonce_length(aead);
+  size_t key_len = EVP_AEAD_key_length(aead);
   if (mac_secret_len > 0) {
     /* For "stateful" AEADs (i.e. compatibility with pre-AEAD cipher suites) the
      * key length reported by |EVP_AEAD_key_length| will include the MAC key
@@ -397,50 +349,36 @@
       return 0;
     }
     key_len -= mac_secret_len + fixed_iv_len;
-  } else {
-    /* The nonce is split into a fixed portion and a variable portion. */
-    if (variable_iv_len < fixed_iv_len) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-      return 0;
-    }
-    variable_iv_len -= fixed_iv_len;
   }
 
   assert(mac_secret_len < 256);
+  assert(key_len < 256);
   assert(fixed_iv_len < 256);
-  assert(variable_iv_len < 256);
 
-  ssl->s3->tmp.new_aead = aead;
   ssl->s3->tmp.new_mac_secret_len = (uint8_t)mac_secret_len;
+  ssl->s3->tmp.new_key_len = (uint8_t)key_len;
   ssl->s3->tmp.new_fixed_iv_len = (uint8_t)fixed_iv_len;
-  ssl->s3->tmp.new_variable_iv_len = (uint8_t)variable_iv_len;
 
-  key_block_len = key_len + mac_secret_len + fixed_iv_len;
+  size_t key_block_len = mac_secret_len + key_len + fixed_iv_len;
   key_block_len *= 2;
 
   ssl3_cleanup_key_block(ssl);
 
-  p = (uint8_t *)OPENSSL_malloc(key_block_len);
-  if (p == NULL) {
+  uint8_t *keyblock = (uint8_t *)OPENSSL_malloc(key_block_len);
+  if (keyblock == NULL) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-    goto err;
+    return 0;
   }
 
-  ssl->s3->tmp.key_block_length = key_block_len;
-  ssl->s3->tmp.key_block = p;
-
-  if (!tls1_generate_key_block(ssl, p, key_block_len)) {
-    goto err;
+  if (!tls1_generate_key_block(ssl, keyblock, key_block_len)) {
+    OPENSSL_free(keyblock);
+    return 0;
   }
 
-  ret = 1;
-
-err:
-  return ret;
-
-cipher_unavailable_err:
-  OPENSSL_PUT_ERROR(SSL, SSL_R_CIPHER_OR_HASH_UNAVAILABLE);
-  return 0;
+  assert(key_block_len < 256);
+  ssl->s3->tmp.key_block_length = (uint8_t)key_block_len;
+  ssl->s3->tmp.key_block = keyblock;
+  return 1;
 }
 
 int tls1_cert_verify_mac(SSL *ssl, int md_nid, uint8_t *out) {
diff --git a/ssl/tls_record.c b/ssl/tls_record.c
index e3a413b..0f29321 100644
--- a/ssl/tls_record.c
+++ b/ssl/tls_record.c
@@ -109,6 +109,7 @@
 #include <openssl/ssl.h>
 
 #include <assert.h>
+#include <string.h>
 
 #include <openssl/bytestring.h>
 #include <openssl/err.h>
@@ -352,3 +353,26 @@
   *out_len += frag_len;
   return 1;
 }
+
+void ssl_set_read_state(SSL *ssl, SSL_AEAD_CTX *aead_ctx) {
+  if (SSL_IS_DTLS(ssl)) {
+    ssl->d1->r_epoch++;
+    memset(&ssl->d1->bitmap, 0, sizeof(ssl->d1->bitmap));
+  }
+  memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
+
+  SSL_AEAD_CTX_free(ssl->aead_read_ctx);
+  ssl->aead_read_ctx = aead_ctx;
+}
+
+void ssl_set_write_state(SSL *ssl, SSL_AEAD_CTX *aead_ctx) {
+  if (SSL_IS_DTLS(ssl)) {
+    ssl->d1->w_epoch++;
+    memcpy(ssl->d1->last_write_sequence, ssl->s3->write_sequence,
+           sizeof(ssl->s3->write_sequence));
+  }
+  memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
+
+  SSL_AEAD_CTX_free(ssl->aead_write_ctx);
+  ssl->aead_write_ctx = aead_ctx;
+}