More reliably report handshake errors through SSL_write.

This CL fixes a couple of things. First, we never tested that SSL_write
refuses to write application data after a fatal alert, so add some tests
here. With those tests, we can revise some of this logic:

Next, this removes the write_shutdown check in SSL_write and instead
relies on the lower-level versions of the check in the write_app_data,
etc., hooks. This improves error-reporting on handshake errors:

We generally try to make SSL_do_handshake errors sticky, analogous to
handshakeErr in the Go implementation. SSL_write and SSL_read both
implicitly call SSL_do_handshake. Callers driving the two in parallel
will naturally call SSL_do_handshake twice. Since the error effectively
applies to both operations, we save and replay handshake errors
(hs->error).

Handshake errors typically come with sending alerts, which also sets
write_shutdown so we don't try to send more data over the channel.
Checking this early in SSL_write means we don't get a chance to replay
the handshake error. So this CL defers it, and the test ensures we still
ultimately get it right.

Finally, https://crbug.com/1078515 is a particular incarnation of this.
If the server enables 0-RTT and then reverts to TLS 1.2, clients need
to catch the error and retry. There, deferring the SSL_write check
isn't sufficient, because the can_early_write bit removes the write
path's dependency on the handshake, so we don't call into
SSL_do_handshake at all.

For now, I've made this error path clear can_early_write. I suspect
we want it to apply to all handshake errors, though it's weird because
the handshake error is effectively a read error in 0-RTT. We don't
currently replay record decryption failures at SSL_write, even though
those also send a fatal alert and thus break all subsequent writes.

Bug: chromium:1078515
Change-Id: Icdfae6a8f2e7c1b1c921068dca244795a670807f
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/48065
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 1fb43fb..1f26242 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -720,6 +720,12 @@
   // an error code sooner. The caller may use this error code to implement the
   // fallback described in RFC 8446 appendix D.3.
   if (hs->early_data_offered) {
+    // Disconnect early writes. This ensures subsequent |SSL_write| calls query
+    // the handshake which, in turn, will replay the error code rather than fail
+    // at the |write_shutdown| check. See https://crbug.com/1078515.
+    // TODO(davidben): Should all handshake errors do this? What about record
+    // decryption failures?
+    hs->can_early_write = false;
     OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_VERSION_ON_EARLY_DATA);
     ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_PROTOCOL_VERSION);
     return ssl_hs_error;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 1211761..8b67cf8 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -1101,11 +1101,6 @@
     return -1;
   }
 
-  if (ssl->s3->write_shutdown != ssl_shutdown_none) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
-    return -1;
-  }
-
   int ret = 0;
   bool needs_handshake = false;
   do {
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index e2a41b8..ceb52de 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -1533,6 +1533,38 @@
   return true;
 }
 
+static bssl::UniquePtr<SSL_SESSION> g_last_session;
+
+static int SaveLastSession(SSL *ssl, SSL_SESSION *session) {
+  // Save the most recent session.
+  g_last_session.reset(session);
+  return 1;
+}
+
+static bssl::UniquePtr<SSL_SESSION> CreateClientSession(
+    SSL_CTX *client_ctx, SSL_CTX *server_ctx,
+    const ClientConfig &config = ClientConfig()) {
+  g_last_session = nullptr;
+  SSL_CTX_sess_set_new_cb(client_ctx, SaveLastSession);
+
+  // Connect client and server to get a session.
+  bssl::UniquePtr<SSL> client, server;
+  if (!ConnectClientAndServer(&client, &server, client_ctx, server_ctx,
+                              config) ||
+      !FlushNewSessionTickets(client.get(), server.get())) {
+    fprintf(stderr, "Failed to connect client and server.\n");
+    return nullptr;
+  }
+
+  SSL_CTX_sess_set_new_cb(client_ctx, nullptr);
+
+  if (!g_last_session) {
+    fprintf(stderr, "Client did not receive a session.\n");
+    return nullptr;
+  }
+  return std::move(g_last_session);
+}
+
 // Test that |SSL_get_client_CA_list| echoes back the configured parameter even
 // before configuring as a server.
 TEST(SSLTest, ClientCAList) {
@@ -2340,7 +2372,7 @@
   }
   ASSERT_TRUE(Connect());
 
