Don't use dtls1_read_bytes to read messages.

This was probably the worst offender of them all as read_bytes is the wrong
abstraction to begin with. Note this is a slight change in how processing a
record works. Rather than reading one fragment at a time, we process all
fragments in a record and return. The intent here is so that all records are
processed atomically since the connection eventually will not be able to retain
a buffer holding the record.

This loses a ton of (though not quite all yet) those a2b macros.

Change-Id: Ibe4bbcc33c496328de08d272457d2282c411b38b
Reviewed-on: https://boringssl-review.googlesource.com/8176
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 11c9b65..093bb69 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -415,24 +415,6 @@
          frag->reassembly == NULL;
 }
 
-/* dtls1_discard_fragment_body discards a handshake fragment body of length
- * |frag_len|. It returns one on success and zero on error.
- *
- * TODO(davidben): This function will go away when ssl_read_bytes is gone from
- * the DTLS side. */
-static int dtls1_discard_fragment_body(SSL *ssl, size_t frag_len) {
-  uint8_t discard[256];
-  while (frag_len > 0) {
-    size_t chunk = frag_len < sizeof(discard) ? frag_len : sizeof(discard);
-    int ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, discard, chunk, 0);
-    if (ret != (int) chunk) {
-      return 0;
-    }
-    frag_len -= chunk;
-  }
-  return 1;
-}
-
 /* dtls1_get_buffered_message returns the buffered message corresponding to
  * |msg_hdr|. If none exists, it creates a new one and inserts it in the
  * queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
@@ -478,75 +460,92 @@
   return frag;
 }
 
-/* dtls1_process_fragment reads a handshake fragment and processes it. It
- * returns one if a fragment was successfully processed and 0 or -1 on error. */
-static int dtls1_process_fragment(SSL *ssl) {
-  /* Read handshake message header. */
-  uint8_t header[DTLS1_HM_HEADER_LENGTH];
-  int ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, header,
-                             DTLS1_HM_HEADER_LENGTH, 0);
-  if (ret <= 0) {
-    return ret;
+/* dtls1_process_handshake_record reads a handshake record and processes it. It
+ * returns one if the record was successfully processed and 0 or -1 on error. */
+static int dtls1_process_handshake_record(SSL *ssl) {
+  SSL3_RECORD *rr = &ssl->s3->rrec;
+
+start:
+  if (rr->length == 0) {
+    int ret = dtls1_get_record(ssl);
+    if (ret <= 0) {
+      return ret;
+    }
   }
-  if (ret != DTLS1_HM_HEADER_LENGTH) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+
+  /* Cross-epoch records are discarded, but we may receive out-of-order
+   * application data between ChangeCipherSpec and Finished or a ChangeCipherSpec
+   * before the appropriate point in the handshake. Those must be silently
+   * discarded.
+   *
+   * However, only allow the out-of-order records in the correct epoch.
+   * Application data must come in the encrypted epoch, and ChangeCipherSpec in
+   * the unencrypted epoch (we never renegotiate). Other cases fall through and
+   * fail with a fatal error. */
+  if ((rr->type == SSL3_RT_APPLICATION_DATA &&
+       ssl->s3->aead_read_ctx != NULL) ||
+      (rr->type == SSL3_RT_CHANGE_CIPHER_SPEC &&
+       ssl->s3->aead_read_ctx == NULL)) {
+    rr->length = 0;
+    goto start;
+  }
+
+  if (rr->type != SSL3_RT_HANDSHAKE) {
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
+    OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
     return -1;
   }
 
-  /* Parse the message fragment header. */
-  struct hm_header_st msg_hdr;
-  dtls1_get_message_header(header, &msg_hdr);
+  CBS cbs;
+  CBS_init(&cbs, rr->data, rr->length);
 
-  /* TODO(davidben): dtls1_read_bytes is the wrong abstraction for DTLS. There
-   * should be no need to reach into |ssl->s3->rrec.length|. */
-  const size_t frag_off = msg_hdr.frag_off;
-  const size_t frag_len = msg_hdr.frag_len;
-  const size_t msg_len = msg_hdr.msg_len;
-  if (frag_off > msg_len || frag_off + frag_len < frag_off ||
-      frag_off + frag_len > msg_len ||
-      msg_len > ssl_max_handshake_message_len(ssl) ||
-      frag_len > ssl->s3->rrec.length) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
-    return -1;
-  }
-
-  if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
-      msg_hdr.seq > (unsigned)ssl->d1->handshake_read_seq +
-                    kHandshakeBufferSize) {
-    /* Ignore fragments from the past, or ones too far in the future. */
-    if (!dtls1_discard_fragment_body(ssl, frag_len)) {
+  while (CBS_len(&cbs) > 0) {
+    /* Read a handshake fragment. */
+    struct hm_header_st msg_hdr;
+    CBS body;
+    if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
       return -1;
     }
-    return 1;
-  }
 
-  hm_fragment *frag = dtls1_get_buffered_message(ssl, &msg_hdr);
-  if (frag == NULL) {
-    return -1;
-  }
-  assert(frag->msg_header.msg_len == msg_len);
-
-  if (frag->reassembly == NULL) {
-    /* The message is already assembled. */
-    if (!dtls1_discard_fragment_body(ssl, frag_len)) {
+    const size_t frag_off = msg_hdr.frag_off;
+    const size_t frag_len = msg_hdr.frag_len;
+    const size_t msg_len = msg_hdr.msg_len;
+    if (frag_off > msg_len || frag_off + frag_len < frag_off ||
+        frag_off + frag_len > msg_len ||
+        msg_len > ssl_max_handshake_message_len(ssl)) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
       return -1;
     }
-    return 1;
-  }
-  assert(msg_len > 0);
 
-  /* Read the body of the fragment. */
-  ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, frag->fragment + frag_off,
-                         frag_len, 0);
-  if (ret != (int) frag_len) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-    return -1;
-  }
-  dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
+    if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
+        msg_hdr.seq >
+            (unsigned)ssl->d1->handshake_read_seq + kHandshakeBufferSize) {
+      /* Ignore fragments from the past, or ones too far in the future. */
+      continue;
+    }
 
