Stage new DTLS 1.3 read epochs until the first record comes in

In DTLS 1.2, we install the new read epoch as soon as we change keys,
and immediately stop accepting data from the previous epoch. This is
fine because there is only one key change, and we don't need to receive
data from the previous epoch at that point. (Even if we used previous
flights to trigger retransmits---which we probably should but don't---I
believe all DTLS 1.2 key changes are such that the previous flight will
contain *some* message from the epoch we expect and we can key on that.)

In DTLS 1.3, things are different. The peer's KeyUpdates do not apply
until we send an ACK, but that ACK might be lost. So that means, at
least for KeyUpdate, we must stage the new epoch and only apply it
later.

During the handshake, this comes up in two cases:

- A server might send ServerHello..Finished but some of ServerHello (or
  the whole flight) is lost. The server is now expecting
  Certificate..Finished (epoch 2), but the client cannot transition to
  epoch 2 without the ServerHello. The client will then try to induce a
  retransmit by either sending an ACK or resending the ClientHello. The
  client must parse epoch 0 to be able to catch this. This is nice
  but not critical, because the server also has a retransmit timer.

- A client might send Certificate..Finished but it was all lost. The
  client is now done with the handshake (epoch 3), but the server
  cannot complete. The server will then retransmit its
  ServerHello..Finished flight or send an ACK. The ACK will actually
  come at epoch 3, so we don't need epoch 2 in that case. But the
  retransmit is purely in epochs 0 and 2, so the client must parse
  epoch 2 to catch this. This is also not critical because the
  client is expected to keep its retransmit timer running after the
  handshake (still TODO).

To support all this, new DTLS 1.3 read epochs are staged until we get a
record.

This means that record processing must now account for receiving records
at more epochs, including some invalid cases. Specifically:

- We might receive new handshake messages at the current epoch, even
  though the current epoch is being closed. This is an error.

- We might receive application data at epoch 2 even though we've
  finished the handshake and expect it at epoch 3. This is an error.

I've added checks for both of these, with tests. This also resolves a
few TODOs. The app data check will also help with 0-RTT, where epochs 1
and 2 will flow concurrently. (We'll still need an inverse check that
epoch 1 never carries handshake data.) It also once again changes the
broken SSL_get_read_sequence API from one kind of broken to another kind
of broken, so I've updated the comment to explain the current state.

As part of this, stop tracking read_level and write_level outside of
QUIC. It's only used in QUIC and the semantics are a little weird when
we defer activating new read epochs.

(I'm not sure whether this behavior is really necessary for epochs
before epoch 3. The only case it accomodates is the client using epoch 0
to induce a retransmit when ServerHello is lost. The impact it has on
SSL_get_read_sequence immediately after the handshake is kind of weird,
so maybe we want to special case it a bit. I can't think of any case
where we'd want to defer installing epoch 3.)

Bug: 42290594
Change-Id: Id222de2b77ff8e68e64f301f32d8c2b95ad3e7a7
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72450
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 0b3711f..0eb3970 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -350,9 +350,8 @@
 }
 
 bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
-                                       Span<uint8_t> record) {
-  CBS cbs;
-  CBS_init(&cbs, record.data(), record.size());
+                                       Span<const uint8_t> record) {
+  CBS cbs = record;
   while (CBS_len(&cbs) > 0) {
     // Read a handshake fragment.
     struct hm_header_st msg_hdr;
@@ -373,10 +372,28 @@
       return false;
     }
 
-    if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
-        msg_hdr.seq >
-            (unsigned)ssl->d1->handshake_read_seq + SSL_MAX_HANDSHAKE_FLIGHT) {
-      // Ignore fragments from the past, or ones too far in the future.
+    if (msg_hdr.seq < ssl->d1->handshake_read_seq) {
+      // Ignore fragments from the past. This is a retransmit of data we already
+      // received.
+      //
+      // TODO(crbug.com/42290594): Use this to drive retransmits.
+      continue;
+    }
+
+    if (ssl->d1->next_read_epoch != nullptr) {
+      // Any any time, we only expect new messages in one epoch. If
+      // |next_read_epoch| is set, we've started a new epoch but haven't
+      // received records in it yet. (Once a record is received in the new
+      // epoch, |next_read_epoch| becomes the current read epoch.) This new
+      // fragment is in the old epoch, but we expect handshake messages to be in
+      // the next epoch, so this is an error.
+      OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
+      *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
+      return false;
+    }
+
+    if (msg_hdr.seq - ssl->d1->handshake_read_seq > SSL_MAX_HANDSHAKE_FLIGHT) {
+      // Ignore fragments too far in the future.
       continue;
     }
 
@@ -415,19 +432,8 @@
 
   switch (type) {
     case SSL3_RT_APPLICATION_DATA:
-      // Unencrypted application data records are always illegal.
-      //
-      // TODO(crbug.com/42290594): Revisit both of these checks for DTLS 1.3.
-      // Many more epochs cannot have application data, and there is a key
-      // change immediately before the first application data record.
-      if (ssl->d1->read_epoch.epoch == 0) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
-        *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
-        return ssl_open_record_error;
-      }
-
-      // Out-of-order application data may be received between ChangeCipherSpec
-      // and finished. Discard it.
+      // In DTLS 1.2, out-of-order application data may be received between
+      // ChangeCipherSpec and Finished. Discard it.
       return ssl_open_record_discard;
 
     case SSL3_RT_CHANGE_CIPHER_SPEC:
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 9fc854e..ac32346 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -88,27 +88,30 @@
   }
 
   DTLSReadEpoch new_epoch;