-  // Shut down half the connection. SSL_shutdown will return 0 to signal only
+  // Shut down half the connection. |SSL_shutdown| will return 0 to signal only
   // one side has shut down.
   ASSERT_EQ(SSL_shutdown(client_.get()), 0);
 
@@ -2361,6 +2393,217 @@
   EXPECT_EQ(SSL_shutdown(client_.get()), 1);
 }
 
+// Test that, after calling |SSL_shutdown|, |SSL_write| fails.
+TEST_P(SSLVersionTest, WriteAfterShutdown) {
+  ASSERT_TRUE(Connect());
+
+  for (SSL *ssl : {client_.get(), server_.get()}) {
+    SCOPED_TRACE(SSL_is_server(ssl) ? "server" : "client");
+
+    bssl::UniquePtr<BIO> mem(BIO_new(BIO_s_mem()));
+    ASSERT_TRUE(mem);
+    SSL_set0_wbio(ssl, bssl::UpRef(mem).release());
+
+    // Shut down half the connection. |SSL_shutdown| will return 0 to signal
+    // only one side has shut down.
+    ASSERT_EQ(SSL_shutdown(ssl), 0);
+
+    // |ssl| should have written an alert to the transport.
+    const uint8_t *unused;
+    size_t len;
+    ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+    EXPECT_NE(0u, len);
+    EXPECT_TRUE(BIO_reset(mem.get()));
+
+    // Writing should fail.
+    EXPECT_EQ(-1, SSL_write(ssl, "a", 1));
+
+    // Nothing should be written to the transport.
+    ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+    EXPECT_EQ(0u, len);
+  }
+}
+
+// Test that, after sending a fatal alert in a failed |SSL_read|, |SSL_write|
+// fails.
+TEST_P(SSLVersionTest, WriteAfterReadSentFatalAlert) {
+  // Decryption failures are not fatal in DTLS.
+  if (is_dtls()) {
+    return;
+  }
+
+  ASSERT_TRUE(Connect());
+
+  // Save the write |BIO|s as the test will overwrite them.
+  bssl::UniquePtr<BIO> client_wbio = bssl::UpRef(SSL_get_wbio(client_.get()));
+  bssl::UniquePtr<BIO> server_wbio = bssl::UpRef(SSL_get_wbio(server_.get()));
+
+  for (bool test_server : {false, true}) {
+    SCOPED_TRACE(test_server ? "server" : "client");
+    SSL *ssl = test_server ? server_.get() : client_.get();
+    BIO *other_wbio = test_server ? client_wbio.get() : server_wbio.get();
+
+    bssl::UniquePtr<BIO> mem(BIO_new(BIO_s_mem()));
+    ASSERT_TRUE(mem);
+    SSL_set0_wbio(ssl, bssl::UpRef(mem).release());
+
+    // Read an invalid record from the peer.
+    static const uint8_t kInvalidRecord[] = "invalid record";
+    EXPECT_EQ(int{sizeof(kInvalidRecord)},
+              BIO_write(other_wbio, kInvalidRecord, sizeof(kInvalidRecord)));
+    char buf[256];
+    EXPECT_EQ(-1, SSL_read(ssl, buf, sizeof(buf)));
+
+    // |ssl| should have written an alert to the transport.
+    const uint8_t *unused;
+    size_t len;
+    ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+    EXPECT_NE(0u, len);
+    EXPECT_TRUE(BIO_reset(mem.get()));
+
+    // Writing should fail.
+    EXPECT_EQ(-1, SSL_write(ssl, "a", 1));
+
+    // Nothing should be written to the transport.
+    ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+    EXPECT_EQ(0u, len);
+  }
+}
+
+// Test that, after sending a fatal alert from the handshake, |SSL_write| fails.
+TEST_P(SSLVersionTest, WriteAfterHandshakeSentFatalAlert) {
+  for (bool test_server : {false, true}) {
+    SCOPED_TRACE(test_server ? "server" : "client");
+
+    bssl::UniquePtr<SSL> ssl(
+        SSL_new(test_server ? server_ctx_.get() : client_ctx_.get()));
+    ASSERT_TRUE(ssl);
+    if (test_server) {
+      SSL_set_accept_state(ssl.get());
+    } else {
+      SSL_set_connect_state(ssl.get());
+    }
+
+    std::vector<uint8_t> invalid;
+    if (is_dtls()) {
+      // In DTLS, invalid records are discarded. To cause the handshake to fail,
+      // use a valid handshake record with invalid contents.
+      invalid.push_back(SSL3_RT_HANDSHAKE);
+      invalid.push_back(DTLS1_VERSION >> 8);
+      invalid.push_back(DTLS1_VERSION & 0xff);
+      // epoch and sequence_number
+      for (int i = 0; i < 8; i++) {
+        invalid.push_back(0);
+      }
+      // A one-byte fragment is invalid.
+      invalid.push_back(0);
+      invalid.push_back(1);
+      // Arbitrary contents.
+      invalid.push_back(0);
+    } else {
+      invalid = {'i', 'n', 'v', 'a', 'l', 'i', 'd'};
+    }
+    bssl::UniquePtr<BIO> rbio(
+        BIO_new_mem_buf(invalid.data(), invalid.size()));
+    ASSERT_TRUE(rbio);
+    SSL_set0_rbio(ssl.get(), rbio.release());
+
+    bssl::UniquePtr<BIO> mem(BIO_new(BIO_s_mem()));
+    ASSERT_TRUE(mem);
+    SSL_set0_wbio(ssl.get(), bssl::UpRef(mem).release());
+
+    // The handshake should fail.
+    EXPECT_EQ(-1, SSL_do_handshake(ssl.get()));
+    EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(ssl.get(), -1));
+    uint32_t err = ERR_get_error();
+
+    // |ssl| should have written an alert (and, in the client's case, a
+    // ClientHello) to the transport.
+    const uint8_t *unused;
+    size_t len;
+    ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+    EXPECT_NE(0u, len);
+    EXPECT_TRUE(BIO_reset(mem.get()));
+
+    // Writing should fail, with the same error as the handshake.
+    EXPECT_EQ(-1, SSL_write(ssl.get(), "a", 1));
+    EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(ssl.get(), -1));
+    EXPECT_EQ(err, ERR_get_error());
+
+    // Nothing should be written to the transport.
+    ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+    EXPECT_EQ(0u, len);
+  }
+}
+
+// Test that, after seeing TLS 1.2 in response to early data, |SSL_write|
+// continues to report |SSL_R_WRONG_VERSION_ON_EARLY_DATA|. See
+// https://crbug.com/1078515.
+TEST(SSLTest, WriteAfterWrongVersionOnEarlyData) {
+  // Set up some 0-RTT-enabled contexts.
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
+  bssl::UniquePtr<SSL_CTX> server_ctx =
+      CreateContextWithTestCertificate(TLS_method());
+  ASSERT_TRUE(client_ctx);
+  ASSERT_TRUE(server_ctx);
+  SSL_CTX_set_early_data_enabled(client_ctx.get(), 1);
+  SSL_CTX_set_early_data_enabled(server_ctx.get(), 1);
+  SSL_CTX_set_session_cache_mode(client_ctx.get(), SSL_SESS_CACHE_BOTH);
+  SSL_CTX_set_session_cache_mode(server_ctx.get(), SSL_SESS_CACHE_BOTH);
+
+  // Get an early-data-capable session.
+  bssl::UniquePtr<SSL_SESSION> session =
+      CreateClientSession(client_ctx.get(), server_ctx.get());
+  ASSERT_TRUE(session);
+  EXPECT_TRUE(SSL_SESSION_early_data_capable(session.get()));
+
+  // Offer the session to the server, but now the server speaks TLS 1.2.
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(CreateClientAndServer(&client, &server, client_ctx.get(),
+                                    server_ctx.get()));
+  SSL_set_session(client.get(), session.get());
+  EXPECT_TRUE(SSL_set_max_proto_version(server.get(), TLS1_2_VERSION));
+
+  // The client handshake initially succeeds in the early data state.
+  EXPECT_EQ(1, SSL_do_handshake(client.get()));
+  EXPECT_TRUE(SSL_in_early_data(client.get()));
+
+  // The server processes the ClientHello and negotiates TLS 1.2.
+  EXPECT_EQ(-1, SSL_do_handshake(server.get()));
+  EXPECT_EQ(SSL_ERROR_WANT_READ, SSL_get_error(server.get(), -1));
+  EXPECT_EQ(TLS1_2_VERSION, SSL_version(server.get()));
+
+  // Capture the client's output.
+  bssl::UniquePtr<BIO> mem(BIO_new(BIO_s_mem()));
+  ASSERT_TRUE(mem);
+  SSL_set0_wbio(client.get(), bssl::UpRef(mem).release());
+
+  // The client processes the ServerHello and fails.
+  EXPECT_EQ(-1, SSL_do_handshake(client.get()));
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(client.get(), -1));
+  uint32_t err = ERR_get_error();
+  EXPECT_EQ(ERR_LIB_SSL, ERR_GET_LIB(err));
+  EXPECT_EQ(SSL_R_WRONG_VERSION_ON_EARLY_DATA, ERR_GET_REASON(err));
+
+  // The client should have written an alert to the transport.
+  const uint8_t *unused;
+  size_t len;
+  ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+  EXPECT_NE(0u, len);
+  EXPECT_TRUE(BIO_reset(mem.get()));
+
+  // Writing should fail, with the same error as the handshake.
+  EXPECT_EQ(-1, SSL_write(client.get(), "a", 1));
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(client.get(), -1));
+  err = ERR_get_error();
+  EXPECT_EQ(ERR_LIB_SSL, ERR_GET_LIB(err));
+  EXPECT_EQ(SSL_R_WRONG_VERSION_ON_EARLY_DATA, ERR_GET_REASON(err));
+
+  // Nothing should be written to the transport.
+  ASSERT_TRUE(BIO_mem_contents(mem.get(), &unused, &len));
+  EXPECT_EQ(0u, len);
+}
+
 TEST(SSLTest, SessionDuplication) {
   bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
   bssl::UniquePtr<SSL_CTX> server_ctx =
@@ -2698,38 +2941,6 @@
   }
 }
 
