Acknowledge KeyUpdate messages.

Also remove TODO about post-handshake authentication. The only sensible
way to handle an unexpected post-handshake authentication request is a
fatal error (dropping it would cause a deadlock), and we treat all
post-handshake authentication as unexpected.

BUG=74

Change-Id: Ic92035b26ddcbcf25241262ce84bcc57b736b7a7
Reviewed-on: https://boringssl-review.googlesource.com/14744
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/internal.h b/ssl/internal.h
index 9a523d4..8d38cda 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1633,10 +1633,15 @@
    * handshake. */
   unsigned tlsext_channel_id_valid:1;
 
+  /* key_update_pending is one if we have a KeyUpdate acknowledgment
+   * outstanding. */
+  unsigned key_update_pending:1;
+
   uint8_t send_alert[2];
 
   /* pending_flight is the pending outgoing flight. This is used to flush each
-   * handshake flight in a single write. */
+   * handshake flight in a single write. |write_buffer| must be written out
+   * before this data. */
   BUF_MEM *pending_flight;
 
   /* pending_flight_offset is the number of bytes of |pending_flight| which have
diff --git a/ssl/s3_pkt.c b/ssl/s3_pkt.c
index fc21c2c..8e3613a 100644
--- a/ssl/s3_pkt.c
+++ b/ssl/s3_pkt.c
@@ -260,14 +260,6 @@
     return ssl3_write_pending(ssl, type, buf, len);
   }
 
-  /* The handshake flight buffer is mutually exclusive with application data.
-   *
-   * TODO(davidben): This will not be true when closure alerts use this. */
-  if (ssl->s3->pending_flight != NULL) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    return -1;
-  }
-
   if (len > SSL3_RT_MAX_PLAIN_LENGTH) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return -1;
@@ -277,18 +269,47 @@
     return 0;
   }
 
+  size_t flight_len = 0;
+  if (ssl->s3->pending_flight != NULL) {
+    flight_len =
+        ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset;
+  }
+
   size_t max_out = len + SSL_max_seal_overhead(ssl);
-  if (max_out < len) {
+  if (max_out < len || max_out + flight_len < max_out) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
     return -1;
   }
+  max_out += flight_len;
+
   uint8_t *out;
   size_t ciphertext_len;
