Don't use init_buf in DTLS.

This machinery is so different between TLS and DTLS that there is no
sense in having them share structures. This switches us to maintaining
the full reassembled message in hm_fragment and get_message just lets
the caller read out of that when ready.

This removes the last direct handshake dependency on init_buf,
ssl3_hash_message.

Change-Id: I4eccfb6e6021116255daead5359a0aa3f4d5be7b
Reviewed-on: https://boringssl-review.googlesource.com/8667
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 6aa4cc6..3da1fd8 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -146,34 +146,50 @@
   if (frag == NULL) {
     return;
   }
-  OPENSSL_free(frag->fragment);
+  OPENSSL_free(frag->data);
   OPENSSL_free(frag->reassembly);
   OPENSSL_free(frag);
 }
 
-static hm_fragment *dtls1_hm_fragment_new(size_t frag_len) {
+static hm_fragment *dtls1_hm_fragment_new(const struct hm_header_st *msg_hdr) {
   hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
   if (frag == NULL) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
     return NULL;
   }
   memset(frag, 0, sizeof(hm_fragment));
+  frag->type = msg_hdr->type;
+  frag->seq = msg_hdr->seq;
+  frag->msg_len = msg_hdr->msg_len;
 
-  /* If the handshake message is empty, |frag->fragment| and |frag->reassembly|
-   * are NULL. */
-  if (frag_len > 0) {
-    frag->fragment = OPENSSL_malloc(frag_len);
-    if (frag->fragment == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-      goto err;
-    }
+  /* Allocate space for the reassembled message and fill in the header. */
+  frag->data = OPENSSL_malloc(DTLS1_HM_HEADER_LENGTH + msg_hdr->msg_len);
+  if (frag->data == NULL) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+    goto err;
+  }
 
+  CBB cbb;
+  if (!CBB_init_fixed(&cbb, frag->data, DTLS1_HM_HEADER_LENGTH) ||
+      !CBB_add_u8(&cbb, msg_hdr->type) ||
+      !CBB_add_u24(&cbb, msg_hdr->msg_len) ||
+      !CBB_add_u16(&cbb, msg_hdr->seq) ||
+      !CBB_add_u24(&cbb, 0 /* frag_off */) ||
+      !CBB_add_u24(&cbb, msg_hdr->msg_len) ||
+      !CBB_finish(&cbb, NULL, NULL)) {
+    CBB_cleanup(&cbb);
+    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+    goto err;
+  }
+
+  /* If the handshake message is empty, |frag->reassembly| is NULL. */
+  if (msg_hdr->msg_len > 0) {
     /* Initialize reassembly bitmask. */
-    if (frag_len + 7 < frag_len) {
+    if (msg_hdr->msg_len + 7 < msg_hdr->msg_len) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
       goto err;
     }
-    size_t bitmask_len = (frag_len + 7) / 8;
+    size_t bitmask_len = (msg_hdr->msg_len + 7) / 8;
     frag->reassembly = OPENSSL_malloc(bitmask_len);
     if (frag->reassembly == NULL) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
@@ -202,7 +218,7 @@
 static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start,
                                    size_t end) {
   size_t i;
-  size_t msg_len = frag->msg_header.msg_len;
+  size_t msg_len = frag->msg_len;
 
   if (frag->reassembly == NULL || start > end || end > msg_len) {
     assert(0);
@@ -238,14 +254,25 @@
   frag->reassembly = NULL;
 }
 
-/* dtls1_is_next_message_complete returns one if the next handshake message is
- * complete and zero otherwise. */
-static int dtls1_is_next_message_complete(SSL *ssl) {
+/* dtls1_is_current_message_complete returns one if the current handshake
+ * message is complete and zero otherwise. */
+static int dtls1_is_current_message_complete(SSL *ssl) {
   hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
                                                  SSL_MAX_HANDSHAKE_FLIGHT];
   return frag != NULL && frag->reassembly == NULL;
 }
 
+/* dtls1_pop_message removes the current handshake message, which must be
+ * complete, and advances to the next one. */
+static void dtls1_pop_message(SSL *ssl) {
+  assert(dtls1_is_current_message_complete(ssl));
+
+  size_t index = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
+  dtls1_hm_fragment_free(ssl->d1->incoming_messages[index]);
+  ssl->d1->incoming_messages[index] = NULL;
+  ssl->d1->handshake_read_seq++;
+}
+
 /* dtls1_get_incoming_message returns the incoming message corresponding to
  * |msg_hdr|. If none exists, it creates a new one and inserts it in the
  * queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
@@ -260,11 +287,11 @@
   size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
   hm_fragment *frag = ssl->d1->incoming_messages[idx];
   if (frag != NULL) {
-    assert(frag->msg_header.seq == msg_hdr->seq);
+    assert(frag->seq == msg_hdr->seq);
     /* The new fragment must be compatible with the previous fragments from this
      * message. */
-    if (frag->msg_header.type != msg_hdr->type ||
-        frag->msg_header.msg_len != msg_hdr->msg_len) {
+    if (frag->type != msg_hdr->type ||
+        frag->msg_len != msg_hdr->msg_len) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
       return NULL;
@@ -273,11 +300,10 @@
   }
 
   /* This is the first fragment from this message. */
-  frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
+  frag = dtls1_hm_fragment_new(msg_hdr);
   if (frag == NULL) {
     return NULL;
   }
-  memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
   ssl->d1->incoming_messages[idx] = frag;
   return frag;
 }
@@ -353,7 +379,7 @@
     if (frag == NULL) {
       return -1;
     }
-    assert(frag->msg_header.msg_len == msg_len);
+    assert(frag->msg_len == msg_len);
 
     if (frag->reassembly == NULL) {
       /* The message is already assembled. */
@@ -362,7 +388,8 @@
     assert(msg_len > 0);
 
     /* Copy the body into the fragment. */
-    memcpy(frag->fragment + frag_off, CBS_data(&body), CBS_len(&body));
+    memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off, CBS_data(&body),
+           CBS_len(&body));
     dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
   }
 
