Configure QUIC secrets inside set_{read,write}_state.

set_write_state flushes buffered handshake data, and we should finish
writing to a level before moving on to the next one.

I've moved the callback into set_{read,write}_state to ensure we still
update read_level and write_level after installing secrets, since that's
how we decide what level to write things and we should never write
alerts with keys we don't have. (I believe the only way this can come up
is if the QUIC callback itself fails, but it still seems prudent to
defer updating the levels.)

This does unfortunately mean a goofy secret_for_quic parameter, though
it is arguably more "correct" in that QUIC would ideally be a third
SSL_PROTOCOL_METHOD, rather than escape hatches over TLS. Probably a
cleaner abstraction would be for set_read_state and set_write_state to
take the secret and derive an SSLAEADContext internally.

Update-Note: See b/151142920#comment9
Change-Id: I4bbb76e15b5d95615ea643bccf796db87fae4989
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/40244
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index ae26de7..d179bae 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -78,7 +78,9 @@
 }
 
 static bool dtls1_set_read_state(SSL *ssl, ssl_encryption_level_t level,
-                                 UniquePtr<SSLAEADContext> aead_ctx) {
+                                 UniquePtr<SSLAEADContext> aead_ctx,
+                                 Span<const uint8_t> secret_for_quic) {
+  assert(secret_for_quic.empty());  // QUIC does not use DTLS.
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (dtls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -97,7 +99,9 @@
 }
 
 static bool dtls1_set_write_state(SSL *ssl, ssl_encryption_level_t level,
-                                  UniquePtr<SSLAEADContext> aead_ctx) {
+                                  UniquePtr<SSLAEADContext> aead_ctx,
+                                  Span<const uint8_t> secret_for_quic) {
+  assert(secret_for_quic.empty());  // QUIC does not use DTLS.
   ssl->d1->w_epoch++;
   OPENSSL_memcpy(ssl->d1->last_write_sequence, ssl->s3->write_sequence,
                  sizeof(ssl->s3->write_sequence));
diff --git a/ssl/internal.h b/ssl/internal.h
index 91036ae..ac3a844 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2140,15 +2140,19 @@
   // on_handshake_complete is called when the handshake is complete.
   void (*on_handshake_complete)(SSL *ssl);
   // set_read_state sets |ssl|'s read cipher state and level to |aead_ctx| and
-  // |level|. It returns true on success and false if changing the read state is
-  // forbidden at this point.
+  // |level|. In QUIC, |aead_ctx| is a placeholder object and |secret_for_quic|
+  // is the original secret. This function returns true on success and false on
+  // error.
   bool (*set_read_state)(SSL *ssl, ssl_encryption_level_t level,
-                         UniquePtr<SSLAEADContext> aead_ctx);
+                         UniquePtr<SSLAEADContext> aead_ctx,
+                         Span<const uint8_t> secret_for_quic);
   // set_write_state sets |ssl|'s write cipher state and level to |aead_ctx| and
-  // |level|. It returns true on success and false if changing the write state
-  // is forbidden at this point.
+  // |level|. In QUIC, |aead_ctx| is a placeholder object and |secret_for_quic|
+  // is the original secret. This function returns true on success and false on
+  // error.
   bool (*set_write_state)(SSL *ssl, ssl_encryption_level_t level,
-                          UniquePtr<SSLAEADContext> aead_ctx);
+                          UniquePtr<SSLAEADContext> aead_ctx,
+                          Span<const uint8_t> secret_for_quic);
 };
 
 // The following wrappers call |open_*| but handle |read_shutdown| correctly.
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 2293405..cc6fdee 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -4850,6 +4850,8 @@
     return !levels_[level].write_secret.empty();
   }
 
