Buffer up QUIC data within a level internally.

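Avoid forcing the QUIC implementation to buffer handshake data when we already
have code to do it. This also avoids QUIC implementations relying on the hook
being called for each individual message. To reflect this, the |add_message|
hook is renamed to |add_handshake_data|: a single call may now carry several
messages, or only part of one.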

Change-Id: If2d70f045a25da1aa2b10fdae262cae331da06b1
Reviewed-on: https://boringssl-review.googlesource.com/c/32785
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index 689dd1d..55e9aaa 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -184,56 +184,49 @@
 }
 
 bool ssl3_add_message(SSL *ssl, Array<uint8_t> msg) {
-  if (ssl->ctx->quic_method) {
-    if (!ssl->ctx->quic_method->add_message(ssl, ssl->s3->write_level,
-                                            msg.data(), msg.size())) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
-      return false;
+  // Pack handshake data into the minimal number of records. This avoids
+  // unnecessary encryption overhead, notably in TLS 1.3 where we send several
+  // encrypted messages in a row. For now, TLS skips this for the null cipher
+  // (smaller benefit, risk of breaking buggy peers) and for draft-23 (in case
+  // middleboxes have fixated on sizes). QUIC always buffers here, so the
+  // |add_handshake_data| hook is not invoked once per message.
+  //
+  // TODO(davidben): See if we can do this uniformly.
+  Span<const uint8_t> rest = msg;
+  if (ssl->ctx->quic_method == nullptr &&
+      (ssl->s3->aead_write_ctx->is_null_cipher() ||
+       ssl->version == TLS1_3_DRAFT23_VERSION)) {
+    while (!rest.empty()) {
+      Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
+      rest = rest.subspan(chunk.size());
+
+      if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) {
+        return false;
+      }
     }
   } else {
-    // Pack handshake data into the minimal number of records. This avoids
-    // unnecessary encryption overhead, notably in TLS 1.3 where we send several
-    // encrypted messages in a row. For now, we do not do this for the null
-    // cipher. The benefit is smaller and there is a risk of breaking buggy
-    // implementations. Additionally, we tie this to draft-28 as a sanity check,
-    // on the off chance middleboxes have fixated on sizes.
-    //
-    // TODO(davidben): See if we can do this uniformly.
-    Span<const uint8_t> rest = msg;
-    if (ssl->s3->aead_write_ctx->is_null_cipher() ||
-        ssl->version == TLS1_3_DRAFT23_VERSION) {
-      while (!rest.empty()) {
-        Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
-        rest = rest.subspan(chunk.size());
-
-        if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) {
-          return false;
-        }
+    while (!rest.empty()) {
+      // Flush if |pending_hs_data| is full.
+      if (ssl->s3->pending_hs_data &&
+          ssl->s3->pending_hs_data->length >= ssl->max_send_fragment &&
+          !tls_flush_pending_hs_data(ssl)) {
+        return false;
       }
-    } else {
-      while (!rest.empty()) {
-        // Flush if |pending_hs_data| is full.
-        if (ssl->s3->pending_hs_data &&
-            ssl->s3->pending_hs_data->length >= ssl->max_send_fragment &&
-            !tls_flush_pending_hs_data(ssl)) {
-          return false;
-        }
 
-        size_t pending_len =
-            ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0;
-        Span<const uint8_t> chunk =
-            rest.subspan(0, ssl->max_send_fragment - pending_len);
-        assert(!chunk.empty());
-        rest = rest.subspan(chunk.size());
+      size_t pending_len =
+          ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0;
+      Span<const uint8_t> chunk =
+          rest.subspan(0, ssl->max_send_fragment - pending_len);
+      assert(!chunk.empty());
+      rest = rest.subspan(chunk.size());
 
-        if (!ssl->s3->pending_hs_data) {
-          ssl->s3->pending_hs_data.reset(BUF_MEM_new());
-        }
-        if (!ssl->s3->pending_hs_data ||
-            !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(),
-                            chunk.size())) {
-          return false;
-        }
+      if (!ssl->s3->pending_hs_data) {
+        ssl->s3->pending_hs_data.reset(BUF_MEM_new());
+      }
+      if (!ssl->s3->pending_hs_data ||
+          !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(),
+                          chunk.size())) {
+        return false;
       }
     }
   }
@@ -249,16 +242,24 @@
 }
 
 bool tls_flush_pending_hs_data(SSL *ssl) {
-  if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0 ||
-      ssl->ctx->quic_method) {
+  if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0) {
     return true;
   }
 
   UniquePtr<BUF_MEM> pending_hs_data = std::move(ssl->s3->pending_hs_data);
-  return add_record_to_flight(
-      ssl, SSL3_RT_HANDSHAKE,
+  auto data =
       MakeConstSpan(reinterpret_cast<const uint8_t *>(pending_hs_data->data),
-                    pending_hs_data->length));
+                    pending_hs_data->length);
+  if (ssl->ctx->quic_method) {
+    if (!ssl->ctx->quic_method->add_handshake_data(ssl, ssl->s3->write_level,
+                                                   data.data(), data.size())) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
+      return false;
+    }
+    return true;
+  }
+
+  return add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, data);
 }
 
 bool ssl3_add_change_cipher_spec(SSL *ssl) {
@@ -280,6 +281,10 @@
 }
 
 int ssl3_flush_flight(SSL *ssl) {
+  if (!tls_flush_pending_hs_data(ssl)) {
+    return -1;
+  }
+
   if (ssl->ctx->quic_method) {
     if (ssl->s3->write_shutdown != ssl_shutdown_none) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
@@ -292,10 +297,6 @@
     }
   }
 
-  if (!tls_flush_pending_hs_data(ssl)) {
-    return -1;
-  }
-
   if (ssl->s3->pending_flight == nullptr) {
     return 1;
   }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index c237809..4792560 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -4650,8 +4650,9 @@
                                                        write_key, key_len);
   }
 
-  static int AddMessageCallback(SSL *ssl, enum ssl_encryption_level_t level,
-                                const uint8_t *data, size_t len) {
+  static int AddHandshakeDataCallback(SSL *ssl,
+                                      enum ssl_encryption_level_t level,
+                                      const uint8_t *data, size_t len) {
     EXPECT_EQ(level, SSL_quic_write_level(ssl));
     return TransportFromSSL(ssl)->WriteHandshakeData(level,
                                                      MakeConstSpan(data, len));
@@ -4681,7 +4682,7 @@
 TEST_F(QUICMethodTest, Basic) {
   const SSL_QUIC_METHOD quic_method = {
       SetEncryptionSecretsCallback,
-      AddMessageCallback,
+      AddHandshakeDataCallback,
       FlushFlightCallback,
       SendAlertCallback,
   };
@@ -4732,7 +4733,7 @@
 TEST_F(QUICMethodTest, Async) {
   const SSL_QUIC_METHOD quic_method = {
       SetEncryptionSecretsCallback,
-      AddMessageCallback,
+      AddHandshakeDataCallback,
       FlushFlightCallback,
       SendAlertCallback,
   };
@@ -4793,8 +4794,8 @@
   };
   static UnownedSSLExData<BufferedFlight> buffered_flights;
 
-  auto add_message = [](SSL *ssl, enum ssl_encryption_level_t level,
-                        const uint8_t *data, size_t len) -> int {
+  auto add_handshake_data = [](SSL *ssl, enum ssl_encryption_level_t level,
+                               const uint8_t *data, size_t len) -> int {
     BufferedFlight *flight = buffered_flights.Get(ssl);
     flight->data[level].insert(flight->data[level].end(), data, data + len);
     return 1;
@@ -4817,7 +4818,7 @@
 
   const SSL_QUIC_METHOD quic_method = {
     SetEncryptionSecretsCallback,
-    add_message,
+    add_handshake_data,
     flush_flight,
     SendAlertCallback,
   };
@@ -4862,8 +4863,8 @@
 // EncryptedExtensions in a single chunk, BoringSSL notices and rejects this on
 // key change.
 TEST_F(QUICMethodTest, ExcessProvidedData) {
-  auto add_message = [](SSL *ssl, enum ssl_encryption_level_t level,
-                        const uint8_t *data, size_t len) -> int {
+  auto add_handshake_data = [](SSL *ssl, enum ssl_encryption_level_t level,
+                               const uint8_t *data, size_t len) -> int {
     // Switch everything to the initial level.
     return TransportFromSSL(ssl)->WriteHandshakeData(ssl_encryption_initial,
                                                      MakeConstSpan(data, len));
@@ -4871,7 +4872,7 @@
 
   const SSL_QUIC_METHOD quic_method = {
       SetEncryptionSecretsCallback,
-      add_message,
+      add_handshake_data,
       FlushFlightCallback,
       SendAlertCallback,
   };
@@ -4891,8 +4892,8 @@
   // encryption.
   ASSERT_EQ(ssl_encryption_initial, SSL_quic_read_level(client_.get()));
 
-  // |add_message| incorrectly wrote everything at the initial level, so this
-  // queues up ServerHello through Finished in one chunk.
+  // |add_handshake_data| incorrectly wrote everything at the initial level, so
+  // this queues up ServerHello through Finished in one chunk.
   ASSERT_TRUE(ProvideHandshakeData(client_.get()));
 
   // The client reads ServerHello successfully, but then rejects the buffered
@@ -4917,7 +4918,7 @@
 TEST_F(QUICMethodTest, ProvideWrongLevel) {
   const SSL_QUIC_METHOD quic_method = {
       SetEncryptionSecretsCallback,
-      AddMessageCallback,
+      AddHandshakeDataCallback,
       FlushFlightCallback,
       SendAlertCallback,
   };
@@ -4962,7 +4963,7 @@
 TEST_F(QUICMethodTest, TooMuchData) {
   const SSL_QUIC_METHOD quic_method = {
       SetEncryptionSecretsCallback,
-      AddMessageCallback,
+      AddHandshakeDataCallback,
       FlushFlightCallback,
       SendAlertCallback,
   };