Push alert handling down into the record functions.

Alert handling is more-or-less identical across all contexts, so push it down
from read_bytes into the low-level record functions. This also deduplicates the
code shared between TLS and DTLS. As a result, DTLS now enforces the same limit
on consecutive warning alerts as TLS.

Now the only record-type mismatch read_bytes must handle is handshake data
arriving in read_app_data.

Change-Id: Ia8331897b304566e66d901899cfbf31d2870194e
Reviewed-on: https://boringssl-review.googlesource.com/8124
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index c3f03bc..68e6a4d 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -144,8 +144,8 @@
 
   /* Read a new packet if there is no unconsumed one. */
   if (ssl_read_buffer_len(ssl) == 0) {
-    int ret = ssl_read_buffer_extend_to(ssl, 0 /* unused */);
-    if (ret < 0 && dtls1_is_timer_expired(ssl)) {
+    int read_ret = ssl_read_buffer_extend_to(ssl, 0 /* unused */);
+    if (read_ret < 0 && dtls1_is_timer_expired(ssl)) {
       /* For blocking BIOs, retransmits must be handled internally. */
       int timeout_ret = DTLSv1_handle_timeout(ssl);
       if (timeout_ret <= 0) {
@@ -153,8 +153,8 @@
       }
       goto again;
     }
-    if (ret <= 0) {
-      return ret;
+    if (read_ret <= 0) {
+      return read_ret;
     }
   }
   assert(ssl_read_buffer_len(ssl) > 0);
@@ -169,11 +169,16 @@
   size_t max_out = ssl_read_buffer_len(ssl) - ssl_record_prefix_len(ssl);
   uint8_t type, alert;
   size_t len, consumed;
-  switch (dtls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out,
-                           ssl_read_buffer(ssl), ssl_read_buffer_len(ssl))) {
-    case ssl_open_record_success:
-      ssl_read_buffer_consume(ssl, consumed);
+  enum ssl_open_record_t open_ret =
+      dtls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out,
+                       ssl_read_buffer(ssl), ssl_read_buffer_len(ssl));
+  ssl_read_buffer_consume(ssl, consumed);
+  switch (open_ret) {
+    case ssl_open_record_partial:
+      /* Impossible in DTLS. */
+      break;
 
+    case ssl_open_record_success:
       if (len > 0xffff) {
         OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
         return -1;
@@ -186,16 +191,17 @@
       return 1;
 
     case ssl_open_record_discard:
-      ssl_read_buffer_consume(ssl, consumed);
       goto again;
 
+    case ssl_open_record_close_notify:
+      return 0;
+
+    case ssl_open_record_fatal_alert:
+      return -1;
+
     case ssl_open_record_error:
       ssl3_send_alert(ssl, SSL3_AL_FATAL, alert);
       return -1;
-
-    case ssl_open_record_partial:
-      /* Impossible in DTLS. */
-      break;
   }
 
   assert(0);
@@ -310,49 +316,6 @@
 
   /* If we get here, then type != rr->type. */
 
-  /* If an alert record, process the alert. */
-  if (rr->type == SSL3_RT_ALERT) {
-    /* Alerts records may not contain fragmented or multiple alerts. */
-    if (rr->length != 2) {
-      al = SSL_AD_DECODE_ERROR;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_ALERT);
-      goto f_err;
-    }
-
-    ssl_do_msg_callback(ssl, 0 /* read */, ssl->version, SSL3_RT_ALERT,
-                        rr->data, 2);
-
-    const uint8_t alert_level = rr->data[0];
-    const uint8_t alert_descr = rr->data[1];
-    rr->length -= 2;
-    rr->data += 2;
-
-    uint16_t alert = (alert_level << 8) | alert_descr;
-    ssl_do_info_callback(ssl, SSL_CB_READ_ALERT, alert);
-
-    if (alert_level == SSL3_AL_WARNING) {
-      if (alert_descr == SSL_AD_CLOSE_NOTIFY) {
-        ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
-        return 0;
-      }
-    } else if (alert_level == SSL3_AL_FATAL) {
-      char tmp[16];
-
-      OPENSSL_PUT_ERROR(SSL, SSL_AD_REASON_OFFSET + alert_descr);
-      BIO_snprintf(tmp, sizeof tmp, "%d", alert_descr);
-      ERR_add_error_data(2, "SSL alert number ", tmp);
-      ssl->s3->recv_shutdown = ssl_shutdown_fatal_alert;
-      SSL_CTX_remove_session(ssl->ctx, ssl->session);
-      return 0;
-    } else {
-      al = SSL_AD_ILLEGAL_PARAMETER;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_ALERT_TYPE);
-      goto f_err;
-    }
-
-    goto start;
-  }
-
   /* Cross-epoch records are discarded, but we may receive out-of-order
    * application data between ChangeCipherSpec and Finished or a ChangeCipherSpec
    * before the appropriate point in the handshake. Those must be silently
diff --git a/ssl/dtls_record.c b/ssl/dtls_record.c
index 71d7ba5..94dfb28 100644
--- a/ssl/dtls_record.c
+++ b/ssl/dtls_record.c
@@ -175,6 +175,8 @@
     SSL *ssl, uint8_t *out_type, uint8_t *out, size_t *out_len,
     size_t *out_consumed, uint8_t *out_alert, size_t max_out, const uint8_t *in,
     size_t in_len) {
+  *out_consumed = 0;
+
   CBS cbs;
   CBS_init(&cbs, in, in_len);
 
@@ -224,6 +226,7 @@
     *out_consumed = in_len - CBS_len(&cbs);
     return ssl_open_record_discard;
   }
+  *out_consumed = in_len - CBS_len(&cbs);
 
   /* Check the plaintext length. */
   if (plaintext_len > SSL3_RT_MAX_PLAIN_LENGTH) {
@@ -237,9 +240,14 @@
   /* TODO(davidben): Limit the number of empty records as in TLS? This is only
    * useful if we also limit discarded packets. */
 
+  if (type == SSL3_RT_ALERT) {
+    return ssl_process_alert(ssl, out_alert, out, plaintext_len);
+  }
+
+  ssl->s3->warning_alert_count = 0;
+
   *out_type = type;
   *out_len = plaintext_len;
-  *out_consumed = in_len - CBS_len(&cbs);
   return ssl_open_record_success;
 }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index b35ccc5..13e7935 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -365,25 +365,32 @@
   ssl_open_record_success,
   ssl_open_record_discard,
   ssl_open_record_partial,
