Output a ClientHello during handoff.

This will allow edge servers to pass judgement on the ClientHello before
completing the handoff process. This also means that edge servers will
now enforce ClientHello well-formedness — previously that check didn't
occur until the handshaker tried to parse the handoff submission.

Change-Id: I9804ac0224632b4b4381c1a81f434d188e0b9376
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/35584
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 629f006..4240c29 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -4757,7 +4757,8 @@
 
 OPENSSL_EXPORT void SSL_CTX_set_handoff_mode(SSL_CTX *ctx, bool on);
 OPENSSL_EXPORT void SSL_set_handoff_mode(SSL *SSL, bool on);
-OPENSSL_EXPORT bool SSL_serialize_handoff(const SSL *ssl, CBB *out);
+OPENSSL_EXPORT bool SSL_serialize_handoff(const SSL *ssl, CBB *out,
+                                          SSL_CLIENT_HELLO *out_hello);
 OPENSSL_EXPORT bool SSL_decline_handoff(SSL *ssl);
 OPENSSL_EXPORT bool SSL_apply_handoff(SSL *ssl, Span<const uint8_t> handoff);
 OPENSSL_EXPORT bool SSL_serialize_handback(const SSL *ssl, CBB *out);
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index d82852d..f68cd1c 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -405,7 +405,7 @@
   return ssl_open_record_success;
 }
 
-bool dtls1_get_message(SSL *ssl, SSLMessage *out) {
+bool dtls1_get_message(const SSL *ssl, SSLMessage *out) {
   if (!dtls1_is_current_message_complete(ssl)) {
     return false;
   }
diff --git a/ssl/handoff.cc b/ssl/handoff.cc
index f9dbd13..0928015 100644
--- a/ssl/handoff.cc
+++ b/ssl/handoff.cc
@@ -49,7 +49,8 @@
   return CBB_flush(out);
 }
 
-bool SSL_serialize_handoff(const SSL *ssl, CBB *out) {
+bool SSL_serialize_handoff(const SSL *ssl, CBB *out,
+                           SSL_CLIENT_HELLO *out_hello) {
   const SSL3_STATE *const s3 = ssl->s3;
   if (!ssl->server ||
       s3->hs == nullptr ||
@@ -58,6 +59,7 @@
   }
 
   CBB seq;
+  SSLMessage msg;
   Span<const uint8_t> transcript = s3->hs->transcript.buffer();
   if (!CBB_add_asn1(out, &seq, CBS_ASN1_SEQUENCE) ||
       !CBB_add_asn1_uint64(&seq, kHandoffVersion) ||
@@ -66,7 +68,9 @@
                                  reinterpret_cast<uint8_t *>(s3->hs_buf->data),
                                  s3->hs_buf->length) ||
       !serialize_features(&seq) ||
-      !CBB_flush(out)) {
+      !CBB_flush(out) ||
+      !ssl->method->get_message(ssl, &msg) ||
+      !ssl_client_hello_init(ssl, out_hello, msg)) {
     return false;
   }
 
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index cb4e9d1..4622ad0 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -515,10 +515,6 @@
     return ssl_hs_error;
   }
 
-  if (hs->config->handoff) {
-    return ssl_hs_handoff;
-  }
-
   SSL_CLIENT_HELLO client_hello;
   if (!ssl_client_hello_init(ssl, &client_hello, msg)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
@@ -526,6 +522,10 @@
     return ssl_hs_error;
   }
 
+  if (hs->config->handoff) {
+    return ssl_hs_handoff;
+  }
+
   // Run the early callback.
   if (ssl->ctx->select_certificate_cb != NULL) {
     switch (ssl->ctx->select_certificate_cb(&client_hello)) {
diff --git a/ssl/internal.h b/ssl/internal.h
index 16b2866..ee2952a 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1081,7 +1081,7 @@
 void ssl_do_info_callback(const SSL *ssl, int type, int value);
 
 // ssl_do_msg_callback calls |ssl|'s message callback, if set.
-void ssl_do_msg_callback(SSL *ssl, int is_write, int content_type,
+void ssl_do_msg_callback(const SSL *ssl, int is_write, int content_type,
                          Span<const uint8_t> in);
 
 
@@ -1798,7 +1798,7 @@
 
 // ClientHello functions.
 
-bool ssl_client_hello_init(SSL *ssl, SSL_CLIENT_HELLO *out,
+bool ssl_client_hello_init(const SSL *ssl, SSL_CLIENT_HELLO *out,
                            const SSLMessage &msg);
 
 bool ssl_client_hello_get_extension(const SSL_CLIENT_HELLO *client_hello,
@@ -1958,7 +1958,7 @@
   void (*ssl_free)(SSL *ssl);
   // get_message sets |*out| to the current handshake message and returns true
   // if one has been received. It returns false if more input is needed.
-  bool (*get_message)(SSL *ssl, SSLMessage *out);
+  bool (*get_message)(const SSL *ssl, SSLMessage *out);
   // next_message is called to release the current handshake message.
   void (*next_message)(SSL *ssl);
   // Use the |ssl_open_handshake| wrapper.
@@ -2675,7 +2675,7 @@
 void ssl_update_cache(SSL_HANDSHAKE *hs, int mode);
 
 int ssl_send_alert(SSL *ssl, int level, int desc);
-bool ssl3_get_message(SSL *ssl, SSLMessage *out);
+bool ssl3_get_message(const SSL *ssl, SSLMessage *out);
 ssl_open_record_t ssl3_open_handshake(SSL *ssl, size_t *out_consumed,
                                       uint8_t *out_alert, Span<uint8_t> in);
 void ssl3_next_message(SSL *ssl);
@@ -2741,7 +2741,7 @@
 bool dtls1_new(SSL *ssl);
 void dtls1_free(SSL *ssl);
 
-bool dtls1_get_message(SSL *ssl, SSLMessage *out);
+bool dtls1_get_message(const SSL *ssl, SSLMessage *out);
 ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
                                        uint8_t *out_alert, Span<uint8_t> in);
 void dtls1_next_message(SSL *ssl);
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index aec6cae..27e9454 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -494,7 +494,7 @@
   return true;
 }
 
-bool ssl3_get_message(SSL *ssl, SSLMessage *out) {
+bool ssl3_get_message(const SSL *ssl, SSLMessage *out) {
   size_t unused;
   if (!parse_message(ssl, out, &unused)) {
     return false;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index d3e76d0..f9910f7 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -378,7 +378,7 @@
   }
 }
 
-void ssl_do_msg_callback(SSL *ssl, int is_write, int content_type,
+void ssl_do_msg_callback(const SSL *ssl, int is_write, int content_type,
                          Span<const uint8_t> in) {
   if (ssl->msg_callback == NULL) {
     return;
@@ -399,8 +399,8 @@
       version = SSL_version(ssl);
   }
 
-  ssl->msg_callback(is_write, version, content_type, in.data(), in.size(), ssl,
-                    ssl->msg_callback_arg);
+  ssl->msg_callback(is_write, version, content_type, in.data(), in.size(),
+                    const_cast<SSL *>(ssl), ssl->msg_callback_arg);
 }
 
 void ssl_get_current_time(const SSL *ssl, struct OPENSSL_timeval *out_clock) {
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index cd6f389..d01b649 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -4050,8 +4050,9 @@
 
   ScopedCBB cbb;
   Array<uint8_t> handoff;
+  SSL_CLIENT_HELLO hello;
   ASSERT_TRUE(CBB_init(cbb.get(), 256));
-  ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get()));
+  ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get(), &hello));
   ASSERT_TRUE(CBBFinishArray(cbb.get(), &handoff));
 
   bssl::UniquePtr<SSL> handshaker(SSL_new(handshaker_ctx.get()));
@@ -4122,8 +4123,9 @@
   ASSERT_EQ(server_err, SSL_ERROR_HANDOFF);
 
   ScopedCBB cbb;
+  SSL_CLIENT_HELLO hello;
   ASSERT_TRUE(CBB_init(cbb.get(), 256));
-  ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get()));
+  ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get(), &hello));
 
   ASSERT_TRUE(SSL_decline_handoff(server.get()));
 
