Pull the DTLS reassembly bitmap into its own abstraction

We'll use the same bitmap for tracking which parts of a message have
been ACKed. To that end, I've gone ahead and implemented a
NextUnmarkedRange.

This does make the hm_fragment structure a little larger, but we
indirect these structures to the heap, so I think this is fine. In the
steady state, they do not contribute to per-connection memory use.

Bug: 42290594
Change-Id: I326520454ee4d6832248b50ea3f5205f41d74cbe
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72269
Auto-Submit: David Benjamin <davidben@google.com>
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index e0c2614..dd7d146 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -117,6 +117,8 @@
 #include <limits.h>
 #include <string.h>
 
+#include <algorithm>
+
 #include <openssl/err.h>
 #include <openssl/evp.h>
 #include <openssl/mem.h>
@@ -140,13 +142,127 @@
 // the underlying BIO supplies one.
 static const unsigned int kDefaultMTU = 1500 - 28;
 
+// BitRange returns a |uint8_t| with bits |start|, inclusive, to |end|,
+// exclusive, set.
+static uint8_t BitRange(size_t start, size_t end) {
+  assert(start <= end && end <= 8);
+  return static_cast<uint8_t>(~((1u << start) - 1) & ((1u << end) - 1));
+}
+
+// FirstUnmarkedRangeInByte returns the first unmarked range in bits |b|.
+static DTLSMessageBitmap::Range FirstUnmarkedRangeInByte(uint8_t b) {
+  size_t start, end;
+  for (start = 0; start < 8; start++) {
+    if ((b & (1u << start)) == 0) {
+      break;
+    }
+  }
+  for (end = start; end < 8; end++) {
+    if ((b & (1u << end)) != 0) {
+      break;
+    }
+  }
+  return DTLSMessageBitmap::Range{start, end};
+}
+
+bool DTLSMessageBitmap::Init(size_t num_bits) {
+  if (num_bits + 7 < num_bits) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+    return false;
+  }
+  size_t num_bytes = (num_bits + 7) / 8;
+  size_t bits_rounded = num_bytes * 8;
+  if (!bytes_.Init(num_bytes)) {
+    return false;
+  }
+  MarkRange(num_bits, bits_rounded);
+  return true;
+}
+
+void DTLSMessageBitmap::MarkRange(size_t start, size_t end) {
+  // Clamp everything within range.
+  start = std::min(start, bytes_.size() << 3);
+  end = std::min(end, bytes_.size() << 3);
+  assert(start <= end);
+  if (start >= end) {
+    return;
+  }
+
+  if ((start >> 3) == (end >> 3)) {
+    bytes_[start >> 3] |= BitRange(start & 7, end & 7);
+  } else {
+    bytes_[start >> 3] |= BitRange(start & 7, 8);
+    for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) {
+      bytes_[i] = 0xff;
+    }
+    if ((end & 7) != 0) {
+      bytes_[end >> 3] |= BitRange(0, end & 7);
+    }
+  }
+
+  // Release the buffer if we've marked everything.
+  auto iter = std::find_if(bytes_.begin(), bytes_.end(),
+                           [](uint8_t b) { return b != 0xff; });
+  if (iter == bytes_.end()) {
+    assert(NextUnmarkedRange(0).empty());
+    bytes_.Reset();
+  }
+}
+
+DTLSMessageBitmap::Range DTLSMessageBitmap::NextUnmarkedRange(
+    size_t start) const {
+  size_t idx = start >> 3;
+  if (idx >= bytes_.size()) {
+    return Range{0, 0};
+  }
+
+  // Look at the bits from |start| up to a byte boundary.
+  uint8_t byte = bytes_[idx] | BitRange(0, start & 7);
+  if (byte == 0xff) {
+    // Nothing unmarked at this byte. Keep searching for an unmarked bit.
+    for (idx = idx + 1; idx < bytes_.size(); idx++) {
+      if (bytes_[idx] != 0xff) {
+        byte = bytes_[idx];
+        break;
+      }
+    }
+    if (idx >= bytes_.size()) {
+      return Range{0, 0};
+    }
+  }
+
+  Range range = FirstUnmarkedRangeInByte(byte);
+  assert(!range.empty());
+  bool should_extend = range.end == 8;
+  range.start += idx << 3;
+  range.end += idx << 3;
+  if (!should_extend) {
+    // The range did not end at a byte boundary. We're done.
+    return range;
+  }
+
+  // Collect all fully unmarked bytes.
+  for (idx = idx + 1; idx < bytes_.size(); idx++) {
+    if (bytes_[idx] != 0) {
+      break;
+    }
+  }
+  range.end = idx << 3;
+
+  // Add any bits from the remaining byte, if any.
+  if (idx < bytes_.size()) {
+    Range extra = FirstUnmarkedRangeInByte(bytes_[idx]);
+    if (extra.start == 0) {
+      range.end += extra.end;
+    }
+  }
+
+  return range;
+}
 
 // Receiving handshake messages.
 
