Rewrite DTLS handshake message sending logic.

This fixes a number of bugs with the original logic:

- If handshake messages are fragmented and writes need to be retried, frag_off
  gets completely confused.

- The BIO_flush call didn't set rwstate, so it wasn't resumable at that point.

- The msg_callback call gets garbage because the fragment header would get
  scribbled over the handshake buffer.

The original logic was also extremely confusing in how it handled init_off.
(init_off gets rewound to make room for the fragment header.  Depending on
where you pause, init_off may or may not have already been rewound on resume.)

For simplicity, just allocate a new buffer to assemble the fragment in and
avoid clobbering the old one. I don't think it's worth the complexity to
optimize that. If we want to optimize this sort of thing, not clobbering seems
better anyway because the message may need to be retransmitted. We could avoid
doing a copy when buffering the outgoing message for retransmission later.

We do still need to track how far we are in sending the current message via
init_off, so I haven't opted to disconnect this function from
init_{buf,off,num} yet.

Test the fix for the retry + fragment case: in DTLS, the splitHandshake option
to the state machine tests now also clamps the MTU to force handshake
fragmentation.

Change-Id: I66f634d6c752ea63649db8ed2f898f9cc2b13908
Reviewed-on: https://boringssl-review.googlesource.com/6421
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 2882320..a940af6 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -145,10 +145,6 @@
  * current one to buffer. */
 static const unsigned int kHandshakeBufferSize = 10;
 
