Add |SSL_key_update|.

This function allows a client to send a TLS 1.3 KeyUpdate message.

Change-Id: I69935253795a79d65a8c85b652378bf04b7058e2
Reviewed-on: https://boringssl-review.googlesource.com/c/33706
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 6898674..d588395 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -394,6 +394,26 @@
 // https://crbug.com/466303.
 OPENSSL_EXPORT int SSL_write(SSL *ssl, const void *buf, int num);
 
+// SSL_KEY_UPDATE_REQUESTED indicates that the peer should reply to a KeyUpdate
+// message with its own, thus updating traffic secrets for both directions on
+// the connection.
+#define SSL_KEY_UPDATE_REQUESTED 1
+
+// SSL_KEY_UPDATE_NOT_REQUESTED indicates that the peer should not reply with
+// it's own KeyUpdate message.
+#define SSL_KEY_UPDATE_NOT_REQUESTED 0
+
+// SSL_key_update queues a TLS 1.3 KeyUpdate message to be sent on |ssl|
+// if one is not already queued. The |request_type| argument must one of the
+// |SSL_KEY_UPDATE_*| values. This function requires that |ssl| have completed a
+// TLS >= 1.3 handshake. It returns one on success or zero on error.
+//
+// Note that this function does not _send_ the message itself. The next call to
+// |SSL_write| will cause the message to be sent. |SSL_write| may be called with
+// a zero length to flush a KeyUpdate message when no application data is
+// pending.
+OPENSSL_EXPORT int SSL_key_update(SSL *ssl, int request_type);
+
 // SSL_shutdown shuts down |ssl|. It runs in two stages. First, it sends
 // close_notify and returns zero or one on success or -1 on failure. Zero
 // indicates that close_notify was sent, but not received, and one additionally
diff --git a/ssl/internal.h b/ssl/internal.h
index bbce7ec..fa755d2 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1649,6 +1649,11 @@
 const char *tls13_client_handshake_state(SSL_HANDSHAKE *hs);
 const char *tls13_server_handshake_state(SSL_HANDSHAKE *hs);
 
+// tls13_add_key_update queues a KeyUpdate message on |ssl|. The
+// |update_requested| argument must be one of |SSL_KEY_UPDATE_REQUESTED| or
+// |SSL_KEY_UPDATE_NOT_REQUESTED|.
+bool tls13_add_key_update(SSL *ssl, int update_requested);
+
 // tls13_post_handshake processes a post-handshake message. It returns true on
 // success and false on failure.
 bool tls13_post_handshake(SSL *ssl, const SSLMessage &msg);
@@ -2506,10 +2511,6 @@
 // From RFC 8446, used in determining PSK modes.
 #define SSL_PSK_DHE_KE 0x1
 
-// From RFC 8446, used in determining whether to respond with a KeyUpdate.
-#define SSL_KEY_UPDATE_NOT_REQUESTED 0
-#define SSL_KEY_UPDATE_REQUESTED 1
-
 // kMaxEarlyDataAccepted is the advertised number of plaintext bytes of early
 // data that will be accepted. This value should be slightly below
 // kMaxEarlyDataSkipped in tls_record.c, which is measured in ciphertext.
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index f0ae8a2..f1d67d0 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -232,25 +232,29 @@
     return -1;
   }
 
-  if (len == 0) {
-    return 0;
-  }
-
   if (!tls_flush_pending_hs_data(ssl)) {
     return -1;
   }
+
   size_t flight_len = 0;
   if (ssl->s3->pending_flight != nullptr) {
     flight_len =
         ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset;
   }
 
-  size_t max_out = len + SSL_max_seal_overhead(ssl);
-  if (max_out < len || max_out + flight_len < max_out) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
-    return -1;
+  size_t max_out = flight_len;
+  if (len > 0) {
+    const size_t max_ciphertext_len = len + SSL_max_seal_overhead(ssl);
+    if (max_ciphertext_len < len || max_out + max_ciphertext_len < max_out) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+      return -1;
+    }
+    max_out += max_ciphertext_len;
   }
-  max_out += flight_len;
+
+  if (max_out == 0) {
+    return 0;
+  }
 
   if (!buf->EnsureCap(flight_len + ssl_seal_align_prefix_len(ssl), max_out)) {
     return -1;
@@ -270,12 +274,14 @@
     buf->DidWrite(flight_len);
   }
 
-  size_t ciphertext_len;
-  if (!tls_seal_record(ssl, buf->remaining().data(), &ciphertext_len,
-                       buf->remaining().size(), type, in, len)) {
-    return -1;
+  if (len > 0) {
+    size_t ciphertext_len;
+    if (!tls_seal_record(ssl, buf->remaining().data(), &ciphertext_len,
+                         buf->remaining().size(), type, in, len)) {
+      return -1;
+    }
+    buf->DidWrite(ciphertext_len);
   }
-  buf->DidWrite(ciphertext_len);
 
   // Now that we've made progress on the connection, uncork KeyUpdate
   // acknowledgments.
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index ceeba89..a3c25f3 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -1135,6 +1135,37 @@
   return ret;
 }
 
