ssl_test: test early data with split handshakes.

This helps to clarify where SSL_set_early_data_enabled() needs to be
called: in the shim tests it was being set everywhere, which concealed
the fact that the |enable_early_data| bit was not being set by
SSL_apply_handback().

Change-Id: I35bfdc6dd43f4fa07ef79eb02e4624b59fcdda5e
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/39385
Commit-Queue: Matt Braithwaite <mab@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 6fdab84..0105a8b 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -1574,7 +1574,7 @@
         client_err != SSL_ERROR_WANT_READ &&
         client_err != SSL_ERROR_WANT_WRITE &&
         client_err != SSL_ERROR_PENDING_TICKET) {
-      fprintf(stderr, "Client error: %d\n", client_err);
+      fprintf(stderr, "Client error: %s\n", SSL_error_description(client_err));
       return false;
     }
 
@@ -1584,7 +1584,7 @@
         server_err != SSL_ERROR_WANT_READ &&
         server_err != SSL_ERROR_WANT_WRITE &&
         server_err != SSL_ERROR_PENDING_TICKET) {
-      fprintf(stderr, "Server error: %d\n", server_err);
+      fprintf(stderr, "Server error: %s\n", SSL_error_description(server_err));
       return false;
     }
 
@@ -1630,6 +1630,7 @@
 struct ClientConfig {
   SSL_SESSION *session = nullptr;
   std::string servername;
+  bool early_data = false;
 };
 
 static bool ConnectClientAndServer(bssl::UniquePtr<SSL> *out_client,
@@ -1642,6 +1643,9 @@
   if (!client || !server) {
     return false;
   }
+  if (config.early_data) {
+    SSL_set_early_data_enabled(client.get(), 1);
+  }
   SSL_set_connect_state(client.get());
   SSL_set_accept_state(server.get());
 
@@ -4172,9 +4176,6 @@
   SSL_CTX_set_session_cache_mode(client_ctx.get(), SSL_SESS_CACHE_CLIENT);
   SSL_CTX_sess_set_new_cb(client_ctx.get(), SaveLastSession);
   SSL_CTX_set_handoff_mode(server_ctx.get(), 1);
-  ASSERT_TRUE(SSL_CTX_set_max_proto_version(server_ctx.get(), TLS1_2_VERSION));
-  ASSERT_TRUE(
-      SSL_CTX_set_max_proto_version(handshaker_ctx.get(), TLS1_2_VERSION));
   uint8_t keys[48];
   SSL_CTX_get_tlsext_ticket_keys(server_ctx.get(), &keys, sizeof(keys));
   SSL_CTX_set_tlsext_ticket_keys(handshaker_ctx.get(), &keys, sizeof(keys));
@@ -4186,70 +4187,97 @@
   ASSERT_TRUE(SSL_CTX_use_certificate(handshaker_ctx.get(), cert.get()));
   ASSERT_TRUE(SSL_CTX_use_PrivateKey(handshaker_ctx.get(), key.get()));
 
-  for (int i = 0; i < 2; ++i) {
-    bssl::UniquePtr<SSL> client, server;
-    bool is_resume = i > 0;
-    auto config = ClientConfig();
-    if (is_resume) {
-      ASSERT_TRUE(g_last_session);
-      config.session = g_last_session.get();
+  for (bool early_data : {false, true}) {
+    SCOPED_TRACE(early_data);
+    for (bool is_resume : {false, true}) {
+      SCOPED_TRACE(is_resume);
+      bssl::UniquePtr<SSL> client, server;
+      auto config = ClientConfig();
+      config.early_data = early_data;
+      if (is_resume) {
+        ASSERT_TRUE(g_last_session);
+        config.session = g_last_session.get();
+      }
+      if (is_resume && config.early_data) {
+        EXPECT_GT(g_last_session->ticket_max_early_data, 0u);
+      }
+      ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                         server_ctx.get(), config,
+                                         false /* don't handshake */));
+
+      int client_ret = SSL_do_handshake(client.get());
+      int client_err = SSL_get_error(client.get(), client_ret);
+
+      uint8_t byte_written;
+      if (config.early_data && is_resume) {
+        ASSERT_EQ(client_err, 0);
+        EXPECT_TRUE(SSL_in_early_data(client.get()));
+        // Attempt to write early data.
+        byte_written = 43;
+        EXPECT_EQ(SSL_write(client.get(), &byte_written, 1), 1);
+      } else {
+        ASSERT_EQ(client_err, SSL_ERROR_WANT_READ);
+      }
+
+      int server_ret = SSL_do_handshake(server.get());
+      int server_err = SSL_get_error(server.get(), server_ret);
+      ASSERT_EQ(server_err, SSL_ERROR_HANDOFF);
+
+      ScopedCBB cbb;
+      Array<uint8_t> handoff;
+      SSL_CLIENT_HELLO hello;
+      ASSERT_TRUE(CBB_init(cbb.get(), 256));
+      ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get(), &hello));
+      ASSERT_TRUE(CBBFinishArray(cbb.get(), &handoff));
+
+      bssl::UniquePtr<SSL> handshaker(SSL_new(handshaker_ctx.get()));
+      // Note split handshakes determines 0-RTT support, for both the current
+      // handshake and newly-issued tickets, entirely by |handshaker|. There is
+      // no need to call |SSL_set_early_data_enabled| on |server|.
+      SSL_set_early_data_enabled(handshaker.get(), 1);
+      ASSERT_TRUE(SSL_apply_handoff(handshaker.get(), handoff));
+
+      MoveBIOs(handshaker.get(), server.get());
+
+      int handshake_ret = SSL_do_handshake(handshaker.get());
+      int handshake_err = SSL_get_error(handshaker.get(), handshake_ret);
+      ASSERT_EQ(handshake_err, SSL_ERROR_HANDBACK);
+
+      // Double-check that additional calls to |SSL_do_handshake| continue
+      // to get |SSL_ERRROR_HANDBACK|.
+      handshake_ret = SSL_do_handshake(handshaker.get());
+      handshake_err = SSL_get_error(handshaker.get(), handshake_ret);
+      ASSERT_EQ(handshake_err, SSL_ERROR_HANDBACK);
+
+      ScopedCBB cbb_handback;
+      Array<uint8_t> handback;
+      ASSERT_TRUE(CBB_init(cbb_handback.get(), 1024));
+      ASSERT_TRUE(SSL_serialize_handback(handshaker.get(), cbb_handback.get()));
+      ASSERT_TRUE(CBBFinishArray(cbb_handback.get(), &handback));
+
+      bssl::UniquePtr<SSL> server2(SSL_new(server_ctx.get()));
+      ASSERT_TRUE(SSL_apply_handback(server2.get(), handback));
+
+      MoveBIOs(server2.get(), handshaker.get());
+      ASSERT_TRUE(CompleteHandshakes(client.get(), server2.get()));
+      EXPECT_EQ(is_resume, SSL_session_reused(client.get()));
+
+      if (config.early_data && is_resume) {
+        // In this case, one byte of early data has already been written above.
+        EXPECT_TRUE(SSL_early_data_accepted(client.get()));
+      } else {
+        byte_written = 42;
+        EXPECT_EQ(SSL_write(client.get(), &byte_written, 1), 1);
+      }
+      uint8_t byte;
+      EXPECT_EQ(SSL_read(server2.get(), &byte, 1), 1);
+      EXPECT_EQ(byte_written, byte);
+
+      byte = 44;
+      EXPECT_EQ(SSL_write(server2.get(), &byte, 1), 1);
+      EXPECT_EQ(SSL_read(client.get(), &byte, 1), 1);
+      EXPECT_EQ(44, byte);
     }
