Make low-level record errors idempotent.

Enough were to make record processing idempotent (we either consume a
record or we don't), but some errors would cause us to keep processing
records when we should get stuck.

This leaves errors in the layer between the record bits and the
handshake. I'm hoping that will be easier to resolve once they do not
depend on BIO, at which point the checks added in this CL may move
around.

Bug: 206
Change-Id: I6b177079388820335e25947c5bd736451780ab8f
Reviewed-on: https://boringssl-review.googlesource.com/21366
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 2bf1d42..eccc66b 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -174,22 +174,11 @@
   }
 }
 
-enum ssl_open_record_t dtls_open_record(SSL *ssl, uint8_t *out_type,
-                                        Span<uint8_t> *out,
-                                        size_t *out_consumed,
-                                        uint8_t *out_alert, Span<uint8_t> in) {
-  *out_consumed = 0;
-  switch (ssl->s3->read_shutdown) {
-    case ssl_shutdown_none:
-      break;
-    case ssl_shutdown_fatal_alert:
-      OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
-      *out_alert = 0;
-      return ssl_open_record_error;
-    case ssl_shutdown_close_notify:
-      return ssl_open_record_close_notify;
-  }
-
+static enum ssl_open_record_t do_dtls_open_record(SSL *ssl, uint8_t *out_type,
+                                                  Span<uint8_t> *out,
+                                                  size_t *out_consumed,
+                                                  uint8_t *out_alert,
+                                                  Span<uint8_t> in) {
   if (in.empty()) {
     return ssl_open_record_partial;
   }
@@ -278,6 +267,30 @@
   return ssl_open_record_success;
 }
 
+enum ssl_open_record_t dtls_open_record(SSL *ssl, uint8_t *out_type,
+                                        Span<uint8_t> *out,
+                                        size_t *out_consumed,
+                                        uint8_t *out_alert, Span<uint8_t> in) {
+  *out_consumed = 0;
+  switch (ssl->s3->read_shutdown) {
+    case ssl_shutdown_none:
+      break;
+    case ssl_shutdown_error:
+      ERR_restore_state(ssl->s3->read_error);
+      *out_alert = 0;
+      return ssl_open_record_error;
+    case ssl_shutdown_close_notify:
+      return ssl_open_record_close_notify;
+  }
+
+  enum ssl_open_record_t ret =
+      do_dtls_open_record(ssl, out_type, out, out_consumed, out_alert, in);
+  if (ret == ssl_open_record_error) {
+    ssl_set_read_error(ssl);
+  }
+  return ret;
+}
+
 static const SSLAEADContext *get_write_aead(const SSL *ssl,
                                             enum dtls1_use_epoch_t use_epoch) {
   if (use_epoch == dtls1_use_previous_epoch) {
diff --git a/ssl/internal.h b/ssl/internal.h
index edbf4eb..b034b66 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2130,7 +2130,7 @@
 enum ssl_shutdown_t {
   ssl_shutdown_none = 0,
   ssl_shutdown_close_notify = 1,
-  ssl_shutdown_fatal_alert = 2,
+  ssl_shutdown_error = 2,
 };
 
 struct SSL3_STATE {
@@ -2160,6 +2160,10 @@
   // write_shutdown is the shutdown state for the write half of the connection.
   enum ssl_shutdown_t write_shutdown;
 
+  // read_error, if |read_shutdown| is |ssl_shutdown_error|, is the error for
+  // the receive half of the connection.
+  ERR_SAVE_STATE *read_error;
+
   int alert_dispatch;
 
   int total_renegotiations;
@@ -2858,6 +2862,10 @@
 // ssl_reset_error_state resets state for |SSL_get_error|.
 void ssl_reset_error_state(SSL *ssl);
 
+// ssl_set_read_error sets |ssl|'s read half into an error state, saving the
+// current state of the error queue.
+void ssl_set_read_error(SSL* ssl);
+
 }  // namespace bssl
 
 
diff --git a/ssl/s3_lib.cc b/ssl/s3_lib.cc
index 3df8e1b..4d3cbb1 100644
--- a/ssl/s3_lib.cc
+++ b/ssl/s3_lib.cc
@@ -206,6 +206,7 @@
   ssl_read_buffer_clear(ssl);
   ssl_write_buffer_clear(ssl);
 
+  ERR_SAVE_STATE_free(ssl->s3->read_error);
   SSL_SESSION_free(ssl->s3->established_session);
   ssl_handshake_free(ssl->s3->hs);
   OPENSSL_free(ssl->s3->next_proto_negotiated);
@@ -215,7 +216,6 @@
   Delete(ssl->s3->aead_write_ctx);
   BUF_MEM_free(ssl->s3->pending_flight);
 
-  OPENSSL_cleanse(ssl->s3, sizeof *ssl->s3);
   OPENSSL_free(ssl->s3);
   ssl->s3 = NULL;
 }
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index 71e1a08..0b3331c 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -497,7 +497,7 @@
   } else {
     assert(level == SSL3_AL_FATAL);
     assert(desc != SSL_AD_CLOSE_NOTIFY);
-    ssl->s3->write_shutdown = ssl_shutdown_fatal_alert;
+    ssl->s3->write_shutdown = ssl_shutdown_error;
   }
 
   ssl->s3->alert_dispatch = 1;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index c975337..2f5374b 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -206,6 +206,12 @@
   ERR_clear_system_error();
 }
 
