Simplify DTLS epoch rewind.

SSL_AEAD_CTX ownership is currently too confusing. Instead, rely on the lack of
renego, so the previous epoch always uses the NULL cipher. (Were we to support
DTLS renego, we could keep track of s->d1->last_aead_write_ctx like
s->d1->last_write_sequence, but it isn't worth it.)

Buffered messages also tracked an old s->session, but this is unnecessary. The
s->session NULL check in tls1_enc dates to the OpenSSL initial commit and is
redundant with the aead NULL check.

Change-Id: I9a510468d95934c65bca4979094551c7536980ae
Reviewed-on: https://boringssl-review.googlesource.com/3234
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/dtls1.h b/include/openssl/dtls1.h
index b9395a3..c31ddfc 100644
--- a/include/openssl/dtls1.h
+++ b/include/openssl/dtls1.h
@@ -93,12 +93,6 @@
   uint8_t max_seq_num[8];
 } DTLS1_BITMAP;
 
-struct dtls1_retransmit_state {
-  SSL_AEAD_CTX *aead_write_ctx;
-  SSL_SESSION *session;
-  uint16_t epoch;
-};
-
 struct hm_header_st {
   uint8_t type;
   unsigned long msg_len;
@@ -106,7 +100,9 @@
   unsigned long frag_off;
   unsigned long frag_len;
   unsigned int is_ccs;
-  struct dtls1_retransmit_state saved_retransmit_state;
+  /* epoch, for buffered outgoing messages, is the epoch the message was
+   * originally sent in. */
+  uint16_t epoch;
 };
 
 struct ccs_header_st {
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 25ad08b..72dadc6 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -221,16 +221,6 @@
 }
 
 void dtls1_hm_fragment_free(hm_fragment *frag) {
-  if (frag->msg_header.is_ccs) {
-    /* TODO(davidben): Simplify aead_write_ctx ownership, probably by just
-     * forbidding DTLS renego. */
-    SSL_AEAD_CTX *aead_write_ctx =
-        frag->msg_header.saved_retransmit_state.aead_write_ctx;
-    if (aead_write_ctx) {
-      EVP_AEAD_CTX_cleanup(&aead_write_ctx->ctx);
-      OPENSSL_free(aead_write_ctx);
-    }
-  }
   if (frag->fragment) {
     OPENSSL_free(frag->fragment);
   }
@@ -995,11 +985,7 @@
   frag->msg_header.frag_off = 0;
   frag->msg_header.frag_len = s->d1->w_msg_hdr.msg_len;
   frag->msg_header.is_ccs = is_ccs;
-
-  /* save current state*/
-  frag->msg_header.saved_retransmit_state.aead_write_ctx = s->aead_write_ctx;
-  frag->msg_header.saved_retransmit_state.session = s->session;
-  frag->msg_header.saved_retransmit_state.epoch = s->d1->w_epoch;
+  frag->msg_header.epoch = s->d1->w_epoch;
 
   memset(seq64be, 0, sizeof(seq64be));
   seq64be[6] = (uint8_t)(
@@ -1026,7 +1012,6 @@
   hm_fragment *frag;
   unsigned long header_length;
   uint8_t seq64be[8];
-  struct dtls1_retransmit_state saved_state;
   uint8_t save_write_sequence[8];
 
   /* assert(s->init_num == 0);
@@ -1061,32 +1046,40 @@
                            frag->msg_header.msg_len, frag->msg_header.seq,
                            0, frag->msg_header.frag_len);
 
-  /* save current state */
-  saved_state.aead_write_ctx = s->aead_write_ctx;
-  saved_state.session = s->session;
-  saved_state.epoch = s->d1->w_epoch;
+  /* Save current state. */
+  SSL_AEAD_CTX *aead_write_ctx = s->aead_write_ctx;
+  uint16_t epoch = s->d1->w_epoch;
 
-  /* restore state in which the message was originally sent */
-  s->aead_write_ctx = frag->msg_header.saved_retransmit_state.aead_write_ctx;
-  s->session = frag->msg_header.saved_retransmit_state.session;
-  s->d1->w_epoch = frag->msg_header.saved_retransmit_state.epoch;
-
-  if (frag->msg_header.saved_retransmit_state.epoch == saved_state.epoch - 1) {
+  /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
+   * (negotiated cipher) exist. */
+  assert(epoch == 0 || epoch == 1);
+  assert(frag->msg_header.epoch <= epoch);
+  const int fragment_from_previous_epoch = (epoch == 1 &&
+                                            frag->msg_header.epoch == 0);
+  if (fragment_from_previous_epoch) {
+    /* Rewind to the previous epoch.
+     *
+     * TODO(davidben): Instead of swapping out connection-global state, this
+     * logic should pass a "use previous epoch" parameter down to lower-level
+     * functions. */
+    s->d1->w_epoch = frag->msg_header.epoch;
+    s->aead_write_ctx = NULL;
     memcpy(save_write_sequence, s->s3->write_sequence,
            sizeof(s->s3->write_sequence));
     memcpy(s->s3->write_sequence, s->d1->last_write_sequence,
            sizeof(s->s3->write_sequence));
+  } else {
+    /* Otherwise the messages must be from the same epoch. */
+    assert(frag->msg_header.epoch == epoch);
   }
 
   ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC
                                                   : SSL3_RT_HANDSHAKE);
 
-  /* restore current state */
-  s->aead_write_ctx = saved_state.aead_write_ctx;
-  s->session = saved_state.session;
-  s->d1->w_epoch = saved_state.epoch;
-
-  if (frag->msg_header.saved_retransmit_state.epoch == saved_state.epoch - 1) {
+  if (fragment_from_previous_epoch) {
+    /* Restore the current epoch. */
+    s->aead_write_ctx = aead_write_ctx;
+    s->d1->w_epoch = epoch;
     memcpy(s->d1->last_write_sequence, s->s3->write_sequence,
            sizeof(s->s3->write_sequence));
     memcpy(s->s3->write_sequence, save_write_sequence,
diff --git a/ssl/t1_enc.c b/ssl/t1_enc.c
index 014bc88..fd0223d 100644
--- a/ssl/t1_enc.c
+++ b/ssl/t1_enc.c
@@ -340,14 +340,12 @@
     }
     aead_ctx = s->aead_read_ctx;
   } else {
-    /* When updating the cipher state for DTLS, we do not wish to overwrite the
-     * old ones because DTLS stores pointers to them in order to implement
-     * retransmission. See dtls1_hm_fragment_free.
-     *
-     * TODO(davidben): Simplify aead_write_ctx ownership, probably by just
-     * forbidding DTLS renego. */
-    if (SSL_IS_DTLS(s)) {
-      s->aead_write_ctx = NULL;
+    if (SSL_IS_DTLS(s) && s->aead_write_ctx != NULL) {
+      /* DTLS renegotiation is unsupported, so a CCS can only switch away from
+       * the NULL cipher. This simplifies renegotiation. */
+      OPENSSL_PUT_ERROR(SSL, tls1_change_cipher_state_aead,
+                        ERR_R_INTERNAL_ERROR);
+      return 0;
     }
     if (!tls1_aead_ctx_init(&s->aead_write_ctx)) {
       return 0;
@@ -578,7 +576,7 @@
     aead = s->aead_read_ctx;
   }
 
-  if (s->session == NULL || aead == NULL) {
+  if (aead == NULL) {
     /* Handle the initial NULL cipher. */
     memmove(rec->data, rec->input, rec->length);
     rec->input = rec->data;
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index f84530f..aa8b189 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -171,7 +171,7 @@
 		return nil
 	}
 
-	var fragments []byte
+	var fragments [][]byte
 	fragments, c.pendingFragments = c.pendingFragments, fragments
 
 	if c.config.Bugs.ReorderHandshakeFragments {