-  if (!ssl_write_buffer_init(ssl, &out, max_out) ||
-      !tls_seal_record(ssl, out, &ciphertext_len, max_out, type, buf, len)) {
+  if (!ssl_write_buffer_init(ssl, &out, max_out)) {
     return -1;
   }
-  ssl_write_buffer_set_len(ssl, ciphertext_len);
+
+  /* Add any unflushed handshake data as a prefix. This may be a KeyUpdate
+   * acknowledgment or 0-RTT key change messages. |pending_flight| must be clear
+   * when data is added to |write_buffer| or it will be written in the wrong
+   * order. */
+  if (ssl->s3->pending_flight != NULL) {
+    OPENSSL_memcpy(
+        out, ssl->s3->pending_flight->data + ssl->s3->pending_flight_offset,
+        flight_len);
+    BUF_MEM_free(ssl->s3->pending_flight);
+    ssl->s3->pending_flight = NULL;
+    ssl->s3->pending_flight_offset = 0;
+  }
+
+  if (!tls_seal_record(ssl, out + flight_len, &ciphertext_len,
+                       max_out - flight_len, type, buf, len)) {
+    return -1;
+  }
+  ssl_write_buffer_set_len(ssl, flight_len + ciphertext_len);
+
+  /* Now that we've made progress on the connection, uncork KeyUpdate
+   * acknowledgments. */
+  ssl->s3->key_update_pending = 0;
 
   /* memorize arguments so that ssl3_write_pending can detect bad write retries
    * later */
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 8e25e11..39b28e0 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -1305,7 +1305,8 @@
 
 // WriteAll writes |in_len| bytes from |in| to |ssl|, resolving any asynchronous
 // operations. It returns the result of the final |SSL_write| call.
-static int WriteAll(SSL *ssl, const uint8_t *in, size_t in_len) {
+static int WriteAll(SSL *ssl, const void *in_, size_t in_len) {
+  const uint8_t *in = reinterpret_cast<const uint8_t *>(in_);
   const TestConfig *config = GetTestConfig(ssl);
   int ret;
   do {
@@ -1969,22 +1970,23 @@
       }
     }
   } else {
+    static const char kInitialWrite[] = "hello";
+    bool pending_initial_write = false;
     if (config->read_with_unfinished_write) {
       if (!config->async) {
         fprintf(stderr, "-read-with-unfinished-write requires -async.\n");
         return false;
       }
 
-      int write_ret = SSL_write(ssl.get(),
-                          reinterpret_cast<const uint8_t *>("unfinished"), 10);
+      int write_ret =
+          SSL_write(ssl.get(), kInitialWrite, strlen(kInitialWrite));
       if (SSL_get_error(ssl.get(), write_ret) != SSL_ERROR_WANT_WRITE) {
         fprintf(stderr, "Failed to leave unfinished write.\n");
         return false;
       }
-    }
-    if (config->shim_writes_first) {
-      if (WriteAll(ssl.get(), reinterpret_cast<const uint8_t *>("hello"),
-                   5) < 0) {
+      pending_initial_write = true;
+    } else if (config->shim_writes_first) {
+      if (WriteAll(ssl.get(), kInitialWrite, strlen(kInitialWrite)) < 0) {
         return false;
       }
     }
@@ -2029,6 +2031,14 @@
           return false;
         }
 
+        // Clear the initial write, if unfinished.
+        if (pending_initial_write) {
+          if (WriteAll(ssl.get(), kInitialWrite, strlen(kInitialWrite)) < 0) {
+            return false;
+          }
+          pending_initial_write = false;
+        }
+
         for (int i = 0; i < n; i++) {
           buf[i] ^= 0xff;
         }
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index db17d98..0f54800 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -1335,6 +1335,10 @@
 	// SendServerHelloAsHelloRetryRequest, if true, causes the server to
 	// send ServerHello messages with a HelloRetryRequest type field.
 	SendServerHelloAsHelloRetryRequest bool
+
+	// RejectUnsolicitedKeyUpdate, if true, causes all unsolicited
+	// KeyUpdates from the peer to be rejected.
+	RejectUnsolicitedKeyUpdate bool
 }
 
 func (c *Config) serverInit() {
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 6cc1b5c..0eb64e7 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -852,20 +852,22 @@
 	default:
 		c.sendAlert(alertInternalError)
 		return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
-	case recordTypeHandshake, recordTypeChangeCipherSpec:
+	case recordTypeChangeCipherSpec:
 		if c.handshakeComplete {
 			c.sendAlert(alertInternalError)
-			return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete"))
+			return c.in.setErrorLocked(errors.New("tls: ChangeCipherSpec requested after handshake complete"))
 		}
 	case recordTypeApplicationData:
 		if !c.handshakeComplete && !c.config.Bugs.ExpectFalseStart && len(c.config.Bugs.ExpectHalfRTTData) == 0 && len(c.config.Bugs.ExpectEarlyData) == 0 {
 			c.sendAlert(alertInternalError)
 			return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete"))
 		}
-	case recordTypeAlert:
-		// Looking for a close_notify. Note: unlike a real
-		// implementation, this is not tolerant of additional records.
-		// See the documentation for ExpectCloseNotify.
+	case recordTypeAlert, recordTypeHandshake:
+		// Looking for a close_notify or handshake message. Note: unlike
+		// a real implementation, this is not tolerant of additional
+		// records. See the documentation for ExpectCloseNotify.
+		// Post-handshake requests for handshake messages are allowed if
+		// the caller used ReadKeyUpdateACK.
 	}
 
 Again:
@@ -1497,6 +1499,9 @@
 	}
 
 	if keyUpdate, ok := msg.(*keyUpdateMsg); ok {
+		if c.config.Bugs.RejectUnsolicitedKeyUpdate {
+			return errors.New("tls: unexpected KeyUpdate message")
+		}
 		c.in.doKeyUpdate(c, false)
 		if keyUpdate.keyUpdateRequest == keyUpdateRequested {
 			c.keyUpdateRequested = true
@@ -1504,9 +1509,33 @@
 		return nil
 	}
 
-	// TODO(davidben): Add support for KeyUpdate.
 	c.sendAlert(alertUnexpectedMessage)
-	return alertUnexpectedMessage
+	return errors.New("tls: unexpected post-handshake message")
+}
+
+// ReadKeyUpdateACK reads a KeyUpdate acknowledgment from the peer, which must
+// not be preceded by application data records or itself request an update.
+func (c *Conn) ReadKeyUpdateACK() error {
+	c.in.Lock()
+	defer c.in.Unlock()
+
+	msg, err := c.readHandshake()
+	if err != nil {
+		return err
+	}
+
+	keyUpdate, ok := msg.(*keyUpdateMsg)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return errors.New("tls: unexpected message when reading KeyUpdate")
+	}
+
+	if keyUpdate.keyUpdateRequest != keyUpdateNotRequested {
+		return errors.New("tls: received invalid KeyUpdate message")
+	}
+
+	c.in.doKeyUpdate(c, false)
+	return nil
+}
 
 func (c *Conn) Renegotiate() error {
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 02824b8..2da0bc5 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -377,6 +377,10 @@
 	// shimWritesFirst controls whether the shim sends an initial "hello"
 	// message before doing a roundtrip with the runner.
 	shimWritesFirst bool
+	// readWithUnfinishedWrite behaves like shimWritesFirst, but the shim
+	// does not complete the write until responding to the first runner
+	// message.
+	readWithUnfinishedWrite bool
 	// shimShutsDown, if true, runs a test where the shim shuts down the
 	// connection immediately after the handshake rather than echoing
 	// messages from the runner.
@@ -678,36 +682,26 @@
 		}
 	}
 