+int SSL_key_update(SSL *ssl, int request_type) {
+  ssl_reset_error_state(ssl);
+
+  if (ssl->do_handshake == NULL) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_UNINITIALIZED);
+    return 0;
+  }
+
+  if (ssl->ctx->quic_method != nullptr) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+    return 0;
+  }
+
+  if (!ssl->s3->initial_handshake_complete) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_HANDSHAKE_NOT_COMPLETE);
+    return 0;
+  }
+
+  if (ssl_protocol_version(ssl) < TLS1_3_VERSION) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SSL_VERSION);
+    return 0;
+  }
+
+  if (!ssl->s3->key_update_pending &&
+      !tls13_add_key_update(ssl, request_type)) {
+    return 0;
+  }
+
+  return 1;
+}
+
 int SSL_shutdown(SSL *ssl) {
   ssl_reset_error_state(ssl);
 
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 8d01c03..1f09156 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -4358,6 +4358,35 @@
   EXPECT_EQ(1u, server->config->supported_group_list.size());
 }
 
+TEST(SSLTest, ZeroSizedWiteFlushesHandshakeMessages) {
+  // If there are pending handshake mesages, an |SSL_write| of zero bytes should
+  // flush them.
+  bssl::UniquePtr<SSL_CTX> server_ctx(SSL_CTX_new(TLS_method()));
+  EXPECT_TRUE(SSL_CTX_set_max_proto_version(server_ctx.get(), TLS1_3_VERSION));
+  EXPECT_TRUE(SSL_CTX_set_min_proto_version(server_ctx.get(), TLS1_3_VERSION));
+  bssl::UniquePtr<X509> cert = GetTestCertificate();
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(cert);
+  ASSERT_TRUE(key);
+  ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx.get(), cert.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
+  EXPECT_TRUE(SSL_CTX_set_max_proto_version(client_ctx.get(), TLS1_3_VERSION));
+  EXPECT_TRUE(SSL_CTX_set_min_proto_version(client_ctx.get(), TLS1_3_VERSION));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                     server_ctx.get()));
+
+  BIO *client_wbio = SSL_get_wbio(client.get());
+  EXPECT_EQ(0u, BIO_wpending(client_wbio));
+  EXPECT_TRUE(SSL_key_update(client.get(), SSL_KEY_UPDATE_NOT_REQUESTED));
+  EXPECT_EQ(0u, BIO_wpending(client_wbio));
+  EXPECT_EQ(0, SSL_write(client.get(), nullptr, 0));
+  EXPECT_NE(0u, BIO_wpending(client_wbio));
+}
+
 TEST_P(SSLVersionTest, VerifyBeforeCertRequest) {
   // Configure the server to request client certificates.
   SSL_CTX_set_custom_verify(
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 77ed796..62db076 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -981,6 +981,12 @@
           pending_initial_write = false;
         }
 
+        if (config->key_update &&
+            !SSL_key_update(ssl, SSL_KEY_UPDATE_NOT_REQUESTED)) {
+          fprintf(stderr, "SSL_key_update failed.\n");
+          return false;
+        }
+
         for (int i = 0; i < n; i++) {
           buf[i] ^= 0xff;
         }
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index b6b6ffa..816ffca 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -106,6 +106,7 @@
 	pendingFragments [][]byte // pending outgoing handshake fragments.
 	pendingPacket    []byte   // pending outgoing packet.
 
+	keyUpdateSeen      bool
 	keyUpdateRequested bool
 	seenOneByteRecord  bool
 
@@ -1635,6 +1636,8 @@
 	}
 
 	if keyUpdate, ok := msg.(*keyUpdateMsg); ok {
+		c.keyUpdateSeen = true
+
 		if c.config.Bugs.RejectUnsolicitedKeyUpdate {
 			return errors.New("tls: unexpected KeyUpdate message")
 		}
@@ -1665,7 +1668,7 @@
 	keyUpdate, ok := msg.(*keyUpdateMsg)
 	if !ok {
 		c.sendAlert(alertUnexpectedMessage)
-		return errors.New("tls: unexpected message when reading KeyUpdate")
+		return fmt.Errorf("tls: unexpected message (%T) when reading KeyUpdate", msg)
 	}
 
 	if keyUpdate.keyUpdateRequest != keyUpdateNotRequested {
@@ -1708,13 +1711,12 @@
 				// Soft error, like EAGAIN
 				return 0, err
 			}
-			if c.hand.Len() > 0 {
+			for c.hand.Len() > 0 {
 				// We received handshake bytes, indicating a
 				// post-handshake message.
 				if err := c.handlePostHandshakeMessage(); err != nil {
 					return 0, err
 				}
-				continue
 			}
 		}
 		if err := c.in.err; err != nil {
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index b5cc0a7..d64e95f 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -473,6 +473,11 @@
 	sendKeyUpdates int
 	// keyUpdateRequest is the KeyUpdateRequest value to send in KeyUpdate messages.
 	keyUpdateRequest byte
+	// expectUnsolicitedKeyUpdate makes the test expect a one or more KeyUpdate
+	// messages while reading data from the shim. Don't use this in combination
+	// with any of the fields that send a KeyUpdate otherwise any received
+	// KeyUpdate might not be as unsolicited as expected.
+	expectUnsolicitedKeyUpdate bool
 	// expectMessageDropped, if true, means the test message is expected to
 	// be dropped by the client rather than echoed back.
 	expectMessageDropped bool
@@ -951,6 +956,10 @@
 				return fmt.Errorf("bad reply contents at byte %d; got %q and wanted %q", i, buf, testMessage)
 			}
 		}
