Store an Array in hm_fragment
Also rename it to DTLSIncomingMessage to mirror DTLSOutgoingMessage.
(Also renamed to be C++ style like other, slightly more C++-y,
internals.)
It is slightly less compact in memory to store the size as a size_t
rather than a uint32_t that can be packed in there, but this object is
largely discarded after the handshake, and this way we get more bounds
checking.
I haven't actually replaced the memcpy with a span-based copy since we
don't have one yet, but this opens the door for that.
Bug: 374890768
Change-Id: I0f69fc4ca64bebed95d90979ec99fe1ea57c66fb
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72273
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 1aac3a1..e0faa6d 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -262,27 +262,22 @@
// Receiving handshake messages.
-hm_fragment::~hm_fragment() { OPENSSL_free(data); }
-
-static UniquePtr<hm_fragment> dtls1_hm_fragment_new(
+static UniquePtr<DTLSIncomingMessage> dtls_new_incoming_message(
const struct hm_header_st *msg_hdr) {
ScopedCBB cbb;
- UniquePtr<hm_fragment> frag = MakeUnique<hm_fragment>();
+ UniquePtr<DTLSIncomingMessage> frag = MakeUnique<DTLSIncomingMessage>();
if (!frag) {
return nullptr;
}
frag->type = msg_hdr->type;
frag->seq = msg_hdr->seq;
- frag->msg_len = msg_hdr->msg_len;
// Allocate space for the reassembled message and fill in the header.
- frag->data =
- (uint8_t *)OPENSSL_malloc(DTLS1_HM_HEADER_LENGTH + msg_hdr->msg_len);
- if (frag->data == NULL) {
+ if (!frag->data.InitForOverwrite(DTLS1_HM_HEADER_LENGTH + msg_hdr->msg_len)) {
return nullptr;
}
- if (!CBB_init_fixed(cbb.get(), frag->data, DTLS1_HM_HEADER_LENGTH) ||
+ if (!CBB_init_fixed(cbb.get(), frag->data.data(), DTLS1_HM_HEADER_LENGTH) ||
!CBB_add_u8(cbb.get(), msg_hdr->type) ||
!CBB_add_u24(cbb.get(), msg_hdr->msg_len) ||
!CBB_add_u16(cbb.get(), msg_hdr->seq) ||
@@ -303,7 +298,7 @@
// message is complete.
static bool dtls1_is_current_message_complete(const SSL *ssl) {
size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
- hm_fragment *frag = ssl->d1->incoming_messages[idx].get();
+ DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
return frag != nullptr && frag->reassembly.IsComplete();
}
@@ -311,7 +306,7 @@
// |msg_hdr|. If none exists, it creates a new one and inserts it in the
// queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
// returns NULL on failure. The caller does not take ownership of the result.
-static hm_fragment *dtls1_get_incoming_message(
+static DTLSIncomingMessage *dtls1_get_incoming_message(
SSL *ssl, uint8_t *out_alert, const struct hm_header_st *msg_hdr) {
if (msg_hdr->seq < ssl->d1->handshake_read_seq ||
msg_hdr->seq - ssl->d1->handshake_read_seq >= SSL_MAX_HANDSHAKE_FLIGHT) {
@@ -320,13 +315,13 @@
}
size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
- hm_fragment *frag = ssl->d1->incoming_messages[idx].get();
+ DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
if (frag != NULL) {
assert(frag->seq == msg_hdr->seq);
// The new fragment must be compatible with the previous fragments from this
// message.
if (frag->type != msg_hdr->type ||
- frag->msg_len != msg_hdr->msg_len) {
+ frag->msg_len() != msg_hdr->msg_len) {
OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
*out_alert = SSL_AD_ILLEGAL_PARAMETER;
return NULL;
@@ -335,7 +330,7 @@
}
// This is the first fragment from this message.
- ssl->d1->incoming_messages[idx] = dtls1_hm_fragment_new(msg_hdr);
+ ssl->d1->incoming_messages[idx] = dtls_new_incoming_message(msg_hdr);
if (!ssl->d1->incoming_messages[idx]) {
*out_alert = SSL_AD_INTERNAL_ERROR;
return NULL;
@@ -388,11 +383,12 @@
continue;
}
- hm_fragment *frag = dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
+ DTLSIncomingMessage *frag =
+ dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
if (frag == nullptr) {
return false;
}
- assert(frag->msg_len == msg_len);
+ assert(frag->msg_len() == msg_len);
if (frag->reassembly.IsComplete()) {
// The message is already assembled.
@@ -401,8 +397,8 @@
assert(msg_len > 0);
// Copy the body into the fragment.
- OPENSSL_memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off,
- CBS_data(&body), CBS_len(&body));
+ Span<uint8_t> dest = frag->msg().subspan(frag_off, CBS_len(&body));
+ OPENSSL_memcpy(dest.data(), CBS_data(&body), CBS_len(&body));
frag->reassembly.MarkRange(frag_off, frag_off + frag_len);
}
@@ -484,10 +480,10 @@
}
size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
- hm_fragment *frag = ssl->d1->incoming_messages[idx].get();
+ const DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
out->type = frag->type;
- CBS_init(&out->body, frag->data + DTLS1_HM_HEADER_LENGTH, frag->msg_len);
- CBS_init(&out->raw, frag->data, DTLS1_HM_HEADER_LENGTH + frag->msg_len);
+ out->raw = CBS(frag->data);
+ out->body = CBS(frag->msg());
out->is_v2_hello = false;
if (!ssl->s3->has_message) {
ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, out->raw);
@@ -638,7 +634,7 @@
ssl->d1->handshake_write_seq++;
}
- DTLS_OUTGOING_MESSAGE msg;
+ DTLSOutgoingMessage msg;
msg.data = std::move(data);
msg.epoch = ssl->d1->write_epoch.epoch();
msg.is_ccs = is_ccs;
@@ -697,7 +693,7 @@
// |*out_len| to the number of bytes written.
static enum seal_result_t seal_next_message(SSL *ssl, uint8_t *out,
size_t *out_len, size_t max_out,
- const DTLS_OUTGOING_MESSAGE *msg) {
+ const DTLSOutgoingMessage *msg) {
assert(ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size());
assert(msg == &ssl->d1->outgoing_messages[ssl->d1->outgoing_written]);
@@ -793,7 +789,7 @@
assert(ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size());
for (; ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size();
ssl->d1->outgoing_written++) {
- const DTLS_OUTGOING_MESSAGE *msg =
+ const DTLSOutgoingMessage *msg =
&ssl->d1->outgoing_messages[ssl->d1->outgoing_written];
size_t len;
enum seal_result_t ret = seal_next_message(ssl, out, &len, max_out, msg);
diff --git a/ssl/internal.h b/ssl/internal.h
index 10cea64..af988de 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1547,7 +1547,7 @@
// tls_flush_pending_hs_data flushes any handshake plaintext data.
bool tls_flush_pending_hs_data(SSL *ssl);
-struct DTLS_OUTGOING_MESSAGE {
+struct DTLSOutgoingMessage {
Array<uint8_t> data;
uint16_t epoch = 0;
bool is_ccs = false;
@@ -3343,25 +3343,24 @@
uint32_t frag_len;
};
-// An hm_fragment is an incoming DTLS message, possibly not yet assembled.
-struct hm_fragment {
+// An DTLSIncomingMessage is an incoming DTLS message, possibly not yet
+// assembled.
+struct DTLSIncomingMessage {
static constexpr bool kAllowUniquePtr = true;
- hm_fragment() {}
- hm_fragment(const hm_fragment &) = delete;
- hm_fragment &operator=(const hm_fragment &) = delete;
-
- ~hm_fragment();
+ Span<uint8_t> msg() { return MakeSpan(data).subspan(DTLS1_HM_HEADER_LENGTH); }
+ Span<const uint8_t> msg() const {
+ return MakeSpan(data).subspan(DTLS1_HM_HEADER_LENGTH);
+ }
+ size_t msg_len() const { return msg().size(); }
// type is the type of the message.
uint8_t type = 0;
// seq is the sequence number of this message.
uint16_t seq = 0;
- // msg_len is the length of the message body.
- uint32_t msg_len = 0;
- // data is a pointer to the message, including message header. It has length
- // |DTLS1_HM_HEADER_LENGTH| + |msg_len|.
- uint8_t *data = nullptr;
+ // data contains the message, including the message header of length
+ // |DTLS1_HM_HEADER_LENGTH|.
+ Array<uint8_t> data;
// reassembly tracks which parts of the message have been received.
DTLSMessageBitmap reassembly;
};
@@ -3426,11 +3425,11 @@
// yet to be processed. The front of the ring buffer is message number
// |handshake_read_seq|, at position |handshake_read_seq| %
// |SSL_MAX_HANDSHAKE_FLIGHT|.
- UniquePtr<hm_fragment> incoming_messages[SSL_MAX_HANDSHAKE_FLIGHT];
+ UniquePtr<DTLSIncomingMessage> incoming_messages[SSL_MAX_HANDSHAKE_FLIGHT];
// outgoing_messages is the queue of outgoing messages from the last handshake
// flight.
- InplaceVector<DTLS_OUTGOING_MESSAGE, SSL_MAX_HANDSHAKE_FLIGHT>
+ InplaceVector<DTLSOutgoingMessage, SSL_MAX_HANDSHAKE_FLIGHT>
outgoing_messages;
// outgoing_written is the number of outgoing messages that have been