-hm_fragment::~hm_fragment() {
-  OPENSSL_free(data);
-  OPENSSL_free(reassembly);
-}
+hm_fragment::~hm_fragment() { OPENSSL_free(data); }
 
 static UniquePtr<hm_fragment> dtls1_hm_fragment_new(
     const struct hm_header_st *msg_hdr) {
@@ -176,81 +292,19 @@
     return nullptr;
   }
 
-  // If the handshake message is empty, |frag->reassembly| is NULL.
-  if (msg_hdr->msg_len > 0) {
-    // Initialize reassembly bitmask.
-    if (msg_hdr->msg_len + 7 < msg_hdr->msg_len) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
-      return nullptr;
-    }
-    size_t bitmask_len = (msg_hdr->msg_len + 7) / 8;
-    frag->reassembly = (uint8_t *)OPENSSL_zalloc(bitmask_len);
-    if (frag->reassembly == NULL) {
-      return nullptr;
-    }
+  if (!frag->reassembly.Init(msg_hdr->msg_len)) {
+    return nullptr;
   }
 
   return frag;
 }
 
-// bit_range returns a |uint8_t| with bits |start|, inclusive, to |end|,
-// exclusive, set.
-static uint8_t bit_range(size_t start, size_t end) {
-  return (uint8_t)(~((1u << start) - 1) & ((1u << end) - 1));
-}
-
-// dtls1_hm_fragment_mark marks bytes |start|, inclusive, to |end|, exclusive,
-// as received in |frag|. If |frag| becomes complete, it clears
-// |frag->reassembly|. The range must be within the bounds of |frag|'s message
-// and |frag->reassembly| must not be NULL.
-static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start,
-                                   size_t end) {
-  size_t msg_len = frag->msg_len;
-
-  if (frag->reassembly == NULL || start > end || end > msg_len) {
-    assert(0);
-    return;
-  }
-  // A zero-length message will never have a pending reassembly.
-  assert(msg_len > 0);
-
-  if (start == end) {
-    return;
-  }
-
-  if ((start >> 3) == (end >> 3)) {
-    frag->reassembly[start >> 3] |= bit_range(start & 7, end & 7);
-  } else {
-    frag->reassembly[start >> 3] |= bit_range(start & 7, 8);
-    for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) {
-      frag->reassembly[i] = 0xff;
-    }
-    if ((end & 7) != 0) {
-      frag->reassembly[end >> 3] |= bit_range(0, end & 7);
-    }
-  }
-
-  // Check if the fragment is complete.
-  for (size_t i = 0; i < (msg_len >> 3); i++) {
-    if (frag->reassembly[i] != 0xff) {
-      return;
-    }
-  }
-  if ((msg_len & 7) != 0 &&
-      frag->reassembly[msg_len >> 3] != bit_range(0, msg_len & 7)) {
-    return;
-  }
-
-  OPENSSL_free(frag->reassembly);
-  frag->reassembly = NULL;
-}
-
 // dtls1_is_current_message_complete returns whether the current handshake
 // message is complete.
 static bool dtls1_is_current_message_complete(const SSL *ssl) {
   size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
   hm_fragment *frag = ssl->d1->incoming_messages[idx].get();
-  return frag != NULL && frag->reassembly == NULL;
+  return frag != nullptr && frag->reassembly.IsComplete();
 }
 
 // dtls1_get_incoming_message returns the incoming message corresponding to
@@ -306,8 +360,7 @@
     const size_t frag_off = msg_hdr.frag_off;
     const size_t frag_len = msg_hdr.frag_len;
     const size_t msg_len = msg_hdr.msg_len;
-    if (frag_off > msg_len || frag_off + frag_len < frag_off ||
-        frag_off + frag_len > msg_len ||
+    if (frag_off > msg_len || frag_len > msg_len - frag_off ||
         msg_len > ssl_max_handshake_message_len(ssl)) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
@@ -336,12 +389,12 @@
     }
 
     hm_fragment *frag = dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
