Add tests for reconstruct_seqnum.

Bug: 715
Change-Id: Ibb8ae0c152477eb5aa035582fac06368ef3c7c1e
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/70347
Auto-Submit: Nick Harper <nharper@chromium.org>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 1d32630..783e950 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -171,24 +171,25 @@
   return epoch;
 }
 
-// reconstruct_seqnum returns the smallest sequence number that hasn't been seen
-// in |bitmap| and is still within |bitmap|'s window to handle as a reordered
-// record.
-//
-// Section 4.2.2 of RFC 9147 describes an algorithm for reconstructing sequence
-// numbers, which is implemented here. This algorithm finds the sequence number
-// that is numerically closest to one plus the largest sequence number seen in
-// this epoch.
-static uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask,
-                                   DTLS1_BITMAP *bitmap) {
-  uint64_t max_seqnum_plus_one = bitmap->max_seq_num + 1;
+uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask,
+                            uint64_t max_valid_seqnum) {
+  uint64_t max_seqnum_plus_one = max_valid_seqnum + 1;
   uint64_t diff = (wire_seq - max_seqnum_plus_one) & seq_mask;
   uint64_t step = seq_mask + 1;
   uint64_t seqnum = max_seqnum_plus_one + diff;
-  // diff is always non-negative, so seqnum is >= max_seqnum_plus_one. If the
-  // diff is larger than half the step size, then the numerically closest
-  // sequence number is less than max_seqnum_plus_one instead of greater.
-  if (diff > step / 2) {
+  // seqnum is computed as the addition of 3 non-negative values
+  // (max_valid_seqnum, 1, and diff). The values 1 and diff are small (relative
+  // to the size of a uint64_t), while max_valid_seqnum can span the range of
+  // all uint64_t values. If seqnum is less than max_valid_seqnum, then the
+  // addition overflowed.
+  bool overflowed = seqnum < max_valid_seqnum;
+  // If the diff is larger than half the step size, then the closest seqnum
+  // to max_seqnum_plus_one (in Z_{2^64}) is seqnum minus step instead of
+  // seqnum.
+  bool closer_is_less = diff > step / 2;
+  // Subtracting step from seqnum will cause underflow if seqnum is too small.
+  bool would_underflow = seqnum < step;
+  if (overflowed || (closer_is_less && !would_underflow)) {
     seqnum -= step;
   }
   return seqnum;
@@ -216,7 +217,8 @@
       // The record header was incomplete or malformed.
       return false;
     }
-    *out_sequence = reconstruct_seqnum(seq, 0xffff, &ssl->d1->bitmap);
+    *out_sequence =
+        reconstruct_seqnum(seq, 0xffff, ssl->d1->bitmap.max_seq_num);
   } else {
     // 8-bit sequence number.
     uint8_t seq;
@@ -224,7 +226,7 @@
       // The record header was incomplete or malformed.
       return false;
     }
-    *out_sequence = reconstruct_seqnum(seq, 0xff, &ssl->d1->bitmap);
+    *out_sequence = reconstruct_seqnum(seq, 0xff, ssl->d1->bitmap.max_seq_num);
   }
   *out_header_len = packet_size - CBS_len(in);
   if ((type & 0x04) == 0x04) {
diff --git a/ssl/internal.h b/ssl/internal.h
index 8cf0339..febb676 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -965,6 +965,17 @@
   uint64_t max_seq_num = 0;
 };
 
+// reconstruct_seqnum takes the low order bits of a record sequence number from
+// the wire and reconstructs the full sequence number. It does so using the
+// algorithm described in section 4.2.2 of RFC 9147, where |wire_seq| is the
+// low bits of the sequence number as seen on the wire, |seq_mask| is a bitmask
+// of 8 or 16 1 bits corresponding to the length of the sequence number on the
+// wire, and |max_valid_seqnum| is the largest sequence number of a record
+// successfully deprotected in this epoch. This function returns the sequence
+// number that is numerically closest to one plus |max_valid_seqnum| that when
+// bitwise and-ed with |seq_mask| equals |wire_seq|.
+OPENSSL_EXPORT uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask,
+                                           uint64_t max_valid_seqnum);
 
 // Record layer.
 
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index ba1a22b..1b71c97 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -631,6 +631,127 @@
   }
 }
 