+  void AllowOutOfOrderWrites() { allow_out_of_order_writes_ = true; }
+
   bool SetReadSecret(ssl_encryption_level_t level, const SSL_CIPHER *cipher,
                      Span<const uint8_t> secret) {
     if (HasReadSecret(level)) {
@@ -4915,6 +4917,35 @@
                     << " write secret not yet configured";
       return false;
     }
+
+    // Although the levels are conceptually separate, BoringSSL finishes writing
+    // data from a previous level before installing keys for the next level.
+    if (!allow_out_of_order_writes_) {
+      switch (level) {
+        case ssl_encryption_early_data:
+          ADD_FAILURE() << "unexpected handshake data at early data level";
+          return false;
+        case ssl_encryption_initial:
+          if (!levels_[ssl_encryption_handshake].write_secret.empty()) {
+            ADD_FAILURE()
+                << LevelToString(level)
+                << " handshake data written after handshake keys installed";
+            return false;
+          }
+          OPENSSL_FALLTHROUGH;
+        case ssl_encryption_handshake:
+          if (!levels_[ssl_encryption_application].write_secret.empty()) {
+            ADD_FAILURE()
+                << LevelToString(level)
+                << " handshake data written after application keys installed";
+            return false;
+          }
+          OPENSSL_FALLTHROUGH;
+        case ssl_encryption_application:
+          break;
+      }
+    }
+
     levels_[level].write_data.insert(levels_[level].write_data.end(),
                                      data.begin(), data.end());
     return true;
@@ -4971,6 +5002,7 @@
   Role role_;
   MockQUICTransport *peer_ = nullptr;
 
+  bool allow_out_of_order_writes_ = false;
   bool has_alert_ = false;
   ssl_encryption_level_t alert_level_ = ssl_encryption_initial;
   uint8_t alert_ = 0;
@@ -5048,6 +5080,10 @@
            SSL_provide_quic_data(ssl, level, data.data(), data.size());
   }
 
+  void AllowOutOfOrderWrites() {
+    allow_out_of_order_writes_ = true;
+  }
+
   bool CreateClientAndServer() {
     client_.reset(SSL_new(client_ctx_.get()));
     server_.reset(SSL_new(server_ctx_.get()));
@@ -5061,6 +5097,10 @@
     transport_.reset(new MockQUICTransportPair);
     ex_data_.Set(client_.get(), transport_->client());
     ex_data_.Set(server_.get(), transport_->server());
+    if (allow_out_of_order_writes_) {
+      transport_->client()->AllowOutOfOrderWrites();
+      transport_->server()->AllowOutOfOrderWrites();
+    }
     return true;
   }
 
@@ -5183,6 +5223,8 @@
 
   bssl::UniquePtr<SSL> client_;
   bssl::UniquePtr<SSL> server_;
+
+  bool allow_out_of_order_writes_ = false;
 };
 
 UnownedSSLExData<MockQUICTransport> QUICMethodTest::ex_data_;