-    if (frag == NULL) {
+    if (frag == nullptr) {
       return false;
     }
     assert(frag->msg_len == msg_len);
 
-    if (frag->reassembly == NULL) {
+    if (frag->reassembly.IsComplete()) {
       // The message is already assembled.
       continue;
     }
@@ -350,7 +403,7 @@
     // Copy the body into the fragment.
     OPENSSL_memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off,
                    CBS_data(&body), CBS_len(&body));
-    dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
+    frag->reassembly.MarkRange(frag_off, frag_off + frag_len);
   }
 
   return true;
diff --git a/ssl/internal.h b/ssl/internal.h
index 4a1f84c..29a674c 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3257,9 +3257,43 @@
 
 #define DTLS1_HM_HEADER_LENGTH 12
 
-#define DTLS1_CCS_HEADER_LENGTH 1
+// A DTLSMessageBitmap maintains a list of bits which may be marked to indicate
+// a portion of a message was received or ACKed.
+class DTLSMessageBitmap {
+ public:
+  // A Range represents a range of bits from |start|, inclusive, to |end|,
+  // exclusive.
+  struct Range {
+    size_t start = 0;
+    size_t end = 0;
 
-#define DTLS1_AL_HEADER_LENGTH 2
+    bool empty() const { return start == end; }
+    bool operator==(const Range &r) const {
+      return start == r.start && end == r.end;
+    }
+    bool operator!=(const Range &r) const { return !(*this == r); }
+  };
+
+  // Init initializes the structure with |num_bits| unmarked bits, from zero
+  // to |num_bits - 1|.
+  bool Init(size_t num_bits);
+
+  // MarkRange marks the bits from |start|, inclusive, to |end|, exclusive.
+  void MarkRange(size_t start, size_t end);
+
+  // NextUnmarkedRange returns the next range of unmarked bits, starting from
+  // |start|, inclusive. If all bits at or after |start| are marked, it returns
+  // an empty range.
+  Range NextUnmarkedRange(size_t start) const;
+
+  // IsComplete returns whether every bit in the bitmask has been marked.
+  bool IsComplete() const { return bytes_.empty(); }
+
+ private:
+  // bytes_ stores one bit per position; a set bit is marked. We maintain an
+  // invariant: if |bytes_| is not empty, some bit is unset.
+  Array<uint8_t> bytes_;
+};
 
 struct hm_header_st {
   uint8_t type;
@@ -3288,9 +3322,8 @@
   // data is a pointer to the message, including message header. It has length
   // |DTLS1_HM_HEADER_LENGTH| + |msg_len|.
   uint8_t *data = nullptr;
-  // reassembly is a bitmask of |msg_len| bits corresponding to which parts of
-  // the message have been received. It is NULL if the message is complete.
-  uint8_t *reassembly = nullptr;
+  // reassembly tracks which parts of the message have been received.
+  DTLSMessageBitmap reassembly;
 };
 
 struct OPENSSL_timeval {
diff --git a/ssl/ssl_internal_test.cc b/ssl/ssl_internal_test.cc
index f5982a2..2cca65a 100644
--- a/ssl/ssl_internal_test.cc
+++ b/ssl/ssl_internal_test.cc
@@ -482,6 +482,127 @@
   EXPECT_EQ(reconstruct_seqnum(0x8002, 0xffff, 0x10000), 0x8002u);
 }
 
