Add tests for split handshakes.

This change adds a couple of focused tests to ssl_test.cc, and also
programmatically duplicates many runner tests to run in a split-handshake mode.

Change-Id: I9dafc8a394581e5daf1318722e1015de82117fd9
Reviewed-on: https://boringssl-review.googlesource.com/25388
Commit-Queue: Adam Langley <agl@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 937c9fe..b67637d 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -342,7 +342,8 @@
 };
 
 // CBBFinishArray behaves like |CBB_finish| but stores the result in an Array.
-bool CBBFinishArray(CBB *cbb, Array<uint8_t> *out);
+// It is exported because test code (ssl_test.cc and bssl_shim.cc) calls it.
+OPENSSL_EXPORT bool CBBFinishArray(CBB *cbb, Array<uint8_t> *out);
 
 
 // Protocol versions.
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index b2042ea..0f2a33c 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -1537,7 +1537,8 @@
 static bool ConnectClientAndServer(bssl::UniquePtr<SSL> *out_client,
                                    bssl::UniquePtr<SSL> *out_server,
                                    SSL_CTX *client_ctx, SSL_CTX *server_ctx,
-                                   const ClientConfig &config = ClientConfig()) {
+                                   const ClientConfig &config = ClientConfig(),
+                                   bool do_handshake = true) {
   bssl::UniquePtr<SSL> client(SSL_new(client_ctx)), server(SSL_new(server_ctx));
   if (!client || !server) {
     return false;
@@ -1561,7 +1562,8 @@
   SSL_set_bio(client.get(), bio1, bio1);
   SSL_set_bio(server.get(), bio2, bio2);
 
-  if (!CompleteHandshakes(client.get(), server.get())) {
+  // With |do_handshake| false, the handshake is left to the caller.
+  if (do_handshake && !CompleteHandshakes(client.get(), server.get())) {
     return false;
   }
 
@@ -3923,6 +3925,149 @@
   EXPECT_TRUE(SSL_is_signature_algorithm_rsa_pss(SSL_SIGN_RSA_PSS_SHA384));
 }
 
+// MoveBIOs moves the read and write BIOs of |src| to |dest|, leaving |src|
+// with no BIOs.
+void MoveBIOs(SSL *dest, SSL *src) {
+  BIO *rbio = SSL_get_rbio(src);
+  BIO_up_ref(rbio);
+  SSL_set0_rbio(dest, rbio);
+
+  BIO *wbio = SSL_get_wbio(src);
+  BIO_up_ref(wbio);
+  SSL_set0_wbio(dest, wbio);
+
+  SSL_set0_rbio(src, nullptr);
+  SSL_set0_wbio(src, nullptr);
+}
+
+// Handoff tests that a TLS 1.2 server handshake can be handed off to another
+// |SSL|, completed there, and handed back to a fresh |SSL| that then exchanges
+// application data with the client.
+TEST(SSLTest, Handoff) {
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
+  bssl::UniquePtr<SSL_CTX> server_ctx(SSL_CTX_new(TLS_method()));
+  bssl::UniquePtr<SSL_CTX> handshaker_ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(client_ctx);
+  ASSERT_TRUE(server_ctx);
+  ASSERT_TRUE(handshaker_ctx);
+
+  SSL_CTX_set_handoff_mode(server_ctx.get(), 1);
+  ASSERT_TRUE(SSL_CTX_set_max_proto_version(server_ctx.get(), TLS1_2_VERSION));
+  ASSERT_TRUE(
+      SSL_CTX_set_max_proto_version(handshaker_ctx.get(), TLS1_2_VERSION));
+
+  bssl::UniquePtr<X509> cert = GetTestCertificate();
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(cert);
+  ASSERT_TRUE(key);
+  ASSERT_TRUE(SSL_CTX_use_certificate(handshaker_ctx.get(), cert.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(handshaker_ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                     server_ctx.get(), ClientConfig(),
+                                     false /* don't handshake */));
+
+  int client_ret = SSL_do_handshake(client.get());
+  int client_err = SSL_get_error(client.get(), client_ret);
+  ASSERT_EQ(client_err, SSL_ERROR_WANT_READ);
+
+  int server_ret = SSL_do_handshake(server.get());
+  int server_err = SSL_get_error(server.get(), server_ret);
+  ASSERT_EQ(server_err, SSL_ERROR_HANDOFF);
+
+  ScopedCBB cbb;
+  Array<uint8_t> handoff;
+  ASSERT_TRUE(CBB_init(cbb.get(), 256));
+  ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get()));
+  ASSERT_TRUE(CBBFinishArray(cbb.get(), &handoff));
+
+  bssl::UniquePtr<SSL> handshaker(SSL_new(handshaker_ctx.get()));
+  ASSERT_TRUE(handshaker);
+  ASSERT_TRUE(SSL_apply_handoff(handshaker.get(), handoff));
+
+  MoveBIOs(handshaker.get(), server.get());
+
+  int handshake_ret = SSL_do_handshake(handshaker.get());
+  int handshake_err = SSL_get_error(handshaker.get(), handshake_ret);
+  ASSERT_EQ(handshake_err, SSL_ERROR_WANT_READ);
+
+  ASSERT_TRUE(CompleteHandshakes(client.get(), handshaker.get()));
+
+  ScopedCBB cbb_handback;
+  Array<uint8_t> handback;
+  ASSERT_TRUE(CBB_init(cbb_handback.get(), 1024));
+  ASSERT_TRUE(SSL_serialize_handback(handshaker.get(), cbb_handback.get()));
+  ASSERT_TRUE(CBBFinishArray(cbb_handback.get(), &handback));
+
+  bssl::UniquePtr<SSL> server2(SSL_new(server_ctx.get()));
+  ASSERT_TRUE(server2);
+  ASSERT_TRUE(SSL_apply_handback(server2.get(), handback));
+
+  MoveBIOs(server2.get(), handshaker.get());
+
+  uint8_t byte = 42;
+  EXPECT_EQ(SSL_write(client.get(), &byte, 1), 1);
+  EXPECT_EQ(SSL_read(server2.get(), &byte, 1), 1);
+  EXPECT_EQ(42, byte);
+
+  byte = 43;
+  EXPECT_EQ(SSL_write(server2.get(), &byte, 1), 1);
+  EXPECT_EQ(SSL_read(client.get(), &byte, 1), 1);
+  EXPECT_EQ(43, byte);
+}
+
+// HandoffDeclined tests that a server that reaches the handoff point can
+// decline the handoff and complete the handshake itself.
+TEST(SSLTest, HandoffDeclined) {
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
+  bssl::UniquePtr<SSL_CTX> server_ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(client_ctx);
+  ASSERT_TRUE(server_ctx);
+
+  SSL_CTX_set_handoff_mode(server_ctx.get(), 1);
+  ASSERT_TRUE(SSL_CTX_set_max_proto_version(server_ctx.get(), TLS1_2_VERSION));
+
+  bssl::UniquePtr<X509> cert = GetTestCertificate();
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(cert);
+  ASSERT_TRUE(key);
+  ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx.get(), cert.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                     server_ctx.get(), ClientConfig(),
+                                     false /* don't handshake */));
+
+  int client_ret = SSL_do_handshake(client.get());
+  int client_err = SSL_get_error(client.get(), client_ret);
+  ASSERT_EQ(client_err, SSL_ERROR_WANT_READ);
+
+  int server_ret = SSL_do_handshake(server.get());
+  int server_err = SSL_get_error(server.get(), server_ret);
+  ASSERT_EQ(server_err, SSL_ERROR_HANDOFF);
+
+  ScopedCBB cbb;
+  ASSERT_TRUE(CBB_init(cbb.get(), 256));
+  // Serialize a handoff, discard it, and then decline the handoff.
+  ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get()));
+
+  ASSERT_TRUE(SSL_decline_handoff(server.get()));
+
+  ASSERT_TRUE(CompleteHandshakes(client.get(), server.get()));
+
+  uint8_t byte = 42;
+  EXPECT_EQ(SSL_write(client.get(), &byte, 1), 1);
+  EXPECT_EQ(SSL_read(server.get(), &byte, 1), 1);
+  EXPECT_EQ(42, byte);
+
+  byte = 43;
+  EXPECT_EQ(SSL_write(server.get(), &byte, 1), 1);
+  EXPECT_EQ(SSL_read(client.get(), &byte, 1), 1);
+  EXPECT_EQ(43, byte);
+}
+
 // TODO(davidben): Convert this file to GTest properly.
 TEST(SSLTest, AllTests) {
   if (!TestSSL_SESSIONEncoding(kOpenSSLSession) ||
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 6885b0f..5790dc3 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -150,6 +150,34 @@
   return (TestState *)SSL_get_ex_data(ssl, g_state_index);
 }
 
+// MoveExData moves the |TestState| and |TestConfig| ex_data of |src| to |dest|
+// and clears them on |src|. It returns true on success and false otherwise.
+static bool MoveExData(SSL *dest, SSL *src) {
+  TestState *state = GetTestState(src);
+  const TestConfig *config = GetTestConfig(src);
+  // |SSL_set_ex_data| takes a non-const pointer, so cast away const.
+  return SSL_set_ex_data(src, g_state_index, nullptr) &&
+         SSL_set_ex_data(dest, g_state_index, state) &&
+         SSL_set_ex_data(src, g_config_index, nullptr) &&
+         SSL_set_ex_data(dest, g_config_index,
+                         const_cast<TestConfig *>(config));
+}
+
+// MoveBIOs moves the read and write BIOs of |src| to |dest|, leaving |src|
+// with no BIOs.
+static void MoveBIOs(SSL *dest, SSL *src) {
+  BIO *rbio = SSL_get_rbio(src);
+  BIO_up_ref(rbio);
+  SSL_set0_rbio(dest, rbio);
+
+  BIO *wbio = SSL_get_wbio(src);
+  BIO_up_ref(wbio);
+  SSL_set0_wbio(dest, wbio);
+
+  SSL_set0_rbio(src, nullptr);
+  SSL_set0_wbio(src, nullptr);
+}
+
 static bool LoadCertificate(bssl::UniquePtr<X509> *out_x509,
                             bssl::UniquePtr<STACK_OF(X509)> *out_chain,
                             const std::string &file) {
@@ -1902,7 +1930,8 @@
   return fwrite(settings, settings_len, 1, file.get()) == 1;
 }
 
-static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session, SSL *ssl,
+static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session,
+                       bssl::UniquePtr<SSL> *ssl_uniqueptr,
                        const TestConfig *config, bool is_resume, bool is_retry);
 
 // DoConnection tests an SSL connection against the peer. On success, it returns