@@ -5486,6 +5528,8 @@
 
 // Test buffering write data until explicit flushes.
 TEST_F(QUICMethodTest, Buffered) {
+  AllowOutOfOrderWrites();
+
   struct BufferedFlight {
     std::vector<uint8_t> data[kNumQUICLevels];
   };
@@ -5535,6 +5579,8 @@
 // EncryptedExtensions in a single chunk, BoringSSL notices and rejects this on
 // key change.
 TEST_F(QUICMethodTest, ExcessProvidedData) {
+  AllowOutOfOrderWrites();
+
   auto add_handshake_data = [](SSL *ssl, enum ssl_encryption_level_t level,
                                const uint8_t *data, size_t len) -> int {
     // Switch everything to the initial level.
@@ -5580,10 +5626,10 @@
   EXPECT_EQ(transport_->client()->alert_level(), ssl_encryption_handshake);
   EXPECT_EQ(transport_->client()->alert(), SSL_AD_UNEXPECTED_MESSAGE);
 
-  // Sanity-check client did get far enough to process the ServerHello and
-  // install keys.
-  EXPECT_TRUE(transport_->client()->HasReadSecret(ssl_encryption_handshake));
+  // Sanity-check handshake secrets. The error is discovered while setting the
+  // read secret, so only the write secret has been installed.
   EXPECT_TRUE(transport_->client()->HasWriteSecret(ssl_encryption_handshake));
+  EXPECT_FALSE(transport_->client()->HasReadSecret(ssl_encryption_handshake));
 }
 
 // Test that |SSL_provide_quic_data| will reject data at the wrong level.
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 8091021..73b6544 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -237,11 +237,13 @@
 
   if (direction == evp_aead_open) {
     return ssl->method->set_read_state(ssl, ssl_encryption_application,
-                                       std::move(aead_ctx));
+                                       std::move(aead_ctx),
+                                       /*secret_for_quic=*/{});
   }
 
   return ssl->method->set_write_state(ssl, ssl_encryption_application,
-                                      std::move(aead_ctx));
+                                      std::move(aead_ctx),
+                                      /*secret_for_quic=*/{});
 }
 
 int tls1_change_cipher_state(SSL_HANDSHAKE *hs,
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index 7228471..b889ac2 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -79,8 +79,10 @@
     if (level == ssl_encryption_initial) {
       bssl::UniquePtr<SSLAEADContext> null_ctx =
           SSLAEADContext::CreateNullCipher(SSL_is_dtls(ssl));
-      if (!null_ctx || !ssl->method->set_write_state(ssl, ssl_encryption_initial,
-                                                     std::move(null_ctx))) {
+      if (!null_ctx ||
+          !ssl->method->set_write_state(ssl, ssl_encryption_initial,
+                                        std::move(null_ctx),
+                                        /*secret_for_quic=*/{})) {
         return false;
       }
       ssl->s3->aead_write_ctx->SetVersionIfNullCipher(ssl->version);
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index 3a2e4e5..69a5578 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -143,33 +143,13 @@
                            Span<const uint8_t> traffic_secret) {
   uint16_t version = ssl_session_protocol_version(session);
   UniquePtr<SSLAEADContext> traffic_aead;
+  Span<const uint8_t> secret_for_quic;
   if (ssl->quic_method != nullptr) {
-    // Pass the traffic secrets to QUIC.
-    if (direction == evp_aead_open) {
-      if (!ssl->quic_method->set_read_secret(ssl, level, session->cipher,
-                                             traffic_secret.data(),
-                                             traffic_secret.size())) {
-        return false;
-      }
-    } else {
-      if (!ssl->quic_method->set_write_secret(ssl, level, session->cipher,
-                                              traffic_secret.data(),
-                                              traffic_secret.size())) {
-        return false;
-      }
-    }
-
-    // QUIC only uses |ssl| for handshake messages, which never use early data
-    // keys, so we return installing anything. This avoids needing to have two
-    // secrets active at once in 0-RTT.
-    if (level == ssl_encryption_early_data) {
-      return true;
-    }
-
     // Install a placeholder SSLAEADContext so that SSL accessors work. The
     // encryption itself will be handled by the SSL_QUIC_METHOD.
     traffic_aead =
         SSLAEADContext::CreatePlaceholderForQUIC(version, session->cipher);
+    secret_for_quic = traffic_secret;
   } else {
     // Look up cipher suite properties.
     const EVP_AEAD *aead;
@@ -217,14 +197,16 @@
   }
 
   if (direction == evp_aead_open) {
-    if (!ssl->method->set_read_state(ssl, level, std::move(traffic_aead))) {
+    if (!ssl->method->set_read_state(ssl, level, std::move(traffic_aead),
+                                     secret_for_quic)) {
       return false;
     }
     OPENSSL_memmove(ssl->s3->read_traffic_secret, traffic_secret.data(),
                     traffic_secret.size());
     ssl->s3->read_traffic_secret_len = traffic_secret.size();
   } else {
-    if (!ssl->method->set_write_state(ssl, level, std::move(traffic_aead))) {
+    if (!ssl->method->set_write_state(ssl, level, std::move(traffic_aead),
+                                      secret_for_quic)) {
       return false;
     }
     OPENSSL_memmove(ssl->s3->write_traffic_secret, traffic_secret.data(),
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 3868852..8165d1c 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -83,7 +83,8 @@
 }
 
 static bool tls_set_read_state(SSL *ssl, ssl_encryption_level_t level,
-                               UniquePtr<SSLAEADContext> aead_ctx) {
+                               UniquePtr<SSLAEADContext> aead_ctx,
+                               Span<const uint8_t> secret_for_quic) {
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (tls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -91,6 +92,21 @@
     return false;
   }
 
+  if (ssl->quic_method != nullptr) {
+    if (!ssl->quic_method->set_read_secret(ssl, level, aead_ctx->cipher(),
+                                           secret_for_quic.data(),
+                                           secret_for_quic.size())) {
+      return false;
+    }
+
+    // QUIC only uses |ssl| for handshake messages, which never use early data
+    // keys, so we return without installing anything. This avoids needing to
+    // have two secrets active at once in 0-RTT.
+    if (level == ssl_encryption_early_data) {
+      return true;
+    }
+  }
+
   OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
   ssl->s3->read_level = level;
@@ -98,11 +114,27 @@
 }
 
 static bool tls_set_write_state(SSL *ssl, ssl_encryption_level_t level,
-                                UniquePtr<SSLAEADContext> aead_ctx) {
+                                UniquePtr<SSLAEADContext> aead_ctx,
+                                Span<const uint8_t> secret_for_quic) {
   if (!tls_flush_pending_hs_data(ssl)) {
     return false;
   }
 
+  if (ssl->quic_method != nullptr) {
+    if (!ssl->quic_method->set_write_secret(ssl, level, aead_ctx->cipher(),
+                                            secret_for_quic.data(),
+                                            secret_for_quic.size())) {
+      return false;
+    }
+
+    // QUIC only uses |ssl| for handshake messages, which never use early data
+    // keys, so we return without installing anything. This avoids needing to
+    // have two secrets active at once in 0-RTT.
+    if (level == ssl_encryption_early_data) {
+      return true;
+    }
+  }
+
   OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
   ssl->s3->write_level = level;