Introduce a DTLSTimer abstraction

We'll need to carry a couple of timers with DTLS 1.3. Abstract this into
a DTLSTimer.

Bug: 42290594
Change-Id: I4d57dfae9c5984cb10f5db251642af5aaec9a495
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72951
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Nick Harper <nharper@chromium.org>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 94a70ae..48ccbbd 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -613,7 +613,7 @@
 // This duration overrides the default of 400 milliseconds, which is
 // recommendation of RFC 9147 for real-time protocols.
 OPENSSL_EXPORT void DTLSv1_set_initial_timeout_duration(SSL *ssl,
-                                                        unsigned duration_ms);
+                                                        uint32_t duration_ms);
 
 // DTLSv1_get_timeout queries the running DTLS timers. If there are any in
 // progress, it sets |*out| to the time remaining until the first timer expires
diff --git a/ssl/d1_lib.cc b/ssl/d1_lib.cc
index 55f46b1..c73a102 100644
--- a/ssl/d1_lib.cc
+++ b/ssl/d1_lib.cc
@@ -121,39 +121,79 @@
   ssl->d1 = NULL;
 }
 
+void DTLSTimer::StartMicroseconds(OPENSSL_timeval now, uint64_t microseconds) {
+  uint64_t seconds = microseconds / 1000000;
+  microseconds %= 1000000;
+
+  now.tv_usec += microseconds;
+  if (now.tv_usec >= 1000000) {
+    now.tv_usec -= 1000000;
+    seconds++;
+  }
+
+  if (now.tv_sec > UINT64_MAX - seconds) {
+    Stop();
+    return;
+  }
+  now.tv_sec += seconds;
+  expire_time_ = now;
+}
+
+void DTLSTimer::Stop() { expire_time_ = {0, 0}; }
+
+bool DTLSTimer::IsExpired(OPENSSL_timeval now) const {
+  return MicrosecondsRemaining(now) == 0;
+}
+
+bool DTLSTimer::IsSet() const {
+  return expire_time_.tv_sec != 0 || expire_time_.tv_usec != 0;
+}
+
+uint64_t DTLSTimer::MicrosecondsRemaining(OPENSSL_timeval now) const {
+  if (!IsSet()) {
+    return kNever;
+  }
+
+  if (now.tv_sec > expire_time_.tv_sec ||
+      (now.tv_sec == expire_time_.tv_sec &&
+       now.tv_usec >= expire_time_.tv_usec)) {
+    return 0;
+  }
+
+  uint64_t sec = expire_time_.tv_sec - now.tv_sec;
+  uint32_t usec;
+  if (expire_time_.tv_usec >= now.tv_usec) {
+    usec = expire_time_.tv_usec - now.tv_usec;
+  } else {
+    sec--;
+    usec = expire_time_.tv_usec + 1000000 - now.tv_usec;
+  }
+
+  // If remaining time is less than 15 ms, return 0 to prevent issues because of
+  // small divergences with socket timeouts.
+  if (sec == 0 && usec < 15000) {
+    return 0;
+  }
+
+  if (sec > UINT64_MAX / 1000000) {
+    return kNever;
+  }
+  sec *= 1000000;
+  if (sec > UINT64_MAX - usec) {
+    return kNever;
+  }
+  return sec + usec;
+}
+
 void dtls1_start_timer(SSL *ssl) {
   // If timer is not set, initialize duration.
-  if (ssl->d1->next_timeout.tv_sec == 0 && ssl->d1->next_timeout.tv_usec == 0) {
+  if (!ssl->d1->retransmit_timer.IsSet()) {
     ssl->d1->timeout_duration_ms = ssl->initial_timeout_duration_ms;
   }
 
-  // Set timeout to current time
-  ssl->d1->next_timeout = ssl_ctx_get_current_time(ssl->ctx.get());
-
-  // Add duration to current time
-  ssl->d1->next_timeout.tv_sec += ssl->d1->timeout_duration_ms / 1000;
-  ssl->d1->next_timeout.tv_usec += (ssl->d1->timeout_duration_ms % 1000) * 1000;
-  if (ssl->d1->next_timeout.tv_usec >= 1000000) {
-    ssl->d1->next_timeout.tv_sec++;
-    ssl->d1->next_timeout.tv_usec -= 1000000;
-  }
-}
-
-bool dtls1_is_timer_expired(SSL *ssl) {
-  struct timeval timeleft;
-
-  // Get time left until timeout, return false if no timer running
-  if (!DTLSv1_get_timeout(ssl, &timeleft)) {
-    return false;
-  }
-
-  // Return false if timer is not expired yet
-  if (timeleft.tv_sec > 0 || timeleft.tv_usec > 0) {
-    return false;
-  }
-
-  // Timer expired, so return true
-  return true;
+  OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
+  ssl->d1->retransmit_timer.StartMicroseconds(
+      now, uint64_t{ssl->d1->timeout_duration_ms} * 1000);
 }
 
 static void dtls1_double_timeout(SSL *ssl) {
@@ -165,7 +205,7 @@
 
 void dtls1_stop_timer(SSL *ssl) {
   ssl->d1->num_timeouts = 0;
-  ssl->d1->next_timeout = {0, 0};
+  ssl->d1->retransmit_timer.Stop();
   ssl->d1->timeout_duration_ms = ssl->initial_timeout_duration_ms;
 }
 
@@ -195,7 +235,7 @@
 
 using namespace bssl;
 
-void DTLSv1_set_initial_timeout_duration(SSL *ssl, unsigned int duration_ms) {
+void DTLSv1_set_initial_timeout_duration(SSL *ssl, uint32_t duration_ms) {
   ssl->initial_timeout_duration_ms = duration_ms;
 }
 
@@ -204,46 +244,25 @@
     return 0;
   }
 
-  // If no timeout is set, just return 0.
-  if (ssl->d1->next_timeout.tv_sec == 0 && ssl->d1->next_timeout.tv_usec == 0) {
-    return 0;
+  OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
+  uint64_t remaining_usec =
+      ssl->d1->retransmit_timer.MicrosecondsRemaining(now);
+  if (remaining_usec == DTLSTimer::kNever) {
+    return 0;  // No timeout is set.
   }
 
-  OPENSSL_timeval timenow = ssl_ctx_get_current_time(ssl->ctx.get());
+  uint64_t remaining_sec = remaining_usec / 1000000;
+  remaining_usec %= 1000000;
 
-  // If timer already expired, set remaining time to 0.
-  if (ssl->d1->next_timeout.tv_sec < timenow.tv_sec ||
-      (ssl->d1->next_timeout.tv_sec == timenow.tv_sec &&
-       ssl->d1->next_timeout.tv_usec <= timenow.tv_usec)) {
-    OPENSSL_memset(out, 0, sizeof(*out));
-    return 1;
-  }
-
-  // Calculate time left until timer expires.
-  OPENSSL_timeval ret = ssl->d1->next_timeout;
-  ret.tv_sec -= timenow.tv_sec;
-  if (ret.tv_usec >= timenow.tv_usec) {
-    ret.tv_usec -= timenow.tv_usec;
+  // |timeval| uses |time_t|, which may be 32-bit.
+  const auto kTvSecMax = std::numeric_limits<decltype(out->tv_sec)>::max();
+  if (remaining_sec > static_cast<uint64_t>(kTvSecMax)) {
+    out->tv_sec = kTvSecMax;  // Saturate the output.
+    out->tv_usec = 999999;
   } else {
-    ret.tv_usec = 1000000 + ret.tv_usec - timenow.tv_usec;
-    ret.tv_sec--;
+    out->tv_sec = static_cast<decltype(out->tv_sec)>(remaining_sec);
   }
-
-  // If remaining time is less than 15 ms, set it to 0 to prevent issues
-  // because of small divergences with socket timeouts.
-  if (ret.tv_sec == 0 && ret.tv_usec < 15000) {
-    ret = {0, 0};
-  }
-
-  // Clamp the result in case of overflow.
-  if (ret.tv_sec > INT_MAX) {
-    assert(0);
-    out->tv_sec = INT_MAX;
-  } else {
-    out->tv_sec = ret.tv_sec;
-  }
-
-  out->tv_usec = ret.tv_usec;
+  out->tv_usec = remaining_usec;
   return 1;
 }
 
@@ -256,7 +275,8 @@
   }
 
   // If no timer is expired, don't do anything.
-  if (!dtls1_is_timer_expired(ssl)) {
+  OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
+  if (!ssl->d1->retransmit_timer.IsExpired(now)) {
     return 0;
   }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index f559b54..8dcc097 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3463,6 +3463,41 @@
   uint32_t tv_usec;
 };
 
+struct DTLSTimer {
+ public:
+  static constexpr uint64_t kNever = UINT64_MAX;
+
+  // StartMicroseconds schedules the timer to expire the specified number of
+  // microseconds from |now|.
+  void StartMicroseconds(OPENSSL_timeval now, uint64_t microseconds);
+
+  // Stop disables the timer.
+  void Stop();
+
+  // IsExpired returns true if the timer was set and is expired at time |now|.
+  bool IsExpired(OPENSSL_timeval now) const;
+
+  // IsSet returns true if the timer is scheduled or expired, and false if it is
+  // stopped.
+  bool IsSet() const;
+
+  // MicrosecondsRemaining returns the time remaining, in microseconds, at
+  // |now|, or |kNever| if the timer is unset.
+  uint64_t MicrosecondsRemaining(OPENSSL_timeval now) const;
+
+ private:
+  // expire_time_ is the time when the timer expires, or zero if the timer is
+  // unset.
+  //
+  // TODO(crbug.com/366284846): This is an extremely inconvenient time
+  // representation. Switch libssl to something like a 64-bit count of
+  // microseconds. While it's decidedly past 1970 now, zero is a less obviously
+  // sound distinguished value for the monotonic clock, so maybe we should use a
+  // different distinguished time, like |INT64_MAX| in the microseconds
+  // representation.
+  OPENSSL_timeval expire_time_ = {0, 0};
+};
+
 // DTLS_MAX_EXTRA_WRITE_EPOCHS is the maximum number of additional write epochs
 // that DTLS may need to retain.
 //
@@ -3574,12 +3609,12 @@
   // the last time it was reset.
   unsigned num_timeouts = 0;
 
-  // Indicates when the last handshake msg or heartbeat sent will
-  // timeout.
-  struct OPENSSL_timeval next_timeout = {0, 0};
+  // retransmit_timer tracks when to schedule the next DTLS retransmit if we do
+  // not hear from the peer.
+  DTLSTimer retransmit_timer;
 
   // timeout_duration_ms is the timeout duration in milliseconds.
-  unsigned timeout_duration_ms = 0;
+  uint32_t timeout_duration_ms = 0;
 };
 
 // An ALPSConfig is a pair of ALPN protocol and settings value to use with ALPS.
@@ -3944,7 +3979,6 @@
 
 void dtls1_start_timer(SSL *ssl);
 void dtls1_stop_timer(SSL *ssl);
-bool dtls1_is_timer_expired(SSL *ssl);
 unsigned int dtls1_min_mtu(void);
 
 bool dtls1_new(SSL *ssl);
@@ -4440,7 +4474,7 @@
   // initial_timeout_duration_ms is the default DTLS timeout duration in
   // milliseconds. It's used to initialize the timer any time it's restarted. We
   // default to RFC 9147's recommendation for real-time applications, 400ms.
-  unsigned initial_timeout_duration_ms = 400;
+  uint32_t initial_timeout_duration_ms = 400;
 
   // session is the configured session to be offered by the client. This session
   // is immutable.