+TEST(DTLSMessageBitmapTest, Basic) {
+  // expect_bitmap checks that |b|'s unmarked bits are those listed in |ranges|.
+  // Each element of |ranges| must be non-empty and non-overlapping, and
+  // |ranges| must be sorted.
+  auto expect_bitmap = [](const DTLSMessageBitmap &b,
+                          const std::vector<DTLSMessageBitmap::Range> &ranges) {
+    EXPECT_EQ(ranges.empty(), b.IsComplete());
+    size_t start = 0;
+    for (const auto &r : ranges) {
+      for (; start < r.start; start++) {
+        SCOPED_TRACE(start);
+        EXPECT_EQ(b.NextUnmarkedRange(start), r);
+      }
+      for (; start < r.end; start++) {
+        SCOPED_TRACE(start);
+        EXPECT_EQ(b.NextUnmarkedRange(start),
+                  (DTLSMessageBitmap::Range{start, r.end}));
+      }
+    }
+    EXPECT_TRUE(b.NextUnmarkedRange(start).empty());
+    EXPECT_TRUE(b.NextUnmarkedRange(start + 1).empty());
+    EXPECT_TRUE(b.NextUnmarkedRange(start + 42).empty());
+
+    // This is implied from the previous checks, but NextUnmarkedRange should
+    // work as an iterator to reconstruct the ranges.
+    std::vector<DTLSMessageBitmap::Range> got_ranges;
+    for (auto r = b.NextUnmarkedRange(0); !r.empty();
+         r = b.NextUnmarkedRange(r.end)) {
+      got_ranges.push_back(r);
+    }
+    EXPECT_EQ(ranges, got_ranges);
+  };
+
+  // Initially, the bitmap is empty (fully marked).
+  DTLSMessageBitmap bitmap;
+  expect_bitmap(bitmap, {});
+
+  // It can also be initialized to the empty message and marked.
+  ASSERT_TRUE(bitmap.Init(0));
+  expect_bitmap(bitmap, {});
+  bitmap.MarkRange(0, 0);
+  expect_bitmap(bitmap, {});
+
+  // Track 100 bits and mark them one bit at a time.
+  ASSERT_TRUE(bitmap.Init(100));
+  expect_bitmap(bitmap, {{0, 100}});
+  for (size_t i = 0; i < 100; i++) {
+    SCOPED_TRACE(i);
+    bitmap.MarkRange(i, i + 1);
+    if (i < 99) {
+      expect_bitmap(bitmap, {{i + 1, 100}});
+    } else {
+      expect_bitmap(bitmap, {});
+    }
+  }
+
+  // Do the same but in reverse.
+  ASSERT_TRUE(bitmap.Init(100));
+  expect_bitmap(bitmap, {{0, 100}});
+  for (size_t i = 100; i > 0; i--) {
+    SCOPED_TRACE(i);
+    bitmap.MarkRange(i - 1, i);
+    if (i > 1) {
+      expect_bitmap(bitmap, {{0, i - 1}});
+    } else {
+      expect_bitmap(bitmap, {});
+    }
+  }
+
+  // Overlapping ranges are fine.
+  ASSERT_TRUE(bitmap.Init(100));
+  expect_bitmap(bitmap, {{0, 100}});
+  for (size_t i = 0; i < 100; i++) {
+    SCOPED_TRACE(i);
+    bitmap.MarkRange(i / 2, i + 1);
+    if (i < 99) {
+      expect_bitmap(bitmap, {{i + 1, 100}});
+    } else {
+      expect_bitmap(bitmap, {});
+    }
+  }
+
+  // Mark the middle chunk of every power of 3.
+  ASSERT_TRUE(bitmap.Init(100));
+  bitmap.MarkRange(1, 2);
+  bitmap.MarkRange(3, 6);
+  bitmap.MarkRange(9, 18);
+  bitmap.MarkRange(27, 54);
+  bitmap.MarkRange(81, 162);
+  expect_bitmap(bitmap, {{0, 1}, {2, 3}, {6, 9}, {18, 27}, {54, 81}});
+
+  // Mark most of the chunk shifted down a bit, so it both overlaps the previous
+  // and also leaves some of the right chunks unmarked.
+  bitmap.MarkRange(6 - 2, 9 - 2);
+  bitmap.MarkRange(18 - 4, 27 - 4);
+  bitmap.MarkRange(54 - 8, 81 - 8);
+  expect_bitmap(bitmap,
+                {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}});
+
+  // Re-mark things that have already been marked.
+  bitmap.MarkRange(1, 2);
+  bitmap.MarkRange(3, 6);
+  bitmap.MarkRange(9, 18);
+  bitmap.MarkRange(27, 54);
+  bitmap.MarkRange(81, 162);
+  expect_bitmap(bitmap,
+                {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}});
+
+  // Moves should work.
+  DTLSMessageBitmap bitmap2 = std::move(bitmap);
+  expect_bitmap(bitmap, {});
+  expect_bitmap(bitmap2,
+                {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27}, {81 - 8, 81}});
+
+  // Mark everything in two large ranges.
+  bitmap2.MarkRange(27 - 2, 100);
+  expect_bitmap(bitmap2, {{0, 1}, {2, 3}, {9 - 2, 9}, {27 - 4, 27 - 2}});
+  bitmap2.MarkRange(0, 50);
+  expect_bitmap(bitmap2, {});
+}
+
 }  // namespace
 BSSL_NAMESPACE_END
 #endif  // !BORINGSSL_SHARED_LIBRARY