-	if test.shimWritesFirst {
-		var buf [5]byte
-		_, err := io.ReadFull(tlsConn, buf[:])
-		if err != nil {
-			return err
-		}
-		if string(buf[:]) != "hello" {
-			return fmt.Errorf("bad initial message")
-		}
-	}
-
-	for i := 0; i < test.sendKeyUpdates; i++ {
-		if err := tlsConn.SendKeyUpdate(test.keyUpdateRequest); err != nil {
-			return err
-		}
-	}
-
-	for i := 0; i < test.sendEmptyRecords; i++ {
-		tlsConn.Write(nil)
-	}
-
-	for i := 0; i < test.sendWarningAlerts; i++ {
-		tlsConn.SendAlert(alertLevelWarning, alertUnexpectedMessage)
-	}
-
 	if test.sendHalfHelloRequest {
 		tlsConn.SendHalfHelloRequest()
 	}
 
+	shimPrefixPending := test.shimWritesFirst || test.readWithUnfinishedWrite
 	if test.renegotiate > 0 {
+		// If readWithUnfinishedWrite is set, the shim prefix will be
+		// available later.
+		if shimPrefixPending && !test.readWithUnfinishedWrite {
+			var buf [5]byte
+			_, err := io.ReadFull(tlsConn, buf[:])
+			if err != nil {
+				return err
+			}
+			if string(buf[:]) != "hello" {
+				return fmt.Errorf("bad initial message")
+			}
+			shimPrefixPending = false
+		}
+
 		if test.renegotiateCiphers != nil {
 			config.CipherSuites = test.renegotiateCiphers
 		}
@@ -745,12 +739,6 @@
 	}
 
 	for j := 0; j < messageCount; j++ {
-		testMessage := make([]byte, messageLen)
-		for i := range testMessage {
-			testMessage[i] = 0x42 ^ byte(j)
-		}
-		tlsConn.Write(testMessage)
-
 		for i := 0; i < test.sendKeyUpdates; i++ {
 			tlsConn.SendKeyUpdate(test.keyUpdateRequest)
 		}
@@ -763,11 +751,38 @@
 			tlsConn.SendAlert(alertLevelWarning, alertUnexpectedMessage)
 		}
 
+		testMessage := make([]byte, messageLen)
+		for i := range testMessage {
+			testMessage[i] = 0x42 ^ byte(j)
+		}
+		tlsConn.Write(testMessage)
+
+		// Consume the shim prefix if needed.
+		if shimPrefixPending {
+			var buf [5]byte
+			_, err := io.ReadFull(tlsConn, buf[:])
+			if err != nil {
+				return err
+			}
+			if string(buf[:]) != "hello" {
+				return fmt.Errorf("bad initial message")
+			}
+			shimPrefixPending = false
+		}
+
 		if test.shimShutsDown || test.expectMessageDropped {
 			// The shim will not respond.
 			continue
 		}
 
