Store DTLS epoch state separately

This wraps DTLS per-epoch state into DTLSReadEpoch and DTLSWriteEpoch
structures. For now, we only keep one DTLSReadEpoch though we will need
to keep two in DTLS 1.3. In preparation for that, I've reworked the DTLS
record parser to resolve the DTLSReadEpoch at the same time, even though
there's currently only one of them.

On the write side, this removes the special-cased initial write epoch
and just stores an array of the last few epochs.

Some things this does not yet do, but that we ideally would do as
follow-ups:

1. Move RecordNumberEncrypter out of SSLAEADContext, now that we have a
   DTLS-specific struct for it.

2. Pass just a byte secret into set_read_state / set_write_state and let
   the transport-aware code construct the SSLAEADContext.

3. Don't construct the SSLAEADContext at all for QUIC.

4. Don't save the read and write traffic secrets in QUIC at all.

5. KeyUpdate should be part of the SSL_PROTOCOL_METHOD interface to
   accomodate DTLS driving KeyUpdate by ACK.

Update-Note: As part of rearranging the record parser, when the DTLS 1.2
implementation encounters a DTLS 1.3 record, it will now discard just
that record and continue parsing records out of the packet, rather than
discarding the whole packet. This isn't expected to make any difference.

Bug: 371998381
Change-Id: Ie2ae657d41e33152208a001df177630398798394
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72128
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Nick Harper <nharper@chromium.org>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 80da717..f9f3a90 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -315,7 +315,14 @@
     }
 
     // The encrypted epoch in DTLS has only one handshake message.