diff --git a/ssl/t1_lib.cc b/ssl/t1_lib.cc
index c0452dc..87f1888 100644
--- a/ssl/t1_lib.cc
+++ b/ssl/t1_lib.cc
@@ -199,10 +199,10 @@
   return true;
 }
 
-bool ssl_client_hello_init(SSL *ssl, SSL_CLIENT_HELLO *out,
+bool ssl_client_hello_init(const SSL *ssl, SSL_CLIENT_HELLO *out,
                            const SSLMessage &msg) {
   OPENSSL_memset(out, 0, sizeof(*out));
-  out->ssl = ssl;
+  out->ssl = const_cast<SSL *>(ssl);
   out->client_hello = CBS_data(&msg.body);
   out->client_hello_len = CBS_len(&msg.body);
 
diff --git a/ssl/test/handshake_util.cc b/ssl/test/handshake_util.cc
index a36b41a..afead7f 100644
--- a/ssl/test/handshake_util.cc
+++ b/ssl/test/handshake_util.cc
@@ -413,8 +413,9 @@
   }
 
   ScopedCBB cbb;
+  SSL_CLIENT_HELLO hello;
   if (!CBB_init(cbb.get(), 512) ||
-      !SSL_serialize_handoff(ssl, cbb.get()) ||
+      !SSL_serialize_handoff(ssl, cbb.get(), &hello) ||
       !writer->WriteHandoff({CBB_data(cbb.get()), CBB_len(cbb.get())}) ||
       !SerializeContextState(ssl->ctx.get(), cbb.get()) ||
       !GetTestState(ssl)->Serialize(cbb.get())) {