+  ssl_open_record_close_notify,
+  ssl_open_record_fatal_alert,
   ssl_open_record_error,
 };
 
 /* tls_open_record decrypts a record from |in|.
  *
- * On success, it returns |ssl_open_record_success|. It sets |*out_type| to the
- * record type, |*out_len| to the plaintext length, and writes the record body
- * to |out|. It sets |*out_consumed| to the number of bytes of |in| consumed.
- * Note that |*out_len| may be zero.
- *
- * If a record was successfully processed but should be discarded, it returns
- * |ssl_open_record_discard| and sets |*out_consumed| to the number of bytes
- * consumed.
- *
  * If the input did not contain a complete record, it returns
  * |ssl_open_record_partial|. It sets |*out_consumed| to the total number of
  * bytes necessary. It is guaranteed that a successful call to |tls_open_record|
  * will consume at least that many bytes.
  *
+ * Otherwise, it sets |*out_consumed| to the number of bytes of input
+ * consumed. Note that input may be consumed on all return codes if a record was
+ * decrypted.
+ *
+ * On success, it returns |ssl_open_record_success|. It sets |*out_type| to the
+ * record type, |*out_len| to the plaintext length, and writes the record body
+ * to |out|. Note that |*out_len| may be zero.
+ *
+ * If a record was successfully processed but should be discarded, it returns
+ * |ssl_open_record_discard|.
+ *
+ * If a record was successfully processed but is a close_notify or fatal alert,
+ * it returns |ssl_open_record_close_notify| or |ssl_open_record_fatal_alert|.
+ *
  * On failure, it returns |ssl_open_record_error| and sets |*out_alert| to an
  * alert to emit.
  *
@@ -448,6 +455,13 @@
  * ownership of |aead_ctx|. */
 void ssl_set_write_state(SSL *ssl, SSL_AEAD_CTX *aead_ctx);
 