+    hm_fragment *frag = dtls1_get_buffered_message(ssl, &msg_hdr);
+    if (frag == NULL) {
+      return -1;
+    }
+    assert(frag->msg_header.msg_len == msg_len);
+
+    if (frag->reassembly == NULL) {
+      /* The message is already assembled. */
+      continue;
+    }
+    assert(msg_len > 0);
+
+    /* Copy the body into the fragment. */
+    memcpy(frag->fragment + frag_off, CBS_data(&body), CBS_len(&body));
+    dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
+  }
+
+  rr->length = 0;
+  ssl_read_buffer_discard(ssl);
   return 1;
 }
 
@@ -579,9 +578,9 @@
     return ssl->init_num;
   }
 
-  /* Process fragments until one is found. */
+  /* Process handshake records until the next message is ready. */
   while (!dtls1_is_next_message_complete(ssl)) {
-    int ret = dtls1_process_fragment(ssl);
+    int ret = dtls1_process_handshake_record(ssl);
     if (ret <= 0) {
       *ok = 0;
       return ret;
@@ -835,13 +834,18 @@
   return kMinMTU;
 }
 
-void dtls1_get_message_header(uint8_t *data,
-                              struct hm_header_st *msg_hdr) {
-  memset(msg_hdr, 0x00, sizeof(struct hm_header_st));
-  msg_hdr->type = *(data++);
-  n2l3(data, msg_hdr->msg_len);
+int dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr,
+                         CBS *out_body) {
+  memset(out_hdr, 0x00, sizeof(struct hm_header_st));
 
-  n2s(data, msg_hdr->seq);
-  n2l3(data, msg_hdr->frag_off);
-  n2l3(data, msg_hdr->frag_len);
+  if (!CBS_get_u8(cbs, &out_hdr->type) ||
+      !CBS_get_u24(cbs, &out_hdr->msg_len) ||
+      !CBS_get_u16(cbs, &out_hdr->seq) ||
+      !CBS_get_u24(cbs, &out_hdr->frag_off) ||
+      !CBS_get_u24(cbs, &out_hdr->frag_len) ||
+      !CBS_get_bytes(cbs, out_body, out_hdr->frag_len)) {
+    return 0;
+  }
+
+  return 1;
 }
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index d0a884a..c67a7ae 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -116,6 +116,7 @@
 
 #include <openssl/bio.h>
 #include <openssl/buf.h>
+#include <openssl/bytestring.h>
 #include <openssl/mem.h>
 #include <openssl/evp.h>
 #include <openssl/err.h>
@@ -127,10 +128,7 @@
 static int do_dtls1_write(SSL *ssl, int type, const uint8_t *buf,
                           unsigned int len, enum dtls1_use_epoch_t use_epoch);
 
-/* dtls1_get_record reads a new input record. On success, it places it in
- * |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if
- * more data is needed. */
-static int dtls1_get_record(SSL *ssl) {
+int dtls1_get_record(SSL *ssl) {
 again:
   switch (ssl->s3->recv_shutdown) {
     case ssl_shutdown_none:
@@ -258,10 +256,7 @@
 }
 
 /* Return up to 'len' payload bytes received in 'type' records.
- * 'type' is one of the following:
- *
- *   -  SSL3_RT_HANDSHAKE (when dtls1_get_message calls us)
- *   -  SSL3_RT_APPLICATION_DATA (when dtls1_read_app_data calls us)
+ * 'type' must be SSL3_RT_APPLICATION_DATA (when dtls1_read_app_data calls us).
  *
  * If we don't have stored data to work from, read a DTLS record first (possibly
  * multiple records if we still don't have anything to return).
@@ -273,8 +268,7 @@
   unsigned int n;
   SSL3_RECORD *rr;
 
-  if ((type != SSL3_RT_APPLICATION_DATA && type != SSL3_RT_HANDSHAKE) ||
-      (peek && type != SSL3_RT_APPLICATION_DATA)) {
+  if (type != SSL3_RT_APPLICATION_DATA) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return -1;
   }
@@ -327,35 +321,18 @@
 
   /* If we get here, then type != rr->type. */
 
-  /* Cross-epoch records are discarded, but we may receive out-of-order
-   * application data between ChangeCipherSpec and Finished or a ChangeCipherSpec
-   * before the appropriate point in the handshake. Those must be silently
-   * discarded.
-   *
-   * However, only allow the out-of-order records in the correct epoch.
-   * Application data must come in the encrypted epoch, and ChangeCipherSpec in
-   * the unencrypted epoch (we never renegotiate). Other cases fall through and
-   * fail with a fatal error. */
-  if ((rr->type == SSL3_RT_APPLICATION_DATA &&
-       ssl->s3->aead_read_ctx != NULL) ||
-      (rr->type == SSL3_RT_CHANGE_CIPHER_SPEC &&
-       ssl->s3->aead_read_ctx == NULL)) {
-    rr->length = 0;
-    goto start;
-  }
-
   if (rr->type == SSL3_RT_HANDSHAKE) {
-    assert(type == SSL3_RT_APPLICATION_DATA);
     /* Parse the first fragment header to determine if this is a pre-CCS or
      * post-CCS handshake record. DTLS resets handshake message numbers on each
      * handshake, so renegotiations and retransmissions are ambiguous. */
-    if (rr->length < DTLS1_HM_HEADER_LENGTH) {
+    CBS cbs, body;
+    struct hm_header_st msg_hdr;
+    CBS_init(&cbs, rr->data, rr->length);
+    if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
       al = SSL_AD_DECODE_ERROR;
       OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
       goto f_err;
     }
-    struct hm_header_st msg_hdr;
-    dtls1_get_message_header(rr->data, &msg_hdr);
 
     if (msg_hdr.type == SSL3_MT_FINISHED) {
       if (msg_hdr.frag_off == 0) {
diff --git a/ssl/internal.h b/ssl/internal.h
index 4856969..1b2c544 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -704,100 +704,17 @@
  *
  * Functions below here haven't been touched up and may be underdocumented. */
 
-#define c2l(c, l)                                                            \
-  (l = ((unsigned long)(*((c)++))), l |= (((unsigned long)(*((c)++))) << 8), \
-   l |= (((unsigned long)(*((c)++))) << 16),                                 \
-   l |= (((unsigned long)(*((c)++))) << 24))
-
-/* NOTE - c is not incremented as per c2l */
-#define c2ln(c, l1, l2, n)                       \
-  {                                              \
-    c += n;                                      \
-    l1 = l2 = 0;                                 \
-    switch (n) {                                 \
-      case 8:                                    \
-        l2 = ((unsigned long)(*(--(c)))) << 24;  \
-      case 7:                                    \
-        l2 |= ((unsigned long)(*(--(c)))) << 16; \
-      case 6:                                    \
-        l2 |= ((unsigned long)(*(--(c)))) << 8;  \
-      case 5:                                    \
-        l2 |= ((unsigned long)(*(--(c))));       \
-      case 4:                                    \
-        l1 = ((unsigned long)(*(--(c)))) << 24;  \
-      case 3:                                    \
-        l1 |= ((unsigned long)(*(--(c)))) << 16; \
-      case 2:                                    \
-        l1 |= ((unsigned long)(*(--(c)))) << 8;  \
-      case 1:                                    \
-        l1 |= ((unsigned long)(*(--(c))));       \
-    }                                            \
-  }
-
-#define l2c(l, c)                            \
-  (*((c)++) = (uint8_t)(((l)) & 0xff),       \
-   *((c)++) = (uint8_t)(((l) >> 8) & 0xff),  \
-   *((c)++) = (uint8_t)(((l) >> 16) & 0xff), \
-   *((c)++) = (uint8_t)(((l) >> 24) & 0xff))
-
-#define n2l(c, l)                          \
-  (l = ((unsigned long)(*((c)++))) << 24,  \
-   l |= ((unsigned long)(*((c)++))) << 16, \
-   l |= ((unsigned long)(*((c)++))) << 8, l |= ((unsigned long)(*((c)++))))
-
 #define l2n(l, c)                            \
   (*((c)++) = (uint8_t)(((l) >> 24) & 0xff), \
    *((c)++) = (uint8_t)(((l) >> 16) & 0xff), \
    *((c)++) = (uint8_t)(((l) >> 8) & 0xff),  \
    *((c)++) = (uint8_t)(((l)) & 0xff))
 
-#define l2n8(l, c)                           \
-  (*((c)++) = (uint8_t)(((l) >> 56) & 0xff), \
-   *((c)++) = (uint8_t)(((l) >> 48) & 0xff), \
-   *((c)++) = (uint8_t)(((l) >> 40) & 0xff), \
-   *((c)++) = (uint8_t)(((l) >> 32) & 0xff), \
-   *((c)++) = (uint8_t)(((l) >> 24) & 0xff), \
-   *((c)++) = (uint8_t)(((l) >> 16) & 0xff), \
-   *((c)++) = (uint8_t)(((l) >> 8) & 0xff),  \
-   *((c)++) = (uint8_t)(((l)) & 0xff))
-
-/* NOTE - c is not incremented as per l2c */
-#define l2cn(l1, l2, c, n)                               \
-  {                                                      \
-    c += n;                                              \
-    switch (n) {                                         \
-      case 8:                                            \
-        *(--(c)) = (uint8_t)(((l2) >> 24) & 0xff); \
-      case 7:                                            \
-        *(--(c)) = (uint8_t)(((l2) >> 16) & 0xff); \
-      case 6:                                            \
-        *(--(c)) = (uint8_t)(((l2) >> 8) & 0xff);  \
-      case 5:                                            \
-        *(--(c)) = (uint8_t)(((l2)) & 0xff);       \
-      case 4:                                            \
-        *(--(c)) = (uint8_t)(((l1) >> 24) & 0xff); \
-      case 3:                                            \
-        *(--(c)) = (uint8_t)(((l1) >> 16) & 0xff); \
-      case 2:                                            \
-        *(--(c)) = (uint8_t)(((l1) >> 8) & 0xff);  \
-      case 1:                                            \
-        *(--(c)) = (uint8_t)(((l1)) & 0xff);       \
-    }                                                    \
-  }
-
-#define n2s(c, s) \
-  ((s = (((unsigned int)(c[0])) << 8) | (((unsigned int)(c[1])))), c += 2)
-
 #define s2n(s, c)                              \
   ((c[0] = (uint8_t)(((s) >> 8) & 0xff), \
     c[1] = (uint8_t)(((s)) & 0xff)),     \
    c += 2)
 
-#define n2l3(c, l)                                                         \
-  ((l = (((unsigned long)(c[0])) << 16) | (((unsigned long)(c[1])) << 8) | \
-        (((unsigned long)(c[2])))),                                        \
-   c += 3)
-
 #define l2n3(l, c)                              \
   ((c[0] = (uint8_t)(((l) >> 16) & 0xff), \
     c[1] = (uint8_t)(((l) >> 8) & 0xff),  \
@@ -1126,6 +1043,12 @@
 int ssl3_handshake_write(SSL *ssl);
 
 int dtls1_do_handshake_write(SSL *ssl, enum dtls1_use_epoch_t use_epoch);
+
+/* dtls1_get_record reads a new input record. On success, it places it in
+ * |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if
+ * more data is needed. */
+int dtls1_get_record(SSL *ssl);
+
 int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek);
 int dtls1_read_change_cipher_spec(SSL *ssl);
 void dtls1_read_close_notify(SSL *ssl);
@@ -1143,7 +1066,8 @@
 int dtls1_buffer_message(SSL *ssl);
 int dtls1_retransmit_buffered_messages(SSL *ssl);
 void dtls1_clear_record_buffer(SSL *ssl);
-void dtls1_get_message_header(uint8_t *data, struct hm_header_st *msg_hdr);
+int dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr,
+                         CBS *out_body);
 int dtls1_check_timeout_num(SSL *ssl);
 int dtls1_set_handshake_header(SSL *ssl, int type, unsigned long len);
 int dtls1_handshake_write(SSL *ssl);
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 1b0bf54..71d14d2 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1772,7 +1772,7 @@
 				},
 			},
 			shouldFail:    true,
-			expectedError: ":UNEXPECTED_MESSAGE:",
+			expectedError: ":BAD_HANDSHAKE_RECORD:",
 		},
 		{
 			protocol: dtls,
@@ -1783,7 +1783,7 @@
 				},
 			},
 			shouldFail:    true,
-			expectedError: ":EXCESSIVE_MESSAGE_SIZE:",
+			expectedError: ":BAD_HANDSHAKE_RECORD:",
 		},
 		{
 			protocol: dtls,
@@ -1794,7 +1794,7 @@
 				},
 			},
 			shouldFail:    true,
-			expectedError: ":EXCESSIVE_MESSAGE_SIZE:",
+			expectedError: ":BAD_HANDSHAKE_RECORD:",
 		},
 		{
 			protocol: dtls,