+void ssl_set_read_error(SSL* ssl) {
+  ssl->s3->read_shutdown = ssl_shutdown_error;
+  ERR_SAVE_STATE_free(ssl->s3->read_error);
+  ssl->s3->read_error = ERR_save_state();
+}
+
 int ssl_can_write(const SSL *ssl) {
   return !SSL_in_init(ssl) || ssl->s3->hs->can_early_write;
 }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 4da109d..f475b73 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -3801,6 +3801,109 @@
   EXPECT_EQ(SSL_R_DECODE_ERROR, ERR_GET_REASON(ERR_peek_error()));
 }
 
+// Test that alerts during a handshake are sticky.
+TEST_P(SSLVersionTest, StickyErrorHandshake_Alert) {
+  UniquePtr<SSL_CTX> ctx = CreateContext();
+  ASSERT_TRUE(ctx);
+  UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+  ASSERT_TRUE(ssl);
+  SSL_set_accept_state(ssl.get());
+
+  if (is_dtls()) {
+    static const uint8_t kHandshakeFailureDTLS[] = {
+        0x15, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00,
+        0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x28};
+    SSL_set0_rbio(ssl.get(), BIO_new_mem_buf(kHandshakeFailureDTLS,
+                                             sizeof(kHandshakeFailureDTLS)));
+  } else {
+    static const uint8_t kHandshakeFailureTLS[] = {0x15, 0x03, 0x01, 0x00,
+                                                   0x02, 0x02, 0x28};
+    SSL_set0_rbio(ssl.get(), BIO_new_mem_buf(kHandshakeFailureTLS,
+                                             sizeof(kHandshakeFailureTLS)));
+  }
+  SSL_set0_wbio(ssl.get(), BIO_new(BIO_s_mem()));
+
+  int ret = SSL_do_handshake(ssl.get());
+  EXPECT_NE(1, ret);
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(ssl.get(), ret));
+  EXPECT_EQ(ERR_LIB_SSL, ERR_GET_LIB(ERR_peek_error()));
+  EXPECT_EQ(SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE,
+            ERR_GET_REASON(ERR_peek_error()));
+  ERR_clear_error();
+
+  // Driving the handshake again does not consume more records.
+  ret = SSL_do_handshake(ssl.get());
+  EXPECT_NE(1, ret);
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(ssl.get(), ret));
+  EXPECT_EQ(ERR_LIB_SSL, ERR_GET_LIB(ERR_peek_error()));
+  EXPECT_EQ(SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE,
+            ERR_GET_REASON(ERR_peek_error()));
+}
+
+// Test that a bad record header causes a sticky error.
+TEST_P(SSLVersionTest, StickyErrorRead_BadRecordHeader) {
+  // Bad record headers in DTLS are discarded.
+  if (is_dtls()) {
+    return;
+  }
+
+  ASSERT_TRUE(Connect());
+
+  // Inject a record with invalid version into the stream.
+  static const uint8_t kBadRecord[] = {0x16, 0x00, 0x00, 0x00, 0x00};
+  SSL_set0_rbio(server_.get(), BIO_new_mem_buf(kBadRecord, sizeof(kBadRecord)));
+
+  // The bad header should be rejected.
+  char buf[5];
+  int ret = SSL_read(server_.get(), buf, sizeof(buf));
+  EXPECT_EQ(-1, ret);
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(server_.get(), ret));
+  EXPECT_EQ(ERR_LIB_SSL, ERR_GET_LIB(ERR_peek_error()));
+  EXPECT_EQ(SSL_R_WRONG_VERSION_NUMBER, ERR_GET_REASON(ERR_peek_error()));
+  ERR_clear_error();
+
+  // It should continue to be rejected on a retry.
+  ret = SSL_read(server_.get(), buf, sizeof(buf));
+  EXPECT_EQ(-1, ret);
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(server_.get(), ret));
+  EXPECT_EQ(ERR_LIB_SSL, ERR_GET_LIB(ERR_peek_error()));
+  EXPECT_EQ(SSL_R_WRONG_VERSION_NUMBER, ERR_GET_REASON(ERR_peek_error()));
+}
+
+// Test that a bad encrypted record causes a sticky error.
+TEST_P(SSLVersionTest, StickyErrorRead_BadCiphertext) {
+  // Bad ciphertext in DTLS is discarded.
+  if (is_dtls()) {
+    return;
+  }
+  ASSERT_TRUE(Connect());
+
+  // Inject a record with invalid version into the stream.
+  uint16_t record_version =
+      version() >= TLS1_3_VERSION ? TLS1_VERSION : version();
+  uint8_t record[] = {SSL3_RT_APPLICATION_DATA,
+                      static_cast<uint8_t>(record_version >> 8),
+                      static_cast<uint8_t>(record_version),
+                      0x00,
+                      0x01,
+                      0x42};
+  SSL_set0_rbio(server_.get(), BIO_new_mem_buf(record, sizeof(record)));
+
+  // The bad record should be rejected.
+  char buf[5];
+  int ret = SSL_read(server_.get(), buf, sizeof(buf));
+  EXPECT_EQ(-1, ret);
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(server_.get(), ret));
+  uint32_t err = ERR_get_error();
+  ERR_clear_error();
+
+  // It should continue to be rejected on a retry with the same error.
+  ret = SSL_read(server_.get(), buf, sizeof(buf));
+  EXPECT_EQ(-1, ret);
+  EXPECT_EQ(SSL_ERROR_SSL, SSL_get_error(server_.get(), ret));
+  EXPECT_EQ(err, ERR_peek_error());
+}
+
 TEST_P(SSLVersionTest, SSLPending) {
   UniquePtr<SSL> ssl(SSL_new(client_ctx_.get()));
   ASSERT_TRUE(ssl);
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index 44a04d9..12597a6 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -187,21 +187,11 @@
   return ret;
 }
 