+/* ssl_process_alert processes |in| as an alert and updates |ssl|'s shutdown
+ * state. It returns one of |ssl_open_record_discard|, |ssl_open_record_error|,
+ * |ssl_open_record_close_notify|, or |ssl_open_record_fatal_alert| as
+ * appropriate. */
+enum ssl_open_record_t ssl_process_alert(SSL *ssl, uint8_t *out_alert,
+                                         const uint8_t *in, size_t in_len);
+
 
 /* Private key operations. */
 
diff --git a/ssl/s3_pkt.c b/ssl/s3_pkt.c
index 4b5138e..04d41be 100644
--- a/ssl/s3_pkt.c
+++ b/ssl/s3_pkt.c
@@ -123,15 +123,10 @@
 
 static int do_ssl3_write(SSL *ssl, int type, const uint8_t *buf, unsigned len);
 
-/* kMaxWarningAlerts is the number of consecutive warning alerts that will be
- * processed. */
-static const uint8_t kMaxWarningAlerts = 4;
-
 /* ssl3_get_record reads a new input record. On success, it places it in
  * |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if
  * more data is needed. */
 static int ssl3_get_record(SSL *ssl) {
-  int ret;
 again:
   switch (ssl->s3->recv_shutdown) {
     case ssl_shutdown_none:
@@ -144,9 +139,9 @@
   }
 
   /* Ensure the buffer is large enough to decrypt in-place. */
-  ret = ssl_read_buffer_extend_to(ssl, ssl_record_prefix_len(ssl));
-  if (ret <= 0) {
-    return ret;
+  int read_ret = ssl_read_buffer_extend_to(ssl, ssl_record_prefix_len(ssl));
+  if (read_ret <= 0) {
+    return read_ret;
   }
   assert(ssl_read_buffer_len(ssl) >= ssl_record_prefix_len(ssl));
 
@@ -154,11 +149,21 @@
   size_t max_out = ssl_read_buffer_len(ssl) - ssl_record_prefix_len(ssl);
   uint8_t type, alert;
   size_t len, consumed;
-  switch (tls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out,
-                          ssl_read_buffer(ssl), ssl_read_buffer_len(ssl))) {
-    case ssl_open_record_success:
-      ssl_read_buffer_consume(ssl, consumed);
+  enum ssl_open_record_t open_ret =
+      tls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out,
+                      ssl_read_buffer(ssl), ssl_read_buffer_len(ssl));
+  if (open_ret != ssl_open_record_partial) {
+    ssl_read_buffer_consume(ssl, consumed);
+  }
+  switch (open_ret) {
+    case ssl_open_record_partial:
+      read_ret = ssl_read_buffer_extend_to(ssl, consumed);
+      if (read_ret <= 0) {
+        return read_ret;
+      }
+      goto again;
 
+    case ssl_open_record_success:
       if (len > 0xffff) {
         OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
         return -1;
@@ -170,16 +175,14 @@
       rr->data = out;
       return 1;
 
-    case ssl_open_record_partial:
-      ret = ssl_read_buffer_extend_to(ssl, consumed);
-      if (ret <= 0) {
-        return ret;
-      }
+    case ssl_open_record_discard:
       goto again;
 
-    case ssl_open_record_discard:
-      ssl_read_buffer_consume(ssl, consumed);
-      goto again;
+    case ssl_open_record_close_notify:
+      return 0;
+
+    case ssl_open_record_fatal_alert:
+      return -1;
 
     case ssl_open_record_error:
       ssl3_send_alert(ssl, SSL3_AL_FATAL, alert);
@@ -402,8 +405,6 @@
   /* we now have a packet which can be read and processed */
 
   if (type != 0 && type == rr->type) {
-    ssl->s3->warning_alert_count = 0;
-
     /* Discard empty records. */
     if (rr->length == 0) {
       goto start;
@@ -495,56 +496,6 @@
     goto start;
   }
 
-  /* If an alert record, process the alert. */
-  if (rr->type == SSL3_RT_ALERT) {
-    /* Alerts records may not contain fragmented or multiple alerts. */
-    if (rr->length != 2) {
-      al = SSL_AD_DECODE_ERROR;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_ALERT);
-      goto f_err;
-    }
-
-    ssl_do_msg_callback(ssl, 0 /* read */, ssl->version, SSL3_RT_ALERT,
-                        rr->data, 2);
-
-    const uint8_t alert_level = rr->data[0];
-    const uint8_t alert_descr = rr->data[1];
-    rr->length -= 2;
-    rr->data += 2;
-
-    uint16_t alert = (alert_level << 8) | alert_descr;
-    ssl_do_info_callback(ssl, SSL_CB_READ_ALERT, alert);
-
-    if (alert_level == SSL3_AL_WARNING) {
-      if (alert_descr == SSL_AD_CLOSE_NOTIFY) {
-        ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
-        return 0;
-      }
-
-      ssl->s3->warning_alert_count++;
-      if (ssl->s3->warning_alert_count > kMaxWarningAlerts) {
-        al = SSL_AD_UNEXPECTED_MESSAGE;
-        OPENSSL_PUT_ERROR(SSL, SSL_R_TOO_MANY_WARNING_ALERTS);
-        goto f_err;
-      }
-    } else if (alert_level == SSL3_AL_FATAL) {
-      char tmp[16];
-
-      OPENSSL_PUT_ERROR(SSL, SSL_AD_REASON_OFFSET + alert_descr);
-      BIO_snprintf(tmp, sizeof(tmp), "%d", alert_descr);
-      ERR_add_error_data(2, "SSL alert number ", tmp);
-      ssl->s3->recv_shutdown = ssl_shutdown_fatal_alert;
-      SSL_CTX_remove_session(ssl->ctx, ssl->session);
-      return 0;
-    } else {
-      al = SSL_AD_ILLEGAL_PARAMETER;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_ALERT_TYPE);
-      goto f_err;
-    }
-
-    goto start;
-  }
-
   if (type == 0) {
     /* This may only occur from read_close_notify. */
     assert(ssl->s3->send_shutdown == ssl_shutdown_close_notify);
diff --git a/ssl/ssl_buffer.c b/ssl/ssl_buffer.c
index efa2208..df814fa 100644
--- a/ssl/ssl_buffer.c
+++ b/ssl/ssl_buffer.c
@@ -182,14 +182,13 @@
   SSL3_BUFFER *buf = &ssl->s3->read_buffer;
 
   consume_buffer(buf, len);
-  if (!SSL_IS_DTLS(ssl)) {
-    /* The TLS stack never reads beyond the current record, so there will never
-     * be unconsumed data. If read-ahead is ever reimplemented,
-     * |ssl_read_buffer_discard| will require a |memcpy| to shift the excess
-     * back to the front of the buffer, to ensure there is enough space for the
-     * next record. */
-     assert(buf->len == 0);
-  }
+
+  /* The TLS stack never reads beyond the current record, so there will never be
+   * unconsumed data. If read-ahead is ever reimplemented,
+   * |ssl_read_buffer_discard| will require a |memcpy| to shift the excess back
+   * to the front of the buffer, to ensure there is enough space for the next
+   * record. */
+  assert(SSL_IS_DTLS(ssl) || len == 0 || buf->len == 0);
 }
 
 void ssl_read_buffer_discard(SSL *ssl) {
diff --git a/ssl/tls_record.c b/ssl/tls_record.c
index 7036c87..869831c 100644
--- a/ssl/tls_record.c
+++ b/ssl/tls_record.c
@@ -113,6 +113,7 @@
 
 #include <openssl/bytestring.h>
 #include <openssl/err.h>
+#include <openssl/mem.h>
 
 #include "internal.h"
 
@@ -123,6 +124,10 @@
  * forever. */
 static const uint8_t kMaxEmptyRecords = 32;
 
+/* kMaxWarningAlerts is the number of consecutive warning alerts that will be
+ * processed. */
+static const uint8_t kMaxWarningAlerts = 4;
+
 /* ssl_needs_record_splitting returns one if |ssl|'s current outgoing cipher
  * state needs record-splitting and zero otherwise. */
 static int ssl_needs_record_splitting(const SSL *ssl) {
@@ -191,6 +196,8 @@
     SSL *ssl, uint8_t *out_type, uint8_t *out, size_t *out_len,
     size_t *out_consumed, uint8_t *out_alert, size_t max_out, const uint8_t *in,
     size_t in_len) {
+  *out_consumed = 0;
+
   CBS cbs;
   CBS_init(&cbs, in, in_len);
 
@@ -238,6 +245,8 @@
     *out_alert = SSL_AD_BAD_RECORD_MAC;
     return ssl_open_record_error;
   }
+  *out_consumed = in_len - CBS_len(&cbs);
+
   if (!ssl_record_sequence_update(ssl->s3->read_sequence, 8)) {
     *out_alert = SSL_AD_INTERNAL_ERROR;
     return ssl_open_record_error;
@@ -281,9 +290,14 @@
     ssl->s3->empty_record_count = 0;
   }
 
+  if (type == SSL3_RT_ALERT) {
+    return ssl_process_alert(ssl, out_alert, out, plaintext_len);
+  }
+
+  ssl->s3->warning_alert_count = 0;
+
   *out_type = type;
   *out_len = plaintext_len;
-  *out_consumed = in_len - CBS_len(&cbs);
   return ssl_open_record_success;
 }
 
@@ -417,3 +431,60 @@
   SSL_AEAD_CTX_free(ssl->s3->aead_write_ctx);
   ssl->s3->aead_write_ctx = aead_ctx;
 }
+
+enum ssl_open_record_t ssl_process_alert(SSL *ssl, uint8_t *out_alert,
+                                         const uint8_t *in, size_t in_len) {
+  /* Alert records may not contain fragmented or multiple alerts. */
+  if (in_len != 2) {
+    *out_alert = SSL_AD_DECODE_ERROR;
+    OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_ALERT);
+    return ssl_open_record_error;
+  }
+
+  ssl_do_msg_callback(ssl, 0 /* read */, ssl->version, SSL3_RT_ALERT, in,
+                      in_len);
+
+  const uint8_t alert_level = in[0];
+  const uint8_t alert_descr = in[1];
+
+  /* The info callback is passed the level in the high byte and the
+   * description in the low byte. */
+  uint16_t alert = (alert_level << 8) | alert_descr;
+  ssl_do_info_callback(ssl, SSL_CB_READ_ALERT, alert);
+
+  if (alert_level == SSL3_AL_WARNING) {
+    if (alert_descr == SSL_AD_CLOSE_NOTIFY) {
+      ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
+      return ssl_open_record_close_notify;
+    }
+
+    /* Other warning alerts are dropped, but at most |kMaxWarningAlerts| in a
+     * row are tolerated. The record layer resets the counter whenever it
+     * returns a non-alert record. */
+    ssl->s3->warning_alert_count++;
+    if (ssl->s3->warning_alert_count > kMaxWarningAlerts) {
+      *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
+      OPENSSL_PUT_ERROR(SSL, SSL_R_TOO_MANY_WARNING_ALERTS);
+      return ssl_open_record_error;
+    }
+    return ssl_open_record_discard;
+  }
+
+  if (alert_level == SSL3_AL_FATAL) {
+    /* Mark the read side as shut down and drop the session from the cache. */
+    ssl->s3->recv_shutdown = ssl_shutdown_fatal_alert;
+    SSL_CTX_remove_session(ssl->ctx, ssl->session);
+
+    /* Report the peer's alert on the error queue, including its number. */
+    char tmp[16];
+    OPENSSL_PUT_ERROR(SSL, SSL_AD_REASON_OFFSET + alert_descr);
+    BIO_snprintf(tmp, sizeof(tmp), "%d", alert_descr);
+    ERR_add_error_data(2, "SSL alert number ", tmp);
+    return ssl_open_record_fatal_alert;
+  }
+
+  /* Any other alert level is invalid. */
+  *out_alert = SSL_AD_ILLEGAL_PARAMETER;
+  OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_ALERT_TYPE);
+  return ssl_open_record_error;
+}