Don't call ssl3_read_message from ssl3_read_app_data.

With this change, it should now always be the case that rr->length is
zero on entry to ssl3_read_message. This will let us detach everything
but application data from rr. This pushes some init_buf invariants down
into tls_open_record so we don't need to maintain them everywhere.

Change-Id: I206747434e0a9603eea7d19664734fd16fa2de8e
Reviewed-on: https://boringssl-review.googlesource.com/21524
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 577b382..60e69f9 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1001,6 +1001,10 @@
 // dtls_clear_incoming_messages releases all buffered incoming messages.
 void dtls_clear_incoming_messages(SSL *ssl);
 
+// tls_can_accept_handshake_data returns whether |ssl| is able to accept more
+// data into handshake buffer.
+bool tls_can_accept_handshake_data(const SSL *ssl, uint8_t *out_alert);
+
 // tls_has_unprocessed_handshake_data returns whether there is buffered
 // handshake data that has not been consumed by |get_message|.
 bool tls_has_unprocessed_handshake_data(const SSL *ssl);
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index 376018d..7fc843e 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -464,6 +464,26 @@
   return true;
 }
 
+bool tls_can_accept_handshake_data(const SSL *ssl, uint8_t *out_alert) {
+  // If there is a complete message, the caller must have consumed it first.
+  SSLMessage msg;
+  size_t bytes_needed;
+  if (parse_message(ssl, &msg, &bytes_needed)) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    *out_alert = SSL_AD_INTERNAL_ERROR;
+    return false;
+  }
+
+  // Enforce the limit so the peer cannot force us to buffer 16MB.
+  if (bytes_needed > 4 + ssl_max_handshake_message_len(ssl)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+    *out_alert = SSL_AD_ILLEGAL_PARAMETER;
+    return false;
+  }
+
+  return true;
+}
+
 bool tls_has_unprocessed_handshake_data(const SSL *ssl) {
   size_t msg_len = 0;
   if (ssl->s3->has_message) {
@@ -478,20 +498,6 @@
 }
 
 int ssl3_read_message(SSL *ssl) {
-  SSLMessage msg;
-  size_t bytes_needed;
-  if (parse_message(ssl, &msg, &bytes_needed)) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    return -1;
-  }
-
-  // Enforce the limit so the peer cannot force us to buffer 16MB.
-  if (bytes_needed > 4 + ssl_max_handshake_message_len(ssl)) {
-    ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
-    OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-    return -1;
-  }
-
   // Re-create the handshake buffer if needed.
   if (ssl->init_buf == NULL) {
     ssl->init_buf = BUF_MEM_new();
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index 0b3331c..2718fde 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -365,22 +365,20 @@
   SSL3_RECORD *rr = &ssl->s3->rrec;
 
   for (;;) {
-    // A previous iteration may have read a partial handshake message. Do not
-    // allow more app data in that case.
-    int has_hs_data = ssl->init_buf != NULL && ssl->init_buf->length > 0;
-
     // Get new packet if necessary.
-    if (rr->length == 0 && !has_hs_data) {
+    if (rr->length == 0) {
       int ret = ssl3_get_record(ssl);
       if (ret <= 0) {
         return ret;
       }
     }
 
-    if (has_hs_data || rr->type == SSL3_RT_HANDSHAKE) {
+    const bool is_early_data_read = ssl->server && SSL_in_early_data(ssl);
+
+    if (rr->type == SSL3_RT_HANDSHAKE) {
       // If reading 0-RTT data, reject handshake data. 0-RTT data is terminated
       // by an alert.
-      if (SSL_in_init(ssl)) {
+      if (is_early_data_read) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
         ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
         return -1;
@@ -395,20 +393,19 @@
         return -1;
       }
 
-      // Parse post-handshake handshake messages.
-      int ret = ssl3_read_message(ssl);
-      if (ret <= 0) {
-        return ret;
+      if (ssl->init_buf == NULL) {
+        ssl->init_buf = BUF_MEM_new();
+      }
+      if (ssl->init_buf == NULL ||
+          !BUF_MEM_append(ssl->init_buf, rr->data, rr->length)) {
+        return -1;
       }
       *out_got_handshake = true;
+      rr->length = 0;
+      ssl_read_buffer_discard(ssl);
       return -1;
     }
 
-    const int is_early_data_read = ssl->server &&
-                                   ssl->s3->hs != NULL &&
-                                   ssl->s3->hs->can_early_read &&
-                                   ssl_protocol_version(ssl) >= TLS1_3_VERSION;
-
     // Handle the end_of_early_data alert.
     if (rr->type == SSL3_RT_ALERT &&
         rr->length == 2 &&
@@ -457,8 +454,7 @@
     }
   }
 
-  if (rr->type != SSL3_RT_CHANGE_CIPHER_SPEC ||
-      tls_has_unprocessed_handshake_data(ssl)) {
+  if (rr->type != SSL3_RT_CHANGE_CIPHER_SPEC) {
     ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
     return -1;
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index 12597a6..2a28859 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -192,6 +192,12 @@
                                                  size_t *out_consumed,
                                                  uint8_t *out_alert,
                                                  Span<uint8_t> in) {
+  // If there is an unprocessed handshake message or we are already buffering
+  // too much, stop before decrypting another handshake record.
+  if (!tls_can_accept_handshake_data(ssl, out_alert)) {
+    return ssl_open_record_error;
+  }
+
   CBS cbs = CBS(in);
 
   // Decode the record header.
@@ -321,6 +327,14 @@
     return ssl_process_alert(ssl, out_alert, *out);
   }
 
+  // Handshake messages may not interleave with any other record type.
+  if (type != SSL3_RT_HANDSHAKE &&
+      tls_has_unprocessed_handshake_data(ssl)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
+    *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
+    return ssl_open_record_error;
+  }
+
   ssl->s3->warning_alert_count = 0;
 
   *out_type = type;