Separate CCS and handshake writing in DTLS.

They run through completely different logic as only handshake is fragmented.
This'll make it easier to rewrite the handshake logic in a follow-up.

Change-Id: I9515feafc06bf069b261073873966e72fcbe13cb
Reviewed-on: https://boringssl-review.googlesource.com/6420
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 79d7205..465d44f 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -3750,8 +3750,15 @@
 
   BIO *rbio; /* used by SSL_read */
   BIO *wbio; /* used by SSL_write */
-  BIO *bbio; /* used during session-id reuse to concatenate
-              * messages */
+
+  /* bbio, if non-NULL, is a buffer placed in front of |wbio| to pack handshake
+   * messages within one flight into a single |BIO_write|.
+   *
+   * TODO(davidben): This does not work right for DTLS. It assumes the MTU is
+   * smaller than the buffer size so that the buffer's internal flushing never
+   * kicks in. It also doesn't kick in for DTLS retransmission. Replace this
+   * with a better mechanism. */
+  BIO *bbio;
 
   int (*handshake_func)(SSL *);
 
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 4d550e9..2882320 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -251,29 +251,63 @@
   frag->reassembly = NULL;
 }
 
-/* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or
- * SSL3_RT_CHANGE_CIPHER_SPEC) */
-int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) {
+static void dtls1_update_mtu(SSL *ssl) {
+  /* TODO(davidben): What is this code doing and do we need it? */
+  if (ssl->d1->mtu < dtls1_min_mtu() &&
+      !(SSL_get_options(ssl) & SSL_OP_NO_QUERY_MTU)) {
+    long mtu = BIO_ctrl(SSL_get_wbio(ssl), BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
+    if (mtu >= 0 && mtu <= (1 << 30) && (unsigned)mtu >= dtls1_min_mtu()) {
+      ssl->d1->mtu = (unsigned)mtu;
+    } else {
+      ssl->d1->mtu = kDefaultMTU;
+      BIO_ctrl(SSL_get_wbio(ssl), BIO_CTRL_DGRAM_SET_MTU, ssl->d1->mtu, NULL);
+    }
+  }
+
+  /* The MTU should be above the minimum now. */
+  assert(ssl->d1->mtu >= dtls1_min_mtu());
+}
+
+static int dtls1_write_change_cipher_spec(SSL *ssl,
+                                          enum dtls1_use_epoch_t use_epoch) {
+  dtls1_update_mtu(ssl);
+
+  /* During the handshake, wbio is buffered to pack messages together. Flush the
+   * buffer if the ChangeCipherSpec would not fit in a packet. */
+  if (BIO_wpending(SSL_get_wbio(ssl)) + ssl_max_seal_overhead(ssl) + 1 >
+      ssl->d1->mtu) {
+    ssl->rwstate = SSL_WRITING;
+    int ret = BIO_flush(SSL_get_wbio(ssl));
+    if (ret <= 0) {
+      return ret;
+    }
+    ssl->rwstate = SSL_NOTHING;
+  }
+
+  static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
+  int ret = dtls1_write_bytes(ssl, SSL3_RT_CHANGE_CIPHER_SPEC, kChangeCipherSpec,
+                              sizeof(kChangeCipherSpec), use_epoch);
+  if (ret <= 0) {
+    return ret;
+  }
+
+  if (ssl->msg_callback != NULL) {
+    ssl->msg_callback(1 /* write */, ssl->version, SSL3_RT_CHANGE_CIPHER_SPEC,
+                      kChangeCipherSpec, sizeof(kChangeCipherSpec), ssl,
+                      ssl->msg_callback_arg);
+  }
+
+  return 1;
+}
+
+int dtls1_do_handshake_write(SSL *s, enum dtls1_use_epoch_t use_epoch) {
   int ret;
   int curr_mtu;
   unsigned int len, frag_off;
 
-  /* AHA!  Figure out the MTU, and stick to the right size */
-  if (s->d1->mtu < dtls1_min_mtu() &&
-      !(SSL_get_options(s) & SSL_OP_NO_QUERY_MTU)) {
-    long mtu = BIO_ctrl(SSL_get_wbio(s), BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
-    if (mtu >= 0 && mtu <= (1 << 30) && (unsigned)mtu >= dtls1_min_mtu()) {
-      s->d1->mtu = (unsigned)mtu;
-    } else {
-      s->d1->mtu = kDefaultMTU;
-      BIO_ctrl(SSL_get_wbio(s), BIO_CTRL_DGRAM_SET_MTU, s->d1->mtu, NULL);
-    }
-  }
+  dtls1_update_mtu(s);
 
-  /* should have something reasonable now */
-  assert(s->d1->mtu >= dtls1_min_mtu());
-
-  if (s->init_off == 0 && type == SSL3_RT_HANDSHAKE) {
+  if (s->init_off == 0) {
     assert(s->init_num ==
            (int)s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH);
   }
@@ -307,45 +341,34 @@
       curr_mtu = s->d1->mtu - DTLS1_RT_HEADER_LENGTH - max_overhead;
     }
 
-    /* XDTLS: this function is too long.  split out the CCS part */
-    if (type == SSL3_RT_HANDSHAKE) {
-      /* If this isn't the first fragment, reserve space to prepend a new
-       * fragment header. This will override the body of a previous fragment. */
-      if (s->init_off != 0) {
-        assert(s->init_off > DTLS1_HM_HEADER_LENGTH);
-        s->init_off -= DTLS1_HM_HEADER_LENGTH;
-        s->init_num += DTLS1_HM_HEADER_LENGTH;
-      }
-
-      if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
-        /* To make forward progress, the MTU must, at minimum, fit the handshake
-         * header and one byte of handshake body. */
-        OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
-        return -1;
-      }
-
-      if (s->init_num > curr_mtu) {
-        len = curr_mtu;
-      } else {
-        len = s->init_num;
-      }
-      assert(len >= DTLS1_HM_HEADER_LENGTH);
-
-      dtls1_fix_message_header(s, frag_off, len - DTLS1_HM_HEADER_LENGTH);
-      dtls1_write_message_header(
-          s, (uint8_t *)&s->init_buf->data[s->init_off]);
-    } else {
-      assert(type == SSL3_RT_CHANGE_CIPHER_SPEC);
-      /* ChangeCipherSpec cannot be fragmented. */
-      if (s->init_num > curr_mtu) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
-        return -1;
-      }
-      len = s->init_num;
+    /* If this isn't the first fragment, reserve space to prepend a new fragment
+     * header. This will override the body of a previous fragment. */
+    if (s->init_off != 0) {
+      assert(s->init_off > DTLS1_HM_HEADER_LENGTH);
+      s->init_off -= DTLS1_HM_HEADER_LENGTH;
+      s->init_num += DTLS1_HM_HEADER_LENGTH;
     }
 
-    ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len,
-                            use_epoch);
+    if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
+      /* To make forward progress, the MTU must, at minimum, fit the handshake
+       * header and one byte of handshake body. */
+      OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
+      return -1;
+    }
+
+    if (s->init_num > curr_mtu) {
+      len = curr_mtu;
+    } else {
+      len = s->init_num;
+    }
+    assert(len >= DTLS1_HM_HEADER_LENGTH);
+
+    dtls1_fix_message_header(s, frag_off, len - DTLS1_HM_HEADER_LENGTH);
+    dtls1_write_message_header(
+        s, (uint8_t *)&s->init_buf->data[s->init_off]);
+
+    ret = dtls1_write_bytes(s, SSL3_RT_HANDSHAKE,
+                            &s->init_buf->data[s->init_off], len, use_epoch);
     if (ret < 0) {
       return -1;
     }
@@ -356,7 +379,9 @@
 
     if (ret == s->init_num) {
       if (s->msg_callback) {
-        s->msg_callback(1, s->version, type, s->init_buf->data,
+        /* TODO(davidben): At this point, |s->init_buf->data| has been clobbered
+         * already. */
+        s->msg_callback(1, s->version, SSL3_RT_HANDSHAKE, s->init_buf->data,
                         (size_t)(s->init_off + s->init_num), s,
                         s->msg_callback_arg);
       }
@@ -644,37 +669,6 @@
   return -1;
 }
 
-/* for these 2 messages, we need to
- * ssl->enc_read_ctx			re-init
- * ssl->s3->read_sequence		zero
- * ssl->s3->read_mac_secret		re-init
- * ssl->session->read_sym_enc		assign
- * ssl->session->read_compression	assign
- * ssl->session->read_hash		assign */
-int dtls1_send_change_cipher_spec(SSL *s, int a, int b) {
-  uint8_t *p;
-
-  if (s->state == a) {
-    p = (uint8_t *)s->init_buf->data;
-    *p++ = SSL3_MT_CCS;
-    s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
-    s->init_num = DTLS1_CCS_HEADER_LENGTH;
-
-    s->init_off = 0;
-
-    dtls1_set_message_header(s, SSL3_MT_CCS, 0, s->d1->handshake_write_seq, 0,
-                             0);
-
-    /* buffer the message to handle re-xmits */
-    dtls1_buffer_message(s, 1);
-
-    s->state = b;
-  }
-
-  /* SSL3_ST_CW_CHANGE_B */
-  return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC, dtls1_use_current_epoch);
-}
-
 int dtls1_read_failed(SSL *s, int code) {
   if (code > 0) {
     assert(0);
@@ -696,7 +690,9 @@
   return DTLSv1_handle_timeout(s);
 }
 
-int dtls1_get_queue_priority(unsigned short seq, int is_ccs) {
+static uint16_t dtls1_get_queue_priority(uint16_t seq, int is_ccs) {
+  assert(seq * 2 >= seq);
+
   /* The index of the retransmission queue actually is the message sequence
    * number, since the queue only contains messages of a single handshake.
    * However, the ChangeCipherSpec has no message sequence number and so using
@@ -709,27 +705,6 @@
 }
 
 static int dtls1_retransmit_message(SSL *s, hm_fragment *frag) {
-  int ret;
-  /* XDTLS: for now assuming that read/writes are blocking */
-  unsigned long header_length;
-
-  /* assert(s->init_num == 0);
-     assert(s->init_off == 0); */
-
-  if (frag->msg_header.is_ccs) {
-    header_length = DTLS1_CCS_HEADER_LENGTH;
-  } else {
-    header_length = DTLS1_HM_HEADER_LENGTH;
-  }
-
-  memcpy(s->init_buf->data, frag->fragment,
-         frag->msg_header.msg_len + header_length);
-  s->init_num = frag->msg_header.msg_len + header_length;
-
-  dtls1_set_message_header(s, frag->msg_header.type,
-                           frag->msg_header.msg_len, frag->msg_header.seq,
-                           0, frag->msg_header.frag_len);
-
   /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
    * (negotiated cipher) exist. */
   assert(s->d1->w_epoch == 0 || s->d1->w_epoch == 1);
@@ -739,10 +714,24 @@
     use_epoch = dtls1_use_previous_epoch;
   }
 
-  ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC
-                                                  : SSL3_RT_HANDSHAKE,
-                       use_epoch);
+  /* TODO(davidben): This cannot handle non-blocking writes. */
+  int ret;
+  if (frag->msg_header.is_ccs) {
+    ret = dtls1_write_change_cipher_spec(s, use_epoch);
+  } else {
+    /* Restore the message body.
+     * TODO(davidben): Make this less stateful. */
+    memcpy(s->init_buf->data, frag->fragment,
+           frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
+    s->init_num = frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH;
 
+    dtls1_set_message_header(s, frag->msg_header.type,
+                             frag->msg_header.msg_len, frag->msg_header.seq,
+                             0, frag->msg_header.frag_len);
+    ret = dtls1_do_handshake_write(s, use_epoch);
+  }
+
+  /* TODO(davidben): Check return value? */
   (void)BIO_flush(SSL_get_wbio(s));
   return ret;
 }
@@ -763,46 +752,65 @@
   return 1;
 }
 
-int dtls1_buffer_message(SSL *s, int is_ccs) {
-  pitem *item;
-  hm_fragment *frag;
-  uint8_t seq64be[8];
+/* dtls1_buffer_change_cipher_spec adds a ChangeCipherSpec to the current
+ * handshake flight, ordered just before the handshake message numbered
+ * |seq|. */
+static int dtls1_buffer_change_cipher_spec(SSL *ssl, uint16_t seq) {
+  hm_fragment *frag = dtls1_hm_fragment_new(0 /* frag_len */,
+                                            0 /* no reassembly */);
+  if (frag == NULL) {
+    return 0;
+  }
+  frag->msg_header.is_ccs = 1;
+  frag->msg_header.epoch = ssl->d1->w_epoch;
 
+  uint16_t priority = dtls1_get_queue_priority(seq, 1 /* is_ccs */);
+  uint8_t seq64be[8];
+  memset(seq64be, 0, sizeof(seq64be));
+  seq64be[6] = (uint8_t)(priority >> 8);
+  seq64be[7] = (uint8_t)priority;
+
+  pitem *item = pitem_new(seq64be, frag);
+  if (item == NULL) {
+    dtls1_hm_fragment_free(frag);
+    return 0;
+  }
+
+  pqueue_insert(ssl->d1->sent_messages, item);
+  return 1;
+}
+
+int dtls1_buffer_message(SSL *s) {
   /* this function is called immediately after a message has
    * been serialized */
   assert(s->init_off == 0);
 
-  frag = dtls1_hm_fragment_new(s->init_num, 0);
+  hm_fragment *frag = dtls1_hm_fragment_new(s->init_num, 0);
   if (!frag) {
     return 0;
   }
 
   memcpy(frag->fragment, s->init_buf->data, s->init_num);
 
-  if (is_ccs) {
-    assert(s->d1->w_msg_hdr.msg_len + DTLS1_CCS_HEADER_LENGTH ==
-           (unsigned int)s->init_num);
-  } else {
-    assert(s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH ==
-           (unsigned int)s->init_num);
-  }
+  assert(s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH ==
+         (unsigned int)s->init_num);
 
   frag->msg_header.msg_len = s->d1->w_msg_hdr.msg_len;
   frag->msg_header.seq = s->d1->w_msg_hdr.seq;
   frag->msg_header.type = s->d1->w_msg_hdr.type;
   frag->msg_header.frag_off = 0;
   frag->msg_header.frag_len = s->d1->w_msg_hdr.msg_len;
-  frag->msg_header.is_ccs = is_ccs;
+  frag->msg_header.is_ccs = 0;
   frag->msg_header.epoch = s->d1->w_epoch;
 
+  uint16_t priority = dtls1_get_queue_priority(frag->msg_header.seq,
+                                               0 /* handshake */);
+  uint8_t seq64be[8];
   memset(seq64be, 0, sizeof(seq64be));
-  seq64be[6] = (uint8_t)(
-      dtls1_get_queue_priority(frag->msg_header.seq, frag->msg_header.is_ccs) >>
-      8);
-  seq64be[7] = (uint8_t)(
-      dtls1_get_queue_priority(frag->msg_header.seq, frag->msg_header.is_ccs));
+  seq64be[6] = (uint8_t)(priority >> 8);
+  seq64be[7] = (uint8_t)priority;
 
-  item = pitem_new(seq64be, frag);
+  pitem *item = pitem_new(seq64be, frag);
   if (item == NULL) {
     dtls1_hm_fragment_free(frag);
     return 0;
@@ -812,6 +820,17 @@
   return 1;
 }
 
+int dtls1_send_change_cipher_spec(SSL *s, int a, int b) {
+  if (s->state == a) {
+    /* Buffer the message to handle retransmits. */
+    s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
+    dtls1_buffer_change_cipher_spec(s, s->d1->handshake_write_seq);
+    s->state = b;
+  }
+
+  return dtls1_write_change_cipher_spec(s, dtls1_use_current_epoch);
+}
+
 /* call this function when the buffered messages are no longer needed */
 void dtls1_clear_record_buffer(SSL *s) {
   pitem *item;
diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c
index cb95585..787ad9a 100644
--- a/ssl/d1_lib.c
+++ b/ssl/d1_lib.c
@@ -321,7 +321,7 @@
   s->init_off = 0;
 
   /* Buffer the message to handle re-xmits */
-  dtls1_buffer_message(s, 0);
+  dtls1_buffer_message(s);
 
   /* Add the new message to the handshake hash. Serialize the message
    * header as if it were a single fragment. */
@@ -336,5 +336,5 @@
 }
 
 int dtls1_handshake_write(SSL *s) {
-  return dtls1_do_write(s, SSL3_RT_HANDSHAKE, dtls1_use_current_epoch);
+  return dtls1_do_handshake_write(s, dtls1_use_current_epoch);
 }
diff --git a/ssl/internal.h b/ssl/internal.h
index 842cf3f..cdf8592 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -912,7 +912,9 @@
   /* records being received in the current epoch */
   DTLS1_BITMAP bitmap;
 
-  /* handshake message numbers */
+  /* handshake message numbers.
+   * TODO(davidben): It doesn't make much sense to store both of these. Only
+   * store one. */
   uint16_t handshake_write_seq;
   uint16_t next_handshake_write_seq;
 
@@ -1075,7 +1077,7 @@
 int ssl3_set_handshake_header(SSL *s, int htype, unsigned long len);
 int ssl3_handshake_write(SSL *s);
 
-int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch);
+int dtls1_do_handshake_write(SSL *s, enum dtls1_use_epoch_t use_epoch);
 int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek);
 void dtls1_read_close_notify(SSL *ssl);
 int dtls1_read_bytes(SSL *s, int type, uint8_t *buf, int len, int peek);
@@ -1090,8 +1092,7 @@
 int dtls1_send_change_cipher_spec(SSL *s, int a, int b);
 int dtls1_send_finished(SSL *s, int a, int b, const char *sender, int slen);
 int dtls1_read_failed(SSL *s, int code);
-int dtls1_buffer_message(SSL *s, int ccs);
-int dtls1_get_queue_priority(unsigned short seq, int is_ccs);
+int dtls1_buffer_message(SSL *s);
 int dtls1_retransmit_buffered_messages(SSL *s);
 void dtls1_clear_record_buffer(SSL *s);
 void dtls1_get_message_header(uint8_t *data, struct hm_header_st *msg_hdr);