Maintain the sequence number as a uint64_t.

We spend a lot of effort implementing a big-endian sequence number
update, etc., when the sequence number is just a 64-bit counter. (Or
48-bit counter in DTLS because we currently retain the epoch
separately. We can probably tidy that a bit too, but I'll leave that
for later. Right now the DTLS record layer state is a bit entwined
with the TLS one.)

Just store it as uint64_t. This should also simplify
https://boringssl-review.googlesource.com/c/boringssl/+/54325 a little.

Change-Id: I95233f924a660bc523b21496fdc9211055b75073
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/54505
Reviewed-by: Bob Beck <bbe@google.com>
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index a28dcdc..677418b 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -90,7 +90,7 @@
 
   ssl->d1->r_epoch++;
   OPENSSL_memset(&ssl->d1->bitmap, 0, sizeof(ssl->d1->bitmap));
-  OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
+  ssl->s3->read_sequence = 0;
 
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
   ssl->s3->read_level = level;
@@ -103,9 +103,8 @@
                                   Span<const uint8_t> secret_for_quic) {
   assert(secret_for_quic.empty());  // QUIC does not use DTLS.
   ssl->d1->w_epoch++;
-  OPENSSL_memcpy(ssl->d1->last_write_sequence, ssl->s3->write_sequence,
-                 sizeof(ssl->s3->write_sequence));
-  OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
+  ssl->d1->last_write_sequence = ssl->s3->write_sequence;
+  ssl->s3->write_sequence = 0;
 
   ssl->d1->last_aead_write_ctx = std::move(ssl->s3->aead_write_ctx);
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 992fb52..eb3df69 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -123,52 +123,37 @@
 
 BSSL_NAMESPACE_BEGIN
 
-// to_u64_be treats |in| as a 8-byte big-endian integer and returns the value as
-// a |uint64_t|.
-static uint64_t to_u64_be(const uint8_t in[8]) {
-  uint64_t ret = 0;
-  unsigned i;
-  for (i = 0; i < 8; i++) {
-    ret <<= 8;
-    ret |= in[i];
-  }
-  return ret;
-}
-
 // dtls1_bitmap_should_discard returns one if |seq_num| has been seen in
 // |bitmap| or is stale. Otherwise it returns zero.
 static bool dtls1_bitmap_should_discard(DTLS1_BITMAP *bitmap,
-                                        const uint8_t seq_num[8]) {
+                                        uint64_t seq_num) {
   const unsigned kWindowSize = sizeof(bitmap->map) * 8;
 
-  uint64_t seq_num_u = to_u64_be(seq_num);
-  if (seq_num_u > bitmap->max_seq_num) {
+  if (seq_num > bitmap->max_seq_num) {
     return false;
   }
-  uint64_t idx = bitmap->max_seq_num - seq_num_u;
+  uint64_t idx = bitmap->max_seq_num - seq_num;
   return idx >= kWindowSize || (bitmap->map & (((uint64_t)1) << idx));
 }
 
 // dtls1_bitmap_record updates |bitmap| to record receipt of sequence number
 // |seq_num|. It slides the window forward if needed. It is an error to call
 // this function on a stale sequence number.
-static void dtls1_bitmap_record(DTLS1_BITMAP *bitmap,
-                                const uint8_t seq_num[8]) {
+static void dtls1_bitmap_record(DTLS1_BITMAP *bitmap, uint64_t seq_num) {
   const unsigned kWindowSize = sizeof(bitmap->map) * 8;
 
-  uint64_t seq_num_u = to_u64_be(seq_num);
   // Shift the window if necessary.
-  if (seq_num_u > bitmap->max_seq_num) {
-    uint64_t shift = seq_num_u - bitmap->max_seq_num;
+  if (seq_num > bitmap->max_seq_num) {
+    uint64_t shift = seq_num - bitmap->max_seq_num;
     if (shift >= kWindowSize) {
       bitmap->map = 0;
     } else {
       bitmap->map <<= shift;
     }
-    bitmap->max_seq_num = seq_num_u;
+    bitmap->max_seq_num = seq_num;
   }
 
-  uint64_t idx = bitmap->max_seq_num - seq_num_u;
+  uint64_t idx = bitmap->max_seq_num - seq_num;
   if (idx < kWindowSize) {
     bitmap->map |= ((uint64_t)1) << idx;
   }
@@ -192,11 +177,11 @@
   // Decode the record.
   uint8_t type;
   uint16_t version;
-  uint8_t sequence[8];
+  uint8_t sequence_bytes[8];
   CBS body;
   if (!CBS_get_u8(&cbs, &type) ||
       !CBS_get_u16(&cbs, &version) ||
-      !CBS_copy_bytes(&cbs, sequence, 8) ||
+      !CBS_copy_bytes(&cbs, sequence_bytes, sizeof(sequence_bytes)) ||
       !CBS_get_u16_length_prefixed(&cbs, &body) ||
       CBS_len(&body) > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
     // The record header was incomplete or malformed. Drop the entire packet.
@@ -222,7 +207,8 @@
   Span<const uint8_t> header = in.subspan(0, DTLS1_RT_HEADER_LENGTH);
   ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HEADER, header);
 
-  uint16_t epoch = (((uint16_t)sequence[0]) << 8) | sequence[1];
+  uint64_t sequence = CRYPTO_load_u64_be(sequence_bytes);
+  uint16_t epoch = static_cast<uint16_t>(sequence >> 48);
   if (epoch != ssl->d1->r_epoch ||
       dtls1_bitmap_should_discard(&ssl->d1->bitmap, sequence)) {
     // Drop this record. It's from the wrong epoch or is a replay. Note that if
@@ -304,12 +290,12 @@
   // Determine the parameters for the current epoch.
   uint16_t epoch = ssl->d1->w_epoch;
   SSLAEADContext *aead = ssl->s3->aead_write_ctx.get();
-  uint8_t *seq = ssl->s3->write_sequence;
+  uint64_t *seq = &ssl->s3->write_sequence;
   if (use_epoch == dtls1_use_previous_epoch) {
     assert(ssl->d1->w_epoch >= 1);
     epoch = ssl->d1->w_epoch - 1;
     aead = ssl->d1->last_aead_write_ctx.get();
-    seq = ssl->d1->last_write_sequence;
+    seq = &ssl->d1->last_write_sequence;
   }
 
   if (max_out < DTLS1_RT_HEADER_LENGTH) {
@@ -323,9 +309,15 @@
   out[1] = record_version >> 8;
   out[2] = record_version & 0xff;
 
-  out[3] = epoch >> 8;
-  out[4] = epoch & 0xff;
-  OPENSSL_memcpy(&out[5], &seq[2], 6);
+  // Ensure the sequence number update does not overflow.
+  const uint64_t kMaxSequenceNumber = (uint64_t{1} << 48) - 1;
+  if (*seq + 1 > kMaxSequenceNumber) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+    return false;
+  }
+
+  uint64_t seq_with_epoch = (uint64_t{epoch} << 48) | *seq;
+  CRYPTO_store_u64_be(&out[3], seq_with_epoch);
 
   size_t ciphertext_len;
   if (!aead->CiphertextLen(&ciphertext_len, in_len, 0)) {
@@ -339,12 +331,12 @@
   size_t len_copy;
   if (!aead->Seal(out + DTLS1_RT_HEADER_LENGTH, &len_copy,
                   max_out - DTLS1_RT_HEADER_LENGTH, type, record_version,
-                  &out[3] /* seq */, header, in, in_len) ||
-      !ssl_record_sequence_update(&seq[2], 6)) {
+                  seq_with_epoch, header, in, in_len)) {
     return false;
   }
   assert(ciphertext_len == len_copy);
 
+  (*seq)++;
   *out_len = DTLS1_RT_HEADER_LENGTH + ciphertext_len;
   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, header);
   return true;
diff --git a/ssl/handoff.cc b/ssl/handoff.cc
index 736ab47..dd24811 100644
--- a/ssl/handoff.cc
+++ b/ssl/handoff.cc
@@ -17,6 +17,7 @@
 #include <openssl/bytestring.h>
 #include <openssl/err.h>
 
+#include "../crypto/internal.h"
 #include "internal.h"
 
 
@@ -338,14 +339,16 @@
   } else {
     session = s3->session_reused ? ssl->session.get() : hs->new_session.get();
   }
+  uint8_t read_sequence[8], write_sequence[8];
+  CRYPTO_store_u64_be(read_sequence, s3->read_sequence);
+  CRYPTO_store_u64_be(write_sequence, s3->write_sequence);
   static const uint8_t kUnusedChannelID[64] = {0};
   if (!CBB_add_asn1(out, &seq, CBS_ASN1_SEQUENCE) ||
       !CBB_add_asn1_uint64(&seq, kHandbackVersion) ||
       !CBB_add_asn1_uint64(&seq, type) ||
-      !CBB_add_asn1_octet_string(&seq, s3->read_sequence,
-                                 sizeof(s3->read_sequence)) ||
-      !CBB_add_asn1_octet_string(&seq, s3->write_sequence,
-                                 sizeof(s3->write_sequence)) ||
+      !CBB_add_asn1_octet_string(&seq, read_sequence, sizeof(read_sequence)) ||
+      !CBB_add_asn1_octet_string(&seq, write_sequence,
+                                 sizeof(write_sequence)) ||
       !CBB_add_asn1_octet_string(&seq, s3->server_random,
                                  sizeof(s3->server_random)) ||
       !CBB_add_asn1_octet_string(&seq, s3->client_random,
@@ -366,7 +369,7 @@
                                  sizeof(kUnusedChannelID)) ||
       // These two fields were historically |token_binding_negotiated| and
       // |negotiated_token_binding_param|.
-      !CBB_add_asn1_bool(&seq, 0) ||
+      !CBB_add_asn1_bool(&seq, 0) ||  //
       !CBB_add_asn1_uint64(&seq, 0) ||
       !CBB_add_asn1_bool(&seq, s3->hs->next_proto_neg_seen) ||
       !CBB_add_asn1_bool(&seq, s3->hs->cert_request) ||
@@ -694,11 +697,13 @@
       }
       break;
   }
-  if (!CopyExact({s3->read_sequence, sizeof(s3->read_sequence)}, &read_seq) ||
-      !CopyExact({s3->write_sequence, sizeof(s3->write_sequence)},
-                 &write_seq)) {
+  uint8_t read_sequence[8], write_sequence[8];
+  if (!CopyExact(read_sequence, &read_seq) ||
+      !CopyExact(write_sequence, &write_seq)) {
     return false;
   }
+  s3->read_sequence = CRYPTO_load_u64_be(read_sequence);
+  s3->write_sequence = CRYPTO_load_u64_be(write_sequence);
   if (type == handback_after_ecdhe &&
       (hs->key_shares[0] = SSLKeyShare::Create(&key_share)) == nullptr) {
     return false;
diff --git a/ssl/internal.h b/ssl/internal.h
index 71ff0ff..8ef1509 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -829,15 +829,14 @@
   // to the plaintext in |in| and returns true.  Otherwise, it returns
   // false. The output will always be |ExplicitNonceLen| bytes ahead of |in|.
   bool Open(Span<uint8_t> *out, uint8_t type, uint16_t record_version,
-            const uint8_t seqnum[8], Span<const uint8_t> header,
-            Span<uint8_t> in);
+            uint64_t seqnum, Span<const uint8_t> header, Span<uint8_t> in);
 
   // Seal encrypts and authenticates |in_len| bytes from |in| and writes the
   // result to |out|. It returns true on success and false on error.
   //
   // If |in| and |out| alias then |out| + |ExplicitNonceLen| must be == |in|.
   bool Seal(uint8_t *out, size_t *out_len, size_t max_out, uint8_t type,
-            uint16_t record_version, const uint8_t seqnum[8],
+            uint16_t record_version, uint64_t seqnum,
             Span<const uint8_t> header, const uint8_t *in, size_t in_len);
 
   // SealScatter encrypts and authenticates |in_len| bytes from |in| and splits
@@ -856,10 +855,9 @@
   // If |in| and |out| alias then |out| must be == |in|. Other arguments may not
   // alias anything.
   bool SealScatter(uint8_t *out_prefix, uint8_t *out, uint8_t *out_suffix,
-                   uint8_t type, uint16_t record_version,
-                   const uint8_t seqnum[8], Span<const uint8_t> header,
-                   const uint8_t *in, size_t in_len, const uint8_t *extra_in,
-                   size_t extra_in_len);
+                   uint8_t type, uint16_t record_version, uint64_t seqnum,
+                   Span<const uint8_t> header, const uint8_t *in, size_t in_len,
+                   const uint8_t *extra_in, size_t extra_in_len);
 
   bool GetIV(const uint8_t **out_iv, size_t *out_iv_len) const;
 
@@ -868,8 +866,7 @@
   // necessary.
   Span<const uint8_t> GetAdditionalData(uint8_t storage[13], uint8_t type,
                                         uint16_t record_version,
-                                        const uint8_t seqnum[8],
-                                        size_t plaintext_len,
+                                        uint64_t seqnum, size_t plaintext_len,
                                         Span<const uint8_t> header);
 
   const SSL_CIPHER *cipher_;
@@ -916,10 +913,6 @@
 
 // Record layer.
 
-// ssl_record_sequence_update increments the sequence number in |seq|. It
-// returns true on success and false on wraparound.
-bool ssl_record_sequence_update(uint8_t *seq, size_t seq_len);
-
 // ssl_record_prefix_len returns the length of the prefix before the ciphertext
 // of a record for |ssl|.
 //
@@ -2644,8 +2637,8 @@
   SSL3_STATE();
   ~SSL3_STATE();
 
-  uint8_t read_sequence[8] = {0};
-  uint8_t write_sequence[8] = {0};
+  uint64_t read_sequence = 0;
+  uint64_t write_sequence = 0;
 
   uint8_t server_random[SSL3_RANDOM_SIZE] = {0};
   uint8_t client_random[SSL3_RANDOM_SIZE] = {0};
@@ -2935,7 +2928,7 @@
   uint16_t handshake_read_seq = 0;
 
   // save last sequence number for retransmissions
-  uint8_t last_write_sequence[8] = {0};
+  uint64_t last_write_sequence = 0;
   UniquePtr<SSLAEADContext> last_aead_write_ctx;
 
   // incoming_messages is a ring buffer of incoming handshake messages that have
diff --git a/ssl/ssl_aead_ctx.cc b/ssl/ssl_aead_ctx.cc
index 0bad266..27f0084 100644
--- a/ssl/ssl_aead_ctx.cc
+++ b/ssl/ssl_aead_ctx.cc
@@ -220,13 +220,13 @@
 }
 
 Span<const uint8_t> SSLAEADContext::GetAdditionalData(
-    uint8_t storage[13], uint8_t type, uint16_t record_version,
-    const uint8_t seqnum[8], size_t plaintext_len, Span<const uint8_t> header) {
+    uint8_t storage[13], uint8_t type, uint16_t record_version, uint64_t seqnum,
+    size_t plaintext_len, Span<const uint8_t> header) {
   if (ad_is_header_) {
     return header;
   }
 
-  OPENSSL_memcpy(storage, seqnum, 8);
+  CRYPTO_store_u64_be(storage, seqnum);
   size_t len = 8;
   storage[len++] = type;
   storage[len++] = static_cast<uint8_t>((record_version >> 8));
@@ -239,7 +239,7 @@
 }
 
 bool SSLAEADContext::Open(Span<uint8_t> *out, uint8_t type,
-                          uint16_t record_version, const uint8_t seqnum[8],
+                          uint16_t record_version, uint64_t seqnum,
                           Span<const uint8_t> header, Span<uint8_t> in) {
   if (is_null_cipher() || FUZZER_MODE) {
     // Handle the initial NULL cipher.
@@ -288,7 +288,7 @@
     in = in.subspan(variable_nonce_len_);
   } else {
     assert(variable_nonce_len_ == 8);
-    OPENSSL_memcpy(nonce + nonce_len, seqnum, variable_nonce_len_);
+    CRYPTO_store_u64_be(nonce + nonce_len, seqnum);
   }
   nonce_len += variable_nonce_len_;
 
@@ -313,8 +313,7 @@
 
 bool SSLAEADContext::SealScatter(uint8_t *out_prefix, uint8_t *out,
                                  uint8_t *out_suffix, uint8_t type,
-                                 uint16_t record_version,
-                                 const uint8_t seqnum[8],
+                                 uint16_t record_version, uint64_t seqnum,
                                  Span<const uint8_t> header, const uint8_t *in,
                                  size_t in_len, const uint8_t *extra_in,
                                  size_t extra_in_len) {
@@ -365,7 +364,7 @@
     // When sending we use the sequence number as the variable part of the
     // nonce.
     assert(variable_nonce_len_ == 8);
-    OPENSSL_memcpy(nonce + nonce_len, seqnum, variable_nonce_len_);
+    CRYPTO_store_u64_be(nonce + nonce_len, seqnum);
   }
   nonce_len += variable_nonce_len_;
 
@@ -398,7 +397,7 @@
 
 bool SSLAEADContext::Seal(uint8_t *out, size_t *out_len, size_t max_out_len,
                           uint8_t type, uint16_t record_version,
-                          const uint8_t seqnum[8], Span<const uint8_t> header,
+                          uint64_t seqnum, Span<const uint8_t> header,
                           const uint8_t *in, size_t in_len) {
   const size_t prefix_len = ExplicitNonceLen();
   size_t suffix_len;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 1b2e9f4..4d56d37 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -2823,20 +2823,19 @@
 }
 
 uint64_t SSL_get_read_sequence(const SSL *ssl) {
-  // TODO(davidben): Internally represent sequence numbers as uint64_t.
   if (SSL_is_dtls(ssl)) {
     // max_seq_num already includes the epoch.
     assert(ssl->d1->r_epoch == (ssl->d1->bitmap.max_seq_num >> 48));
     return ssl->d1->bitmap.max_seq_num;
   }
-  return CRYPTO_load_u64_be(ssl->s3->read_sequence);
+  return ssl->s3->read_sequence;
 }
 
 uint64_t SSL_get_write_sequence(const SSL *ssl) {
-  uint64_t ret = CRYPTO_load_u64_be(ssl->s3->write_sequence);
+  uint64_t ret = ssl->s3->write_sequence;
   if (SSL_is_dtls(ssl)) {
     assert((ret >> 48) == 0);
-    ret |= ((uint64_t)ssl->d1->w_epoch) << 48;
+    ret |= uint64_t{ssl->d1->w_epoch} << 48;
   }
   return ret;
 }
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 326cbe7..5fcf684 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -108,7 +108,7 @@
     }
   }
 
-  OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
+  ssl->s3->read_sequence = 0;
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
   ssl->s3->read_level = level;
   return true;
@@ -137,7 +137,7 @@
     }
   }
 
