Use DTLSRecordNumber in DTLSWriteEpoch This saves 8 bytes per write epoch and is a bit tidier. We could do something similar with DTLSReadEpoch if we hid it in DTLSReplayBitmap but that was a bit messy. Bug: 42290594 Change-Id: I3de19ef90d59566c303bf3b2ff85e76bafc790d4 Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72272 Reviewed-by: Nick Harper <nharper@chromium.org> Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc index 4bd30b6..1aac3a1 100644 --- a/ssl/d1_both.cc +++ b/ssl/d1_both.cc
@@ -581,7 +581,7 @@ // TODO(crbug.com/42290594): Epoch 1 (0-RTT) should be retained until // epoch 3 (app data) is available. for (const auto &msg : ssl->d1->outgoing_messages) { - if (msg.epoch == write_epoch->epoch) { + if (msg.epoch == write_epoch->epoch()) { return false; } } @@ -640,7 +640,7 @@ DTLS_OUTGOING_MESSAGE msg; msg.data = std::move(data); - msg.epoch = ssl->d1->write_epoch.epoch; + msg.epoch = ssl->d1->write_epoch.epoch(); msg.is_ccs = is_ccs; if (!ssl->d1->outgoing_messages.TryPushBack(std::move(msg))) { assert(false);
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc index 18cd5ee..a26fc70 100644 --- a/ssl/d1_pkt.cc +++ b/ssl/d1_pkt.cc
@@ -239,7 +239,7 @@ // TODO(crbug.com/42290594): Use the 0-RTT epoch if writing 0-RTT. int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in, - ssl->d1->write_epoch.epoch); + ssl->d1->write_epoch.epoch()); if (ret <= 0) { return ret; } @@ -282,7 +282,7 @@ int dtls1_dispatch_alert(SSL *ssl) { int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, ssl->s3->send_alert, - ssl->d1->write_epoch.epoch); + ssl->d1->write_epoch.epoch()); if (ret <= 0) { return ret; }
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc index 27993da..9fc854e 100644 --- a/ssl/dtls_method.cc +++ b/ssl/dtls_method.cc
@@ -118,14 +118,15 @@ DTLSWriteEpoch new_epoch; if (ssl_protocol_version(ssl) > TLS1_2_VERSION) { // TODO(crbug.com/42290594): See above. - new_epoch.epoch = level; + new_epoch.next_record = DTLSRecordNumber(level, 0); new_epoch.rn_encrypter = RecordNumberEncrypter::Create(aead_ctx->cipher(), traffic_secret); if (new_epoch.rn_encrypter == nullptr) { return false; } } else { - new_epoch.epoch = ssl->d1->write_epoch.epoch + 1; + new_epoch.next_record = + DTLSRecordNumber(ssl->d1->write_epoch.epoch() + 1, 0); } new_epoch.aead = std::move(aead_ctx);
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc index 5ab89bb..586b852 100644 --- a/ssl/dtls_record.cc +++ b/ssl/dtls_record.cc
@@ -443,11 +443,11 @@ } static DTLSWriteEpoch *get_write_epoch(const SSL *ssl, uint16_t epoch) { - if (ssl->d1->write_epoch.epoch == epoch) { + if (ssl->d1->write_epoch.epoch() == epoch) { return &ssl->d1->write_epoch; } for (const auto &e : ssl->d1->extra_write_epochs) { - if (e->epoch == epoch) { + if (e->epoch() == epoch) { return e.get(); } } @@ -510,7 +510,8 @@ const size_t record_header_len = dtls_record_header_write_len(ssl, epoch); // Ensure the sequence number update does not overflow. - if (write_epoch->next_seq + 1 > DTLSRecordNumber::kMaxSequence) { + DTLSRecordNumber record_number = write_epoch->next_record; + if (!record_number.HasNext()) { OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); return false; } @@ -535,7 +536,6 @@ } uint16_t record_version = dtls_record_version(ssl); - DTLSRecordNumber record_number(epoch, write_epoch->next_seq); if (dtls13_header) { // The first byte of the DTLS 1.3 record header has the following format: // 0 1 2 3 4 5 6 7 @@ -554,19 +554,15 @@ // We always use a two-byte sequence number. A one-byte sequence number // would require coordinating with the application on ACK feedback to know // that the peer is not too far behind. - out[1] = write_epoch->next_seq >> 8; - out[2] = write_epoch->next_seq & 0xff; + CRYPTO_store_u16_be(out + 1, write_epoch->next_record.sequence()); // TODO(crbug.com/42290594): When we know the record is last in the packet, // omit the length. - out[3] = ciphertext_len >> 8; - out[4] = ciphertext_len & 0xff; + CRYPTO_store_u16_be(out + 3, ciphertext_len); } else { out[0] = type; - out[1] = record_version >> 8; - out[2] = record_version & 0xff; - CRYPTO_store_u64_be(&out[3], record_number.combined()); - out[11] = ciphertext_len >> 8; - out[12] = ciphertext_len & 0xff; + CRYPTO_store_u16_be(out + 1, record_version); + CRYPTO_store_u64_be(out + 3, record_number.combined()); + CRYPTO_store_u16_be(out + 11, ciphertext_len); } Span<const uint8_t> header = MakeConstSpan(out, record_header_len); @@ -595,7 +591,7 @@ } *out_number = record_number; - write_epoch->next_seq++; + write_epoch->next_record = record_number.Next(); *out_len = record_header_len + ciphertext_len; ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, header); return true;
diff --git a/ssl/internal.h b/ssl/internal.h index d9b6519..10cea64 100644 --- a/ssl/internal.h +++ b/ssl/internal.h
@@ -1231,7 +1231,6 @@ // Record layer. -// TODO(davidben): Use this type more extensively in the epoch state. class DTLSRecordNumber { public: static constexpr uint64_t kMaxSequence = (uint64_t{1} << 48) - 1; @@ -1250,6 +1249,13 @@ uint16_t epoch() const { return combined_ >> 48; } uint64_t sequence() const { return combined_ & kMaxSequence; } + bool HasNext() const { return sequence() < kMaxSequence; } + DTLSRecordNumber Next() const { + BSSL_CHECK(HasNext()); + // This will not overflow into the epoch. + return DTLSRecordNumber::FromCombined(combined_ + 1); + } + private: explicit DTLSRecordNumber(uint64_t combined) : combined_(combined) {} @@ -1275,6 +1281,8 @@ struct DTLSReadEpoch { static constexpr bool kAllowUniquePtr = true; + // TODO(davidben): This could be made slightly more compact if |bitmap| stored + // a DTLSRecordNumber. uint16_t epoch = 0; UniquePtr<SSLAEADContext> aead; UniquePtr<RecordNumberEncrypter> rn_encrypter; @@ -1284,10 +1292,11 @@ struct DTLSWriteEpoch { static constexpr bool kAllowUniquePtr = true; - uint16_t epoch = 0; + uint16_t epoch() const { return next_record.epoch(); } + + DTLSRecordNumber next_record; UniquePtr<SSLAEADContext> aead; UniquePtr<RecordNumberEncrypter> rn_encrypter; - uint64_t next_seq = 0; }; // ssl_record_prefix_len returns the length of the prefix before the ciphertext
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index a5023c8..e74f886 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc
@@ -2976,9 +2976,7 @@ uint64_t SSL_get_write_sequence(const SSL *ssl) { if (SSL_is_dtls(ssl)) { - const DTLSWriteEpoch *write_epoch = &ssl->d1->write_epoch; - return DTLSRecordNumber(write_epoch->epoch, write_epoch->next_seq) - .combined(); + return ssl->d1->write_epoch.next_record.combined(); } return ssl->s3->write_sequence;
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc index d982683..bdd6519 100644 --- a/ssl/tls_record.cc +++ b/ssl/tls_record.cc
@@ -594,7 +594,7 @@ size_t SSL_max_seal_overhead(const SSL *ssl) { if (SSL_is_dtls(ssl)) { // TODO(crbug.com/42290594): Use the 0-RTT epoch if writing 0-RTT. - return dtls_max_seal_overhead(ssl, ssl->d1->write_epoch.epoch); + return dtls_max_seal_overhead(ssl, ssl->d1->write_epoch.epoch()); } size_t ret = SSL3_RT_HEADER_LENGTH;