-enum ssl_open_record_t tls_open_record(SSL *ssl, uint8_t *out_type,
-                                       Span<uint8_t> *out, size_t *out_consumed,
-                                       uint8_t *out_alert, Span<uint8_t> in) {
-  *out_consumed = 0;
-  switch (ssl->s3->read_shutdown) {
-    case ssl_shutdown_none:
-      break;
-    case ssl_shutdown_fatal_alert:
-      OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
-      *out_alert = 0;
-      return ssl_open_record_error;
-    case ssl_shutdown_close_notify:
-      return ssl_open_record_close_notify;
-  }
-
+static enum ssl_open_record_t do_tls_open_record(SSL *ssl, uint8_t *out_type,
+                                                 Span<uint8_t> *out,
+                                                 size_t *out_consumed,
+                                                 uint8_t *out_alert,
+                                                 Span<uint8_t> in) {
   CBS cbs = CBS(in);
 
   // Decode the record header.
@@ -351,6 +341,29 @@
   return ssl_open_record_discard;
 }
 
+enum ssl_open_record_t tls_open_record(SSL *ssl, uint8_t *out_type,
+                                       Span<uint8_t> *out, size_t *out_consumed,
+                                       uint8_t *out_alert, Span<uint8_t> in) {
+  *out_consumed = 0;
+  switch (ssl->s3->read_shutdown) {
+    case ssl_shutdown_none:
+      break;
+    case ssl_shutdown_error:
+      ERR_restore_state(ssl->s3->read_error);
+      *out_alert = 0;
+      return ssl_open_record_error;
+    case ssl_shutdown_close_notify:
+      return ssl_open_record_close_notify;
+  }
+
+  enum ssl_open_record_t ret =
+      do_tls_open_record(ssl, out_type, out, out_consumed, out_alert, in);
+  if (ret == ssl_open_record_error) {
+    ssl_set_read_error(ssl);
+  }
+  return ret;
+}
+
 static int do_seal_record(SSL *ssl, uint8_t *out_prefix, uint8_t *out,
                           uint8_t *out_suffix, uint8_t type, const uint8_t *in,
                           const size_t in_len) {
@@ -567,12 +580,8 @@
   }
 
   if (alert_level == SSL3_AL_FATAL) {
-    ssl->s3->read_shutdown = ssl_shutdown_fatal_alert;
-
-    char tmp[16];
     OPENSSL_PUT_ERROR(SSL, SSL_AD_REASON_OFFSET + alert_descr);
-    BIO_snprintf(tmp, sizeof(tmp), "%d", alert_descr);
-    ERR_add_error_data(2, "SSL alert number ", tmp);
+    ERR_add_error_dataf("SSL alert number %d", alert_descr);
     *out_alert = 0;  // No alert to send back to the peer.
     return ssl_open_record_error;
   }