+  new_epoch.aead = std::move(aead_ctx);
   if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
     // TODO(crbug.com/42290594): Handle the additional epochs used for key
     // update.
-    // TODO(crbug.com/42290594): If we want to gracefully handle packet
-    // reordering around KeyUpdate (i.e. accept records from both epochs), we'll
-    // need a separate bitmap for each epoch.
     new_epoch.epoch = level;
     new_epoch.rn_encrypter =
-        RecordNumberEncrypter::Create(aead_ctx->cipher(), traffic_secret);
+        RecordNumberEncrypter::Create(new_epoch.aead->cipher(), traffic_secret);
     if (new_epoch.rn_encrypter == nullptr) {
       return false;
     }
+
+    // In DTLS 1.3, new read epochs are not applied immediately. In principle,
+    // we could do the same in DTLS 1.2, but we would ignore every record from
+    // the previous epoch anyway.
+    assert(ssl->d1->next_read_epoch == nullptr);
+    ssl->d1->next_read_epoch = MakeUnique<DTLSReadEpoch>(std::move(new_epoch));
+    if (ssl->d1->next_read_epoch == nullptr) {
+      return false;
+    }
   } else {
     new_epoch.epoch = ssl->d1->read_epoch.epoch + 1;
+    ssl->d1->read_epoch = std::move(new_epoch);
+    ssl->d1->has_change_cipher_spec = false;
   }
-  new_epoch.bitmap = DTLSReplayBitmap();
-  new_epoch.aead = std::move(aead_ctx);
-
-  ssl->d1->read_epoch = std::move(new_epoch);
-  ssl->s3->read_level = level;
-  ssl->d1->has_change_cipher_spec = false;
   return true;
 }
 
@@ -137,7 +140,6 @@
 
   ssl->d1->write_epoch = std::move(new_epoch);
   ssl->d1->extra_write_epochs.PushBack(std::move(current));
-  ssl->s3->write_level = level;
   dtls_clear_unused_write_epochs(ssl);
   return true;
 }
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 8a17ee2..3d381fb 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -265,15 +265,19 @@
 
   // Look up the corresponding epoch. This header form only matches encrypted
   // DTLS 1.3 epochs.
