Use DTLSRecordNumber in DTLSWriteEpoch

This saves 8 bytes per write epoch and is a bit tidier. We could do
something similar with DTLSReadEpoch if we hid it in DTLSReplayBitmap
but that was a bit messy.

Bug: 42290594
Change-Id: I3de19ef90d59566c303bf3b2ff85e76bafc790d4
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72272
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 4bd30b6..1aac3a1 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -581,7 +581,7 @@
         // TODO(crbug.com/42290594): Epoch 1 (0-RTT) should be retained until
         // epoch 3 (app data) is available.
         for (const auto &msg : ssl->d1->outgoing_messages) {
-          if (msg.epoch == write_epoch->epoch) {
+          if (msg.epoch == write_epoch->epoch()) {
             return false;
           }
         }
@@ -640,7 +640,7 @@
 
   DTLS_OUTGOING_MESSAGE msg;
   msg.data = std::move(data);
-  msg.epoch = ssl->d1->write_epoch.epoch;
+  msg.epoch = ssl->d1->write_epoch.epoch();
   msg.is_ccs = is_ccs;
   if (!ssl->d1->outgoing_messages.TryPushBack(std::move(msg))) {
     assert(false);
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc
index 18cd5ee..a26fc70 100644
--- a/ssl/d1_pkt.cc
+++ b/ssl/d1_pkt.cc
@@ -239,7 +239,7 @@
 
   // TODO(crbug.com/42290594): Use the 0-RTT epoch if writing 0-RTT.
   int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in,
-                               ssl->d1->write_epoch.epoch);
+                               ssl->d1->write_epoch.epoch());
   if (ret <= 0) {
     return ret;
   }
@@ -282,7 +282,7 @@
 
 int dtls1_dispatch_alert(SSL *ssl) {
   int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, ssl->s3->send_alert,
-                               ssl->d1->write_epoch.epoch);
+                               ssl->d1->write_epoch.epoch());
   if (ret <= 0) {
     return ret;
   }
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 27993da..9fc854e 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -118,14 +118,15 @@
   DTLSWriteEpoch new_epoch;
   if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
     // TODO(crbug.com/42290594): See above.
-    new_epoch.epoch = level;
+    new_epoch.next_record = DTLSRecordNumber(level, 0);
     new_epoch.rn_encrypter =
         RecordNumberEncrypter::Create(aead_ctx->cipher(), traffic_secret);
     if (new_epoch.rn_encrypter == nullptr) {
       return false;
     }
   } else {
-    new_epoch.epoch = ssl->d1->write_epoch.epoch + 1;
+    new_epoch.next_record =
+        DTLSRecordNumber(ssl->d1->write_epoch.epoch() + 1, 0);
   }
   new_epoch.aead = std::move(aead_ctx);
 
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 5ab89bb..586b852 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -443,11 +443,11 @@
 }
 
 static DTLSWriteEpoch *get_write_epoch(const SSL *ssl, uint16_t epoch) {
-  if (ssl->d1->write_epoch.epoch == epoch) {
+  if (ssl->d1->write_epoch.epoch() == epoch) {
     return &ssl->d1->write_epoch;
   }
   for (const auto &e : ssl->d1->extra_write_epochs) {
-    if (e->epoch == epoch) {
+    if (e->epoch() == epoch) {
       return e.get();
     }
   }
@@ -510,7 +510,8 @@
   const size_t record_header_len = dtls_record_header_write_len(ssl, epoch);
 
   // Ensure the sequence number update does not overflow.
-  if (write_epoch->next_seq + 1 > DTLSRecordNumber::kMaxSequence) {
+  DTLSRecordNumber record_number = write_epoch->next_record;
+  if (!record_number.HasNext()) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
     return false;
   }
@@ -535,7 +536,6 @@
   }
 
   uint16_t record_version = dtls_record_version(ssl);
-  DTLSRecordNumber record_number(epoch, write_epoch->next_seq);
   if (dtls13_header) {
     // The first byte of the DTLS 1.3 record header has the following format:
     // 0 1 2 3 4 5 6 7
@@ -554,19 +554,15 @@
     // We always use a two-byte sequence number. A one-byte sequence number
     // would require coordinating with the application on ACK feedback to know
     // that the peer is not too far behind.
-    out[1] = write_epoch->next_seq >> 8;
-    out[2] = write_epoch->next_seq & 0xff;
+    CRYPTO_store_u16_be(out + 1, write_epoch->next_record.sequence());
     // TODO(crbug.com/42290594): When we know the record is last in the packet,
     // omit the length.
-    out[3] = ciphertext_len >> 8;
-    out[4] = ciphertext_len & 0xff;
+    CRYPTO_store_u16_be(out + 3, ciphertext_len);
   } else {
     out[0] = type;
-    out[1] = record_version >> 8;
-    out[2] = record_version & 0xff;
-    CRYPTO_store_u64_be(&out[3], record_number.combined());
-    out[11] = ciphertext_len >> 8;
-    out[12] = ciphertext_len & 0xff;
+    CRYPTO_store_u16_be(out + 1, record_version);
+    CRYPTO_store_u64_be(out + 3, record_number.combined());
+    CRYPTO_store_u16_be(out + 11, ciphertext_len);
   }
   Span<const uint8_t> header = MakeConstSpan(out, record_header_len);
 
@@ -595,7 +591,7 @@
   }
 
   *out_number = record_number;
-  write_epoch->next_seq++;
+  write_epoch->next_record = record_number.Next();
   *out_len = record_header_len + ciphertext_len;
   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, header);
   return true;