-    if (ssl->d1->r_epoch == 1 && msg_hdr.seq != ssl->d1->handshake_read_seq) {
+    //
+    // TODO(crbug.com/42290594): This check doesn't make any sense in DTLS 1.3,
+    // but is currently a no-op because epoch 1 is 0-RTT. Revisit this and
+    // figure out if we need to change anything. See
+    // https://boringssl-review.googlesource.com/c/boringssl/+/8988 for when
+    // this check was added.
+    if (ssl->d1->read_epoch.epoch == 1 &&
+        msg_hdr.seq != ssl->d1->handshake_read_seq) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
       *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
       return false;
@@ -361,7 +368,11 @@
   switch (type) {
     case SSL3_RT_APPLICATION_DATA:
       // Unencrypted application data records are always illegal.
-      if (ssl->s3->aead_read_ctx->is_null_cipher()) {
+      //
+      // TODO(crbug.com/42290594): Revisit both of these checks for DTLS 1.3.
+      // Many more epochs cannot have application data, and there is a key
+      // change immediately before the first application data record.
+      if (ssl->d1->read_epoch.epoch == 0) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
         *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
         return ssl_open_record_error;
@@ -374,7 +385,7 @@
     case SSL3_RT_CHANGE_CIPHER_SPEC:
       // We do not support renegotiation, so encrypted ChangeCipherSpec records
       // are illegal.
-      if (!ssl->s3->aead_read_ctx->is_null_cipher()) {
+      if (ssl->d1->read_epoch.epoch != 0) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
         *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
         return ssl_open_record_error;
@@ -387,6 +398,7 @@
       }
 
       // Flag the ChangeCipherSpec for later.
+      // TODO(crbug.com/42290594): Should we reject this in DTLS 1.3?
       ssl->d1->has_change_cipher_spec = true;
       ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_CHANGE_CIPHER_SPEC,
                           record);
@@ -497,6 +509,26 @@
   ssl->d1->outgoing_offset = 0;
   ssl->d1->outgoing_messages_complete = false;
   ssl->d1->flight_has_reply = false;
+  dtls_clear_unused_write_epochs(ssl);
+}
+
+void dtls_clear_unused_write_epochs(SSL *ssl) {
+  ssl->d1->extra_write_epochs.EraseIf(
+      [ssl](const UniquePtr<DTLSWriteEpoch> &write_epoch) -> bool {
+        // Non-current epochs may be discarded once there are no outgoing
+        // messages that reference them.
+        //
+        // TODO(crbug.com/42290594): If |msg| has been fully ACKed, its epoch
+        // may be discarded.
+        // TODO(crbug.com/42290594): Epoch 1 (0-RTT) should be retained until
+        // epoch 3 (app data) is available.
+        for (const auto &msg : ssl->d1->outgoing_messages) {
+          if (msg.epoch == write_epoch->epoch) {
+            return false;
+          }
+        }
+        return true;
+      });
 }
 
 bool dtls1_init_message(const SSL *ssl, CBB *cbb, CBB *body, uint8_t type) {
@@ -550,7 +582,7 @@
 
   DTLS_OUTGOING_MESSAGE msg;
   msg.data = std::move(data);
-  msg.epoch = ssl->d1->w_epoch;
+  msg.epoch = ssl->d1->write_epoch.epoch;
   msg.is_ccs = is_ccs;
   if (!ssl->d1->outgoing_messages.TryPushBack(std::move(msg))) {
     assert(false);
diff --git a/ssl/d1_lib.cc b/ssl/d1_lib.cc
index 4aa71ac..3204b69 100644
--- a/ssl/d1_lib.cc
+++ b/ssl/d1_lib.cc
@@ -85,23 +85,23 @@
 
 DTLS1_STATE::~DTLS1_STATE() {}
 
+bool DTLS1_STATE::Init() {
+  // Set up the initial epochs.
+  read_epoch.aead = SSLAEADContext::CreateNullCipher();
+  write_epoch.aead = SSLAEADContext::CreateNullCipher();
+  if (read_epoch.aead == nullptr || write_epoch.aead == nullptr) {
+    return false;
+  }
+
+  return true;
+}
+
 bool dtls1_new(SSL *ssl) {
   if (!tls_new(ssl)) {
     return false;
   }
   UniquePtr<DTLS1_STATE> d1 = MakeUnique<DTLS1_STATE>();
-  if (!d1) {
-    tls_free(ssl);
-    return false;
-  }
-
-  d1->initial_epoch_state = MakeUnique<DTLSEpochState>();
-  if (!d1->initial_epoch_state) {
-    tls_free(ssl);
-    return false;
-  }
-  d1->initial_epoch_state->aead_write_ctx = SSLAEADContext::CreateNullCipher();
-  if (!d1->initial_epoch_state->aead_write_ctx) {
+  if (!d1 || !d1->Init()) {
     tls_free(ssl);
     return false;
   }
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc
index 13da69a..1dc5775 100644
--- a/ssl/d1_pkt.cc
+++ b/ssl/d1_pkt.cc
@@ -215,8 +215,9 @@
     return 1;
   }
 
+  // TODO(crbug.com/42290594): Use the 0-RTT epoch if writing 0-RTT.
   int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in,
-                               ssl->d1->w_epoch);
+                               ssl->d1->write_epoch.epoch);
   if (ret <= 0) {
     return ret;
   }
@@ -224,11 +225,6 @@
   return 1;
 }
 
-static size_t dtls_seal_align_prefix_len(const SSL *ssl, uint16_t epoch) {
-  return dtls_record_header_write_len(ssl, epoch) +
-         ssl->s3->aead_write_ctx->ExplicitNonceLen();
-}
-
 int dtls1_write_record(SSL *ssl, int type, Span<const uint8_t> in,
                        uint16_t epoch) {
   SSLBuffer *buf = &ssl->s3->write_buffer;
@@ -244,7 +240,7 @@
   }
 
   size_t ciphertext_len;
-  if (!buf->EnsureCap(dtls_seal_align_prefix_len(ssl, epoch),
+  if (!buf->EnsureCap(dtls_seal_prefix_len(ssl, epoch),
                       in.size() + SSL_max_seal_overhead(ssl)) ||
       !dtls_seal_record(ssl, buf->remaining().data(), &ciphertext_len,
                         buf->remaining().size(), type, in.data(), in.size(),
@@ -263,7 +259,7 @@
 
 int dtls1_dispatch_alert(SSL *ssl) {
   int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, ssl->s3->send_alert,
-                               ssl->d1->w_epoch);
+                               ssl->d1->write_epoch.epoch);
   if (ret <= 0) {
     return ret;
   }
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 501dd97..dff3d5c 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -79,8 +79,7 @@
 
 static bool dtls1_set_read_state(SSL *ssl, ssl_encryption_level_t level,
                                  UniquePtr<SSLAEADContext> aead_ctx,
-                                 Span<const uint8_t> secret_for_quic) {
-  assert(secret_for_quic.empty());  // QUIC does not use DTLS.
+                                 Span<const uint8_t> traffic_secret) {
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (dtls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -88,20 +87,21 @@
     return false;
   }
 
+  DTLSReadEpoch new_epoch;
   if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
     // TODO(crbug.com/boringssl/715): Handle the additional epochs used for key
     // update.
     // TODO(crbug.com/boringssl/715): If we want to gracefully handle packet
     // reordering around KeyUpdate (i.e. accept records from both epochs), we'll
     // need a separate bitmap for each epoch.
-    ssl->d1->r_epoch = level;
+    new_epoch.epoch = level;
   } else {
-    ssl->d1->r_epoch++;
+    new_epoch.epoch = ssl->d1->read_epoch.epoch + 1;
   }
-  ssl->d1->bitmap = DTLS1_BITMAP();
-  ssl->s3->read_sequence = 0;
+  new_epoch.bitmap = DTLSReplayBitmap();
+  new_epoch.aead = std::move(aead_ctx);
 
-  ssl->s3->aead_read_ctx = std::move(aead_ctx);
+  ssl->d1->read_epoch = std::move(new_epoch);
   ssl->s3->read_level = level;
   ssl->d1->has_change_cipher_spec = false;
   return true;
@@ -109,18 +109,25 @@
 
 static bool dtls1_set_write_state(SSL *ssl, ssl_encryption_level_t level,
                                   UniquePtr<SSLAEADContext> aead_ctx,
-                                  Span<const uint8_t> secret_for_quic) {
-  assert(secret_for_quic.empty());  // QUIC does not use DTLS.
-  ssl->d1->w_epoch++;
-  ssl->s3->write_sequence = 0;
-
+                                  Span<const uint8_t> traffic_secret) {
+  DTLSWriteEpoch new_epoch;
   if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
-    ssl->d1->w_epoch = level;
+    // TODO(crbug.com/boringssl/715): See above.
+    new_epoch.epoch = level;
+  } else {
+    new_epoch.epoch = ssl->d1->write_epoch.epoch + 1;
   }
-  ssl->d1->last_epoch_state.aead_write_ctx = std::move(ssl->s3->aead_write_ctx);
-  ssl->d1->last_epoch_state.write_sequence = ssl->s3->write_sequence;
-  ssl->s3->aead_write_ctx = std::move(aead_ctx);
+  new_epoch.aead = std::move(aead_ctx);
+
+  auto current = MakeUnique<DTLSWriteEpoch>(std::move(ssl->d1->write_epoch));
+  if (current == nullptr) {
+    return false;
+  }
+
+  ssl->d1->write_epoch = std::move(new_epoch);
+  ssl->d1->extra_write_epochs.PushBack(std::move(current));
   ssl->s3->write_level = level;
+  dtls_clear_unused_write_epochs(ssl);
   return true;
 }
 
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 479ecd4..161530f 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -123,39 +123,33 @@
 
 BSSL_NAMESPACE_BEGIN
 
-// dtls1_bitmap_should_discard returns one if |seq_num| has been seen in
-// |bitmap| or is stale. Otherwise it returns zero.
-static bool dtls1_bitmap_should_discard(DTLS1_BITMAP *bitmap,
-                                        uint64_t seq_num) {
-  const size_t kWindowSize = bitmap->map.size();
+bool DTLSReplayBitmap::ShouldDiscard(uint64_t seq_num) const {
+  const size_t kWindowSize = map_.size();
 
-  if (seq_num > bitmap->max_seq_num) {
+  if (seq_num > max_seq_num_) {
     return false;
   }
-  uint64_t idx = bitmap->max_seq_num - seq_num;
-  return idx >= kWindowSize || bitmap->map[idx];
+  uint64_t idx = max_seq_num_ - seq_num;
+  return idx >= kWindowSize || map_[idx];
 }
 
-// dtls1_bitmap_record updates |bitmap| to record receipt of sequence number
-// |seq_num|. It slides the window forward if needed. It is an error to call
-// this function on a stale sequence number.
-static void dtls1_bitmap_record(DTLS1_BITMAP *bitmap, uint64_t seq_num) {
-  const size_t kWindowSize = bitmap->map.size();
+void DTLSReplayBitmap::Record(uint64_t seq_num) {
+  const size_t kWindowSize = map_.size();
 
   // Shift the window if necessary.
-  if (seq_num > bitmap->max_seq_num) {
-    uint64_t shift = seq_num - bitmap->max_seq_num;
+  if (seq_num > max_seq_num_) {
+    uint64_t shift = seq_num - max_seq_num_;
     if (shift >= kWindowSize) {
-      bitmap->map.reset();
+      map_.reset();
     } else {
-      bitmap->map <<= shift;
+      map_ <<= shift;
     }
-    bitmap->max_seq_num = seq_num;
+    max_seq_num_ = seq_num;
   }
 
-  uint64_t idx = bitmap->max_seq_num - seq_num;
+  uint64_t idx = max_seq_num_ - seq_num;
   if (idx < kWindowSize) {
-    bitmap->map[idx] = true;
+    map_[idx] = true;
   }
 }
 
@@ -207,109 +201,138 @@
   return seqnum;
 }
 
-static bool parse_dtls13_record_header(SSL *ssl, CBS *in, Span<uint8_t> packet,
-                                       uint8_t type, CBS *out_body,
-                                       uint64_t *out_sequence,
-                                       uint16_t *out_epoch,
-                                       size_t *out_header_len) {
-  // TODO(crbug.com/boringssl/715): Decrypt the sequence number before
-  // decoding it.
-  if ((type & 0x10) == 0x10) {
+static Span<uint8_t> cbs_to_writable_bytes(CBS cbs) {
+  return MakeSpan(const_cast<uint8_t *>(CBS_data(&cbs)), CBS_len(&cbs));
+}
+
+struct ParsedDTLSRecord {
+  // read_epoch will be null if the record is for an unrecognized epoch. In that
+  // case, sequence may be zero if we are unable to decrypt the sequence number.
+  DTLSReadEpoch *read_epoch = nullptr;
+  // sequence includes the epoch for DTLS 1.2 records and does not include it
+  // for DTLS 1.3.
+  uint64_t sequence = 0;
+  CBS header, body;
+  uint8_t type = 0;
+  uint16_t version = 0;
+};
+
+static bool use_dtls13_record_header(const SSL *ssl, uint16_t epoch) {
+  // Plaintext records in DTLS 1.3 also use the DTLSPlaintext structure for
+  // backwards compatibility.
+  return ssl->s3->version != 0 && ssl_protocol_version(ssl) > TLS1_2_VERSION &&
+         epoch > 0;
+}
+
+static bool parse_dtls13_record(SSL *ssl, CBS *in, ParsedDTLSRecord *out) {
+  if (out->type & 0x10) {
     // Connection ID bit set, which we didn't negotiate.
     return false;
   }
 
-  // TODO(crbug.com/boringssl/715): Add a runner test that performs many
-  // key updates to verify epoch reconstruction works for epochs larger than
-  // 3.
-  *out_epoch = reconstruct_epoch(type, ssl->d1->r_epoch);
-  size_t seqlen = 1;
-  if ((type & 0x08) == 0x08) {
-    // If this bit is set, the sequence number is 16 bits long, otherwise it is
-    // 8 bits. The seqlen variable tracks the length of the sequence number in
-    // bytes.
-    seqlen = 2;
-  }
-  if (!CBS_skip(in, seqlen)) {
-    // The record header was incomplete or malformed.
+  // TODO(crbug.com/42290594): Add a runner test that performs many
+  // key updates to verify epoch reconstruction works for epochs larger than 3.
+  uint16_t epoch = reconstruct_epoch(out->type, ssl->d1->read_epoch.epoch);
+  size_t seq_len = (out->type & 0x08) ? 2 : 1;
+  CBS seq_bytes;
+  if (!CBS_get_bytes(in, &seq_bytes, seq_len)) {
     return false;
   }
-  *out_header_len = packet.size() - CBS_len(in);
-  if ((type & 0x04) == 0x04) {
-    *out_header_len += 2;
+  if (out->type & 0x04) {
     // 16-bit length present
-    if (!CBS_get_u16_length_prefixed(in, out_body)) {
-      // The record header was incomplete or malformed.
+    if (!CBS_get_u16_length_prefixed(in, &out->body)) {
       return false;
     }
   } else {
     // No length present - the remaining contents are the whole packet.
     // CBS_get_bytes is used here to advance |in| to the end so that future
     // code that computes the number of consumed bytes functions correctly.
-    if (!CBS_get_bytes(in, out_body, CBS_len(in))) {
-      return false;
-    }
+    BSSL_CHECK(CBS_get_bytes(in, &out->body, CBS_len(in)));
   }
 
-  // Decrypt and reconstruct the sequence number:
-  uint8_t mask[AES_BLOCK_SIZE];
-  SSLAEADContext *aead = ssl->s3->aead_read_ctx.get();
-  if (!aead->GenerateRecordNumberMask(mask, *out_body)) {
-    // GenerateRecordNumberMask most likely failed because the record body was
-    // not long enough.
-    return false;
+  // Look up the corresponding epoch. This header form only matches encrypted
+  // DTLS 1.3 epochs.
+  // TODO(crbug.com/42290594): DTLS 1.3 will require that we track multiple
+  // epochs.
+  if (epoch == ssl->d1->read_epoch.epoch &&
+      use_dtls13_record_header(ssl, epoch)) {
+    out->read_epoch = &ssl->d1->read_epoch;
+
+    // Decrypt and reconstruct the sequence number:
+    uint8_t mask[AES_BLOCK_SIZE];
+    if (!out->read_epoch->aead->GenerateRecordNumberMask(mask, out->body)) {
+      // GenerateRecordNumberMask most likely failed because the record body was
+      // not long enough.
+      return false;
+    }
+    // Apply the mask to the sequence number in-place. The header (with the
+    // decrypted sequence number bytes) is used as the additional data for the
+    // AEAD function.
+    auto writable_seq = cbs_to_writable_bytes(seq_bytes);
+    uint64_t seq = 0;
+    for (size_t i = 0; i < writable_seq.size(); i++) {
+      writable_seq[i] ^= mask[i];
+      seq = (seq << 8) | writable_seq[i];
+    }
+    out->sequence = reconstruct_seqnum(seq, (1 << (seq_len * 8)) - 1,
+                                       out->read_epoch->bitmap.max_seq_num());
   }
-  // Apply the mask to the sequence number as it exists in the header. The
-  // header (with the decrypted sequence number bytes) is used as the
-  // additional data for the AEAD function. Since we don't support Connection
-  // ID, the sequence number starts immediately after the type byte.
-  uint64_t seq = 0;
-  for (size_t i = 0; i < seqlen; i++) {
-    packet[i + 1] ^= mask[i];
-    seq = (seq << 8) | packet[i + 1];
-  }
-  *out_sequence = reconstruct_seqnum(seq, (1 << (seqlen * 8)) - 1,
-                                     ssl->d1->bitmap.max_seq_num);
+
   return true;
 }
 
-static bool parse_dtls_plaintext_record_header(
-    SSL *ssl, CBS *in, size_t packet_size, uint8_t type, CBS *out_body,
-    uint64_t *out_sequence, uint16_t *out_epoch, size_t *out_header_len,
-    uint16_t *out_version) {
-  SSLAEADContext *aead = ssl->s3->aead_read_ctx.get();
-  uint8_t sequence_bytes[8];
-  if (!CBS_get_u16(in, out_version) ||
-      !CBS_copy_bytes(in, sequence_bytes, sizeof(sequence_bytes))) {
-    return false;
-  }
-  *out_header_len = packet_size - CBS_len(in) + 2;
-  if (!CBS_get_u16_length_prefixed(in, out_body) ||
-      CBS_len(out_body) > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
+static bool parse_dtls12_record(SSL *ssl, CBS *in, ParsedDTLSRecord *out) {
+  if (!CBS_get_u16(in, &out->version) ||  //
+      !CBS_get_u64(in, &out->sequence) ||
+      !CBS_get_u16_length_prefixed(in, &out->body)) {
     return false;
   }
 
+  uint16_t epoch = static_cast<uint16_t>(out->sequence >> 48);
   bool version_ok;
-  if (aead->is_null_cipher()) {
+  if (epoch == 0) {
     // Only check the first byte. Enforcing beyond that can prevent decoding
     // version negotiation failure alerts.
-    version_ok = (*out_version >> 8) == DTLS1_VERSION_MAJOR;
+    version_ok = (out->version >> 8) == DTLS1_VERSION_MAJOR;
   } else {
-    version_ok = *out_version == dtls_record_version(ssl);
+    version_ok = out->version == dtls_record_version(ssl);
   }
-
   if (!version_ok) {
     return false;
   }
 
-  *out_sequence = CRYPTO_load_u64_be(sequence_bytes);
-  *out_epoch = static_cast<uint16_t>(*out_sequence >> 48);
+  // Look up the corresponding epoch. In DTLS 1.2, we only need to consider one
+  // epoch.
+  if (epoch == ssl->d1->read_epoch.epoch &&
+      !use_dtls13_record_header(ssl, epoch)) {
+    out->read_epoch = &ssl->d1->read_epoch;
+  }
 
-  // Discard the packet if we're expecting an encrypted DTLS 1.3 record but we
-  // get the old record header format.
-  if (!aead->is_null_cipher() && ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
+  return true;
+}
+
+static bool parse_dtls_record(SSL *ssl, CBS *cbs, ParsedDTLSRecord *out) {
+  CBS copy = *cbs;
+  if (!CBS_get_u8(cbs, &out->type)) {
     return false;
   }
+
+  bool ok;
+  if ((out->type & 0xe0) == 0x20) {
+    ok = parse_dtls13_record(ssl, cbs, out);
+  } else {
+    ok = parse_dtls12_record(ssl, cbs, out);
+  }
+  if (!ok) {
+    return false;
+  }
+
+  if (CBS_len(&out->body) > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
+    return false;
+  }
+
+  size_t header_len = CBS_data(&out->body) - CBS_data(&copy);
+  BSSL_CHECK(CBS_get_bytes(&copy, &out->header, header_len));
   return true;
 }
 
@@ -326,47 +349,20 @@
     return ssl_open_record_partial;
   }
 
-  CBS cbs = CBS(in);
-
-  uint8_t type;
-  size_t record_header_len;
-  if (!CBS_get_u8(&cbs, &type)) {
-    // The record header was incomplete or malformed. Drop the entire packet.
-    *out_consumed = in.size();
-    return ssl_open_record_discard;
-  }
-  SSLAEADContext *aead = ssl->s3->aead_read_ctx.get();
-  uint64_t sequence;
-  uint16_t epoch;
-  uint16_t version = 0;
-  CBS body;
-  bool valid_record_header;
-  // Decode the record header. If the 3 high bits of the type are 001, then the
-  // record header is the DTLS 1.3 format. The DTLS 1.3 format should only be
-  // used for encrypted records with DTLS 1.3. Plaintext records or DTLS 1.2
-  // records use the old record header format.
-  if ((type & 0xe0) == 0x20 && !aead->is_null_cipher() &&
-      ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
-    valid_record_header = parse_dtls13_record_header(
-        ssl, &cbs, in, type, &body, &sequence, &epoch, &record_header_len);
-  } else {
-    valid_record_header = parse_dtls_plaintext_record_header(
-        ssl, &cbs, in.size(), type, &body, &sequence, &epoch,
-        &record_header_len, &version);
-  }
-  if (!valid_record_header) {
+  CBS cbs(in);
+  ParsedDTLSRecord record;
+  if (!parse_dtls_record(ssl, &cbs, &record)) {
     // The record header was incomplete or malformed. Drop the entire packet.
     *out_consumed = in.size();
     return ssl_open_record_discard;
   }
 
-  Span<const uint8_t> header = in.subspan(0, record_header_len);
-  ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HEADER, header);
+  ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HEADER, record.header);
 
-  if (epoch != ssl->d1->r_epoch ||
-      dtls1_bitmap_should_discard(&ssl->d1->bitmap, sequence)) {
-    // Drop this record. It's from the wrong epoch or is a replay. Note that if
-    // |epoch| is the next epoch, the record could be buffered for later. For
+  if (record.read_epoch == nullptr ||
+      record.read_epoch->bitmap.ShouldDiscard(record.sequence)) {
+    // Drop this record. It's from an unknown epoch or is a replay. Note that if
+    // the record is from next epoch, it could be buffered for later. For
     // simplicity, drop it and expect retransmit to handle it later; DTLS must
     // handle packet loss anyway.
     *out_consumed = in.size() - CBS_len(&cbs);
@@ -374,9 +370,9 @@
   }
 
   // discard the body in-place.
-  if (!aead->Open(
-          out, type, version, sequence, header,
-          MakeSpan(const_cast<uint8_t *>(CBS_data(&body)), CBS_len(&body)))) {
+  if (!record.read_epoch->aead->Open(out, record.type, record.version,
+                                     record.sequence, record.header,
+                                     cbs_to_writable_bytes(record.body))) {
     // Bad packets are silently dropped in DTLS. See section 4.2.1 of RFC 6347.
     // Clear the error queue of any errors decryption may have added. Drop the
     // entire packet as it must not have come from the peer.
@@ -390,8 +386,8 @@
   *out_consumed = in.size() - CBS_len(&cbs);
 
   // DTLS 1.3 hides the record type inside the encrypted data.
-  bool has_padding =
-      !aead->is_null_cipher() && ssl_protocol_version(ssl) >= TLS1_3_VERSION;
+  bool has_padding = !record.read_epoch->aead->is_null_cipher() &&
+                     ssl_protocol_version(ssl) >= TLS1_3_VERSION;
   // Check the plaintext length.
   size_t plaintext_limit = SSL3_RT_MAX_PLAIN_LENGTH + (has_padding ? 1 : 0);
   if (out->size() > plaintext_limit) {
@@ -407,45 +403,36 @@
         *out_alert = SSL_AD_DECRYPT_ERROR;
         return ssl_open_record_error;
       }
-      type = out->back();
+      record.type = out->back();
       *out = out->subspan(0, out->size() - 1);
-    } while (type == 0);
+    } while (record.type == 0);
   }
 
-  dtls1_bitmap_record(&ssl->d1->bitmap, sequence);
+  record.read_epoch->bitmap.Record(record.sequence);
 
   // TODO(davidben): Limit the number of empty records as in TLS? This is only
   // useful if we also limit discarded packets.
 
-  if (type == SSL3_RT_ALERT) {
+  if (record.type == SSL3_RT_ALERT) {
     return ssl_process_alert(ssl, out_alert, *out);
   }
 
   ssl->s3->warning_alert_count = 0;
 
-  *out_type = type;
+  *out_type = record.type;
   return ssl_open_record_success;
 }
 
-static SSLAEADContext *get_write_aead(const SSL *ssl, uint16_t epoch) {
-  if (epoch == 0) {
-    return ssl->d1->initial_epoch_state->aead_write_ctx.get();
+static DTLSWriteEpoch *get_write_epoch(const SSL *ssl, uint16_t epoch) {
+  if (ssl->d1->write_epoch.epoch == epoch) {
+    return &ssl->d1->write_epoch;
   }
-
-  if (epoch < ssl->d1->w_epoch) {
-    BSSL_CHECK(epoch + 1 == ssl->d1->w_epoch);
-    return ssl->d1->last_epoch_state.aead_write_ctx.get();
+  for (const auto &e : ssl->d1->extra_write_epochs) {
+    if (e->epoch == epoch) {
+      return e.get();
+    }
   }
-
-  BSSL_CHECK(epoch == ssl->d1->w_epoch);
-  return ssl->s3->aead_write_ctx.get();
-}
-
-static bool use_dtls13_record_header(const SSL *ssl, uint16_t epoch) {
-  // Plaintext records in DTLS 1.3 also use the DTLSPlaintext structure for
-  // backwards compatibility.
-  return ssl->s3->version != 0 && ssl_protocol_version(ssl) > TLS1_2_VERSION &&
-         epoch > 0;
+  return nullptr;
 }
 
 size_t dtls_record_header_write_len(const SSL *ssl, uint16_t epoch) {
@@ -462,8 +449,12 @@
 
 size_t dtls_max_seal_overhead(const SSL *ssl,
                               uint16_t epoch) {
+  DTLSWriteEpoch *write_epoch = get_write_epoch(ssl, epoch);
+  if (write_epoch == nullptr) {
+    return 0;
+  }
   size_t ret = dtls_record_header_write_len(ssl, epoch) +
-               get_write_aead(ssl, epoch)->MaxOverhead();
+               write_epoch->aead->MaxOverhead();
   if (use_dtls13_record_header(ssl, epoch)) {
     // Add 1 byte for the encrypted record type.
     ret++;
@@ -472,8 +463,12 @@
 }
 
 size_t dtls_seal_prefix_len(const SSL *ssl, uint16_t epoch) {
+  DTLSWriteEpoch *write_epoch = get_write_epoch(ssl, epoch);
+  if (write_epoch == nullptr) {
+    return 0;
+  }
   return dtls_record_header_write_len(ssl, epoch) +
-         get_write_aead(ssl, epoch)->ExplicitNonceLen();
+         write_epoch->aead->ExplicitNonceLen();
 }
 
 bool dtls_seal_record(SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
@@ -487,26 +482,21 @@
   }
 
   // Determine the parameters for the current epoch.
-  SSLAEADContext *aead = get_write_aead(ssl, epoch);
-  uint64_t *seq = &ssl->s3->write_sequence;
-  if (epoch == 0) {
-    seq = &ssl->d1->initial_epoch_state->write_sequence;
-  } else if (epoch < ssl->d1->w_epoch) {
-    seq = &ssl->d1->last_epoch_state.write_sequence;
+  DTLSWriteEpoch *write_epoch = get_write_epoch(ssl, epoch);
+  if (write_epoch == nullptr) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    return false;
   }
 
   const size_t record_header_len = dtls_record_header_write_len(ssl, epoch);
 
   // Ensure the sequence number update does not overflow.
   const uint64_t kMaxSequenceNumber = (uint64_t{1} << 48) - 1;
-  if (*seq + 1 > kMaxSequenceNumber) {
+  if (write_epoch->next_seq + 1 > kMaxSequenceNumber) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
     return false;
   }
 
-  uint16_t record_version = dtls_record_version(ssl);
-  uint64_t seq_with_epoch = (uint64_t{epoch} << 48) | *seq;
-
   bool dtls13_header = use_dtls13_record_header(ssl, epoch);
   uint8_t *extra_in = NULL;
   size_t extra_in_len = 0;
@@ -516,7 +506,8 @@
   }
 
   size_t ciphertext_len;
-  if (!aead->CiphertextLen(&ciphertext_len, in_len, extra_in_len)) {
+  if (!write_epoch->aead->CiphertextLen(&ciphertext_len, in_len,
+                                        extra_in_len)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_TOO_LARGE);
     return false;
   }
@@ -525,6 +516,8 @@
     return false;
   }
 
+  uint16_t record_version = dtls_record_version(ssl);
+  uint64_t aead_seq;
   if (dtls13_header) {
     // The first byte of the DTLS 1.3 record header has the following format:
     // 0 1 2 3 4 5 6 7
@@ -543,29 +536,31 @@
     // We always use a two-byte sequence number. A one-byte sequence number
     // would require coordinating with the application on ACK feedback to know
     // that the peer is not too far behind.
-    out[1] = *seq >> 8;
-    out[2] = *seq & 0xff;
+    out[1] = write_epoch->next_seq >> 8;
+    out[2] = write_epoch->next_seq & 0xff;
     // TODO(crbug.com/42290594): When we know the record is last in the packet,
     // omit the length.
     out[3] = ciphertext_len >> 8;
     out[4] = ciphertext_len & 0xff;
     // DTLS 1.3 uses the sequence number without the epoch for the AEAD.
-    seq_with_epoch = *seq;
+    aead_seq = write_epoch->next_seq;
   } else {
     out[0] = type;
     out[1] = record_version >> 8;
     out[2] = record_version & 0xff;
-    CRYPTO_store_u64_be(&out[3], seq_with_epoch);
+    // DTLS 1.2 uses the sequence number with the epoch for the AEAD.
+    aead_seq = (uint64_t{epoch} << 48) | write_epoch->next_seq;
+    CRYPTO_store_u64_be(&out[3], aead_seq);
     out[11] = ciphertext_len >> 8;
     out[12] = ciphertext_len & 0xff;
   }
   Span<const uint8_t> header = MakeConstSpan(out, record_header_len);
 
 
-  if (!aead->SealScatter(out + record_header_len, out + prefix,
-                         out + prefix + in_len, type, record_version,
-                         seq_with_epoch, header, in, in_len, extra_in,
-                         extra_in_len)) {
+  if (!write_epoch->aead->SealScatter(out + record_header_len, out + prefix,
+                                      out + prefix + in_len, type,
+                                      record_version, aead_seq, header, in,
+                                      in_len, extra_in, extra_in_len)) {
     return false;
   }
 
@@ -581,14 +576,14 @@
     // cipher suites have no requirements on the mask size. We only need the
     // first two bytes from the mask.
     uint8_t mask[AES_BLOCK_SIZE];
-    if (!aead->GenerateRecordNumberMask(mask, sample)) {
+    if (!write_epoch->aead->GenerateRecordNumberMask(mask, sample)) {
       return false;
     }
     out[1] ^= mask[0];
     out[2] ^= mask[1];
   }
 
-  (*seq)++;
+  write_epoch->next_seq++;
   *out_len = record_header_len + ciphertext_len;
   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, header);
   return true;
diff --git a/ssl/internal.h b/ssl/internal.h
index b3b1f67..174b7e4 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -559,23 +559,29 @@
   T *end() { return data() + size_; }
   const T *end() const { return data() + size_; }
 
-  void clear() {
-    cxx17_destroy_n(data(), size_);
-    size_ = 0;
+  void clear() { Shrink(0); }
+
+  // Shrink resizes the vector to |new_size|, which must not be larger than the
+  // current size. Unlike |Resize|, this can be called when |T| is not
+  // default-constructible.
+  void Shrink(size_t new_size) {
+    BSSL_CHECK(new_size <= size_);
+    cxx17_destroy_n(data() + new_size, size_ - new_size);
+    size_ = static_cast<PackedSize<N>>(new_size);
   }
 
   // TryResize resizes the vector to |new_size| and returns true, or returns
   // false if |new_size| is too large. Any newly-added elements are
   // value-initialized.
   bool TryResize(size_t new_size) {
+    if (new_size <= size_) {
+      Shrink(new_size);
+      return true;
+    }
     if (new_size > capacity()) {
       return false;
     }
-    if (new_size < size_) {
-      cxx17_destroy_n(data() + new_size, size_ - new_size);
-    } else {
-      cxx17_uninitialized_value_construct_n(data() + size_, new_size - size_);
-    }
+    cxx17_uninitialized_value_construct_n(data() + size_, new_size - size_);
     size_ = static_cast<PackedSize<N>>(new_size);
     return true;
   }
@@ -584,14 +590,14 @@
   // default-initialized, so POD types may contain uninitialized values that the
   // caller is responsible for filling in.
   bool TryResizeMaybeUninit(size_t new_size) {
+    if (new_size <= size_) {
+      Shrink(new_size);
+      return true;
+    }
     if (new_size > capacity()) {
       return false;
     }
-    if (new_size < size_) {
-      cxx17_destroy_n(data() + new_size, size_ - new_size);
-    } else {
-      cxx17_uninitialized_default_construct_n(data() + size_, new_size - size_);
-    }
+    cxx17_uninitialized_default_construct_n(data() + size_, new_size - size_);
     size_ = static_cast<PackedSize<N>>(new_size);
     return true;
   }
@@ -633,6 +639,27 @@
     return *ret;
   }
 
+  template <typename Pred>
+  void EraseIf(Pred pred) {
+    // See if anything needs to be erased at all. This avoids a self-move.
+    auto iter = std::find_if(begin(), end(), pred);
+    if (iter == end()) {
+      return;
+    }
+
+    // Elements before the first to be erased may be left as-is.
+    size_t new_size = iter - begin();
+    // Swap all subsequent elements in if they are to be kept.
+    for (size_t i = new_size + 1; i < size(); i++) {
+      if (!pred((*this)[i])) {
+        (*this)[new_size] = std::move((*this)[i]);
+        new_size++;
+      }
+    }
+
+    Shrink(new_size);
+  }
+
  private:
   alignas(T) char storage_[sizeof(T[N])];
   PackedSize<N> size_ = 0;
@@ -1157,6 +1184,7 @@
   // records.
   InplaceVector<uint8_t, 12> fixed_nonce_;
   uint8_t variable_nonce_len_ = 0;
+  // TODO(crbug.com/42290594): Move this into DTLSReadEpoch and DTLSWriteEpoch.
   UniquePtr<RecordNumberEncrypter> rn_encrypter_;
   // variable_nonce_included_in_record_ is true if the variable nonce
   // for a record is included as a prefix before the ciphertext.
@@ -1217,15 +1245,28 @@
 
 // DTLS replay bitmap.
 
-// DTLS1_BITMAP maintains a sliding window of 64 sequence numbers to detect
-// replayed packets. It should be initialized by zeroing every field.
-struct DTLS1_BITMAP {
+// DTLSReplayBitmap maintains a sliding window of sequence numbers to detect
+// replayed packets.
+class DTLSReplayBitmap {
+ public:
+  // ShouldDiscard returns true if |seq_num| has been seen in
+  // |bitmap| or is stale. Otherwise it returns false.
+  bool ShouldDiscard(uint64_t seqnum) const;
+
+  // Record updates the bitmap to record receipt of sequence number
+  // |seq_num|. It slides the window forward if needed. It is an error to call
+  // this function on a stale sequence number.
+  void Record(uint64_t seqnum);
+
+  uint64_t max_seq_num() const { return max_seq_num_; }
+
+ private:
   // map is a bitset of sequence numbers that have been seen. Bit i corresponds
-  // to |max_seq_num - i|.
-  std::bitset<256> map;
-  // max_seq_num is the largest sequence number seen so far as a 64-bit
+  // to |max_seq_num_ - i|.
+  std::bitset<256> map_;
+  // max_seq_num_ is the largest sequence number seen so far as a 64-bit
   // integer.
-  uint64_t max_seq_num = 0;
+  uint64_t max_seq_num_ = 0;
 };
 
 // reconstruct_seqnum takes the low order bits of a record sequence number from
@@ -1240,8 +1281,25 @@
 OPENSSL_EXPORT uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask,
                                            uint64_t max_valid_seqnum);
 
+
 // Record layer.
 
+struct DTLSReadEpoch {
+  static constexpr bool kAllowUniquePtr = true;
+
+  uint16_t epoch = 0;
+  UniquePtr<SSLAEADContext> aead;
+  DTLSReplayBitmap bitmap;
+};
+
+struct DTLSWriteEpoch {
+  static constexpr bool kAllowUniquePtr = true;
+
+  uint16_t epoch = 0;
+  UniquePtr<SSLAEADContext> aead;
+  uint64_t next_seq = 0;
+};
+
 // ssl_record_prefix_len returns the length of the prefix before the ciphertext
 // of a record for |ssl|.
 //
@@ -1496,6 +1554,10 @@
 // dtls_clear_outgoing_messages releases all buffered outgoing messages.
 void dtls_clear_outgoing_messages(SSL *ssl);
 
+// dtls_clear_unused_write_epochs releases any write epochs that are no longer
+// needed.
+void dtls_clear_unused_write_epochs(SSL *ssl);
+
 
 // Callbacks.
 
@@ -2844,19 +2906,25 @@
   // on_handshake_complete is called when the handshake is complete.
   void (*on_handshake_complete)(SSL *ssl);
   // set_read_state sets |ssl|'s read cipher state and level to |aead_ctx| and
-  // |level|. In QUIC, |aead_ctx| is a placeholder object and |secret_for_quic|
-  // is the original secret. This function returns true on success and false on
-  // error.
+  // |level|. In QUIC, |aead_ctx| is a placeholder object. In TLS 1.3,
+  // |traffic_secret| is the original traffic secret. This function returns true
+  // on success and false on error.
+  //
+  // TODO(crbug.com/371998381): Take the traffic secrets as input and let the
+  // function create the SSLAEADContext.
   bool (*set_read_state)(SSL *ssl, ssl_encryption_level_t level,
                          UniquePtr<SSLAEADContext> aead_ctx,
-                         Span<const uint8_t> secret_for_quic);
+                         Span<const uint8_t> traffic_secret);
   // set_write_state sets |ssl|'s write cipher state and level to |aead_ctx| and
-  // |level|. In QUIC, |aead_ctx| is a placeholder object and |secret_for_quic|
-  // is the original secret. This function returns true on success and false on
-  // error.
+  // |level|. In QUIC, |aead_ctx| is a placeholder object In TLS 1.3,
+  // |traffic_secret| is the original traffic secret. This function returns true
+  // on success and false on error.
+  //
+  // TODO(crbug.com/371998381): Take the traffic secrets as input and let the
+  // function create the SSLAEADContext.
   bool (*set_write_state)(SSL *ssl, ssl_encryption_level_t level,
                           UniquePtr<SSLAEADContext> aead_ctx,
-                          Span<const uint8_t> secret_for_quic);
+                          Span<const uint8_t> traffic_secret);
 };
 
 // The following wrappers call |open_*| but handle |read_shutdown| correctly.
@@ -3268,13 +3336,17 @@
   uint32_t tv_usec;
 };
 
-// A DTLSEpochState object contains state about a DTLS epoch.
-struct DTLSEpochState {
-  static constexpr bool kAllowUniquePtr = true;
-
-  UniquePtr<SSLAEADContext> aead_write_ctx;
-  uint64_t write_sequence = 0;
-};
+// DTLS_MAX_EXTRA_WRITE_EPOCHS is the maximum number of additional write epochs
+// that DTLS may need to retain.
+//
+// The maximum is, as a DTLS 1.3 server, immediately after sending Finished. At
+// this point, the current epoch is the application write keys (epoch 3), but we
+// may have ServerHello (epoch 0) and EncryptedExtensions (epoch 1) to
+// retransmit. KeyUpdate does not increase this count. If the server were to
+// initiate KeyUpdate from this state, it would not apply the new epoch until
+// the client's ACKs have caught up. At that point, epochs 0 and 1 can be
+// discarded.
+#define DTLS_MAX_EXTRA_WRITE_EPOCHS 2
 
 struct DTLS1_STATE {
   static constexpr bool kAllowUniquePtr = true;
@@ -3282,6 +3354,8 @@
   DTLS1_STATE();
   ~DTLS1_STATE();
 
+  bool Init();
+
   // has_change_cipher_spec is true if we have received a ChangeCipherSpec from
   // the peer in this epoch.
   bool has_change_cipher_spec : 1;
@@ -3296,23 +3370,22 @@
   // peer sent the final flight.
   bool flight_has_reply : 1;
 
-  // The current data and handshake epoch.  This is initially undefined, and
-  // starts at zero once the initial handshake is completed.
-  uint16_t r_epoch = 0;
-  uint16_t w_epoch = 0;
-
-  // records being received in the current epoch
-  DTLS1_BITMAP bitmap;
-
   uint16_t handshake_write_seq = 0;
   uint16_t handshake_read_seq = 0;
 
-  // state from the last epoch
-  DTLSEpochState last_epoch_state;
+  // read_epoch is the current DTLS read epoch.
+  // TODO(crbug.com/42290594): DTLS 1.3 will require that we also store the next
+  // epoch, and switch over on the first record from the new epoch.
+  DTLSReadEpoch read_epoch;
 
-  // In DTLS 1.3, this contains the write AEAD for the initial encryption level.
-  // TODO(crbug.com/boringssl/715): Drop this when it is no longer needed.
-  UniquePtr<DTLSEpochState> initial_epoch_state;
+  // write_epoch is the current DTLS write epoch. Non-retransmit records will
+  // generally use this epoch.
+  // TODO(crbug.com/42290594): 0-RTT will be the exception, when implemented.
+  DTLSWriteEpoch write_epoch;
+
+  // extra_write_epochs is the collection available write epochs.
+  InplaceVector<UniquePtr<DTLSWriteEpoch>, DTLS_MAX_EXTRA_WRITE_EPOCHS>
+      extra_write_epochs;
 
   // incoming_messages is a ring buffer of incoming handshake messages that have
   // yet to be processed. The front of the ring buffer is message number
diff --git a/ssl/s3_lib.cc b/ssl/s3_lib.cc
index 14d9b64..faa4d61 100644
--- a/ssl/s3_lib.cc
+++ b/ssl/s3_lib.cc
@@ -187,10 +187,20 @@
     return false;
   }
 
-  s3->aead_read_ctx = SSLAEADContext::CreateNullCipher();
-  s3->aead_write_ctx = SSLAEADContext::CreateNullCipher();
+  // TODO(crbug.com/368805255): Fields that aren't used in DTLS should not be
+  // allocated at all.
+  // TODO(crbug.com/371998381): Don't create these in QUIC either, once the
+  // placeholder QUIC ones for subsequent epochs are removed.
+  if (!SSL_is_dtls(ssl)) {
+    s3->aead_read_ctx = SSLAEADContext::CreateNullCipher();
+    s3->aead_write_ctx = SSLAEADContext::CreateNullCipher();
+    if (!s3->aead_read_ctx || !s3->aead_write_ctx) {
+      return false;
+    }
+  }
+
   s3->hs = ssl_handshake_new(ssl);
-  if (!s3->aead_read_ctx || !s3->aead_write_ctx || !s3->hs) {
+  if (!s3->hs) {
     return false;
   }
 
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index f0b3872..cbb5cbf 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -2933,6 +2933,13 @@
 
 int SSL_get_ivs(const SSL *ssl, const uint8_t **out_read_iv,
                 const uint8_t **out_write_iv, size_t *out_iv_len) {
+  // No cipher suites maintain stateful internal IVs in DTLS. It would not be
+  // compatible with reordering.
+  if (SSL_is_dtls(ssl)) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+    return 0;
+  }
+
   size_t write_iv_len;
   if (!ssl->s3->aead_read_ctx->GetIV(out_read_iv, out_iv_len) ||
       !ssl->s3->aead_write_ctx->GetIV(out_write_iv, &write_iv_len) ||
@@ -2956,19 +2963,25 @@
     // max_seq_num already includes the epoch. However, the current epoch may
     // be one ahead of the highest record received, immediately after a key
     // change.
-    assert(ssl->d1->r_epoch >= ssl->d1->bitmap.max_seq_num >> 48);
-    return ssl->d1->bitmap.max_seq_num;
+    const DTLSReadEpoch *read_epoch = &ssl->d1->read_epoch;
+    assert(read_epoch->epoch >= read_epoch->bitmap.max_seq_num() >> 48);
+    return read_epoch->bitmap.max_seq_num();
   }
   return ssl->s3->read_sequence;
 }
 
 uint64_t SSL_get_write_sequence(const SSL *ssl) {
-  uint64_t ret = ssl->s3->write_sequence;
   if (SSL_is_dtls(ssl)) {
-    assert((ret >> 48) == 0);
-    ret |= uint64_t{ssl->d1->w_epoch} << 48;
+    const DTLSWriteEpoch *write_epoch = &ssl->d1->write_epoch;
+    uint64_t ret = write_epoch->next_seq;
+    if (SSL_is_dtls(ssl)) {
+      assert((ret >> 48) == 0);
+      ret |= uint64_t{write_epoch->epoch} << 48;
+    }
+    return ret;
   }
-  return ret;
+
+  return ssl->s3->write_sequence;
 }
 
 uint16_t SSL_get_peer_signature_algorithm(const SSL *ssl) {
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 963f2ee..1e09a76 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -802,6 +802,81 @@
   EXPECT_FALSE(vec_of_vecs5.TryPushBack(v));
 }
 
+TEST(InplaceVectorTest, EraseIf) {
+  // Test that EraseIf never causes a self-move, and also correctly works with
+  // a move-only type that cannot be default-constructed.
+  class NoSelfMove {
+   public:
+    explicit NoSelfMove(int v) : v_(std::make_unique<int>(v)) {}
+    NoSelfMove(NoSelfMove &&other) { *this = std::move(other); }
+    NoSelfMove &operator=(NoSelfMove &&other) {
+      BSSL_CHECK(this != &other);
+      v_ = std::move(other.v_);
+      return *this;
+    }
+
+    int value() const { return *v_; }
+
+   private:
+    std::unique_ptr<int> v_;
+  };
+
+  InplaceVector<NoSelfMove, 8> vec;
+  auto reset = [&] {
+    vec.clear();
+    for (int i = 0; i < 8; i++) {
+      vec.PushBack(NoSelfMove(i));
+    }
+  };
+  auto expect = [&](const std::vector<int> &expected) {
+    ASSERT_EQ(vec.size(), expected.size());
+    for (size_t i = 0; i < vec.size(); i++) {
+      SCOPED_TRACE(i);
+      EXPECT_EQ(vec[i].value(), expected[i]);
+    }
+  };
+
+  reset();
+  vec.EraseIf([](const auto &) { return false; });
+  expect({0, 1, 2, 3, 4, 5, 6, 7});
+
+  reset();
+  vec.EraseIf([](const auto &) { return true; });
+  expect({});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return v.value() < 4; });
+  expect({4, 5, 6, 7});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return v.value() >= 4; });
+  expect({0, 1, 2, 3});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return v.value() % 2 == 0; });
+  expect({1, 3, 5, 7});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return v.value() % 2 == 1; });
+  expect({0, 2, 4, 6});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return 2 <= v.value() && v.value() <= 5; });
+  expect({0, 1, 6, 7});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return v.value() == 0; });
+  expect({1, 2, 3, 4, 5, 6, 7});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return v.value() == 4; });
+  expect({0, 1, 2, 3, 5, 6, 7});
+
+  reset();
+  vec.EraseIf([](const auto &v) { return v.value() == 7; });
+  expect({0, 1, 2, 3, 4, 5, 6});
+}
+
 TEST(InplaceVectorDeathTest, BoundsChecks) {
   InplaceVector<int, 4> vec;
   // The vector is currently empty.
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 360c855..a13629e 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -251,12 +251,12 @@
   if (direction == evp_aead_open) {
     return ssl->method->set_read_state(ssl, ssl_encryption_application,
                                        std::move(aead_ctx),
-                                       /*secret_for_quic=*/{});
+                                       /*traffic_secret=*/{});
   }
 
   return ssl->method->set_write_state(ssl, ssl_encryption_application,
                                       std::move(aead_ctx),
-                                      /*secret_for_quic=*/{});
+                                      /*traffic_secret=*/{});
 }
 
 bool tls1_change_cipher_state(SSL_HANDSHAKE *hs,
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index b747215..24da90a 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -86,7 +86,7 @@
       if (!null_ctx ||
           !ssl->method->set_write_state(ssl, ssl_encryption_initial,
                                         std::move(null_ctx),
-                                        /*secret_for_quic=*/{})) {
+                                        /*traffic_secret=*/{})) {
         return false;
       }
     } else {
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index 11ea7b5..ee3b635 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -186,12 +186,10 @@
   const EVP_MD *digest = ssl_session_get_digest(session);
   bool is_dtls = SSL_is_dtls(ssl);
   UniquePtr<SSLAEADContext> traffic_aead;
-  Span<const uint8_t> secret_for_quic;
   if (ssl->quic_method != nullptr) {
     // Install a placeholder SSLAEADContext so that SSL accessors work. The
     // encryption itself will be handled by the SSL_QUIC_METHOD.
     traffic_aead = SSLAEADContext::CreatePlaceholderForQUIC(session->cipher);
-    secret_for_quic = traffic_secret;
   } else {
     // Look up cipher suite properties.
     const EVP_AEAD *aead;
@@ -237,13 +235,13 @@
 
   if (direction == evp_aead_open) {
     if (!ssl->method->set_read_state(ssl, level, std::move(traffic_aead),
-                                     secret_for_quic)) {
+                                     traffic_secret)) {
       return false;
     }
     ssl->s3->read_traffic_secret.CopyFrom(traffic_secret);
   } else {
     if (!ssl->method->set_write_state(ssl, level, std::move(traffic_aead),
-                                      secret_for_quic)) {
+                                      traffic_secret)) {
       return false;
     }
     ssl->s3->write_traffic_secret.CopyFrom(traffic_secret);
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 5fcf684..0a3fc4e 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -84,7 +84,7 @@
 
 static bool tls_set_read_state(SSL *ssl, ssl_encryption_level_t level,
                                UniquePtr<SSLAEADContext> aead_ctx,
-                               Span<const uint8_t> secret_for_quic) {
+                               Span<const uint8_t> traffic_secret) {
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (tls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -95,8 +95,8 @@
   if (ssl->quic_method != nullptr) {
     if ((ssl->s3->hs == nullptr || !ssl->s3->hs->hints_requested) &&
         !ssl->quic_method->set_read_secret(ssl, level, aead_ctx->cipher(),
-                                           secret_for_quic.data(),
-                                           secret_for_quic.size())) {
+                                           traffic_secret.data(),
+                                           traffic_secret.size())) {
       return false;
     }
 
@@ -116,7 +116,7 @@
 
 static bool tls_set_write_state(SSL *ssl, ssl_encryption_level_t level,
                                 UniquePtr<SSLAEADContext> aead_ctx,
-                                Span<const uint8_t> secret_for_quic) {
+                                Span<const uint8_t> traffic_secret) {
   if (!tls_flush_pending_hs_data(ssl)) {
     return false;
   }
@@ -124,8 +124,8 @@
   if (ssl->quic_method != nullptr) {
     if ((ssl->s3->hs == nullptr || !ssl->s3->hs->hints_requested) &&
         !ssl->quic_method->set_write_secret(ssl, level, aead_ctx->cipher(),
-                                            secret_for_quic.data(),
-                                            secret_for_quic.size())) {
+                                            traffic_secret.data(),
+                                            traffic_secret.size())) {
       return false;
     }
 
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index 685b78a..d982683 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -593,7 +593,8 @@
 
 size_t SSL_max_seal_overhead(const SSL *ssl) {
   if (SSL_is_dtls(ssl)) {
-    return dtls_max_seal_overhead(ssl, ssl->d1->w_epoch);
+    // TODO(crbug.com/42290594): Use the 0-RTT epoch if writing 0-RTT.
+    return dtls_max_seal_overhead(ssl, ssl->d1->write_epoch.epoch);
   }
 
   size_t ret = SSL3_RT_HEADER_LENGTH;