-  // TODO(crbug.com/42290594): DTLS 1.3 will require that we track multiple
-  // epochs.
-  if (epoch == ssl->d1->read_epoch.epoch &&
-      use_dtls13_record_header(ssl, epoch)) {
-    out->read_epoch = &ssl->d1->read_epoch;
+  DTLSReadEpoch *read_epoch = nullptr;
+  if (epoch == ssl->d1->read_epoch.epoch) {
+    read_epoch = &ssl->d1->read_epoch;
+  } else if (ssl->d1->next_read_epoch != nullptr &&
+             epoch == ssl->d1->next_read_epoch->epoch) {
+    read_epoch = ssl->d1->next_read_epoch.get();
+  }
+  if (read_epoch != nullptr && use_dtls13_record_header(ssl, epoch)) {
+    out->read_epoch = read_epoch;
 
     // Decrypt and reconstruct the sequence number:
     uint8_t mask[2];
-    if (!out->read_epoch->rn_encrypter->GenerateMask(mask, out->body)) {
+    if (!read_epoch->rn_encrypter->GenerateMask(mask, out->body)) {
       // GenerateMask most likely failed because the record body was not long
       // enough.
       return false;
@@ -287,8 +291,8 @@
       writable_seq[i] ^= mask[i];
       seq = (seq << 8) | writable_seq[i];
     }
-    uint64_t full_seq = reconstruct_seqnum(
-        seq, (1 << (seq_len * 8)) - 1, out->read_epoch->bitmap.max_seq_num());
+    uint64_t full_seq = reconstruct_seqnum(seq, (1 << (seq_len * 8)) - 1,
+                                           read_epoch->bitmap.max_seq_num());
     out->number = DTLSRecordNumber(epoch, full_seq);
   }
 
@@ -428,6 +432,22 @@
 
   record.read_epoch->bitmap.Record(record.number.sequence());
 
+  // Once we receive a record from the next epoch, it becomes the current epoch.
+  if (record.read_epoch == ssl->d1->next_read_epoch.get()) {
+    ssl->d1->read_epoch = std::move(*ssl->d1->next_read_epoch);
+    ssl->d1->next_read_epoch = nullptr;
+  }
+
+  // We do not retain previous epochs, so it is guaranteed records come in at
+  // the "current" epoch. (But the current epoch may be one behind the
+  // handshake.)
+  //
+  // TODO(crbug.com/374890768): In DTLS 1.3, where rekeys may occur
+  // mid-connection, retaining previous epochs would make us more robust to
+  // packet reordering. If we do this, we'll need to take care to not
+  // accidentally accept data at the wrong epoch.
+  assert(record.number.epoch() == ssl->d1->read_epoch.epoch);
+
   // TODO(davidben): Limit the number of empty records as in TLS? This is only
   // useful if we also limit discarded packets.
 
@@ -435,6 +455,25 @@
     return ssl_process_alert(ssl, out_alert, *out);
   }
 
+  // Reject application data in epochs that do not allow it.
+  if (record.type == SSL3_RT_APPLICATION_DATA) {
+    bool app_data_allowed;
+    if (ssl->s3->version != 0 && ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
+      // Application data is allowed in 0-RTT (epoch 1) and after the handshake
+      // (3 and up).
+      app_data_allowed =
+          record.number.epoch() == 1 || record.number.epoch() >= 3;
+    } else {
+      // Application data is allowed starting epoch 1.
+      app_data_allowed = record.number.epoch() >= 1;
+    }
+    if (!app_data_allowed) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
+      *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
+      return ssl_open_record_error;
+    }
+  }
+
   ssl->s3->warning_alert_count = 0;
 
   *out_type = record.type;
diff --git a/ssl/internal.h b/ssl/internal.h
index 15634d6..636decf 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3169,8 +3169,8 @@
   // needs re-doing when in SSL_accept or SSL_connect
   int rwstate = SSL_ERROR_NONE;
 
-  enum ssl_encryption_level_t read_level = ssl_encryption_initial;
-  enum ssl_encryption_level_t write_level = ssl_encryption_initial;
+  enum ssl_encryption_level_t quic_read_level = ssl_encryption_initial;
+  enum ssl_encryption_level_t quic_write_level = ssl_encryption_initial;
 
   // version is the protocol version, or zero if the version has not yet been
   // set. In clients offering 0-RTT, this version will initially be set to the
@@ -3464,10 +3464,12 @@
   uint16_t handshake_read_seq = 0;
 
   // read_epoch is the current DTLS read epoch.
-  // TODO(crbug.com/42290594): DTLS 1.3 will require that we also store the next
-  // epoch, and switch over on the first record from the new epoch.
   DTLSReadEpoch read_epoch;
 
+  // next_read_epoch is the next DTLS read epoch in DTLS 1.3. It will become
+  // current once a record is received from it.
+  UniquePtr<DTLSReadEpoch> next_read_epoch;
+
   // write_epoch is the current DTLS write epoch. Non-retransmit records will
   // generally use this epoch.
   // TODO(crbug.com/42290594): 0-RTT will be the exception, when implemented.
@@ -3856,7 +3858,7 @@
 void dtls1_free(SSL *ssl);
 
 bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
