Pass a dtls1_use_epoch enum down to dtls1_seal_record.

This is considerably less scary than swapping out connection state. It also
fixes a minor bug where, if dtls1_do_write had an alert to dispatch and we
happened to retry during a rexmit, it would use the wrong epoch.

BUG=468889

Change-Id: I754b0d46bfd02f797f4c3f7cfde28d3e5f30c52b
Reviewed-on: https://boringssl-review.googlesource.com/4793
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index f6442fd..3eb26c5 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -259,7 +259,7 @@
 
 /* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or
  * SSL3_RT_CHANGE_CIPHER_SPEC) */
-int dtls1_do_write(SSL *s, int type) {
+int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) {
   int ret;
   int curr_mtu;
   unsigned int len, frag_off;
@@ -350,7 +350,8 @@
       len = s->init_num;
     }
 
-    ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len);
+    ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len,
+                            use_epoch);
     if (ret < 0) {
       return -1;
     }
@@ -684,7 +685,7 @@
   }
 
   /* SSL3_ST_CW_CHANGE_B */
-  return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC);
+  return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC, dtls1_use_current_epoch);
 }
 
 int dtls1_read_failed(SSL *s, int code) {
@@ -724,7 +725,6 @@
   int ret;
   /* XDTLS: for now assuming that read/writes are blocking */
   unsigned long header_length;
-  uint8_t save_write_sequence[8];
 
   /* assert(s->init_num == 0);
      assert(s->init_off == 0); */
@@ -743,45 +743,18 @@
                            frag->msg_header.msg_len, frag->msg_header.seq,
                            0, frag->msg_header.frag_len);
 
-  /* Save current state. */
-  SSL_AEAD_CTX *aead_write_ctx = s->aead_write_ctx;
-  uint16_t epoch = s->d1->w_epoch;
-
   /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
    * (negotiated cipher) exist. */
-  assert(epoch == 0 || epoch == 1);
-  assert(frag->msg_header.epoch <= epoch);
-  const int fragment_from_previous_epoch = (epoch == 1 &&
-                                            frag->msg_header.epoch == 0);
-  if (fragment_from_previous_epoch) {
-    /* Rewind to the previous epoch.
-     *
-     * TODO(davidben): Instead of swapping out connection-global state, this
-     * logic should pass a "use previous epoch" parameter down to lower-level
-     * functions. */
-    s->d1->w_epoch = frag->msg_header.epoch;
-    s->aead_write_ctx = NULL;
-    memcpy(save_write_sequence, s->s3->write_sequence,
-           sizeof(s->s3->write_sequence));
-    memcpy(s->s3->write_sequence, s->d1->last_write_sequence,
-           sizeof(s->s3->write_sequence));
-  } else {
-    /* Otherwise the messages must be from the same epoch. */
-    assert(frag->msg_header.epoch == epoch);
+  assert(s->d1->w_epoch == 0 || s->d1->w_epoch == 1);
+  assert(frag->msg_header.epoch <= s->d1->w_epoch);
+  enum dtls1_use_epoch_t use_epoch = dtls1_use_current_epoch;
+  if (s->d1->w_epoch == 1 && frag->msg_header.epoch == 0) {
+    use_epoch = dtls1_use_previous_epoch;
   }
 
   ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC
-                                                  : SSL3_RT_HANDSHAKE);
-
-  if (fragment_from_previous_epoch) {
-    /* Restore the current epoch. */
-    s->aead_write_ctx = aead_write_ctx;
-    s->d1->w_epoch = epoch;
-    memcpy(s->d1->last_write_sequence, s->s3->write_sequence,
-           sizeof(s->s3->write_sequence));
-    memcpy(s->s3->write_sequence, save_write_sequence,
-           sizeof(s->s3->write_sequence));
-  }
+                                                  : SSL3_RT_HANDSHAKE,
+                       use_epoch);
 
   (void)BIO_flush(SSL_get_wbio(s));
   return ret;
diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c
index e53156f..473588d 100644
--- a/ssl/d1_lib.c
+++ b/ssl/d1_lib.c
@@ -338,5 +338,5 @@
 }
 
 int dtls1_handshake_write(SSL *s) {
-  return dtls1_do_write(s, SSL3_RT_HANDSHAKE);
+  return dtls1_do_write(s, SSL3_RT_HANDSHAKE, dtls1_use_current_epoch);
 }
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index e80e773..b6570ee 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -185,7 +185,7 @@
 static void dtls1_record_bitmap_update(SSL *s, DTLS1_BITMAP *bitmap);
 static int dtls1_process_record(SSL *s);
 static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
-                          unsigned int len);
+                          unsigned int len, enum dtls1_use_epoch_t use_epoch);
 
 static int dtls1_process_record(SSL *s) {
   int al;
@@ -695,18 +695,19 @@
     return -1;
   }
 
-  i = dtls1_write_bytes(s, type, buf_, len);
+  i = dtls1_write_bytes(s, type, buf_, len, dtls1_use_current_epoch);
   return i;
 }
 
 /* Call this to write data in records of type 'type' It will return <= 0 if not
  * all data has been sent or non-blocking IO. */