-  OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
+  ssl->s3->write_sequence = 0;
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
   ssl->s3->write_level = level;
   return true;
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index acff1ad..88b8ef9 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -151,17 +151,6 @@
 #endif
 }
 
-bool ssl_record_sequence_update(uint8_t *seq, size_t seq_len) {
-  for (size_t i = seq_len - 1; i < seq_len; i--) {
-    ++seq[i];
-    if (seq[i] != 0) {
-      return true;
-    }
-  }
-  OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
-  return false;
-}
-
 size_t ssl_record_prefix_len(const SSL *ssl) {
   size_t header_len;
   if (SSL_is_dtls(ssl)) {
@@ -286,6 +275,13 @@
     return skip_early_data(ssl, out_alert, *out_consumed);
   }
 
+  // Ensure the sequence number update does not overflow.
+  if (ssl->s3->read_sequence + 1 == 0) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+    *out_alert = SSL_AD_INTERNAL_ERROR;
+    return ssl_open_record_error;
+  }
+
   // Decrypt the body in-place.
   if (!ssl->s3->aead_read_ctx->Open(
           out, type, version, ssl->s3->read_sequence, header,
@@ -301,11 +297,7 @@
   }
 
   ssl->s3->skip_early_data = false;
-
-  if (!ssl_record_sequence_update(ssl->s3->read_sequence, 8)) {
-    *out_alert = SSL_AD_INTERNAL_ERROR;
-    return ssl_open_record_error;
-  }
+  ssl->s3->read_sequence++;
 
   // TLS 1.3 hides the record type inside the encrypted data.
   bool has_padding =
@@ -411,13 +403,19 @@
   out_prefix[4] = ciphertext_len & 0xff;
   Span<const uint8_t> header = MakeSpan(out_prefix, SSL3_RT_HEADER_LENGTH);
 
-  if (!aead->SealScatter(out_prefix + SSL3_RT_HEADER_LENGTH, out, out_suffix,
-                         out_prefix[0], record_version, ssl->s3->write_sequence,
-                         header, in, in_len, extra_in, extra_in_len) ||
-      !ssl_record_sequence_update(ssl->s3->write_sequence, 8)) {
+  // Ensure the sequence number update does not overflow.
+  if (ssl->s3->write_sequence + 1 == 0) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
     return false;
   }
 
+  if (!aead->SealScatter(out_prefix + SSL3_RT_HEADER_LENGTH, out, out_suffix,
+                         out_prefix[0], record_version, ssl->s3->write_sequence,
+                         header, in, in_len, extra_in, extra_in_len)) {
+    return false;
+  }
+
+  ssl->s3->write_sequence++;
   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, header);
   return true;
 }