+
+		if seen := tlsConn.keyUpdateSeen; seen != test.expectUnsolicitedKeyUpdate {
+			return fmt.Errorf("keyUpdateSeen (%t) != expectUnsolicitedKeyUpdate", seen)
+		}
 	}
 
 	return nil
@@ -2826,7 +2835,7 @@
 			expectedError: ":WRONG_VERSION_NUMBER:",
 		},
 		{
-			name: "KeyUpdate-Client",
+			name: "KeyUpdate-ToClient",
 			config: Config{
 				MaxVersion: VersionTLS13,
 			},
@@ -2835,7 +2844,7 @@
 		},
 		{
 			testType: serverTest,
-			name:     "KeyUpdate-Server",
+			name:     "KeyUpdate-ToServer",
 			config: Config{
 				MaxVersion: VersionTLS13,
 			},
@@ -2843,6 +2852,23 @@
 			keyUpdateRequest: keyUpdateNotRequested,
 		},
 		{
+			name: "KeyUpdate-FromClient",
+			config: Config{
+				MaxVersion: VersionTLS13,
+			},
+			expectUnsolicitedKeyUpdate: true,
+			flags:                      []string{"-key-update"},
+		},
+		{
+			testType: serverTest,
+			name:     "KeyUpdate-FromServer",
+			config: Config{
+				MaxVersion: VersionTLS13,
+			},
+			expectUnsolicitedKeyUpdate: true,
+			flags:                      []string{"-key-update"},
+		},
+		{
 			name: "KeyUpdate-InvalidRequestMode",
 			config: Config{
 				MaxVersion: VersionTLS13,
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index bed0501..b88d0ae 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -148,6 +148,7 @@
   { "-jdk11-workaround", &TestConfig::jdk11_workaround },
   { "-server-preference", &TestConfig::server_preference },
   { "-export-traffic-secrets", &TestConfig::export_traffic_secrets },
+  { "-key-update", &TestConfig::key_update },
 };
 
 const Flag<std::string> kStringFlags[] = {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 0d0753e..5d5eb5a 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -172,6 +172,7 @@
   bool jdk11_workaround = false;
   bool server_preference = false;
   bool export_traffic_secrets = false;
+  bool key_update = false;
 
   int argc;
   char **argv;
diff --git a/ssl/tls13_both.cc b/ssl/tls13_both.cc
index 6baeaf7..f6e359c 100644
--- a/ssl/tls13_both.cc
+++ b/ssl/tls13_both.cc
@@ -607,6 +607,25 @@
   return true;
 }
 
+bool tls13_add_key_update(SSL *ssl, int update_requested) {
+  ScopedCBB cbb;
+  CBB body_cbb;
+  if (!ssl->method->init_message(ssl, cbb.get(), &body_cbb,
+                                 SSL3_MT_KEY_UPDATE) ||
+      !CBB_add_u8(&body_cbb, update_requested) ||
+      !ssl_add_message_cbb(ssl, cbb.get()) ||
+      !tls13_rotate_traffic_key(ssl, evp_aead_seal)) {
+    return false;
+  }
+
+  // Suppress KeyUpdate acknowledgments until this change is written to the
+  // wire. This prevents us from accumulating write obligations when read and
+  // write progress at different rates. See RFC 8446, section 4.6.3.
+  ssl->s3->key_update_pending = true;
+
+  return true;
+}
+
 static bool tls13_receive_key_update(SSL *ssl, const SSLMessage &msg) {
   CBS body = msg.body;
   uint8_t key_update_request;
@@ -625,21 +644,9 @@
 
   // Acknowledge the KeyUpdate
   if (key_update_request == SSL_KEY_UPDATE_REQUESTED &&
-      !ssl->s3->key_update_pending) {
-    ScopedCBB cbb;
-    CBB body_cbb;
-    if (!ssl->method->init_message(ssl, cbb.get(), &body_cbb,
-                                   SSL3_MT_KEY_UPDATE) ||
-        !CBB_add_u8(&body_cbb, SSL_KEY_UPDATE_NOT_REQUESTED) ||
-        !ssl_add_message_cbb(ssl, cbb.get()) ||
-        !tls13_rotate_traffic_key(ssl, evp_aead_seal)) {
-      return false;
-    }
-
-    // Suppress KeyUpdate acknowledgments until this change is written to the
-    // wire. This prevents us from accumulating write obligations when read and
-    // write progress at different rates. See RFC 8446, section 4.6.3.
-    ssl->s3->key_update_pending = true;
+      !ssl->s3->key_update_pending &&
+      !tls13_add_key_update(ssl, SSL_KEY_UPDATE_NOT_REQUESTED)) {
+    return false;
   }
 
   return true;