Store state for each DTLS epoch.

The SSLAEADContext class contains most of the state needed for each
epoch, but instead of trying to shoehorn more state into it, this change
creates a new struct to contain the needed state, including the next
write sequence number for the epoch.

Change-Id: I5c259275fc90920a5c1a4dec87ab83a80f62b47a
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/71347
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Auto-Submit: Nick Harper <nharper@chromium.org>
diff --git a/ssl/d1_lib.cc b/ssl/d1_lib.cc
index f6a4de8..49aa50a 100644
--- a/ssl/d1_lib.cc
+++ b/ssl/d1_lib.cc
@@ -95,8 +95,14 @@
     return false;
   }
 
-  d1->initial_aead_write_ctx = SSLAEADContext::CreateNullCipher(true);
-  if (!d1->initial_aead_write_ctx) {
+  d1->initial_epoch_state = MakeUnique<DTLSEpochState>();
+  if (!d1->initial_epoch_state) {
+    tls_free(ssl);
+    return false;
+  }
+  d1->initial_epoch_state->aead_write_ctx =
+      SSLAEADContext::CreateNullCipher(true);
+  if (!d1->initial_epoch_state->aead_write_ctx) {
     tls_free(ssl);
     return false;
   }
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 6108f5c..501dd97 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -112,13 +112,13 @@
                                   Span<const uint8_t> secret_for_quic) {
   assert(secret_for_quic.empty());  // QUIC does not use DTLS.
   ssl->d1->w_epoch++;
-  ssl->d1->last_write_sequence = ssl->s3->write_sequence;
   ssl->s3->write_sequence = 0;
 
   if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
     ssl->d1->w_epoch = level;
   }
-  ssl->d1->last_aead_write_ctx = std::move(ssl->s3->aead_write_ctx);
+  ssl->d1->last_epoch_state.aead_write_ctx = std::move(ssl->s3->aead_write_ctx);
+  ssl->d1->last_epoch_state.write_sequence = ssl->s3->write_sequence;
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
   ssl->s3->write_level = level;
   return true;
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index a83d6b1..c07636e 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -417,12 +417,12 @@
 
 static SSLAEADContext *get_write_aead(const SSL *ssl, uint16_t epoch) {
   if (epoch == 0) {
-    return ssl->d1->initial_aead_write_ctx.get();
+    return ssl->d1->initial_epoch_state->aead_write_ctx.get();
   }
 
   if (epoch < ssl->d1->w_epoch) {
     BSSL_CHECK(epoch + 1 == ssl->d1->w_epoch);
-    return ssl->d1->last_aead_write_ctx.get();
+    return ssl->d1->last_epoch_state.aead_write_ctx.get();
   }
 
   BSSL_CHECK(epoch == ssl->d1->w_epoch);
@@ -477,11 +477,11 @@
   // Determine the parameters for the current epoch.
   SSLAEADContext *aead = get_write_aead(ssl, epoch);
   uint64_t *seq = &ssl->s3->write_sequence;
-  if (epoch < ssl->d1->w_epoch) {
-    seq = &ssl->d1->last_write_sequence;
+  if (epoch == 0) {
+    seq = &ssl->d1->initial_epoch_state->write_sequence;
+  } else if (epoch < ssl->d1->w_epoch) {
+    seq = &ssl->d1->last_epoch_state.write_sequence;
   }
-  // TODO(crbug.com/boringssl/715): If epoch is initial or handshake, the value
-  // of seq is probably wrong for a retransmission.
 
   const size_t record_header_len = dtls_record_header_write_len(ssl, epoch);
 
diff --git a/ssl/internal.h b/ssl/internal.h
index e651828..67ed227 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3072,6 +3072,14 @@
   uint32_t tv_usec;
 };
 
+// A DTLSEpochState object contains state about a DTLS epoch.
+struct DTLSEpochState {
+  static constexpr bool kAllowUniquePtr = true;
+
+  UniquePtr<SSLAEADContext> aead_write_ctx;
+  uint64_t write_sequence;
+};
+
 struct DTLS1_STATE {
   static constexpr bool kAllowUniquePtr = true;
 
@@ -3103,14 +3111,12 @@
   uint16_t handshake_write_seq = 0;
   uint16_t handshake_read_seq = 0;
 
-  // save last sequence number for retransmissions
-  uint64_t last_write_sequence = 0;
-  UniquePtr<SSLAEADContext> last_aead_write_ctx;
-
+  // state from the last epoch
+  DTLSEpochState last_epoch_state;
 
   // In DTLS 1.3, this contains the write AEAD for the initial encryption level.
   // TODO(crbug.com/boringssl/715): Drop this when it is no longer needed.
-  UniquePtr<SSLAEADContext> initial_aead_write_ctx;
+  UniquePtr<DTLSEpochState> initial_epoch_state;
 
   // incoming_messages is a ring buffer of incoming handshake messages that have
   // yet to be processed. The front of the ring buffer is message number