-    ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
-                                       server_ctx.get(), config,
-                                       false /* don't handshake */));
-
-    int client_ret = SSL_do_handshake(client.get());
-    int client_err = SSL_get_error(client.get(), client_ret);
-    ASSERT_EQ(client_err, SSL_ERROR_WANT_READ);
-
-    int server_ret = SSL_do_handshake(server.get());
-    int server_err = SSL_get_error(server.get(), server_ret);
-    ASSERT_EQ(server_err, SSL_ERROR_HANDOFF);
-
-    ScopedCBB cbb;
-    Array<uint8_t> handoff;
-    SSL_CLIENT_HELLO hello;
-    ASSERT_TRUE(CBB_init(cbb.get(), 256));
-    ASSERT_TRUE(SSL_serialize_handoff(server.get(), cbb.get(), &hello));
-    ASSERT_TRUE(CBBFinishArray(cbb.get(), &handoff));
-
-    bssl::UniquePtr<SSL> handshaker(SSL_new(handshaker_ctx.get()));
-    ASSERT_TRUE(SSL_apply_handoff(handshaker.get(), handoff));
-
-    MoveBIOs(handshaker.get(), server.get());
-
-    int handshake_ret = SSL_do_handshake(handshaker.get());
-    int handshake_err = SSL_get_error(handshaker.get(), handshake_ret);
-    ASSERT_EQ(handshake_err, SSL_ERROR_HANDBACK);
-
-    // Double-check that additional calls to |SSL_do_handshake| continue
-    // to get |SSL_ERRROR_HANDBACK|.
-    handshake_ret = SSL_do_handshake(handshaker.get());
-    handshake_err = SSL_get_error(handshaker.get(), handshake_ret);
-    ASSERT_EQ(handshake_err, SSL_ERROR_HANDBACK);
-
-    ScopedCBB cbb_handback;
-    Array<uint8_t> handback;
-    ASSERT_TRUE(CBB_init(cbb_handback.get(), 1024));
-    ASSERT_TRUE(SSL_serialize_handback(handshaker.get(), cbb_handback.get()));
-    ASSERT_TRUE(CBBFinishArray(cbb_handback.get(), &handback));
-
-    bssl::UniquePtr<SSL> server2(SSL_new(server_ctx.get()));
-    ASSERT_TRUE(SSL_apply_handback(server2.get(), handback));
-
-    MoveBIOs(server2.get(), handshaker.get());
-    ASSERT_TRUE(CompleteHandshakes(client.get(), server2.get()));
-    EXPECT_EQ(is_resume, SSL_session_reused(client.get()));
-
-    uint8_t byte = 42;
-    EXPECT_EQ(SSL_write(client.get(), &byte, 1), 1);
-    EXPECT_EQ(SSL_read(server2.get(), &byte, 1), 1);
-    EXPECT_EQ(42, byte);
-
-    byte = 43;
-    EXPECT_EQ(SSL_write(server2.get(), &byte, 1), 1);
-    EXPECT_EQ(SSL_read(client.get(), &byte, 1), 1);
-    EXPECT_EQ(43, byte);
   }
 }