Allow configuring QUIC method per-connection

This allows sharing SSL_CTX between TCP and QUIC connections, such that
common settings can be configured without having to duplicate the
context.

Change-Id: Ie920e7f2a772dd6c6c7b63fdac243914ac5b7b26
Reviewed-on: https://boringssl-review.googlesource.com/c/33904
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index c128605..a011e0f 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -3170,6 +3170,12 @@
 OPENSSL_EXPORT int SSL_CTX_set_quic_method(SSL_CTX *ctx,
                                            const SSL_QUIC_METHOD *quic_method);
 
+// SSL_set_quic_method configures the QUIC hooks. This should only be
+// configured with a minimum version of TLS 1.3. |quic_method| must remain valid
+// for the lifetime of |ssl|. It returns one on success and zero on error.
+OPENSSL_EXPORT int SSL_set_quic_method(SSL *ssl,
+                                       const SSL_QUIC_METHOD *quic_method);
+
 
 // Early data.
 //
diff --git a/ssl/handshake.cc b/ssl/handshake.cc
index 058a793..091ed44 100644
--- a/ssl/handshake.cc
+++ b/ssl/handshake.cc
@@ -544,7 +544,7 @@
       case ssl_hs_read_server_hello:
       case ssl_hs_read_message:
       case ssl_hs_read_change_cipher_spec: {
-        if (ssl->ctx->quic_method) {
+        if (ssl->quic_method) {
           hs->wait = ssl_hs_ok;
           // The change cipher spec is omitted in QUIC.
           if (hs->wait != ssl_hs_read_change_cipher_spec) {
diff --git a/ssl/internal.h b/ssl/internal.h
index 07e1b89..1116bad 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3170,6 +3170,9 @@
   uint32_t max_cert_list = 0;
   bssl::UniquePtr<char> hostname;
 
+  // quic_method is the method table corresponding to the QUIC hooks.
+  const SSL_QUIC_METHOD *quic_method = nullptr;
+
   // renegotiate_mode controls how peer renegotiation attempts are handled.
   ssl_renegotiate_mode_t renegotiate_mode = ssl_renegotiate_never;
 
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index 02bc3bb..aec6cae 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -192,7 +192,7 @@
   //
   // TODO(davidben): See if we can do this uniformly.
   Span<const uint8_t> rest = msg;
-  if (ssl->ctx->quic_method == nullptr &&
+  if (ssl->quic_method == nullptr &&
       ssl->s3->aead_write_ctx->is_null_cipher()) {
     while (!rest.empty()) {
       Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
@@ -248,9 +248,9 @@
   auto data =
       MakeConstSpan(reinterpret_cast<const uint8_t *>(pending_hs_data->data),
                     pending_hs_data->length);
-  if (ssl->ctx->quic_method) {
-    if (!ssl->ctx->quic_method->add_handshake_data(ssl, ssl->s3->write_level,
-                                                   data.data(), data.size())) {
+  if (ssl->quic_method) {
+    if (!ssl->quic_method->add_handshake_data(ssl, ssl->s3->write_level,
+                                              data.data(), data.size())) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
       return false;
     }
@@ -267,7 +267,7 @@
     return false;
   }
 
-  if (!ssl->ctx->quic_method &&
+  if (!ssl->quic_method &&
       !add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC,
                             kChangeCipherSpec)) {
     return false;
@@ -283,13 +283,13 @@
     return -1;
   }
 
-  if (ssl->ctx->quic_method) {
+  if (ssl->quic_method) {
     if (ssl->s3->write_shutdown != ssl_shutdown_none) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
       return -1;
     }
 
-    if (!ssl->ctx->quic_method->flush_flight(ssl)) {
+    if (!ssl->quic_method->flush_flight(ssl)) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
       return -1;
     }
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index f1d67d0..abc6798 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -410,9 +410,9 @@
 }
 
 int ssl3_dispatch_alert(SSL *ssl) {
-  if (ssl->ctx->quic_method) {
-    if (!ssl->ctx->quic_method->send_alert(ssl, ssl->s3->write_level,
-                                           ssl->s3->send_alert[1])) {
+  if (ssl->quic_method) {
+    if (!ssl->quic_method->send_alert(ssl, ssl->s3->write_level,
+                                      ssl->s3->send_alert[1])) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
       return 0;
     }
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index bbc3758..f3d92a6 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -714,6 +714,7 @@
   ssl->config->ocsp_stapling_enabled = ctx->ocsp_stapling_enabled;
   ssl->config->handoff = ctx->handoff;
   ssl->config->ignore_tls13_downgrade = ctx->ignore_tls13_downgrade;
+  ssl->quic_method = ctx->quic_method;
 
   if (!ssl->method->ssl_new(ssl.get()) ||
       !ssl->ctx->x509_method->ssl_new(ssl->s3->hs.get())) {
@@ -850,7 +851,7 @@
 
 int SSL_provide_quic_data(SSL *ssl, enum ssl_encryption_level_t level,
                           const uint8_t *data, size_t len) {
-  if (ssl->ctx->quic_method == nullptr) {
+  if (ssl->quic_method == nullptr) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
     return 0;
   }
@@ -1077,7 +1078,7 @@
 }
 
 int SSL_peek(SSL *ssl, void *buf, int num) {
-  if (ssl->ctx->quic_method != nullptr) {
+  if (ssl->quic_method != nullptr) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
     return 0;
   }
@@ -1098,7 +1099,7 @@
 int SSL_write(SSL *ssl, const void *buf, int num) {
   ssl_reset_error_state(ssl);
 
-  if (ssl->ctx->quic_method != nullptr) {
+  if (ssl->quic_method != nullptr) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
     return 0;
   }
@@ -1342,7 +1343,7 @@
       return SSL_ERROR_HANDBACK;
 
     case SSL_READING: {
-      if (ssl->ctx->quic_method) {
+      if (ssl->quic_method) {
         return SSL_ERROR_WANT_READ;
       }
       BIO *bio = SSL_get_rbio(ssl);
@@ -2459,6 +2460,14 @@
   return 1;
 }
 
+int SSL_set_quic_method(SSL *ssl, const SSL_QUIC_METHOD *quic_method) {
+  if (ssl->method->is_dtls) {
+    return 0;
+  }
+  ssl->quic_method = quic_method;
+  return 1;
+}
+
 int SSL_get_ex_new_index(long argl, void *argp, CRYPTO_EX_unused *unused,
                          CRYPTO_EX_dup *dup_unused, CRYPTO_EX_free *free_func) {
   int index;
diff --git a/ssl/ssl_versions.cc b/ssl/ssl_versions.cc
index 39540f1..e6dbc8d 100644
--- a/ssl/ssl_versions.cc
+++ b/ssl/ssl_versions.cc
@@ -192,7 +192,7 @@
   uint16_t max_version = hs->config->conf_max_version;
 
   // QUIC requires TLS 1.3.
-  if (hs->ssl->ctx->quic_method && min_version < TLS1_3_VERSION) {
+  if (hs->ssl->quic_method && min_version < TLS1_3_VERSION) {
     min_version = TLS1_3_VERSION;
   }
 
diff --git a/ssl/tls13_both.cc b/ssl/tls13_both.cc
index 605942a..7674d99 100644
--- a/ssl/tls13_both.cc
+++ b/ssl/tls13_both.cc
@@ -655,7 +655,7 @@
 bool tls13_post_handshake(SSL *ssl, const SSLMessage &msg) {
   if (msg.type == SSL3_MT_KEY_UPDATE) {
     ssl->s3->key_update_count++;
-    if (ssl->ctx->quic_method != nullptr ||
+    if (ssl->quic_method != nullptr ||
         ssl->s3->key_update_count > kMaxKeyUpdates) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_TOO_MANY_KEY_UPDATES);
       ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index a59686f..7353561 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -144,7 +144,7 @@
   }
 
   UniquePtr<SSLAEADContext> traffic_aead;
-  if (ssl->ctx->quic_method == nullptr) {
+  if (ssl->quic_method == nullptr) {
     // Look up cipher suite properties.
     const EVP_AEAD *aead;
     size_t discard;
@@ -237,16 +237,16 @@
   }
   ssl->s3->early_exporter_secret_len = hs->hash_len;
 
-  if (ssl->ctx->quic_method != nullptr) {
+  if (ssl->quic_method != nullptr) {
     if (ssl->server) {
-      if (!ssl->ctx->quic_method->set_encryption_secrets(
+      if (!ssl->quic_method->set_encryption_secrets(
               ssl, ssl_encryption_early_data, nullptr, hs->early_traffic_secret,
               hs->hash_len)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
         return false;
       }
     } else {
-      if (!ssl->ctx->quic_method->set_encryption_secrets(
+      if (!ssl->quic_method->set_encryption_secrets(
               ssl, ssl_encryption_early_data, hs->early_traffic_secret, nullptr,
               hs->hash_len)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
@@ -273,16 +273,16 @@
     return false;
   }
 
-  if (ssl->ctx->quic_method != nullptr) {
+  if (ssl->quic_method != nullptr) {
     if (ssl->server) {
-      if (!ssl->ctx->quic_method->set_encryption_secrets(
+      if (!ssl->quic_method->set_encryption_secrets(
               ssl, ssl_encryption_handshake, hs->client_handshake_secret,
               hs->server_handshake_secret, hs->hash_len)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
         return false;
       }
     } else {
-      if (!ssl->ctx->quic_method->set_encryption_secrets(
+      if (!ssl->quic_method->set_encryption_secrets(
               ssl, ssl_encryption_handshake, hs->server_handshake_secret,
               hs->client_handshake_secret, hs->hash_len)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
@@ -314,16 +314,16 @@
     return false;
   }
 
-  if (ssl->ctx->quic_method != nullptr) {
+  if (ssl->quic_method != nullptr) {
     if (ssl->server) {
-      if (!ssl->ctx->quic_method->set_encryption_secrets(
+      if (!ssl->quic_method->set_encryption_secrets(
               ssl, ssl_encryption_application, hs->client_traffic_secret_0,
               hs->server_traffic_secret_0, hs->hash_len)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
         return false;
       }
     } else {
-      if (!ssl->ctx->quic_method->set_encryption_secrets(
+      if (!ssl->quic_method->set_encryption_secrets(
               ssl, ssl_encryption_application, hs->server_traffic_secret_0,
               hs->client_traffic_secret_0, hs->hash_len)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);