Maintain the sequence number as a uint64_t.
We spend a lot of effort implementing a big-endian sequence number
update, etc., when the sequence number is just a 64-bit counter. (Or
48-bit counter in DTLS because we currently retain the epoch
separately. We can probably tidy that a bit too, but I'll leave that
for later. Right now the DTLS record layer state is a bit entwined
with the TLS one.)
Just store it as uint64_t. This should also simplify
https://boringssl-review.googlesource.com/c/boringssl/+/54325 a little.
Change-Id: I95233f924a660bc523b21496fdc9211055b75073
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/54505
Reviewed-by: Bob Beck <bbe@google.com>
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index a28dcdc..677418b 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -90,7 +90,7 @@
ssl->d1->r_epoch++;
OPENSSL_memset(&ssl->d1->bitmap, 0, sizeof(ssl->d1->bitmap));
- OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
+ ssl->s3->read_sequence = 0;
ssl->s3->aead_read_ctx = std::move(aead_ctx);
ssl->s3->read_level = level;
@@ -103,9 +103,8 @@
Span<const uint8_t> secret_for_quic) {
assert(secret_for_quic.empty()); // QUIC does not use DTLS.
ssl->d1->w_epoch++;
- OPENSSL_memcpy(ssl->d1->last_write_sequence, ssl->s3->write_sequence,
- sizeof(ssl->s3->write_sequence));
- OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
+ ssl->d1->last_write_sequence = ssl->s3->write_sequence;
+ ssl->s3->write_sequence = 0;
ssl->d1->last_aead_write_ctx = std::move(ssl->s3->aead_write_ctx);
ssl->s3->aead_write_ctx = std::move(aead_ctx);
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 992fb52..eb3df69 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -123,52 +123,37 @@
BSSL_NAMESPACE_BEGIN
-// to_u64_be treats |in| as a 8-byte big-endian integer and returns the value as
-// a |uint64_t|.
-static uint64_t to_u64_be(const uint8_t in[8]) {
- uint64_t ret = 0;
- unsigned i;
- for (i = 0; i < 8; i++) {
- ret <<= 8;
- ret |= in[i];
- }
- return ret;
-}
-
// dtls1_bitmap_should_discard returns one if |seq_num| has been seen in
// |bitmap| or is stale. Otherwise it returns zero.
static bool dtls1_bitmap_should_discard(DTLS1_BITMAP *bitmap,
- const uint8_t seq_num[8]) {
+ uint64_t seq_num) {
const unsigned kWindowSize = sizeof(bitmap->map) * 8;
- uint64_t seq_num_u = to_u64_be(seq_num);
- if (seq_num_u > bitmap->max_seq_num) {
+ if (seq_num > bitmap->max_seq_num) {
return false;
}
- uint64_t idx = bitmap->max_seq_num - seq_num_u;
+ uint64_t idx = bitmap->max_seq_num - seq_num;
return idx >= kWindowSize || (bitmap->map & (((uint64_t)1) << idx));
}
// dtls1_bitmap_record updates |bitmap| to record receipt of sequence number
// |seq_num|. It slides the window forward if needed. It is an error to call
// this function on a stale sequence number.
-static void dtls1_bitmap_record(DTLS1_BITMAP *bitmap,
- const uint8_t seq_num[8]) {
+static void dtls1_bitmap_record(DTLS1_BITMAP *bitmap, uint64_t seq_num) {
const unsigned kWindowSize = sizeof(bitmap->map) * 8;
- uint64_t seq_num_u = to_u64_be(seq_num);
// Shift the window if necessary.
- if (seq_num_u > bitmap->max_seq_num) {
- uint64_t shift = seq_num_u - bitmap->max_seq_num;
+ if (seq_num > bitmap->max_seq_num) {
+ uint64_t shift = seq_num - bitmap->max_seq_num;
if (shift >= kWindowSize) {
bitmap->map = 0;
} else {
bitmap->map <<= shift;
}
- bitmap->max_seq_num = seq_num_u;
+ bitmap->max_seq_num = seq_num;
}
- uint64_t idx = bitmap->max_seq_num - seq_num_u;
+ uint64_t idx = bitmap->max_seq_num - seq_num;
if (idx < kWindowSize) {
bitmap->map |= ((uint64_t)1) << idx;
}
@@ -192,11 +177,11 @@
// Decode the record.
uint8_t type;
uint16_t version;
- uint8_t sequence[8];
+ uint8_t sequence_bytes[8];
CBS body;
if (!CBS_get_u8(&cbs, &type) ||
!CBS_get_u16(&cbs, &version) ||
- !CBS_copy_bytes(&cbs, sequence, 8) ||
+ !CBS_copy_bytes(&cbs, sequence_bytes, sizeof(sequence_bytes)) ||
!CBS_get_u16_length_prefixed(&cbs, &body) ||
CBS_len(&body) > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
// The record header was incomplete or malformed. Drop the entire packet.
@@ -222,7 +207,8 @@
Span<const uint8_t> header = in.subspan(0, DTLS1_RT_HEADER_LENGTH);
ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HEADER, header);
- uint16_t epoch = (((uint16_t)sequence[0]) << 8) | sequence[1];
+ uint64_t sequence = CRYPTO_load_u64_be(sequence_bytes);
+ uint16_t epoch = static_cast<uint16_t>(sequence >> 48);
if (epoch != ssl->d1->r_epoch ||
dtls1_bitmap_should_discard(&ssl->d1->bitmap, sequence)) {
// Drop this record. It's from the wrong epoch or is a replay. Note that if
@@ -304,12 +290,12 @@
// Determine the parameters for the current epoch.
uint16_t epoch = ssl->d1->w_epoch;
SSLAEADContext *aead = ssl->s3->aead_write_ctx.get();
- uint8_t *seq = ssl->s3->write_sequence;
+ uint64_t *seq = &ssl->s3->write_sequence;
if (use_epoch == dtls1_use_previous_epoch) {
assert(ssl->d1->w_epoch >= 1);
epoch = ssl->d1->w_epoch - 1;
aead = ssl->d1->last_aead_write_ctx.get();
- seq = ssl->d1->last_write_sequence;
+ seq = &ssl->d1->last_write_sequence;
}
if (max_out < DTLS1_RT_HEADER_LENGTH) {
@@ -323,9 +309,15 @@
out[1] = record_version >> 8;
out[2] = record_version & 0xff;
- out[3] = epoch >> 8;
- out[4] = epoch & 0xff;
- OPENSSL_memcpy(&out[5], &seq[2], 6);
+ // Ensure the sequence number update does not overflow.
+ const uint64_t kMaxSequenceNumber = (uint64_t{1} << 48) - 1;
+ if (*seq + 1 > kMaxSequenceNumber) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+ return false;
+ }
+
+ uint64_t seq_with_epoch = (uint64_t{epoch} << 48) | *seq;
+ CRYPTO_store_u64_be(&out[3], seq_with_epoch);
size_t ciphertext_len;
if (!aead->CiphertextLen(&ciphertext_len, in_len, 0)) {
@@ -339,12 +331,12 @@
size_t len_copy;
if (!aead->Seal(out + DTLS1_RT_HEADER_LENGTH, &len_copy,
max_out - DTLS1_RT_HEADER_LENGTH, type, record_version,
- &out[3] /* seq */, header, in, in_len) ||
- !ssl_record_sequence_update(&seq[2], 6)) {
+ seq_with_epoch, header, in, in_len)) {
return false;
}
assert(ciphertext_len == len_copy);
+ (*seq)++;
*out_len = DTLS1_RT_HEADER_LENGTH + ciphertext_len;
ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, header);
return true;
diff --git a/ssl/handoff.cc b/ssl/handoff.cc
index 736ab47..dd24811 100644
--- a/ssl/handoff.cc
+++ b/ssl/handoff.cc
@@ -17,6 +17,7 @@
#include <openssl/bytestring.h>
#include <openssl/err.h>
+#include "../crypto/internal.h"
#include "internal.h"
@@ -338,14 +339,16 @@
} else {
session = s3->session_reused ? ssl->session.get() : hs->new_session.get();
}
+ uint8_t read_sequence[8], write_sequence[8];
+ CRYPTO_store_u64_be(read_sequence, s3->read_sequence);
+ CRYPTO_store_u64_be(write_sequence, s3->write_sequence);
static const uint8_t kUnusedChannelID[64] = {0};
if (!CBB_add_asn1(out, &seq, CBS_ASN1_SEQUENCE) ||
!CBB_add_asn1_uint64(&seq, kHandbackVersion) ||
!CBB_add_asn1_uint64(&seq, type) ||
- !CBB_add_asn1_octet_string(&seq, s3->read_sequence,
- sizeof(s3->read_sequence)) ||
- !CBB_add_asn1_octet_string(&seq, s3->write_sequence,
- sizeof(s3->write_sequence)) ||
+ !CBB_add_asn1_octet_string(&seq, read_sequence, sizeof(read_sequence)) ||
+ !CBB_add_asn1_octet_string(&seq, write_sequence,
+ sizeof(write_sequence)) ||
!CBB_add_asn1_octet_string(&seq, s3->server_random,
sizeof(s3->server_random)) ||
!CBB_add_asn1_octet_string(&seq, s3->client_random,
@@ -366,7 +369,7 @@
sizeof(kUnusedChannelID)) ||
// These two fields were historically |token_binding_negotiated| and
// |negotiated_token_binding_param|.
- !CBB_add_asn1_bool(&seq, 0) ||
+ !CBB_add_asn1_bool(&seq, 0) || //
!CBB_add_asn1_uint64(&seq, 0) ||
!CBB_add_asn1_bool(&seq, s3->hs->next_proto_neg_seen) ||
!CBB_add_asn1_bool(&seq, s3->hs->cert_request) ||
@@ -694,11 +697,13 @@
}
break;
}
- if (!CopyExact({s3->read_sequence, sizeof(s3->read_sequence)}, &read_seq) ||
- !CopyExact({s3->write_sequence, sizeof(s3->write_sequence)},
- &write_seq)) {
+ uint8_t read_sequence[8], write_sequence[8];
+ if (!CopyExact(read_sequence, &read_seq) ||
+ !CopyExact(write_sequence, &write_seq)) {
return false;
}
+ s3->read_sequence = CRYPTO_load_u64_be(read_sequence);
+ s3->write_sequence = CRYPTO_load_u64_be(write_sequence);
if (type == handback_after_ecdhe &&
(hs->key_shares[0] = SSLKeyShare::Create(&key_share)) == nullptr) {
return false;
diff --git a/ssl/internal.h b/ssl/internal.h
index 71ff0ff..8ef1509 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -829,15 +829,14 @@
// to the plaintext in |in| and returns true. Otherwise, it returns
// false. The output will always be |ExplicitNonceLen| bytes ahead of |in|.
bool Open(Span<uint8_t> *out, uint8_t type, uint16_t record_version,
- const uint8_t seqnum[8], Span<const uint8_t> header,
- Span<uint8_t> in);
+ uint64_t seqnum, Span<const uint8_t> header, Span<uint8_t> in);
// Seal encrypts and authenticates |in_len| bytes from |in| and writes the
// result to |out|. It returns true on success and false on error.
//
// If |in| and |out| alias then |out| + |ExplicitNonceLen| must be == |in|.
bool Seal(uint8_t *out, size_t *out_len, size_t max_out, uint8_t type,
- uint16_t record_version, const uint8_t seqnum[8],
+ uint16_t record_version, uint64_t seqnum,
Span<const uint8_t> header, const uint8_t *in, size_t in_len);
// SealScatter encrypts and authenticates |in_len| bytes from |in| and splits
@@ -856,10 +855,9 @@
// If |in| and |out| alias then |out| must be == |in|. Other arguments may not
// alias anything.
bool SealScatter(uint8_t *out_prefix, uint8_t *out, uint8_t *out_suffix,
- uint8_t type, uint16_t record_version,
- const uint8_t seqnum[8], Span<const uint8_t> header,
- const uint8_t *in, size_t in_len, const uint8_t *extra_in,
- size_t extra_in_len);
+ uint8_t type, uint16_t record_version, uint64_t seqnum,
+ Span<const uint8_t> header, const uint8_t *in, size_t in_len,
+ const uint8_t *extra_in, size_t extra_in_len);
bool GetIV(const uint8_t **out_iv, size_t *out_iv_len) const;
@@ -868,8 +866,7 @@
// necessary.
Span<const uint8_t> GetAdditionalData(uint8_t storage[13], uint8_t type,
uint16_t record_version,
- const uint8_t seqnum[8],
- size_t plaintext_len,
+ uint64_t seqnum, size_t plaintext_len,
Span<const uint8_t> header);
const SSL_CIPHER *cipher_;
@@ -916,10 +913,6 @@
// Record layer.
-// ssl_record_sequence_update increments the sequence number in |seq|. It
-// returns true on success and false on wraparound.
-bool ssl_record_sequence_update(uint8_t *seq, size_t seq_len);
-
// ssl_record_prefix_len returns the length of the prefix before the ciphertext
// of a record for |ssl|.
//
@@ -2644,8 +2637,8 @@
SSL3_STATE();
~SSL3_STATE();
- uint8_t read_sequence[8] = {0};
- uint8_t write_sequence[8] = {0};
+ uint64_t read_sequence = 0;
+ uint64_t write_sequence = 0;
uint8_t server_random[SSL3_RANDOM_SIZE] = {0};
uint8_t client_random[SSL3_RANDOM_SIZE] = {0};
@@ -2935,7 +2928,7 @@
uint16_t handshake_read_seq = 0;
// save last sequence number for retransmissions
- uint8_t last_write_sequence[8] = {0};
+ uint64_t last_write_sequence = 0;
UniquePtr<SSLAEADContext> last_aead_write_ctx;
// incoming_messages is a ring buffer of incoming handshake messages that have
diff --git a/ssl/ssl_aead_ctx.cc b/ssl/ssl_aead_ctx.cc
index 0bad266..27f0084 100644
--- a/ssl/ssl_aead_ctx.cc
+++ b/ssl/ssl_aead_ctx.cc
@@ -220,13 +220,13 @@
}
Span<const uint8_t> SSLAEADContext::GetAdditionalData(
- uint8_t storage[13], uint8_t type, uint16_t record_version,
- const uint8_t seqnum[8], size_t plaintext_len, Span<const uint8_t> header) {
+ uint8_t storage[13], uint8_t type, uint16_t record_version, uint64_t seqnum,
+ size_t plaintext_len, Span<const uint8_t> header) {
if (ad_is_header_) {
return header;
}
- OPENSSL_memcpy(storage, seqnum, 8);
+ CRYPTO_store_u64_be(storage, seqnum);
size_t len = 8;
storage[len++] = type;
storage[len++] = static_cast<uint8_t>((record_version >> 8));
@@ -239,7 +239,7 @@
}
bool SSLAEADContext::Open(Span<uint8_t> *out, uint8_t type,
- uint16_t record_version, const uint8_t seqnum[8],
+ uint16_t record_version, uint64_t seqnum,
Span<const uint8_t> header, Span<uint8_t> in) {
if (is_null_cipher() || FUZZER_MODE) {
// Handle the initial NULL cipher.
@@ -288,7 +288,7 @@
in = in.subspan(variable_nonce_len_);
} else {
assert(variable_nonce_len_ == 8);
- OPENSSL_memcpy(nonce + nonce_len, seqnum, variable_nonce_len_);
+ CRYPTO_store_u64_be(nonce + nonce_len, seqnum);
}
nonce_len += variable_nonce_len_;
@@ -313,8 +313,7 @@
bool SSLAEADContext::SealScatter(uint8_t *out_prefix, uint8_t *out,
uint8_t *out_suffix, uint8_t type,
- uint16_t record_version,
- const uint8_t seqnum[8],
+ uint16_t record_version, uint64_t seqnum,
Span<const uint8_t> header, const uint8_t *in,
size_t in_len, const uint8_t *extra_in,
size_t extra_in_len) {
@@ -365,7 +364,7 @@
// When sending we use the sequence number as the variable part of the
// nonce.
assert(variable_nonce_len_ == 8);
- OPENSSL_memcpy(nonce + nonce_len, seqnum, variable_nonce_len_);
+ CRYPTO_store_u64_be(nonce + nonce_len, seqnum);
}
nonce_len += variable_nonce_len_;
@@ -398,7 +397,7 @@
bool SSLAEADContext::Seal(uint8_t *out, size_t *out_len, size_t max_out_len,
uint8_t type, uint16_t record_version,
- const uint8_t seqnum[8], Span<const uint8_t> header,
+ uint64_t seqnum, Span<const uint8_t> header,
const uint8_t *in, size_t in_len) {
const size_t prefix_len = ExplicitNonceLen();
size_t suffix_len;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 1b2e9f4..4d56d37 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -2823,20 +2823,19 @@
}
uint64_t SSL_get_read_sequence(const SSL *ssl) {
- // TODO(davidben): Internally represent sequence numbers as uint64_t.
if (SSL_is_dtls(ssl)) {
// max_seq_num already includes the epoch.
assert(ssl->d1->r_epoch == (ssl->d1->bitmap.max_seq_num >> 48));
return ssl->d1->bitmap.max_seq_num;
}
- return CRYPTO_load_u64_be(ssl->s3->read_sequence);
+ return ssl->s3->read_sequence;
}
uint64_t SSL_get_write_sequence(const SSL *ssl) {
- uint64_t ret = CRYPTO_load_u64_be(ssl->s3->write_sequence);
+ uint64_t ret = ssl->s3->write_sequence;
if (SSL_is_dtls(ssl)) {
assert((ret >> 48) == 0);
- ret |= ((uint64_t)ssl->d1->w_epoch) << 48;
+ ret |= uint64_t{ssl->d1->w_epoch} << 48;
}
return ret;
}
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 326cbe7..5fcf684 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -108,7 +108,7 @@
}
}
- OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
+ ssl->s3->read_sequence = 0;
ssl->s3->aead_read_ctx = std::move(aead_ctx);
ssl->s3->read_level = level;
return true;
@@ -137,7 +137,7 @@
}
}
- OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
+ ssl->s3->write_sequence = 0;
ssl->s3->aead_write_ctx = std::move(aead_ctx);
ssl->s3->write_level = level;
return true;
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index acff1ad..88b8ef9 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -151,17 +151,6 @@
#endif
}
-bool ssl_record_sequence_update(uint8_t *seq, size_t seq_len) {
- for (size_t i = seq_len - 1; i < seq_len; i--) {
- ++seq[i];
- if (seq[i] != 0) {
- return true;
- }
- }
- OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
- return false;
-}
-
size_t ssl_record_prefix_len(const SSL *ssl) {
size_t header_len;
if (SSL_is_dtls(ssl)) {
@@ -286,6 +275,13 @@
return skip_early_data(ssl, out_alert, *out_consumed);
}
+ // Ensure the sequence number update does not overflow.
+ if (ssl->s3->read_sequence + 1 == 0) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+ *out_alert = SSL_AD_INTERNAL_ERROR;
+ return ssl_open_record_error;
+ }
+
// Decrypt the body in-place.
if (!ssl->s3->aead_read_ctx->Open(
out, type, version, ssl->s3->read_sequence, header,
@@ -301,11 +297,7 @@
}
ssl->s3->skip_early_data = false;
-
- if (!ssl_record_sequence_update(ssl->s3->read_sequence, 8)) {
- *out_alert = SSL_AD_INTERNAL_ERROR;
- return ssl_open_record_error;
- }
+ ssl->s3->read_sequence++;
// TLS 1.3 hides the record type inside the encrypted data.
bool has_padding =
@@ -411,13 +403,19 @@
out_prefix[4] = ciphertext_len & 0xff;
Span<const uint8_t> header = MakeSpan(out_prefix, SSL3_RT_HEADER_LENGTH);
- if (!aead->SealScatter(out_prefix + SSL3_RT_HEADER_LENGTH, out, out_suffix,
- out_prefix[0], record_version, ssl->s3->write_sequence,
- header, in, in_len, extra_in, extra_in_len) ||
- !ssl_record_sequence_update(ssl->s3->write_sequence, 8)) {
+ // Ensure the sequence number update does not overflow.
+ if (ssl->s3->write_sequence + 1 == 0) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
return false;
}
+ if (!aead->SealScatter(out_prefix + SSL3_RT_HEADER_LENGTH, out, out_suffix,
+ out_prefix[0], record_version, ssl->s3->write_sequence,
+ header, in, in_len, extra_in, extra_in_len)) {
+ return false;
+ }
+
+ ssl->s3->write_sequence++;
ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, header);
return true;
}