Pull the DTLS reassembly bitmap into its own abstraction We'll use the same bitmap for tracking which parts of a message have been ACKed. To that end, I've gone ahead and implemented a NextUnmarkedRange. This does make the hm_fragment structure a little larger, but we indirect them to the heap, so I think this is fine. In the steady state, they do not contribute to per-connection memory use. Bug: 42290594 Change-Id: I326520454ee4d6832248b50ea3f5205f41d74cbe Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72269 Auto-Submit: David Benjamin <davidben@google.com> Reviewed-by: Nick Harper <nharper@chromium.org> Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc index e0c2614..dd7d146 100644 --- a/ssl/d1_both.cc +++ b/ssl/d1_both.cc
@@ -117,6 +117,8 @@ #include <limits.h> #include <string.h> +#include <algorithm> + #include <openssl/err.h> #include <openssl/evp.h> #include <openssl/mem.h> @@ -140,13 +142,127 @@ // the underlying BIO supplies one. static const unsigned int kDefaultMTU = 1500 - 28; +// BitRange returns a |uint8_t| with bits |start|, inclusive, to |end|, +// exclusive, set. +static uint8_t BitRange(size_t start, size_t end) { + assert(start <= end && end <= 8); + return static_cast<uint8_t>(~((1u << start) - 1) & ((1u << end) - 1)); +} + +// FirstUnmarkedRangeInByte returns the first unmarked range in bits |b|. +static DTLSMessageBitmap::Range FirstUnmarkedRangeInByte(uint8_t b) { + size_t start, end; + for (start = 0; start < 8; start++) { + if ((b & (1u << start)) == 0) { + break; + } + } + for (end = start; end < 8; end++) { + if ((b & (1u << end)) != 0) { + break; + } + } + return DTLSMessageBitmap::Range{start, end}; +} + +bool DTLSMessageBitmap::Init(size_t num_bits) { + if (num_bits + 7 < num_bits) { + OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); + return false; + } + size_t num_bytes = (num_bits + 7) / 8; + size_t bits_rounded = num_bytes * 8; + if (!bytes_.Init(num_bytes)) { + return false; + } + MarkRange(num_bits, bits_rounded); + return true; +} + +void DTLSMessageBitmap::MarkRange(size_t start, size_t end) { + // Clamp everything within range. + start = std::min(start, bytes_.size() << 3); + end = std::min(end, bytes_.size() << 3); + assert(start <= end); + if (start >= end) { + return; + } + + if ((start >> 3) == (end >> 3)) { + bytes_[start >> 3] |= BitRange(start & 7, end & 7); + } else { + bytes_[start >> 3] |= BitRange(start & 7, 8); + for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) { + bytes_[i] = 0xff; + } + if ((end & 7) != 0) { + bytes_[end >> 3] |= BitRange(0, end & 7); + } + } + + // Release the buffer if we've marked everything. + auto iter = std::find_if(bytes_.begin(), bytes_.end(), + [](uint8_t b) { return b != 0xff; }); + if (iter == bytes_.end()) { + assert(NextUnmarkedRange(0).empty()); + bytes_.Reset(); + } +} + +DTLSMessageBitmap::Range DTLSMessageBitmap::NextUnmarkedRange( + size_t start) const { + size_t idx = start >> 3; + if (idx >= bytes_.size()) { + return Range{0, 0}; + } + + // Look at the bits from |start| up to a byte boundary. + uint8_t byte = bytes_[idx] | BitRange(0, start & 7); + if (byte == 0xff) { + // Nothing unmarked at this byte. Keep searching for an unmarked bit. + for (idx = idx + 1; idx < bytes_.size(); idx++) { + if (bytes_[idx] != 0xff) { + byte = bytes_[idx]; + break; + } + } + if (idx >= bytes_.size()) { + return Range{0, 0}; + } + } + + Range range = FirstUnmarkedRangeInByte(byte); + assert(!range.empty()); + bool should_extend = range.end == 8; + range.start += idx << 3; + range.end += idx << 3; + if (!should_extend) { + // The range did not end at a byte boundary. We're done. + return range; + } + + // Collect all fully unmarked bytes. + for (idx = idx + 1; idx < bytes_.size(); idx++) { + if (bytes_[idx] != 0) { + break; + } + } + range.end = idx << 3; + + // Add any bits from the remaining byte, if any. + if (idx < bytes_.size()) { + Range extra = FirstUnmarkedRangeInByte(bytes_[idx]); + if (extra.start == 0) { + range.end += extra.end; + } + } + + return range; +} // Receiving handshake messages. -hm_fragment::~hm_fragment() { - OPENSSL_free(data); - OPENSSL_free(reassembly); -} +hm_fragment::~hm_fragment() { OPENSSL_free(data); } static UniquePtr<hm_fragment> dtls1_hm_fragment_new( const struct hm_header_st *msg_hdr) { @@ -176,81 +292,19 @@ return nullptr; } - // If the handshake message is empty, |frag->reassembly| is NULL. - if (msg_hdr->msg_len > 0) { - // Initialize reassembly bitmask. - if (msg_hdr->msg_len + 7 < msg_hdr->msg_len) { - OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); - return nullptr; - } - size_t bitmask_len = (msg_hdr->msg_len + 7) / 8; - frag->reassembly = (uint8_t *)OPENSSL_zalloc(bitmask_len); - if (frag->reassembly == NULL) { - return nullptr; - } + if (!frag->reassembly.Init(msg_hdr->msg_len)) { + return nullptr; } return frag; } -// bit_range returns a |uint8_t| with bits |start|, inclusive, to |end|, -// exclusive, set. -static uint8_t bit_range(size_t start, size_t end) { - return (uint8_t)(~((1u << start) - 1) & ((1u << end) - 1)); -} - -// dtls1_hm_fragment_mark marks bytes |start|, inclusive, to |end|, exclusive, -// as received in |frag|. If |frag| becomes complete, it clears -// |frag->reassembly|. The range must be within the bounds of |frag|'s message -// and |frag->reassembly| must not be NULL. -static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start, - size_t end) { - size_t msg_len = frag->msg_len; - - if (frag->reassembly == NULL || start > end || end > msg_len) { - assert(0); - return; - } - // A zero-length message will never have a pending reassembly. - assert(msg_len > 0); - - if (start == end) { - return; - } - - if ((start >> 3) == (end >> 3)) { - frag->reassembly[start >> 3] |= bit_range(start & 7, end & 7); - } else { - frag->reassembly[start >> 3] |= bit_range(start & 7, 8); - for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) { - frag->reassembly[i] = 0xff; - } - if ((end & 7) != 0) { - frag->reassembly[end >> 3] |= bit_range(0, end & 7); - } - } - - // Check if the fragment is complete. - for (size_t i = 0; i < (msg_len >> 3); i++) { - if (frag->reassembly[i] != 0xff) { - return; - } - } - if ((msg_len & 7) != 0 && - frag->reassembly[msg_len >> 3] != bit_range(0, msg_len & 7)) { - return; - } - - OPENSSL_free(frag->reassembly); - frag->reassembly = NULL; -} - // dtls1_is_current_message_complete returns whether the current handshake // message is complete. static bool dtls1_is_current_message_complete(const SSL *ssl) { size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT; hm_fragment *frag = ssl->d1->incoming_messages[idx].get(); - return frag != NULL && frag->reassembly == NULL; + return frag != nullptr && frag->reassembly.IsComplete(); } // dtls1_get_incoming_message returns the incoming message corresponding to @@ -306,8 +360,7 @@ const size_t frag_off = msg_hdr.frag_off; const size_t frag_len = msg_hdr.frag_len; const size_t msg_len = msg_hdr.msg_len; - if (frag_off > msg_len || frag_off + frag_len < frag_off || - frag_off + frag_len > msg_len || + if (frag_off > msg_len || frag_len > msg_len - frag_off || msg_len > ssl_max_handshake_message_len(ssl)) { OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE); *out_alert = SSL_AD_ILLEGAL_PARAMETER; @@ -336,12 +389,12 @@ } hm_fragment *frag = dtls1_get_incoming_message(ssl, out_alert, &msg_hdr); - if (frag == NULL) { + if (frag == nullptr) { return false; } assert(frag->msg_len == msg_len); - if (frag->reassembly == NULL) { + if (frag->reassembly.IsComplete()) { // The message is already assembled. continue; } @@ -350,7 +403,7 @@ // Copy the body into the fragment. OPENSSL_memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off, CBS_data(&body), CBS_len(&body)); - dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len); + frag->reassembly.MarkRange(frag_off, frag_off + frag_len); } return true;
diff --git a/ssl/internal.h b/ssl/internal.h index 4a1f84c..29a674c 100644 --- a/ssl/internal.h +++ b/ssl/internal.h
@@ -3257,9 +3257,43 @@ #define DTLS1_HM_HEADER_LENGTH 12 -#define DTLS1_CCS_HEADER_LENGTH 1 +// A DTLSMessageBitmap maintains a list of bits which may be marked to indicate +// a portion of a message was received or ACKed. +class DTLSMessageBitmap { + public: + // A Range represents a range of bits from |start|, inclusive, to |end|, + // exclusive. + struct Range { + size_t start = 0; + size_t end = 0; -#define DTLS1_AL_HEADER_LENGTH 2 + bool empty() const { return start == end; } + bool operator==(const Range &r) const { + return start == r.start && end == r.end; + } + bool operator!=(const Range &r) const { return !(*this == r); } + }; + + // Init initializes the structure with |num_bits| unmarked bits, from zero + // to |num_bits - 1|. + bool Init(size_t num_bits); + + // MarkRange marks the bits from |start|, inclusive, to |end|, exclusive. + void MarkRange(size_t start, size_t end); + + // NextUnmarkedRange returns the next range of unmarked bits, starting from + // |start|, inclusive. If all bits after |start| are marked, it returns an + // empty range. + Range NextUnmarkedRange(size_t start) const; + + // IsComplete returns whether every bit in the bitmask has been marked. + bool IsComplete() const { return bytes_.empty(); } + + private: + // bytes_ contains the unmarked bits. We maintain an invariant: if |bytes_| is + // not empty, some bit is unset. + Array<uint8_t> bytes_; +}; struct hm_header_st { uint8_t type; @@ -3288,9 +3322,8 @@ // data is a pointer to the message, including message header. It has length // |DTLS1_HM_HEADER_LENGTH| + |msg_len|. uint8_t *data = nullptr; - // reassembly is a bitmask of |msg_len| bits corresponding to which parts of - // the message have been received. It is NULL if the message is complete. - uint8_t *reassembly = nullptr; + // reassembly tracks which parts of the message have been received. + DTLSMessageBitmap reassembly; }; struct OPENSSL_timeval {
diff --git a/ssl/ssl_internal_test.cc b/ssl/ssl_internal_test.cc index f5982a2..2cca65a 100644 --- a/ssl/ssl_internal_test.cc +++ b/ssl/ssl_internal_test.cc
@@ -482,6 +482,127 @@ EXPECT_EQ(reconstruct_seqnum(0x8002, 0xffff, 0x10000), 0x8002u); } +TEST(DTLSMessageBitmapTest, Basic) { + // expect_bitmap checks that |b|'s unmarked bits are those listed in |ranges|. + // Each element of |ranges| must be non-empty and non-overlapping, and + // |ranges| must be sorted. + auto expect_bitmap = [](const DTLSMessageBitmap &b, + const std::vector<DTLSMessageBitmap::Range> &ranges) { + EXPECT_EQ(ranges.empty(), b.IsComplete()); + size_t start = 0; + for (const auto &r : ranges) { + for (; start < r.start; start++) { + SCOPED_TRACE(start); + EXPECT_EQ(b.NextUnmarkedRange(start), r); + } + for (; start < r.end; start++) { + SCOPED_TRACE(start); + EXPECT_EQ(b.NextUnmarkedRange(start), + (DTLSMessageBitmap::Range{start, r.end})); + } + } + EXPECT_TRUE(b.NextUnmarkedRange(start).empty()); + EXPECT_TRUE(b.NextUnmarkedRange(start + 1).empty()); + EXPECT_TRUE(b.NextUnmarkedRange(start + 42).empty()); + + // This is implied from the previous checks, but NextUnmarkedRange should + // work as an iterator to reconstruct the ranges. + std::vector<DTLSMessageBitmap::Range> got_ranges; + for (auto r = b.NextUnmarkedRange(0); !r.empty(); + r = b.NextUnmarkedRange(r.end)) { + got_ranges.push_back(r); + } + EXPECT_EQ(ranges, got_ranges); + }; + + // Initially, the bitmap is empty (fully marked). + DTLSMessageBitmap bitmap; + expect_bitmap(bitmap, {}); + + // It can also be initialized to the empty message and marked. + ASSERT_TRUE(bitmap.Init(0)); + expect_bitmap(bitmap, {}); + bitmap.MarkRange(0, 0); + expect_bitmap(bitmap, {}); + + // Track 100 bits and mark byte by byte. + ASSERT_TRUE(bitmap.Init(100)); + expect_bitmap(bitmap, {{0, 100}}); + for (size_t i = 0; i < 100; i++) { + SCOPED_TRACE(i); + bitmap.MarkRange(i, i + 1); + if (i < 99) { + expect_bitmap(bitmap, {{i + 1, 100}}); + } else { + expect_bitmap(bitmap, {}); + } + } + + // Do the same but in reverse. + ASSERT_TRUE(bitmap.Init(100)); + expect_bitmap(bitmap, {{0, 100}}); + for (size_t i = 100; i > 0; i--) { + SCOPED_TRACE(i); + bitmap.MarkRange(i - 1, i); + if (i > 1) { + expect_bitmap(bitmap, {{0, i - 1}}); + } else { + expect_bitmap(bitmap, {}); + } + } + + // Overlapping ranges are fine. + ASSERT_TRUE(bitmap.Init(100)); + expect_bitmap(bitmap, {{0, 100}}); + for (size_t i = 0; i < 100; i++) { + SCOPED_TRACE(i); + bitmap.MarkRange(i / 2, i + 1); + if (i < 99) { + expect_bitmap(bitmap, {{i + 1, 100}}); + } else { + expect_bitmap(bitmap, {}); + } + } + + // Mark the middle chunk of every power of 3. + ASSERT_TRUE(bitmap.Init(100)); + bitmap.MarkRange(1, 2); + bitmap.MarkRange(3, 6); + bitmap.MarkRange(9, 18); + bitmap.MarkRange(27, 54); + bitmap.MarkRange(81, 162); + expect_bitmap(bitmap, {{0, 1}, {2, 3}, {6, 9}, {18, 27}, {54, 81}}); + + // Mark most of the chunk shifted down a bit, so it both overlaps the previous + // and also leaves some of the right chunks unmarked. + bitmap.MarkRange(6 - 2, 9 - 2); + bitmap.MarkRange(18 - 4, 27 - 4); + bitmap.MarkRange(54 - 8, 81 - 8); + expect_bitmap(bitmap, + {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}}); + + // Re-mark things that have already been marked. + bitmap.MarkRange(1, 2); + bitmap.MarkRange(3, 6); + bitmap.MarkRange(9, 18); + bitmap.MarkRange(27, 54); + bitmap.MarkRange(81, 162); + expect_bitmap(bitmap, + {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}}); + + // Moves should work. + DTLSMessageBitmap bitmap2 = std::move(bitmap); + expect_bitmap(bitmap, {}); + expect_bitmap(bitmap2, + {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}}); + + // Mark everything in two large ranges. + bitmap2.MarkRange(27 - 2, 100); + expect_bitmap(bitmap2, {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27 - 2}}); + bitmap2.MarkRange(0, 50); + expect_bitmap(bitmap2, {}); +} + } // namespace BSSL_NAMESPACE_END #endif // !BORINGSSL_SHARED_LIBRARY