@@ -376,101 +403,63 @@
  * arrive in fragments. */
 long dtls1_get_message(SSL *ssl, int msg_type,
                        enum ssl_hash_message_t hash_message, int *ok) {
-  hm_fragment *frag = NULL;
-  int al;
+  *ok = 0;
 
-  /* s3->tmp is used to store messages that are unexpected, caused
-   * by the absence of an optional handshake message */
   if (ssl->s3->tmp.reuse_message) {
     /* A ssl_dont_hash_message call cannot be combined with reuse_message; the
      * ssl_dont_hash_message would have to have been applied to the previous
      * call. */
     assert(hash_message == ssl_hash_message);
+    assert(dtls1_is_current_message_complete(ssl));
+
     ssl->s3->tmp.reuse_message = 0;
-    if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
-      al = SSL_AD_UNEXPECTED_MESSAGE;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
-      goto f_err;
-    }
-    *ok = 1;
-    assert(ssl->init_buf->length >= DTLS1_HM_HEADER_LENGTH);
-    ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH;
-    ssl->init_num = (int)ssl->init_buf->length - DTLS1_HM_HEADER_LENGTH;
-    return ssl->init_num;
+    hash_message = ssl_dont_hash_message;
+  } else if (dtls1_is_current_message_complete(ssl)) {
+    dtls1_pop_message(ssl);
   }
 
-  /* Process handshake records until the next message is ready. */
-  while (!dtls1_is_next_message_complete(ssl)) {
+  /* Process handshake records until the current message is ready. */
+  while (!dtls1_is_current_message_complete(ssl)) {
     int ret = dtls1_process_handshake_record(ssl);
     if (ret <= 0) {
-      *ok = 0;
       return ret;
     }
   }
 
-  /* Pop an entry from the ring buffer. */
-  frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
-                                    SSL_MAX_HANDSHAKE_FLIGHT];
-  ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
-                             SSL_MAX_HANDSHAKE_FLIGHT] = NULL;
-
+  hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+                                                 SSL_MAX_HANDSHAKE_FLIGHT];
   assert(frag != NULL);
   assert(frag->reassembly == NULL);