-static void dtls1_fix_message_header(SSL *s, unsigned long frag_off,
-                                     unsigned long frag_len);
-static unsigned char *dtls1_write_message_header(SSL *s, unsigned char *p);
-
 static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) {
   hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
   if (frag == NULL) {
@@ -268,14 +264,34 @@
   assert(ssl->d1->mtu >= dtls1_min_mtu());
 }
 
+/* dtls1_max_record_size returns the maximum record body length that may be
+ * written without exceeding the MTU. It accounts for any buffering installed on
+ * the write BIO. If no record may be written, it returns zero. */
+static size_t dtls1_max_record_size(SSL *ssl) {
+  size_t ret = ssl->d1->mtu;
+
+  size_t overhead = ssl_max_seal_overhead(ssl);
+  if (ret <= overhead) {
+    return 0;
+  }
+  ret -= overhead;
+
+  size_t pending = BIO_wpending(SSL_get_wbio(ssl));
+  if (ret <= pending) {
+    return 0;
+  }
+  ret -= pending;
+
+  return ret;
+}
+
 static int dtls1_write_change_cipher_spec(SSL *ssl,
                                           enum dtls1_use_epoch_t use_epoch) {
   dtls1_update_mtu(ssl);
 
   /* During the handshake, wbio is buffered to pack messages together. Flush the
    * buffer if the ChangeCipherSpec would not fit in a packet. */
-  if (BIO_wpending(SSL_get_wbio(ssl)) + ssl_max_seal_overhead(ssl) + 1 >
-      ssl->d1->mtu) {
+  if (dtls1_max_record_size(ssl) == 0) {
     ssl->rwstate = SSL_WRITING;
     int ret = BIO_flush(SSL_get_wbio(ssl));
     if (ret <= 0) {
@@ -300,103 +316,96 @@
   return 1;
 }
 
-int dtls1_do_handshake_write(SSL *s, enum dtls1_use_epoch_t use_epoch) {
-  int ret;
-  int curr_mtu;
-  unsigned int len, frag_off;
+int dtls1_do_handshake_write(SSL *ssl, enum dtls1_use_epoch_t use_epoch) {
+  dtls1_update_mtu(ssl);
 
-  dtls1_update_mtu(s);
-
-  if (s->init_off == 0) {
-    assert(s->init_num ==
-           (int)s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH);
+  int ret = -1;
+  CBB cbb;
+  CBB_zero(&cbb);
+  /* Allocate a temporary buffer to hold the message fragments to avoid
+   * clobbering the message. */
+  uint8_t *buf = OPENSSL_malloc(ssl->d1->mtu);
+  if (buf == NULL) {
+    goto err;
   }
 
-  /* Determine the maximum overhead of the current cipher. */
-  size_t max_overhead = SSL_AEAD_CTX_max_overhead(s->aead_write_ctx);
+  /* Consume the message header. Fragments will have different headers
+   * prepended. */
+  if (ssl->init_off == 0) {
+    ssl->init_off += DTLS1_HM_HEADER_LENGTH;
+    ssl->init_num -= DTLS1_HM_HEADER_LENGTH;
+  }
+  assert(ssl->init_off >= DTLS1_HM_HEADER_LENGTH);
 
-  frag_off = 0;
-  while (s->init_num) {
-    /* Account for data in the buffering BIO; multiple records may be packed
-     * into a single packet during the handshake.
-     *
-     * TODO(davidben): This is buggy; if the MTU is larger than the buffer size,
-     * the large record will be split across two packets. Moreover, in that
-     * case, the |dtls1_write_bytes| call may not return synchronously. This
-     * will break on retry as the |s->init_off| and |s->init_num| adjustment
-     * will run a second time. */
-    curr_mtu = s->d1->mtu - BIO_wpending(SSL_get_wbio(s)) -
-        DTLS1_RT_HEADER_LENGTH - max_overhead;
-
-    if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
-      /* Flush the buffer and continue with a fresh packet.
-       *
-       * TODO(davidben): If |BIO_flush| is not synchronous and requires multiple
-       * calls to |dtls1_do_write|, |frag_off| will be wrong. */
-      ret = BIO_flush(SSL_get_wbio(s));
-      if (ret <= 0) {
-        return ret;
+  do {
+    /* During the handshake, wbio is buffered to pack messages together. Flush
+     * the buffer if there isn't enough room to make progress. */
+    if (dtls1_max_record_size(ssl) < DTLS1_HM_HEADER_LENGTH + 1) {
+      ssl->rwstate = SSL_WRITING;
+      int flush_ret = BIO_flush(SSL_get_wbio(ssl));
+      if (flush_ret <= 0) {
+        ret = flush_ret;
+        goto err;
       }
-      assert(BIO_wpending(SSL_get_wbio(s)) == 0);
-      curr_mtu = s->d1->mtu - DTLS1_RT_HEADER_LENGTH - max_overhead;
+      ssl->rwstate = SSL_NOTHING;
+      assert(BIO_wpending(SSL_get_wbio(ssl)) == 0);
     }
 
-    /* If this isn't the first fragment, reserve space to prepend a new fragment
-     * header. This will override the body of a previous fragment. */
-    if (s->init_off != 0) {
-      assert(s->init_off > DTLS1_HM_HEADER_LENGTH);
-      s->init_off -= DTLS1_HM_HEADER_LENGTH;
-      s->init_num += DTLS1_HM_HEADER_LENGTH;
-    }
-
-    if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
+    size_t todo = dtls1_max_record_size(ssl);
+    if (todo < DTLS1_HM_HEADER_LENGTH + 1) {
       /* To make forward progress, the MTU must, at minimum, fit the handshake
        * header and one byte of handshake body. */
       OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
-      return -1;
+      goto err;
+    }
+    todo -= DTLS1_HM_HEADER_LENGTH;
+
+    if (todo > (size_t)ssl->init_num) {
+      todo = ssl->init_num;
+    }
+    if (todo >= (1u << 24)) {
+      todo = (1u << 24) - 1;
     }
 
-    if (s->init_num > curr_mtu) {
-      len = curr_mtu;
-    } else {
-      len = s->init_num;
-    }
-    assert(len >= DTLS1_HM_HEADER_LENGTH);
-
-    dtls1_fix_message_header(s, frag_off, len - DTLS1_HM_HEADER_LENGTH);
-    dtls1_write_message_header(
-        s, (uint8_t *)&s->init_buf->data[s->init_off]);
-
-    ret = dtls1_write_bytes(s, SSL3_RT_HANDSHAKE,
-                            &s->init_buf->data[s->init_off], len, use_epoch);
-    if (ret < 0) {
-      return -1;
+    size_t len;
+    if (!CBB_init_fixed(&cbb, buf, ssl->d1->mtu) ||
+        !CBB_add_u8(&cbb, ssl->d1->w_msg_hdr.type) ||
+        !CBB_add_u24(&cbb, ssl->d1->w_msg_hdr.msg_len) ||
+        !CBB_add_u16(&cbb, ssl->d1->w_msg_hdr.seq) ||
+        !CBB_add_u24(&cbb, ssl->init_off - DTLS1_HM_HEADER_LENGTH) ||
+        !CBB_add_u24(&cbb, todo) ||
+        !CBB_add_bytes(
+            &cbb, (const uint8_t *)ssl->init_buf->data + ssl->init_off, todo) ||
+        !CBB_finish(&cbb, NULL, &len)) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+      goto err;
     }
 
-    /* bad if this assert fails, only part of the handshake message got sent.
-     * But why would this happen? */
-    assert(len == (unsigned int)ret);
-
-    if (ret == s->init_num) {
-      if (s->msg_callback) {
-        /* TODO(davidben): At this point, |s->init_buf->data| has been clobbered
-         * already. */
-        s->msg_callback(1, s->version, SSL3_RT_HANDSHAKE, s->init_buf->data,
-                        (size_t)(s->init_off + s->init_num), s,
-                        s->msg_callback_arg);
-      }
-
-      s->init_off = 0; /* done writing this message */
-      s->init_num = 0;
-
-      return 1;
+    int write_ret = dtls1_write_bytes(ssl, SSL3_RT_HANDSHAKE, buf, len,
+                                      use_epoch);
+    if (write_ret <= 0) {
+      ret = write_ret;
+      goto err;
     }
-    s->init_off += ret;
-    s->init_num -= ret;
-    frag_off += (ret -= DTLS1_HM_HEADER_LENGTH);
+    ssl->init_off += todo;
+    ssl->init_num -= todo;
+  } while (ssl->init_num > 0);
+
+  if (ssl->msg_callback != NULL) {
+    ssl->msg_callback(
+        1 /* write */, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
+        (size_t)(ssl->init_off + ssl->init_num), ssl, ssl->msg_callback_arg);
   }
 
-  return 0;
+  ssl->init_off = 0;
+  ssl->init_num = 0;
+
+  ret = 1;
+
+err:
+  CBB_cleanup(&cbb);
+  OPENSSL_free(buf);
+  return ret;
 }
 
 /* dtls1_is_next_message_complete returns one if the next handshake message is
@@ -855,27 +864,6 @@
   msg_hdr->frag_len = frag_len;
 }
 
-static void dtls1_fix_message_header(SSL *s, unsigned long frag_off,
-                                     unsigned long frag_len) {
-  struct hm_header_st *msg_hdr = &s->d1->w_msg_hdr;
-
-  msg_hdr->frag_off = frag_off;
-  msg_hdr->frag_len = frag_len;
-}
-
-static uint8_t *dtls1_write_message_header(SSL *s, uint8_t *p) {
-  struct hm_header_st *msg_hdr = &s->d1->w_msg_hdr;
-
-  *p++ = msg_hdr->type;
-  l2n3(msg_hdr->msg_len, p);
-
-  s2n(msg_hdr->seq, p);
-  l2n3(msg_hdr->frag_off, p);
-  l2n3(msg_hdr->frag_len, p);
-
-  return p;
-}
-
 unsigned int dtls1_min_mtu(void) {
   return kMinMTU;
 }
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index ad8e12a..158f082 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2932,27 +2932,25 @@
 		})
 	}
 
-	var suffix string
-	var flags []string
-	var maxHandshakeRecordLength int
-	if protocol == dtls {
-		suffix = "-DTLS"
-	}
-	if async {
-		suffix += "-Async"
-		flags = append(flags, "-async")
-	} else {
-		suffix += "-Sync"
-	}
-	if splitHandshake {
-		suffix += "-SplitHandshakeRecords"
-		maxHandshakeRecordLength = 1
-	}
 	for _, test := range tests {
 		test.protocol = protocol
-		test.name += suffix
-		test.config.Bugs.MaxHandshakeRecordLength = maxHandshakeRecordLength
-		test.flags = append(test.flags, flags...)
+		if protocol == dtls {
+			test.name += "-DTLS"
+		}
+		if async {
+			test.name += "-Async"
+			test.flags = append(test.flags, "-async")
+		} else {
+			test.name += "-Sync"
+		}
+		if splitHandshake {
+			test.name += "-SplitHandshakeRecords"
+			test.config.Bugs.MaxHandshakeRecordLength = 1
+			if protocol == dtls {
+				test.config.Bugs.MaxPacketLength = 256
+				test.flags = append(test.flags, "-mtu", "256")
+			}
+		}
 		testCases = append(testCases, test)
 	}
 }