diff --git a/ssl/internal.h b/ssl/internal.h
index d9b6519..10cea64 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1231,7 +1231,6 @@
 
 // Record layer.
 
-// TODO(davidben): Use this type more extensively in the epoch state.
 class DTLSRecordNumber {
  public:
   static constexpr uint64_t kMaxSequence = (uint64_t{1} << 48) - 1;
@@ -1250,6 +1249,13 @@
   uint16_t epoch() const { return combined_ >> 48; }
   uint64_t sequence() const { return combined_ & kMaxSequence; }
 
+  bool HasNext() const { return sequence() < kMaxSequence; }
+  DTLSRecordNumber Next() const {
+    BSSL_CHECK(HasNext());
+    // This will not overflow into the epoch.
+    return DTLSRecordNumber::FromCombined(combined_ + 1);
+  }
+
  private:
   explicit DTLSRecordNumber(uint64_t combined) : combined_(combined) {}
 
@@ -1275,6 +1281,8 @@
 struct DTLSReadEpoch {
   static constexpr bool kAllowUniquePtr = true;
 
+  // TODO(davidben): This could be made slightly more compact if |bitmap| stored
+  // a DTLSRecordNumber.
   uint16_t epoch = 0;
   UniquePtr<SSLAEADContext> aead;
   UniquePtr<RecordNumberEncrypter> rn_encrypter;
@@ -1284,10 +1292,11 @@
 struct DTLSWriteEpoch {
   static constexpr bool kAllowUniquePtr = true;
 
-  uint16_t epoch = 0;
+  uint16_t epoch() const { return next_record.epoch(); }
+
+  DTLSRecordNumber next_record;
   UniquePtr<SSLAEADContext> aead;
   UniquePtr<RecordNumberEncrypter> rn_encrypter;
-  uint64_t next_seq = 0;
 };
 
 // ssl_record_prefix_len returns the length of the prefix before the ciphertext
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index a5023c8..e74f886 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -2976,9 +2976,7 @@
 
 uint64_t SSL_get_write_sequence(const SSL *ssl) {
   if (SSL_is_dtls(ssl)) {
-    const DTLSWriteEpoch *write_epoch = &ssl->d1->write_epoch;
-    return DTLSRecordNumber(write_epoch->epoch, write_epoch->next_seq)
-        .combined();
+    return ssl->d1->write_epoch.next_record.combined();
   }
 
   return ssl->s3->write_sequence;
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index d982683..bdd6519 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -594,7 +594,7 @@
 size_t SSL_max_seal_overhead(const SSL *ssl) {
   if (SSL_is_dtls(ssl)) {
     // TODO(crbug.com/42290594): Use the 0-RTT epoch if writing 0-RTT.
-    return dtls_max_seal_overhead(ssl, ssl->d1->write_epoch.epoch);
+    return dtls_max_seal_overhead(ssl, ssl->d1->write_epoch.epoch());
   }
 
   size_t ret = SSL3_RT_HEADER_LENGTH;