+		// Process the KeyUpdate ACK. However many KeyUpdates the runner
+		// sends, the shim should respond only once.
+		if test.sendKeyUpdates > 0 && test.keyUpdateRequest == keyUpdateRequested {
+			if err := tlsConn.ReadKeyUpdateACK(); err != nil {
+				return err
+			}
+		}
+
 		buf := make([]byte, len(testMessage))
 		if test.protocol == dtls {
 			bufTmp := make([]byte, len(buf)+1)
@@ -940,6 +955,10 @@
 		flags = append(flags, "-shim-writes-first")
 	}
 
+	if test.readWithUnfinishedWrite {
+		flags = append(flags, "-read-with-unfinished-write")
+	}
+
 	if test.shimShutsDown {
 		flags = append(flags, "-shim-shuts-down")
 	}
@@ -2301,6 +2320,38 @@
 			expectedError:    ":DECODE_ERROR:",
 		},
 		{
+			// Test that KeyUpdates are acknowledged properly.
+			name: "KeyUpdate-RequestACK",
+			config: Config{
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					RejectUnsolicitedKeyUpdate: true,
+				},
+			},
+			// Test the shim receiving many KeyUpdates in a row.
+			sendKeyUpdates:   5,
+			messageCount:     5,
+			keyUpdateRequest: keyUpdateRequested,
+		},
+		{
+			// Test that KeyUpdates are acknowledged properly if the
+			// peer's KeyUpdate is discovered while a write is
+			// pending.
+			name: "KeyUpdate-RequestACK-UnfinishedWrite",
+			config: Config{
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					RejectUnsolicitedKeyUpdate: true,
+				},
+			},
+			// Test the shim receiving many KeyUpdates in a row.
+			sendKeyUpdates:          5,
+			messageCount:            5,
+			keyUpdateRequest:        keyUpdateRequested,
+			readWithUnfinishedWrite: true,
+			flags: []string{"-async"},
+		},
+		{
 			name: "SendSNIWarningAlert",
 			config: Config{
 				MaxVersion: VersionTLS12,
@@ -6553,11 +6604,11 @@
 		config: Config{
 			MaxVersion: VersionTLS12,
 		},
-		renegotiate: 1,
+		renegotiate:             1,
+		readWithUnfinishedWrite: true,
 		flags: []string{
 			"-async",
 			"-renegotiate-freely",
-			"-read-with-unfinished-write",
 		},
 		shouldFail:    true,
 		expectedError: ":NO_RENEGOTIATION:",
diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c
index be3c1b7..f44933f 100644
--- a/ssl/tls13_both.c
+++ b/ssl/tls13_both.c
@@ -625,9 +625,30 @@
     return 0;
   }
 
-  /* TODO(svaldez): Send KeyUpdate if |key_update_request| is
-   * |SSL_KEY_UPDATE_REQUESTED|. */
-  return tls13_rotate_traffic_key(ssl, evp_aead_open);
+  if (!tls13_rotate_traffic_key(ssl, evp_aead_open)) {
+    return 0;
+  }
+
+  /* Acknowledge the KeyUpdate unless an acknowledgment is already pending. */
+  if (key_update_request == SSL_KEY_UPDATE_REQUESTED &&
+      !ssl->s3->key_update_pending) {
+    CBB cbb, body;
+    if (!ssl->method->init_message(ssl, &cbb, &body, SSL3_MT_KEY_UPDATE) ||
+        !CBB_add_u8(&body, SSL_KEY_UPDATE_NOT_REQUESTED) ||
+        !ssl_add_message_cbb(ssl, &cbb) ||
+        !tls13_rotate_traffic_key(ssl, evp_aead_seal)) {
+      CBB_cleanup(&cbb);
+      return 0;
+    }
+
+    /* Suppress KeyUpdate acknowledgments until this change is written to the
+     * wire. This prevents us from accumulating write obligations when read and
+     * write progress at different rates. See draft-ietf-tls-tls13-18, section
+     * 4.5.3. */
+    ssl->s3->key_update_pending = 1;
+  }
+
+  return 1;
 }
 
 int tls13_post_handshake(SSL *ssl) {
@@ -649,8 +670,6 @@
     return tls13_process_new_session_ticket(ssl);
   }
 
-  // TODO(svaldez): Handle post-handshake authentication.
-
   ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
   OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
   return 0;