-                                       Span<uint8_t> record);
+                                       Span<const uint8_t> record);
 bool dtls1_get_message(const SSL *ssl, SSLMessage *out);
 ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
                                        uint8_t *out_alert, Span<uint8_t> in);
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index 8110725..4ece96d 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -252,7 +252,7 @@
                     pending_hs_data->length);
   if (ssl->quic_method) {
     if ((ssl->s3->hs == nullptr || !ssl->s3->hs->hints_requested) &&
-        !ssl->quic_method->add_handshake_data(ssl, ssl->s3->write_level,
+        !ssl->quic_method->add_handshake_data(ssl, ssl->s3->quic_write_level,
                                               data.data(), data.size())) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
       return false;
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index a34c5a6..83ec13d 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -456,7 +456,7 @@
 
 int tls_dispatch_alert(SSL *ssl) {
   if (ssl->quic_method) {
-    if (!ssl->quic_method->send_alert(ssl, ssl->s3->write_level,
+    if (!ssl->quic_method->send_alert(ssl, ssl->s3->quic_write_level,
                                       ssl->s3->send_alert[1])) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
       return 0;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index e74f886..f52030b 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -820,11 +820,13 @@
 }
 
 enum ssl_encryption_level_t SSL_quic_read_level(const SSL *ssl) {
-  return ssl->s3->read_level;
+  assert(ssl->quic_method != nullptr);
+  return ssl->s3->quic_read_level;
 }
 
 enum ssl_encryption_level_t SSL_quic_write_level(const SSL *ssl) {
-  return ssl->s3->write_level;
+  assert(ssl->quic_method != nullptr);
+  return ssl->s3->quic_write_level;
 }
 
 int SSL_provide_quic_data(SSL *ssl, enum ssl_encryption_level_t level,
@@ -834,7 +836,7 @@
     return 0;
   }
 
-  if (level != ssl->s3->read_level) {
+  if (level != ssl->s3->quic_read_level) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_ENCRYPTION_LEVEL_RECEIVED);
     return 0;
   }
@@ -2952,21 +2954,17 @@
 
 uint64_t SSL_get_read_sequence(const SSL *ssl) {
   if (SSL_is_dtls(ssl)) {
-    // TODO(crbug.com/42290608): This API needs to reworked. Right at an epoch
-    // transition, it is possible that |read_epoch| has not received any
-    // records. We will then return that sequence 0 is the highest received, but
-    // this is not quite right.
+    // TODO(crbug.com/42290608): This API needs to reworked.
     //
-    // This is mostly moot in DTLS 1.2 because, after the handshake, we will
-    // never be in this state. In DTLS 1.3, there is a key transition
-    // immediately after the handshake, and in the steady state with KeyUpdate.
-    // While not yet implemented, DTLS 1.3 will handle key changes by having two
-    // epochs live at once (current and optional next), only cycling forward
-    // when we receive a record at the new epoch.
+    // In DTLS 1.2, right at an epoch transition, |read_epoch| may not have
+    // received any records. We will then return that sequence 0 is the highest
+    // received, but it's really -1, which is not representable. This is mostly
+    // moot because, after the handshake, we will never be in the state.
     //
-    // When we implement this, this sequence 0 edge case will be gone, but
-    // replaced with a different issue: our record layer APIs have no way to
-    // report transition state. We'll likely need a new API for DTLS offload.
+    // In DTLS 1.3, epochs do not transition until the first record comes in.
+    // This avoids the DTLS 1.2 problem but introduces a different problem:
+    // during a KeyUpdate (which may occur in the steady state), both epochs are
+    // live. We'll likely need a new API for DTLS offload.
     const DTLSReadEpoch *read_epoch = &ssl->d1->read_epoch;
     return DTLSRecordNumber(read_epoch->epoch, read_epoch->bitmap.max_seq_num())
         .combined();
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index e31c532..a4fedd8 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -2810,16 +2810,13 @@
   if (is_dtls()) {
     if (version() == DTLS1_3_EXPERIMENTAL_VERSION) {
       // Both client and server must be at epoch 3 (application data).
-      EXPECT_EQ(EpochFromSequence(client_read_seq), 3);
       EXPECT_EQ(EpochFromSequence(client_write_seq), 3);
-      EXPECT_EQ(EpochFromSequence(server_read_seq), 3);
       EXPECT_EQ(EpochFromSequence(server_write_seq), 3);
 
-      // TODO(crbug.com/42290608): The next record to be written should exceed
-      // the largest received, but they'll actually be equal because the
-      // |SSL_get_read_sequence| API cannot represent DTLS key transitions.
-      EXPECT_GE(client_write_seq, server_read_seq);
-      EXPECT_GE(server_write_seq, client_read_seq);
+      // TODO(crbug.com/42290608): Ideally we would check the read sequence
+      // numbers and compare them against each other, but
+      // |SSL_get_read_sequence| is ill-defined right after DTLS 1.3's key
+      // change. See that function for details.
     } else {
       // Both client and server must be at epoch 1.
       EXPECT_EQ(EpochFromSequence(client_read_seq), 1);
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index b5bf783..64a02e5 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -1285,6 +1285,14 @@
 	// immediately after ChangeCipherSpec.
 	AlertAfterChangeCipherSpec alert
 
+	// AppDataBeforeTLS13KeyChange, if not nil, causes application data to
+	// be sent immediately before the final key change in (D)TLS 1.3.
+	AppDataBeforeTLS13KeyChange []byte
+
+	// UnencryptedEncryptedExtensions, if true, causes the server to send
+	// EncryptedExtensions unencrypted, delaying the first key change.
+	UnencryptedEncryptedExtensions bool
+
 	// TimeoutSchedule is the schedule of packet drops and simulated
 	// timeouts for before each handshake leg from the peer.
 	TimeoutSchedule []time.Duration
diff --git a/ssl/test/runner/fuzzer_mode.json b/ssl/test/runner/fuzzer_mode.json
index 6967ae7..f895546 100644
--- a/ssl/test/runner/fuzzer_mode.json
+++ b/ssl/test/runner/fuzzer_mode.json
@@ -50,6 +50,9 @@
     "*EarlyDataRejected*": "Trial decryption does not work with the NULL cipher.",
     "ALPS-EarlyData-Mismatch-*": "Trial decryption does not work with the NULL cipher.",
 
+    "UnencryptedEncryptedExtensions": "The NULL cipher will not notice that the peer didn't change keys.",
+    "AppDataBeforeTLS13KeyChange*": "The NULL cipher will not notice that the peer didn't change keys.",
+
     "Renegotiate-Client-BadExt*": "Fuzzer mode does not check renegotiation_info.",
 
     "CBCRecordSplitting*": "Fuzzer mode does not implement record-splitting.",
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index b31d636..5281c7a 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -1507,6 +1507,10 @@
 	}
 	c.flushHandshake()
 
+	if data := c.config.Bugs.AppDataBeforeTLS13KeyChange; data != nil {
+		c.writeRecord(recordTypeApplicationData, data)
+	}
+
 	// Switch to application data keys.
 	c.useOutTrafficSecret(uint16(encryptionApplication), c.wireVersion, hs.suite, clientTrafficSecret)
 	c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 2b910aa..1eaef77 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -1061,6 +1061,10 @@
 		c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
 	}
 
+	if config.Bugs.UnencryptedEncryptedExtensions {
+		c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal())
+	}
+
 	// Switch to handshake traffic keys.
 	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
 	c.useOutTrafficSecret(uint16(encryptionHandshake), c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
@@ -1072,7 +1076,7 @@
 	if config.Bugs.PartialEncryptedExtensionsWithServerHello {
 		// The first byte has already been sent.
 		c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()[1:])
-	} else {
+	} else if !config.Bugs.UnencryptedEncryptedExtensions {
 		c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal())
 	}
 
@@ -1243,6 +1247,10 @@
 	serverTrafficSecret := hs.finishedHash.deriveSecret(serverApplicationTrafficLabel)
 	c.exporterSecret = hs.finishedHash.deriveSecret(exporterLabel)
 
+	if data := c.config.Bugs.AppDataBeforeTLS13KeyChange; data != nil {
+		c.writeRecord(recordTypeApplicationData, data)
+	}
+
 	// Switch to application data keys on write. In particular, any alerts
 	// from the client certificate are sent over these keys.
 	c.useOutTrafficSecret(uint16(encryptionApplication), c.wireVersion, hs.suite, serverTrafficSecret)
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 8c745e3..41a75b1 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2507,6 +2507,99 @@
 			expectedError: ":UNEXPECTED_RECORD:",
 		},
 		{
+			name: "AppDataBeforeTLS13KeyChange",
+			config: Config{
+				MinVersion: VersionTLS13,
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					AppDataBeforeTLS13KeyChange: []byte("TEST MESSAGE"),
+				},
+			},
+			// The shim should fail to decrypt this record.
+			shouldFail:         true,
+			expectedError:      ":BAD_DECRYPT:",
+			expectedLocalError: "remote error: bad record MAC",
+		},
+		{
+			name: "AppDataBeforeTLS13KeyChange-Empty",
+			config: Config{
+				MinVersion: VersionTLS13,
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					AppDataBeforeTLS13KeyChange: []byte{},
+				},
+			},
+			// The shim should fail to decrypt this record.
+			shouldFail:         true,
+			expectedError:      ":BAD_DECRYPT:",
+			expectedLocalError: "remote error: bad record MAC",
+		},
+		{
+			protocol: dtls,
+			name:     "AppDataBeforeTLS13KeyChange-DTLS",
+			config: Config{
+				MinVersion: VersionTLS13,
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					AppDataBeforeTLS13KeyChange: []byte("TEST MESSAGE"),
+				},
+			},
+			// The shim will decrypt the record, because it has not
+			// yet applied the key change, but it should know to
+			// reject the record.
+			shouldFail:         true,
+			expectedError:      ":UNEXPECTED_RECORD:",
+			expectedLocalError: "remote error: unexpected message",
+		},
+		{
+			protocol: dtls,
+			name:     "AppDataBeforeTLS13KeyChange-DTLS-Empty",
+			config: Config{
+				MinVersion: VersionTLS13,
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					AppDataBeforeTLS13KeyChange: []byte{},
+				},
+			},
+			// The shim will decrypt the record, because it has not
+			// yet applied the key change, but it should know to
+			// reject the record.
+			shouldFail:         true,
+			expectedError:      ":UNEXPECTED_RECORD:",
+			expectedLocalError: "remote error: unexpected message",
+		},
+		{
+			name: "UnencryptedEncryptedExtensions",
+			config: Config{
+				MinVersion: VersionTLS13,
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					UnencryptedEncryptedExtensions: true,
+				},
+			},
+			// The shim should fail to decrypt this record.
+			shouldFail:         true,
+			expectedError:      ":DECRYPTION_FAILED_OR_BAD_RECORD_MAC:",
+			expectedLocalError: "remote error: bad record MAC",
+		},
+		{
+			protocol: dtls,
+			name:     "UnencryptedEncryptedExtensions-DTLS",
+			config: Config{
+				MinVersion: VersionTLS13,
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					UnencryptedEncryptedExtensions: true,
+				},
+			},
+			// The shim will decrypt the record, because it has not
+			// yet applied the key change, but it should know to
+			// reject new handshake data on the previous epoch.
+			shouldFail:         true,
+			expectedError:      ":EXCESS_HANDSHAKE_DATA:",
+			expectedLocalError: "remote error: unexpected message",
+		},
+		{
 			name: "AppDataAfterChangeCipherSpec",
 			config: Config{
 				MaxVersion: VersionTLS12,
@@ -3314,13 +3407,9 @@
 					SendExtraFinished: true,
 				},
 			},