+TEST(ReconstructSeqnumTest, Increment) {
+  // Test simple cases from the beginning of an epoch with both 8- and 16-bit
+  // wire sequence numbers.
+  EXPECT_EQ(reconstruct_seqnum(0, 0xff, 0), 0u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xff, 0), 1u);
+  EXPECT_EQ(reconstruct_seqnum(2, 0xff, 0), 2u);
+  EXPECT_EQ(reconstruct_seqnum(0, 0xffff, 0), 0u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xffff, 0), 1u);
+  EXPECT_EQ(reconstruct_seqnum(2, 0xffff, 0), 2u);
+
+  // When the max seen sequence number is 0, the numerically closest
+  // reconstructed sequence number could be negative. Sequence numbers are
+  // non-negative, so reconstruct_seqnum should instead return the closest
+  // non-negative number instead of returning a number congruent to that
+  // closest negative number mod 2^64.
+  EXPECT_EQ(reconstruct_seqnum(0xff, 0xff, 0), 0xffu);
+  EXPECT_EQ(reconstruct_seqnum(0xfe, 0xff, 0), 0xfeu);
+  EXPECT_EQ(reconstruct_seqnum(0xffff, 0xffff, 0), 0xffffu);
+  EXPECT_EQ(reconstruct_seqnum(0xfffe, 0xffff, 0), 0xfffeu);
+
+  // When the wire sequence number is less than the corresponding low bytes of
+  // the max seen sequence number, check that the next larger sequence number
+  // is reconstructed as its numerically closer than the corresponding sequence
+  // number that would keep the high order bits the same.
+  EXPECT_EQ(reconstruct_seqnum(0, 0xff, 0xff), 0x100u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xff, 0xff), 0x101u);
+  EXPECT_EQ(reconstruct_seqnum(2, 0xff, 0xff), 0x102u);
+  EXPECT_EQ(reconstruct_seqnum(0, 0xffff, 0xffff), 0x10000u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xffff, 0xffff), 0x10001u);
+  EXPECT_EQ(reconstruct_seqnum(2, 0xffff, 0xffff), 0x10002u);
+
+  // Test cases when the wire sequence number is close to the largest magnitude
+  // that can be represented in 8 or 16 bits.
+  EXPECT_EQ(reconstruct_seqnum(0xff, 0xff, 0x2f0), 0x2ffu);
+  EXPECT_EQ(reconstruct_seqnum(0xfe, 0xff, 0x2f0), 0x2feu);
+  EXPECT_EQ(reconstruct_seqnum(0xffff, 0xffff, 0x2f000), 0x2ffffu);
+  EXPECT_EQ(reconstruct_seqnum(0xfffe, 0xffff, 0x2f000), 0x2fffeu);
+
+  // Test that reconstruct_seqnum can return
+  // std::numeric_limits<uint64_t>::max().
+  EXPECT_EQ(reconstruct_seqnum(0xff, 0xff, 0xffffffffffffffff),
+            std::numeric_limits<uint64_t>::max());
+  EXPECT_EQ(reconstruct_seqnum(0xff, 0xff, 0xfffffffffffffffe),
+            std::numeric_limits<uint64_t>::max());
+  EXPECT_EQ(reconstruct_seqnum(0xffff, 0xffff, 0xffffffffffffffff),
+            std::numeric_limits<uint64_t>::max());
+  EXPECT_EQ(reconstruct_seqnum(0xffff, 0xffff, 0xfffffffffffffffe),
+            std::numeric_limits<uint64_t>::max());
+}
+
+TEST(ReconstructSeqnumTest, Decrement) {
+  // Test that the sequence number 0 can be reconstructed when the max
+  // seen sequence number is greater than 0.
+  EXPECT_EQ(reconstruct_seqnum(0, 0xff, 0x10), 0u);
+  EXPECT_EQ(reconstruct_seqnum(0, 0xffff, 0x1000), 0u);
+
+  // Test cases where the reconstructed sequence number is less than the max
+  // seen sequence number.
+  EXPECT_EQ(reconstruct_seqnum(0, 0xff, 0x210), 0x200u);
+  EXPECT_EQ(reconstruct_seqnum(2, 0xff, 0x210), 0x202u);
+  EXPECT_EQ(reconstruct_seqnum(0, 0xffff, 0x43210), 0x40000u);
+  EXPECT_EQ(reconstruct_seqnum(2, 0xffff, 0x43210), 0x40002u);
+
+  // Test when the wire sequence number is greater than the low bits of the
+  // max seen sequence number.
+  EXPECT_EQ(reconstruct_seqnum(0xff, 0xff, 0x200), 0x1ffu);
+  EXPECT_EQ(reconstruct_seqnum(0xfe, 0xff, 0x200), 0x1feu);
+  EXPECT_EQ(reconstruct_seqnum(0xffff, 0xffff, 0x20000), 0x1ffffu);
+  EXPECT_EQ(reconstruct_seqnum(0xfffe, 0xffff, 0x20000), 0x1fffeu);
+
+  // Test when the max seen sequence number is close to the uint64_t max value.
+  // In some cases, the closest numerical value in the integers will overflow
+  // a uint64_t. Instead of returning the closest value in Z_{2^64},
+  // reconstruct_seqnum should return the closest integer less than 2^64, even
+  // if there is a closer value greater than 2^64.
+  EXPECT_EQ(reconstruct_seqnum(0, 0xff, 0xffffffffffffffff),
+            0xffffffffffffff00u);
+  EXPECT_EQ(reconstruct_seqnum(0, 0xff, 0xfffffffffffffffe),
+            0xffffffffffffff00u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xff, 0xffffffffffffffff),
+            0xffffffffffffff01u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xff, 0xfffffffffffffffe),
+            0xffffffffffffff01u);
+  EXPECT_EQ(reconstruct_seqnum(0xfe, 0xff, 0xffffffffffffffff),
+            0xfffffffffffffffeu);
+  EXPECT_EQ(reconstruct_seqnum(0xfd, 0xff, 0xfffffffffffffffe),
+            0xfffffffffffffffdu);
+  EXPECT_EQ(reconstruct_seqnum(0, 0xffff, 0xffffffffffffffff),
+            0xffffffffffff0000u);
+  EXPECT_EQ(reconstruct_seqnum(0, 0xffff, 0xfffffffffffffffe),
+            0xffffffffffff0000u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xffff, 0xffffffffffffffff),
+            0xffffffffffff0001u);
+  EXPECT_EQ(reconstruct_seqnum(1, 0xffff, 0xfffffffffffffffe),
+            0xffffffffffff0001u);
+  EXPECT_EQ(reconstruct_seqnum(0xfffe, 0xffff, 0xffffffffffffffff),
+            0xfffffffffffffffeu);
+  EXPECT_EQ(reconstruct_seqnum(0xfffd, 0xffff, 0xfffffffffffffffe),
+            0xfffffffffffffffdu);
+}
+
+TEST(ReconstructSeqnumTest, Halfway) {
+  // Test wire sequence numbers that are close to halfway away from the max
+  // seen sequence number. The algorithm specifies that the output should be
+  // numerically closest to 1 plus the max seen (0x100 in the following test
+  // cases). With a max seen of 0x100 and a wire sequence of 0x81, the two
+  // closest values to 1+0x100 are 0x81 and 0x181, which are both the same
+  // amount away. The algorithm doesn't specify what to do on this edge case;
+  // our implementation chooses the larger value (0x181), on the assumption that
+  // it's more likely to be a new or larger sequence number rather than a replay
+  // or an out-of-order packet.
+  EXPECT_EQ(reconstruct_seqnum(0x80, 0xff, 0x100), 0x180u);
+  EXPECT_EQ(reconstruct_seqnum(0x81, 0xff, 0x100), 0x181u);
+  EXPECT_EQ(reconstruct_seqnum(0x82, 0xff, 0x100), 0x82u);
+
+  // Repeat these tests with 16-bit wire sequence numbers.
+  EXPECT_EQ(reconstruct_seqnum(0x8000, 0xffff, 0x10000), 0x18000u);
+  EXPECT_EQ(reconstruct_seqnum(0x8001, 0xffff, 0x10000), 0x18001u);
+  EXPECT_EQ(reconstruct_seqnum(0x8002, 0xffff, 0x10000), 0x8002u);
+}
+
 TEST(SSLTest, CipherRules) {
   for (const CipherTest &t : kCipherTests) {
     SCOPED_TRACE(t.rule);