Adding handling for KeyUpdate post-handshake message.

BUG=74

Change-Id: I72d52c1fbc3413e940dddbc0b20c7f22459da693
Reviewed-on: https://boringssl-review.googlesource.com/8981
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/internal.h b/ssl/internal.h
index 14265f9..fe8bbf5 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -823,6 +823,10 @@
  * zero on error. */
 int tls13_set_handshake_traffic(SSL *ssl);
 
+/* tls13_rotate_traffic_key derives the next read or write traffic secret. It
+ * returns one on success and zero on error. */
+int tls13_rotate_traffic_key(SSL *ssl, enum evp_aead_direction_t direction);
+
 /* tls13_derive_traffic_secret_0 derives the initial application data traffic
  * secret based on the handshake transcripts and |master_secret|. It returns one
  * on success and zero on error. */
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 141f160..dfd5b30 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -885,6 +885,10 @@
 	// message. This only makes sense for a server.
 	SendHelloRequestBeforeEveryHandshakeMessage bool
 
+	// SendKeyUpdateBeforeEveryAppDataRecord, if true, causes a KeyUpdate
+	// handshake message to be sent before each application data record.
+	SendKeyUpdateBeforeEveryAppDataRecord bool
+
 	// RequireDHPublicValueLen causes a fatal error if the length (in
 	// bytes) of the server's Diffie-Hellman public value is not equal to
 	// this.
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index d01643c..703908a 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1325,6 +1325,10 @@
 		c.flushHandshake()
 	}
 
+	if c.config.Bugs.SendKeyUpdateBeforeEveryAppDataRecord {
+		c.sendKeyUpdateLocked()
+	}
+
 	// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
 	// attack when using block mode ciphers due to predictable IVs.
 	// This can be prevented by splitting each Application Data
@@ -1394,7 +1398,7 @@
 	}
 
 	if _, ok := msg.(*keyUpdateMsg); ok {
-		c.in.doKeyUpdate(c, true)
+		c.in.doKeyUpdate(c, false)
 		return nil
 	}
 
@@ -1704,6 +1708,6 @@
 	if err := c.flushHandshake(); err != nil {
 		return err
 	}
-	c.out.doKeyUpdate(c, false)
+	c.out.doKeyUpdate(c, true)
 	return nil
 }
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 48d8703..ccd25d4 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2166,6 +2166,16 @@
 			shouldFail:    true,
 			expectedError: ":WRONG_VERSION_NUMBER:",
 		},
+		{
+			testType: clientTest,
+			name:     "KeyUpdate",
+			config: Config{
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					SendKeyUpdateBeforeEveryAppDataRecord: true,
+				},
+			},
+		},
 	}
 	testCases = append(testCases, basicTests...)
 }
diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c
index 1f60453..15f9a29 100644
--- a/ssl/tls13_both.c
+++ b/ssl/tls13_both.c
@@ -453,7 +453,22 @@
   return 1;
 }
 
+static int tls13_receive_key_update(SSL *ssl) {
+  if (ssl->init_num != 0) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+    return 0;
+  }
+
+  // TODO(svaldez): Send KeyUpdate.
+  return tls13_rotate_traffic_key(ssl, evp_aead_open);
+}
+
 int tls13_post_handshake(SSL *ssl) {
+  if (ssl->s3->tmp.message_type == SSL3_MT_KEY_UPDATE) {
+    return tls13_receive_key_update(ssl);
+  }
+
   if (ssl->s3->tmp.message_type == SSL3_MT_NEW_SESSION_TICKET &&
       !ssl->server) {
     // TODO(svaldez): Handle NewSessionTicket.
@@ -461,7 +476,6 @@
   }
 
   // TODO(svaldez): Handle post-handshake authentication.
-  // TODO(svaldez): Handle KeyUpdate.
 
   ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
   OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
diff --git a/ssl/tls13_enc.c b/ssl/tls13_enc.c
index 1f4fe21..70b041a 100644
--- a/ssl/tls13_enc.c
+++ b/ssl/tls13_enc.c
@@ -176,7 +176,7 @@
   const EVP_MD *digest = ssl_get_handshake_digest(ssl_get_algorithm_prf(ssl));
   size_t mac_secret_len, fixed_iv_len;
   if (!ssl_cipher_get_evp_aead(&aead, &mac_secret_len, &fixed_iv_len,
-                               ssl->s3->new_session->cipher,
+                               SSL_get_session(ssl)->cipher,
                                ssl3_protocol_version(ssl))) {
     return 0;
   }
@@ -207,7 +207,7 @@
   }
 
   SSL_AEAD_CTX *traffic_aead = SSL_AEAD_CTX_new(
-      direction, ssl3_protocol_version(ssl), ssl->s3->new_session->cipher, key,
+      direction, ssl3_protocol_version(ssl), SSL_get_session(ssl)->cipher, key,
       key_len, NULL, 0, iv, iv_len);
   if (traffic_aead == NULL) {
     return 0;
@@ -225,10 +225,10 @@
 
   /* Save the traffic secret. */
   if (direction == evp_aead_open) {
-    memcpy(ssl->s3->read_traffic_secret, traffic_secret, traffic_secret_len);
+    memmove(ssl->s3->read_traffic_secret, traffic_secret, traffic_secret_len);
     ssl->s3->read_traffic_secret_len = traffic_secret_len;
   } else {
-    memcpy(ssl->s3->write_traffic_secret, traffic_secret, traffic_secret_len);
+    memmove(ssl->s3->write_traffic_secret, traffic_secret, traffic_secret_len);
     ssl->s3->write_traffic_secret_len = traffic_secret_len;
   }
 
@@ -267,6 +267,29 @@
                         hs->hash_len);
 }
 
+int tls13_rotate_traffic_key(SSL *ssl, enum evp_aead_direction_t direction) {
+  const EVP_MD *digest = ssl_get_handshake_digest(ssl_get_algorithm_prf(ssl));
+
+  uint8_t *secret;
+  size_t secret_len;
+  if (direction == evp_aead_open) {
+    secret = ssl->s3->read_traffic_secret;
+    secret_len = ssl->s3->read_traffic_secret_len;
+  } else {
+    secret = ssl->s3->write_traffic_secret;
+    secret_len = ssl->s3->write_traffic_secret_len;
+  }
+
+  if (!hkdf_expand_label(secret, digest, secret, secret_len,
+                         (const uint8_t *)kTLS13LabelApplicationTraffic,
+                         strlen(kTLS13LabelApplicationTraffic), NULL, 0,
+                         secret_len)) {
+    return 0;
+  }
+
+  return tls13_set_traffic_key(ssl, type_data, direction, secret, secret_len);
+}
+
 static const char kTLS13LabelExporter[] = "exporter master secret";
 static const char kTLS13LabelResumption[] = "resumption master secret";