-static bssl::UniquePtr<SSL_SESSION> g_last_session;
-
-static int SaveLastSession(SSL *ssl, SSL_SESSION *session) {
-  // Save the most recent session.
-  g_last_session.reset(session);
-  return 1;
-}
-
-static bssl::UniquePtr<SSL_SESSION> CreateClientSession(
-    SSL_CTX *client_ctx, SSL_CTX *server_ctx,
-    const ClientConfig &config = ClientConfig()) {
-  g_last_session = nullptr;
-  SSL_CTX_sess_set_new_cb(client_ctx, SaveLastSession);
-
-  // Connect client and server to get a session.
-  bssl::UniquePtr<SSL> client, server;
-  if (!ConnectClientAndServer(&client, &server, client_ctx, server_ctx,
-                              config) ||
-      !FlushNewSessionTickets(client.get(), server.get())) {
-    fprintf(stderr, "Failed to connect client and server.\n");
-    return nullptr;
-  }
-
-  SSL_CTX_sess_set_new_cb(client_ctx, nullptr);
-
-  if (!g_last_session) {
-    fprintf(stderr, "Client did not receive a session.\n");
-    return nullptr;
-  }
-  return std::move(g_last_session);
-}
-
 static void ExpectSessionReused(SSL_CTX *client_ctx, SSL_CTX *server_ctx,
                                 SSL_SESSION *session, bool want_reused) {
   bssl::UniquePtr<SSL> client, server;