@@ -2155,7 +2184,7 @@
     SSL_set_connect_state(ssl.get());
   }
 
-  bool ret = DoExchange(out_session, ssl.get(), config, is_resume, false);
+  bool ret = DoExchange(out_session, &ssl, config, is_resume, false);
   if (!config->is_server && is_resume && config->expect_reject_early_data) {
     // We must have failed due to an early data rejection.
     if (ret) {
@@ -2189,7 +2218,9 @@
       return false;
     }
 
-    ret = DoExchange(out_session, ssl.get(), retry_config, is_resume, true);
+    // Handoff mode is not supported on this retry path.
+    assert(!config->handoff);
+    ret = DoExchange(out_session, &ssl, retry_config, is_resume, true);
   }
 
   if (!ret) {
@@ -2212,21 +2243,125 @@
   return true;
 }
 
-static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session, SSL *ssl,
+// HandoffReady returns true if |ret|, the return value of an operation on
+// |ssl|, indicates |SSL_ERROR_HANDOFF|.
+static bool HandoffReady(SSL *ssl, int ret) {
+  return ret < 0 && SSL_get_error(ssl, ret) == SSL_ERROR_HANDOFF;
+}
+
+// DoExchange runs the test exchange on |*ssl_uniqueptr|. When
+// |config->handoff| is set and the handshake is explicit, the handshake is
+// first driven on a temporary handoff-mode |SSL| until it is ready to hand
+// off; |*ssl_uniqueptr| then applies the handoff and completes the handshake,
+// after which the result is handed back to a fresh |SSL| that replaces
+// |*ssl_uniqueptr|.
+static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session,
+                       bssl::UniquePtr<SSL> *ssl_uniqueptr,
                        const TestConfig *config, bool is_resume,
                        bool is_retry) {
   int ret;
+  SSL *ssl = ssl_uniqueptr->get();
+
   if (!config->implicit_handshake) {
+    if (config->handoff) {
+      bssl::UniquePtr<SSL_CTX> ctx_handoff(SSL_CTX_new(TLSv1_method()));
+      if (!ctx_handoff) {
+        return false;
+      }
+      SSL_CTX_set_handoff_mode(ctx_handoff.get(), 1);
+
+      bssl::UniquePtr<SSL> ssl_handoff(SSL_new(ctx_handoff.get()));
+      if (!ssl_handoff) {
+        return false;
+      }
+      SSL_set_accept_state(ssl_handoff.get());
+      if (!MoveExData(ssl_handoff.get(), ssl)) {
+        return false;
+      }
+      MoveBIOs(ssl_handoff.get(), ssl);
+
+      // Drive the handshake on |ssl_handoff| until it reports a handoff.
+      do {
+        ret = CheckIdempotentError("SSL_do_handshake", ssl_handoff.get(),
+                                   [&]() -> int {
+          return SSL_do_handshake(ssl_handoff.get());
+        });
+      } while (!HandoffReady(ssl_handoff.get(), ret) &&
+               config->async &&
+               RetryAsync(ssl_handoff.get(), ret));
+
+      if (!HandoffReady(ssl_handoff.get(), ret)) {
+        fprintf(stderr, "Handshake failed while waiting for handoff.\n");
+        return false;
+      }
+
+      bssl::ScopedCBB cbb;
+      bssl::Array<uint8_t> handoff;
+      if (!CBB_init(cbb.get(), 512) ||
+          !SSL_serialize_handoff(ssl_handoff.get(), cbb.get()) ||
+          !CBBFinishArray(cbb.get(), &handoff)) {
+        fprintf(stderr, "Handoff serialisation failed.\n");
+        return false;
+      }
+
+      // Return the BIOs and ex_data to |ssl|, which resumes the handshake.
+      MoveBIOs(ssl, ssl_handoff.get());
+      if (!MoveExData(ssl, ssl_handoff.get())) {
+        return false;
+      }
+
+      if (!SSL_apply_handoff(ssl, handoff)) {
+        fprintf(stderr, "Handoff application failed.\n");
+        return false;
+      }
+    }
+
     do {
       ret = CheckIdempotentError("SSL_do_handshake", ssl, [&]() -> int {
         return SSL_do_handshake(ssl);
       });
     } while (config->async && RetryAsync(ssl, ret));
+
     if (ret != 1 ||
         !CheckHandshakeProperties(ssl, is_resume, config)) {
       return false;
     }
 
+    if (config->handoff) {
+      // Hand the completed handshake back to a fresh |SSL|, which replaces
+      // |*ssl_uniqueptr| for the rest of the exchange.
+      bssl::ScopedCBB cbb;
+      bssl::Array<uint8_t> handback;
+      if (!CBB_init(cbb.get(), 512) ||
+          !SSL_serialize_handback(ssl, cbb.get()) ||
+          !CBBFinishArray(cbb.get(), &handback)) {
+        fprintf(stderr, "Handback serialisation failed.\n");
+        return false;
+      }
+
+      bssl::UniquePtr<SSL_CTX> ctx_handback(SSL_CTX_new(TLSv1_method()));
+      if (!ctx_handback) {
+        return false;
+      }
+      SSL_CTX_set_msg_callback(ctx_handback.get(), MessageCallback);
+      bssl::UniquePtr<SSL> ssl_handback(SSL_new(ctx_handback.get()));
+      if (!ssl_handback) {
+        return false;
+      }
+      if (!SSL_apply_handback(ssl_handback.get(), handback)) {
+        fprintf(stderr, "Applying handback failed.\n");
+        return false;
+      }
+
+      MoveBIOs(ssl_handback.get(), ssl);
+      if (!MoveExData(ssl_handback.get(), ssl)) {
+        return false;
+      }
+
+      *ssl_uniqueptr = std::move(ssl_handback);
+      ssl = ssl_uniqueptr->get();
+    }
+
     if (is_resume && !is_retry && !config->is_server &&
         config->expect_no_offer_early_data && SSL_in_early_data(ssl)) {
       fprintf(stderr, "Client unexpectedly offered early data.\n");
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index b782514..430e3d9 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1455,6 +1455,49 @@
 	return ret
 }
 
+// convertToSplitHandshakeTests returns a "-Split" copy, run with the -handoff
+// flag, of each test in |tests| that split-handshake mode supports: TLS
+// server tests whose MaxVersion (and resume MaxVersion, if any) lies in
+// [TLS 1.0, TLS 1.3), other than VersionNegotiation tests and tests that pass
+// -implicit-handshake.
+func convertToSplitHandshakeTests(tests []testCase) (splitHandshakeTests []testCase) {
+	for _, test := range tests {
+		if test.protocol != tls || test.testType != serverTest {
+			continue
+		}
+		if test.config.MaxVersion < VersionTLS10 || test.config.MaxVersion >= VersionTLS13 {
+			continue
+		}
+		if test.resumeConfig != nil && (test.resumeConfig.MaxVersion < VersionTLS10 || test.resumeConfig.MaxVersion >= VersionTLS13) {
+			continue
+		}
+		if strings.HasPrefix(test.name, "VersionNegotiation-") {
+			continue
+		}
+
+		usesImplicitHandshake := false
+		for _, flag := range test.flags {
+			if flag == "-implicit-handshake" {
+				usesImplicitHandshake = true
+				break
+			}
+		}
+		if usesImplicitHandshake {
+			continue
+		}
+
+		// Build a fresh flags slice so the split test never aliases the
+		// original test's flags.
+		splitTest := test
+		splitTest.name = test.name + "-Split"
+		splitTest.flags = append(append([]string(nil), test.flags...), "-handoff")
+
+		splitHandshakeTests = append(splitHandshakeTests, splitTest)
+	}
+
+	return splitHandshakeTests
+}
+
 func addBasicTests() {
 	basicTests := []testCase{
 		{
@@ -14100,6 +14143,9 @@
 	addExtraHandshakeTests()
 	addOmitExtensionsTests()
 
+	// Also run each eligible server test in split-handshake mode.
+	testCases = append(testCases, convertToSplitHandshakeTests(testCases)...)
+
 	var wg sync.WaitGroup
 
 	statusChan := make(chan statusMsg, *numWorkers)
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 516a9c9..1125aef 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -131,6 +131,7 @@
   { "-allow-false-start-without-alpn",
     &TestConfig::allow_false_start_without_alpn },
   { "-expect-draft-downgrade", &TestConfig::expect_draft_downgrade },
+  { "-handoff", &TestConfig::handoff },
 };
 
 const Flag<std::string> kStringFlags[] = {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index cc1618a..8768654 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -152,6 +152,7 @@
   bool allow_false_start_without_alpn = false;
   bool expect_draft_downgrade = false;
   int dummy_pq_padding_len = 0;
+  bool handoff = false;  // Split the handshake using handoff and handback.
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_initial,