-  assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
-
-  ssl->d1->handshake_read_seq++;
-
-  /* Reconstruct the assembled message. */
-  CBB cbb;
-  CBB_zero(&cbb);
-  if (!BUF_MEM_reserve(ssl->init_buf, (size_t)frag->msg_header.msg_len +
-                                          DTLS1_HM_HEADER_LENGTH) ||
-      !CBB_init_fixed(&cbb, (uint8_t *)ssl->init_buf->data,
-                      ssl->init_buf->max) ||
-      !CBB_add_u8(&cbb, frag->msg_header.type) ||
-      !CBB_add_u24(&cbb, frag->msg_header.msg_len) ||
-      !CBB_add_u16(&cbb, frag->msg_header.seq) ||
-      !CBB_add_u24(&cbb, 0 /* frag_off */) ||
-      !CBB_add_u24(&cbb, frag->msg_header.msg_len) ||
-      !CBB_add_bytes(&cbb, frag->fragment, frag->msg_header.msg_len) ||
-      !CBB_finish(&cbb, NULL, &ssl->init_buf->length)) {
-    CBB_cleanup(&cbb);
-    OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-    goto err;
-  }
-  assert(ssl->init_buf->length ==
-         (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
+  assert(ssl->d1->handshake_read_seq == frag->seq);
 
   /* TODO(davidben): This function has a lot of implicit outputs. Simplify the
    * |ssl_get_message| API. */
-  ssl->s3->tmp.message_type = frag->msg_header.type;
-  ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH;
-  ssl->init_num = frag->msg_header.msg_len;
+  ssl->s3->tmp.message_type = frag->type;
+  ssl->init_msg = frag->data + DTLS1_HM_HEADER_LENGTH;
+  ssl->init_num = frag->msg_len;
 
   if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
-    al = SSL_AD_UNEXPECTED_MESSAGE;
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
-    goto f_err;
+    return -1;
   }
-  if (hash_message == ssl_hash_message && !ssl3_hash_current_message(ssl)) {
-    goto err;
+  if (hash_message == ssl_hash_message && !dtls1_hash_current_message(ssl)) {
+    return -1;
   }
 
   ssl_do_msg_callback(ssl, 0 /* read */, ssl->version, SSL3_RT_HANDSHAKE,
-                      ssl->init_buf->data,
-                      ssl->init_num + DTLS1_HM_HEADER_LENGTH);
-
-  dtls1_hm_fragment_free(frag);
-
+                      frag->data, ssl->init_num + DTLS1_HM_HEADER_LENGTH);
   *ok = 1;
   return ssl->init_num;
+}
 
