Pull the DTLS reassembly bitmap into its own abstraction
We'll use the same bitmap for tracking which parts of a message have
been ACKed. To that end, I've gone ahead and implemented a
NextUnmarkedRange.
This does make the hm_fragment structure a little larger, but we
indirect them to the heap, so I think this is fine. In the steady state,
they do not contribute to per-connection memory use.
Bug: 42290594
Change-Id: I326520454ee4d6832248b50ea3f5205f41d74cbe
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72269
Auto-Submit: David Benjamin <davidben@google.com>
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index e0c2614..dd7d146 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -117,6 +117,8 @@
#include <limits.h>
#include <string.h>
+#include <algorithm>
+
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/mem.h>
@@ -140,13 +142,127 @@
// the underlying BIO supplies one.
static const unsigned int kDefaultMTU = 1500 - 28;
+// BitRange returns a |uint8_t| with bits |start|, inclusive, to |end|,
+// exclusive, set.
+static uint8_t BitRange(size_t start, size_t end) {
+ assert(start <= end && end <= 8);
+ return static_cast<uint8_t>(~((1u << start) - 1) & ((1u << end) - 1));
+}
+
+// FirstUnmarkedRangeInByte returns the first unmarked range in bits |b|.
+static DTLSMessageBitmap::Range FirstUnmarkedRangeInByte(uint8_t b) {
+ size_t start, end;
+ for (start = 0; start < 8; start++) {
+ if ((b & (1u << start)) == 0) {
+ break;
+ }
+ }
+ for (end = start; end < 8; end++) {
+ if ((b & (1u << end)) != 0) {
+ break;
+ }
+ }
+ return DTLSMessageBitmap::Range{start, end};
+}
+
+bool DTLSMessageBitmap::Init(size_t num_bits) {
+ if (num_bits + 7 < num_bits) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+ return false;
+ }
+ size_t num_bytes = (num_bits + 7) / 8;
+ size_t bits_rounded = num_bytes * 8;
+ if (!bytes_.Init(num_bytes)) {
+ return false;
+ }
+ MarkRange(num_bits, bits_rounded);
+ return true;
+}
+
+void DTLSMessageBitmap::MarkRange(size_t start, size_t end) {
+ // Clamp everything within range.
+ start = std::min(start, bytes_.size() << 3);
+ end = std::min(end, bytes_.size() << 3);
+ assert(start <= end);
+ if (start >= end) {
+ return;
+ }
+
+ if ((start >> 3) == (end >> 3)) {
+ bytes_[start >> 3] |= BitRange(start & 7, end & 7);
+ } else {
+ bytes_[start >> 3] |= BitRange(start & 7, 8);
+ for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) {
+ bytes_[i] = 0xff;
+ }
+ if ((end & 7) != 0) {
+ bytes_[end >> 3] |= BitRange(0, end & 7);
+ }
+ }
+
+ // Release the buffer if we've marked everything.
+ auto iter = std::find_if(bytes_.begin(), bytes_.end(),
+ [](uint8_t b) { return b != 0xff; });
+ if (iter == bytes_.end()) {
+ assert(NextUnmarkedRange(0).empty());
+ bytes_.Reset();
+ }
+}
+
+DTLSMessageBitmap::Range DTLSMessageBitmap::NextUnmarkedRange(
+ size_t start) const {
+ size_t idx = start >> 3;
+ if (idx >= bytes_.size()) {
+ return Range{0, 0};
+ }
+
+ // Look at the bits from |start| up to a byte boundary.
+ uint8_t byte = bytes_[idx] | BitRange(0, start & 7);
+ if (byte == 0xff) {
+ // Nothing unmarked at this byte. Keep searching for an unmarked bit.
+ for (idx = idx + 1; idx < bytes_.size(); idx++) {
+ if (bytes_[idx] != 0xff) {
+ byte = bytes_[idx];
+ break;
+ }
+ }
+ if (idx >= bytes_.size()) {
+ return Range{0, 0};
+ }
+ }
+
+ Range range = FirstUnmarkedRangeInByte(byte);
+ assert(!range.empty());
+ bool should_extend = range.end == 8;
+ range.start += idx << 3;
+ range.end += idx << 3;
+ if (!should_extend) {
+ // The range did not end at a byte boundary. We're done.
+ return range;
+ }
+
+ // Collect all fully unmarked bytes.
+ for (idx = idx + 1; idx < bytes_.size(); idx++) {
+ if (bytes_[idx] != 0) {
+ break;
+ }
+ }
+ range.end = idx << 3;
+
+ // Add any bits from the remaining byte, if any.
+ if (idx < bytes_.size()) {
+ Range extra = FirstUnmarkedRangeInByte(bytes_[idx]);
+ if (extra.start == 0) {
+ range.end += extra.end;
+ }
+ }
+
+ return range;
+}
// Receiving handshake messages.
-hm_fragment::~hm_fragment() {
- OPENSSL_free(data);
- OPENSSL_free(reassembly);
-}
+hm_fragment::~hm_fragment() { OPENSSL_free(data); }
static UniquePtr<hm_fragment> dtls1_hm_fragment_new(
const struct hm_header_st *msg_hdr) {
@@ -176,81 +292,19 @@
return nullptr;
}
- // If the handshake message is empty, |frag->reassembly| is NULL.
- if (msg_hdr->msg_len > 0) {
- // Initialize reassembly bitmask.
- if (msg_hdr->msg_len + 7 < msg_hdr->msg_len) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
- return nullptr;
- }
- size_t bitmask_len = (msg_hdr->msg_len + 7) / 8;
- frag->reassembly = (uint8_t *)OPENSSL_zalloc(bitmask_len);
- if (frag->reassembly == NULL) {
- return nullptr;
- }
+ if (!frag->reassembly.Init(msg_hdr->msg_len)) {
+ return nullptr;
}
return frag;
}
-// bit_range returns a |uint8_t| with bits |start|, inclusive, to |end|,
-// exclusive, set.
-static uint8_t bit_range(size_t start, size_t end) {
- return (uint8_t)(~((1u << start) - 1) & ((1u << end) - 1));
-}
-
-// dtls1_hm_fragment_mark marks bytes |start|, inclusive, to |end|, exclusive,
-// as received in |frag|. If |frag| becomes complete, it clears
-// |frag->reassembly|. The range must be within the bounds of |frag|'s message
-// and |frag->reassembly| must not be NULL.
-static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start,
- size_t end) {
- size_t msg_len = frag->msg_len;
-
- if (frag->reassembly == NULL || start > end || end > msg_len) {
- assert(0);
- return;
- }
- // A zero-length message will never have a pending reassembly.
- assert(msg_len > 0);
-
- if (start == end) {
- return;
- }
-
- if ((start >> 3) == (end >> 3)) {
- frag->reassembly[start >> 3] |= bit_range(start & 7, end & 7);
- } else {
- frag->reassembly[start >> 3] |= bit_range(start & 7, 8);
- for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) {
- frag->reassembly[i] = 0xff;
- }
- if ((end & 7) != 0) {
- frag->reassembly[end >> 3] |= bit_range(0, end & 7);
- }
- }
-
- // Check if the fragment is complete.
- for (size_t i = 0; i < (msg_len >> 3); i++) {
- if (frag->reassembly[i] != 0xff) {
- return;
- }
- }
- if ((msg_len & 7) != 0 &&
- frag->reassembly[msg_len >> 3] != bit_range(0, msg_len & 7)) {
- return;
- }
-
- OPENSSL_free(frag->reassembly);
- frag->reassembly = NULL;
-}
-
// dtls1_is_current_message_complete returns whether the current handshake
// message is complete.
static bool dtls1_is_current_message_complete(const SSL *ssl) {
size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
hm_fragment *frag = ssl->d1->incoming_messages[idx].get();
- return frag != NULL && frag->reassembly == NULL;
+ return frag != nullptr && frag->reassembly.IsComplete();
}
// dtls1_get_incoming_message returns the incoming message corresponding to
@@ -306,8 +360,7 @@
const size_t frag_off = msg_hdr.frag_off;
const size_t frag_len = msg_hdr.frag_len;
const size_t msg_len = msg_hdr.msg_len;
- if (frag_off > msg_len || frag_off + frag_len < frag_off ||
- frag_off + frag_len > msg_len ||
+ if (frag_off > msg_len || frag_len > msg_len - frag_off ||
msg_len > ssl_max_handshake_message_len(ssl)) {
OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
*out_alert = SSL_AD_ILLEGAL_PARAMETER;
@@ -336,12 +389,12 @@
}
hm_fragment *frag = dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
- if (frag == NULL) {
+ if (frag == nullptr) {
return false;
}
assert(frag->msg_len == msg_len);
- if (frag->reassembly == NULL) {
+ if (frag->reassembly.IsComplete()) {
// The message is already assembled.
continue;
}
@@ -350,7 +403,7 @@
// Copy the body into the fragment.
OPENSSL_memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off,
CBS_data(&body), CBS_len(&body));
- dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
+ frag->reassembly.MarkRange(frag_off, frag_off + frag_len);
}
return true;
diff --git a/ssl/internal.h b/ssl/internal.h
index 4a1f84c..29a674c 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3257,9 +3257,43 @@
#define DTLS1_HM_HEADER_LENGTH 12
-#define DTLS1_CCS_HEADER_LENGTH 1
+// A DTLSMessageBitmap maintains a list of bits which may be marked to indicate
+// a portion of a message was received or ACKed.
+class DTLSMessageBitmap {
+ public:
+ // A Range represents a range of bits from |start|, inclusive, to |end|,
+ // exclusive.
+ struct Range {
+ size_t start = 0;
+ size_t end = 0;
-#define DTLS1_AL_HEADER_LENGTH 2
+ bool empty() const { return start == end; }
+ bool operator==(const Range &r) const {
+ return start == r.start && end == r.end;
+ }
+ bool operator!=(const Range &r) const { return !(*this == r); }
+ };
+
+ // Init initializes the structure with |num_bits| unmarked bits, from zero
+ // to |num_bits - 1|.
+ bool Init(size_t num_bits);
+
+ // MarkRange marks the bits from |start|, inclusive, to |end|, exclusive.
+ void MarkRange(size_t start, size_t end);
+
+ // NextUnmarkedRange returns the next range of unmarked bits, starting from
+ // |start|, inclusive. If all bits after |start| are marked, it returns an
+ // empty range.
+ Range NextUnmarkedRange(size_t start) const;
+
+ // IsComplete returns whether every bit in the bitmask has been marked.
+ bool IsComplete() const { return bytes_.empty(); }
+
+ private:
+ // bytes_ contains the unmarked bits. We maintain an invariant: if |bytes_| is
+ // not empty, some bit is unset.
+ Array<uint8_t> bytes_;
+};
struct hm_header_st {
uint8_t type;
@@ -3288,9 +3322,8 @@
// data is a pointer to the message, including message header. It has length
// |DTLS1_HM_HEADER_LENGTH| + |msg_len|.
uint8_t *data = nullptr;
- // reassembly is a bitmask of |msg_len| bits corresponding to which parts of
- // the message have been received. It is NULL if the message is complete.
- uint8_t *reassembly = nullptr;
+ // reassembly tracks which parts of the message have been received.
+ DTLSMessageBitmap reassembly;
};
struct OPENSSL_timeval {
diff --git a/ssl/ssl_internal_test.cc b/ssl/ssl_internal_test.cc
index f5982a2..2cca65a 100644
--- a/ssl/ssl_internal_test.cc
+++ b/ssl/ssl_internal_test.cc
@@ -482,6 +482,127 @@
EXPECT_EQ(reconstruct_seqnum(0x8002, 0xffff, 0x10000), 0x8002u);
}
+TEST(DTLSMessageBitmapTest, Basic) {
+ // expect_bitmap checks that |b|'s unmarked bits are those listed in |ranges|.
+ // Each element of |ranges| must be non-empty and non-overlapping, and
+ // |ranges| must be sorted.
+ auto expect_bitmap = [](const DTLSMessageBitmap &b,
+ const std::vector<DTLSMessageBitmap::Range> &ranges) {
+ EXPECT_EQ(ranges.empty(), b.IsComplete());
+ size_t start = 0;
+ for (const auto &r : ranges) {
+ for (; start < r.start; start++) {
+ SCOPED_TRACE(start);
+ EXPECT_EQ(b.NextUnmarkedRange(start), r);
+ }
+ for (; start < r.end; start++) {
+ SCOPED_TRACE(start);
+ EXPECT_EQ(b.NextUnmarkedRange(start),
+ (DTLSMessageBitmap::Range{start, r.end}));
+ }
+ }
+ EXPECT_TRUE(b.NextUnmarkedRange(start).empty());
+ EXPECT_TRUE(b.NextUnmarkedRange(start + 1).empty());
+ EXPECT_TRUE(b.NextUnmarkedRange(start + 42).empty());
+
+ // This is implied from the previous checks, but NextUnmarkedRange should
+ // work as an iterator to reconstruct the ranges.
+ std::vector<DTLSMessageBitmap::Range> got_ranges;
+ for (auto r = b.NextUnmarkedRange(0); !r.empty();
+ r = b.NextUnmarkedRange(r.end)) {
+ got_ranges.push_back(r);
+ }
+ EXPECT_EQ(ranges, got_ranges);
+ };
+
+ // Initially, the bitmap is empty (fully marked).
+ DTLSMessageBitmap bitmap;
+ expect_bitmap(bitmap, {});
+
+ // It can also be initialized to the empty message and marked.
+ ASSERT_TRUE(bitmap.Init(0));
+ expect_bitmap(bitmap, {});
+ bitmap.MarkRange(0, 0);
+ expect_bitmap(bitmap, {});
+
+ // Track 100 bits and mark byte by byte.
+ ASSERT_TRUE(bitmap.Init(100));
+ expect_bitmap(bitmap, {{0, 100}});
+ for (size_t i = 0; i < 100; i++) {
+ SCOPED_TRACE(i);
+ bitmap.MarkRange(i, i + 1);
+ if (i < 99) {
+ expect_bitmap(bitmap, {{i + 1, 100}});
+ } else {
+ expect_bitmap(bitmap, {});
+ }
+ }
+
+ // Do the same but in reverse.
+ ASSERT_TRUE(bitmap.Init(100));
+ expect_bitmap(bitmap, {{0, 100}});
+ for (size_t i = 100; i > 0; i--) {
+ SCOPED_TRACE(i);
+ bitmap.MarkRange(i - 1, i);
+ if (i > 1) {
+ expect_bitmap(bitmap, {{0, i - 1}});
+ } else {
+ expect_bitmap(bitmap, {});
+ }
+ }
+
+ // Overlapping ranges are fine.
+ ASSERT_TRUE(bitmap.Init(100));
+ expect_bitmap(bitmap, {{0, 100}});
+ for (size_t i = 0; i < 100; i++) {
+ SCOPED_TRACE(i);
+ bitmap.MarkRange(i / 2, i + 1);
+ if (i < 99) {
+ expect_bitmap(bitmap, {{i + 1, 100}});
+ } else {
+ expect_bitmap(bitmap, {});
+ }
+ }
+
+ // Mark the middle chunk of every power of 3.
+ ASSERT_TRUE(bitmap.Init(100));
+ bitmap.MarkRange(1, 2);
+ bitmap.MarkRange(3, 6);
+ bitmap.MarkRange(9, 18);
+ bitmap.MarkRange(27, 54);
+ bitmap.MarkRange(81, 162);
+ expect_bitmap(bitmap, {{0, 1}, {2, 3}, {6, 9}, {18, 27}, {54, 81}});
+
+ // Mark most of the chunk shifted down a bit, so it both overlaps the previous
+ // and also leaves some of the right chunks unmarked.
+ bitmap.MarkRange(6 - 2, 9 - 2);
+ bitmap.MarkRange(18 - 4, 27 - 4);
+ bitmap.MarkRange(54 - 8, 81 - 8);
+ expect_bitmap(bitmap,
+ {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}});
+
+ // Re-mark things that have already been marked.
+ bitmap.MarkRange(1, 2);
+ bitmap.MarkRange(3, 6);
+ bitmap.MarkRange(9, 18);
+ bitmap.MarkRange(27, 54);
+ bitmap.MarkRange(81, 162);
+ expect_bitmap(bitmap,
+ {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}});
+
+ // Moves should work.
+ DTLSMessageBitmap bitmap2 = std::move(bitmap);
+ expect_bitmap(bitmap, {});
+ expect_bitmap(bitmap2,
+ {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}});
+
+ // Mark everything in two large ranges.
+ bitmap2.MarkRange(27 - 2, 100);
+ expect_bitmap(bitmap2, {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27 - 2}});
+ bitmap2.MarkRange(0, 50);
+ expect_bitmap(bitmap2, {});
+}
+
} // namespace
BSSL_NAMESPACE_END
#endif // !BORINGSSL_SHARED_LIBRARY