-int dtls1_write_bytes(SSL *s, int type, const void *buf, int len) {
+int dtls1_write_bytes(SSL *s, int type, const void *buf, int len,
+                      enum dtls1_use_epoch_t use_epoch) {
   int i;
 
   assert(len <= SSL3_RT_MAX_PLAIN_LENGTH);
   s->rwstate = SSL_NOTHING;
-  i = do_dtls1_write(s, type, buf, len);
+  i = do_dtls1_write(s, type, buf, len, use_epoch);
   return i;
 }
 
@@ -716,27 +717,40 @@
  * number. */
 static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len,
                              size_t max_out, uint8_t type, const uint8_t *in,
-                             size_t in_len) {
+                             size_t in_len, enum dtls1_use_epoch_t use_epoch) {
   if (max_out < DTLS1_RT_HEADER_LENGTH) {
     OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, SSL_R_BUFFER_TOO_SMALL);
     return 0;
   }
 
+  /* Determine the parameters for the current epoch. */
+  uint16_t epoch = s->d1->w_epoch;
+  SSL_AEAD_CTX *aead = s->aead_write_ctx;
+  uint8_t *seq = s->s3->write_sequence;
+  if (use_epoch == dtls1_use_previous_epoch) {
+    /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
+     * (negotiated cipher) exist. */
+    assert(s->d1->w_epoch == 1);
+    epoch = s->d1->w_epoch - 1;
+    aead = NULL;
+    seq = s->d1->last_write_sequence;
+  }
+
   out[0] = type;
 
   uint16_t wire_version = s->s3->have_version ? s->version : DTLS1_VERSION;
   out[1] = wire_version >> 8;
   out[2] = wire_version & 0xff;
 
-  out[3] = s->d1->w_epoch >> 8;
-  out[4] = s->d1->w_epoch & 0xff;
-  memcpy(&out[5], &s->s3->write_sequence[2], 6);
+  out[3] = epoch >> 8;
+  out[4] = epoch & 0xff;
+  memcpy(&out[5], &seq[2], 6);
 
   size_t ciphertext_len;
-  if (!SSL_AEAD_CTX_seal(s->aead_write_ctx, out + DTLS1_RT_HEADER_LENGTH,
-                         &ciphertext_len, max_out - DTLS1_RT_HEADER_LENGTH,
-                         type, wire_version, &out[3] /* seq */, in, in_len) ||
-      !ssl3_record_sequence_update(&s->s3->write_sequence[2], 6)) {
+  if (!SSL_AEAD_CTX_seal(aead, out + DTLS1_RT_HEADER_LENGTH, &ciphertext_len,
+                         max_out - DTLS1_RT_HEADER_LENGTH, type, wire_version,
+                         &out[3] /* seq */, in, in_len) ||
+      !ssl3_record_sequence_update(&seq[2], 6)) {
     return 0;
   }
 
@@ -758,7 +772,7 @@
 }
 
 static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
-                          unsigned int len) {
+                          unsigned int len, enum dtls1_use_epoch_t use_epoch) {
   SSL3_BUFFER *wb = &s->s3->wbuf;
 
   /* ssl3_write_pending drops the write if |BIO_write| fails in DTLS, so there
@@ -790,7 +804,8 @@
   size_t max_out = wb->len - wb->offset;
 
   size_t ciphertext_len;
-  if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len)) {
+  if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len,
+                         use_epoch)) {
     return -1;
   }
 
@@ -863,7 +878,8 @@
   *ptr++ = s->s3->send_alert[0];
   *ptr++ = s->s3->send_alert[1];
 
-  i = do_dtls1_write(s, SSL3_RT_ALERT, &buf[0], sizeof(buf));
+  i = do_dtls1_write(s, SSL3_RT_ALERT, &buf[0], sizeof(buf),
+                     dtls1_use_current_epoch);
   if (i <= 0) {
     s->s3->alert_dispatch = 1;
   } else {
diff --git a/ssl/internal.h b/ssl/internal.h
index c4e931d..08a74dc 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -940,7 +940,12 @@
 int ssl3_set_handshake_header(SSL *s, int htype, unsigned long len);
 int ssl3_handshake_write(SSL *s);
 
-int dtls1_do_write(SSL *s, int type);
+enum dtls1_use_epoch_t {
+  dtls1_use_previous_epoch,
+  dtls1_use_current_epoch,
+};
+
+int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch);
 int ssl3_read_n(SSL *s, int n, int extend);
 int dtls1_read_bytes(SSL *s, int type, uint8_t *buf, int len, int peek);
 int ssl3_write_pending(SSL *s, int type, const uint8_t *buf, unsigned int len);
@@ -949,7 +954,8 @@
                               unsigned long frag_len);
 
 int dtls1_write_app_data_bytes(SSL *s, int type, const void *buf, int len);
-int dtls1_write_bytes(SSL *s, int type, const void *buf, int len);
+int dtls1_write_bytes(SSL *s, int type, const void *buf, int len,
+                      enum dtls1_use_epoch_t use_epoch);
 
 int dtls1_send_change_cipher_spec(SSL *s, int a, int b);
 int dtls1_send_finished(SSL *s, int a, int b, const char *sender, int slen);