-f_err:
-  ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
-err:
-  dtls1_hm_fragment_free(frag);
-  *ok = 0;
-  return -1;
+int dtls1_hash_current_message(SSL *ssl) {
+  assert(dtls1_is_current_message_complete(ssl));
+
+  hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+                                                 SSL_MAX_HANDSHAKE_FLIGHT];
+  return ssl3_update_handshake_hash(ssl, frag->data,
+                                    DTLS1_HM_HEADER_LENGTH + frag->msg_len);
 }
 
 void dtls_clear_incoming_messages(SSL *ssl) {
diff --git a/ssl/dtls_method.c b/ssl/dtls_method.c
index f6376bb..09c7d40 100644
--- a/ssl/dtls_method.c
+++ b/ssl/dtls_method.c
@@ -91,26 +91,10 @@
 }
 
 static int dtls1_begin_handshake(SSL *ssl) {
-  if (ssl->init_buf != NULL) {
-    return 1;
-  }
-
-  BUF_MEM *buf = BUF_MEM_new();
-  if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
-    BUF_MEM_free(buf);
-    return 0;
-  }
-
-  ssl->init_buf = buf;
-  ssl->init_num = 0;
   return 1;
 }
 
 static void dtls1_finish_handshake(SSL *ssl) {
-  BUF_MEM_free(ssl->init_buf);
-  ssl->init_buf = NULL;
-  ssl->init_num = 0;
-
   ssl->d1->handshake_read_seq = 0;
   ssl->d1->handshake_write_seq = 0;
   dtls_clear_incoming_messages(ssl);
@@ -127,6 +111,7 @@
     dtls1_begin_handshake,
     dtls1_finish_handshake,
     dtls1_get_message,
+    dtls1_hash_current_message,
     dtls1_read_app_data,
     dtls1_read_change_cipher_spec,
     dtls1_read_close_notify,
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index 375f0e3..6dea88e 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -1764,7 +1764,7 @@
   /* The handshake buffer is no longer necessary, and we may hash the current
    * message.*/
   ssl3_free_handshake_buffer(ssl);
-  if (!ssl3_hash_current_message(ssl)) {
+  if (!ssl->method->hash_current_message(ssl)) {
     goto err;
   }
 
@@ -1837,7 +1837,7 @@
   }
   assert(channel_id_hash_len == SHA256_DIGEST_LENGTH);
 
-  if (!ssl3_hash_current_message(ssl)) {
+  if (!ssl->method->hash_current_message(ssl)) {
     return -1;
   }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index 2e4cb46..bdc1230 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -836,6 +836,10 @@
   void (*finish_handshake)(SSL *ssl);
   long (*ssl_get_message)(SSL *ssl, int msg_type,
                           enum ssl_hash_message_t hash_message, int *ok);
+  /* hash_current_message incorporates the current handshake message into the
+   * handshake hash. It returns one on success and zero on allocation
+   * failure. */
+  int (*hash_current_message)(SSL *ssl);
   int (*read_app_data)(SSL *ssl, uint8_t *buf, int len, int peek);
   int (*read_change_cipher_spec)(SSL *ssl);
   void (*read_close_notify)(SSL *ssl);
@@ -897,9 +901,19 @@
   uint32_t frag_len;
 };
 
+/* An hm_fragment is an incoming DTLS message, possibly not yet assembled. */
 typedef struct hm_fragment_st {
-  struct hm_header_st msg_header;
-  uint8_t *fragment;
+  /* type is the type of the message. */
+  uint8_t type;
+  /* seq is the sequence number of this message. */
+  uint16_t seq;
+  /* msg_len is the length of the message body. */
+  uint32_t msg_len;
+  /* data is a pointer to the message, including message header. It has length
+   * |DTLS1_HM_HEADER_LENGTH| + |msg_len|. */
+  uint8_t *data;
+  /* reassembly is a bitmask of |msg_len| bits corresponding to which parts of
+   * the message have been received. It is NULL if the message is complete. */
   uint8_t *reassembly;
 } hm_fragment;
 
@@ -1013,9 +1027,6 @@
 int ssl3_send_alert(SSL *ssl, int level, int desc);
 long ssl3_get_message(SSL *ssl, int msg_type,
                       enum ssl_hash_message_t hash_message, int *ok);
-
-/* ssl3_hash_current_message incorporates the current handshake message into the
- * handshake hash. It returns one on success and zero on allocation failure. */
 int ssl3_hash_current_message(SSL *ssl);
 
 /* ssl3_cert_verify_hash writes the SSL 3.0 CertificateVerify hash into the
@@ -1096,6 +1107,7 @@
 
 long dtls1_get_message(SSL *ssl, int mt, enum ssl_hash_message_t hash_message,
                        int *ok);
+int dtls1_hash_current_message(SSL *ssl);
 int dtls1_dispatch_alert(SSL *ssl);
 
 /* ssl_is_wbio_buffered returns one if |ssl|'s write BIO is buffered and zero
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index b1947c9..42ec70e 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -274,7 +274,7 @@
 
   /* Snapshot the finished hash before incorporating the new message. */
   ssl3_take_mac(ssl);
-  if (!ssl3_hash_current_message(ssl)) {
+  if (!ssl->method->hash_current_message(ssl)) {
     goto err;
   }
 
@@ -606,11 +606,8 @@
 }
 
 int ssl3_hash_current_message(SSL *ssl) {
-  /* The handshake header (different size between DTLS and TLS) is included in
-   * the hash. */
-  size_t header_len = ssl->init_msg - (uint8_t *)ssl->init_buf->data;
   return ssl3_update_handshake_hash(ssl, (uint8_t *)ssl->init_buf->data,
-                                    ssl->init_num + header_len);
+                                    ssl->init_buf->length);
 }
 
 int ssl_verify_alarm_type(long type) {
diff --git a/ssl/tls_method.c b/ssl/tls_method.c
index dab5c47..17905a9 100644
--- a/ssl/tls_method.c
+++ b/ssl/tls_method.c
@@ -100,6 +100,7 @@
     ssl3_begin_handshake,
     ssl3_finish_handshake,
     ssl3_get_message,
+    ssl3_hash_current_message,
     ssl3_read_app_data,
     ssl3_read_change_cipher_spec,
     ssl3_read_close_notify,