-			// TODO(crbug.com/42290594): When not reordered or packed, the extra
-			// Finished in epoch 2 does not arrive until after we've switched to
-			// epoch 3, so the record is simply dropped right now. When we defer
-			// epoch changes to the first record, this will change and we'll
-			// notice this, if no epoch 3 records arrive in the meantime. In
-			// general, a DTLS implementation may or may not notice invalid
-			// messages across key changes.
+			shouldFail:         true,
+			expectedError:      ":EXCESS_HANDSHAKE_DATA:",
+			expectedLocalError: "remote error: unexpected message",
 		},
 		{
 			protocol: dtls,
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index 7082d16..4972fbe 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -97,9 +97,10 @@
         return false;
       }
     }
+  } else {
+    assert(ssl->s3->quic_write_level == level);
   }
 
-  assert(ssl->s3->write_level == level);
   return true;
 }
 
@@ -347,9 +348,6 @@
 }
 
 static enum ssl_hs_wait_t do_send_second_client_hello(SSL_HANDSHAKE *hs) {
-  // Any 0-RTT keys must have been discarded.
-  assert(hs->ssl->s3->write_level == ssl_encryption_initial);
-
   // Build the second ClientHelloInner, if applicable. The second ClientHello
   // uses an empty string for |enc|.
   if (hs->ssl->s3->ech_status == ssl_ech_accepted &&
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 0a3fc4e..2de7ad1 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -106,11 +106,11 @@
     if (level == ssl_encryption_early_data) {
       return true;
     }
+    ssl->s3->quic_read_level = level;
   }
 
   ssl->s3->read_sequence = 0;
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
-  ssl->s3->read_level = level;
   return true;
 }
 
@@ -135,11 +135,11 @@
     if (level == ssl_encryption_early_data) {
       return true;
     }
+    ssl->s3->quic_write_level = level;
   }
 
   ssl->s3->write_sequence = 0;
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
-  ssl->s3->write_level = level;
   return true;
 }