Test DTLSv1_get_timeout behavior

The next CL will rework it a bit, so add some tests for it.

Bug: 42290594
Change-Id: Ib2dc3068c446a22d27f87c238275ca740932b3ac
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72950
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 40e6230..885e5fe 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -845,7 +845,11 @@
 
   if (config->is_dtls) {
     bssl::UniquePtr<BIO> packeted = PacketedBioCreate(
-        GetClock(), [ssl_raw = ssl.get()](uint32_t mtu) -> bool {
+        GetClock(),
+        [ssl_raw = ssl.get()](timeval *out) -> bool {
+          return DTLSv1_get_timeout(ssl_raw, out);
+        },
+        [ssl_raw = ssl.get()](uint32_t mtu) -> bool {
           return SSL_set_mtu(ssl_raw, mtu);
         });
     if (!packeted) {
diff --git a/ssl/test/packeted_bio.cc b/ssl/test/packeted_bio.cc
index c7cccd7..f5e2a54 100644
--- a/ssl/test/packeted_bio.cc
+++ b/ssl/test/packeted_bio.cc
@@ -15,6 +15,7 @@
 #include "packeted_bio.h"
 
 #include <assert.h>
+#include <inttypes.h>
 #include <limits.h>
 #include <stdio.h>
 #include <string.h>
@@ -36,10 +37,15 @@
 constexpr uint8_t kOpcodeTimeout = 'T';
 constexpr uint8_t kOpcodeTimeoutAck = 't';
 constexpr uint8_t kOpcodeMTU = 'M';
+constexpr uint8_t kOpcodeExpectNextTimeout = 'E';
 
 struct PacketedBio {
-  PacketedBio(timeval *clock_arg, std::function<bool(uint32_t)> set_mtu_arg)
-      : clock(clock_arg), set_mtu(std::move(set_mtu_arg)) {
+  PacketedBio(timeval *clock_arg,
+              std::function<bool(timeval *)> get_timeout_arg,
+              std::function<bool(uint32_t)> set_mtu_arg)
+      : clock(clock_arg),
+        get_timeout(std::move(get_timeout_arg)),
+        set_mtu(std::move(set_mtu_arg)) {
     OPENSSL_memset(&timeout, 0, sizeof(timeout));
   }
 
@@ -49,6 +55,7 @@
 
   timeval timeout;
   timeval *clock;
+  std::function<bool(timeval *)> get_timeout;
   std::function<bool(uint32_t)> set_mtu;
 };
 
@@ -173,6 +180,50 @@
       continue;
     }
 
+    if (opcode == kOpcodeExpectNextTimeout) {
+      uint8_t buf[8];
+      ret = ReadAll(bio->next_bio, buf, sizeof(buf));
+      if (ret <= 0) {
+        BIO_copy_next_retry(bio);
+        return ret;
+      }
+      uint64_t expected = CRYPTO_load_u64_be(buf);
+      timeval timeout;
+      bool has_timeout = data->get_timeout(&timeout);
+      if (expected == UINT64_MAX) {
+        if (has_timeout) {
+          fprintf(stderr,
+                  "Expected no timeout, but got %" PRIu64 ".%06" PRIu64 "s.\n",
+                  static_cast<uint64_t>(timeout.tv_sec),
+                  static_cast<uint64_t>(timeout.tv_usec));
+          return -1;
+        }
+      } else {
+        expected /= 1000;  // Convert nanoseconds to microseconds.
+        uint64_t expected_sec = expected / 1000000;
+        uint64_t expected_usec = expected % 1000000;
+        if (!has_timeout) {
+          fprintf(stderr,
+                  "Expected timeout of %" PRIu64 ".%06" PRIu64
+                  "s, but got none.\n",
+                  expected_sec, expected_usec);
+          return -1;
+        }
+        if (static_cast<uint64_t>(timeout.tv_sec) != expected_sec ||
+            static_cast<uint64_t>(timeout.tv_usec) != expected_usec) {
+          fprintf(stderr,
+                  "Expected timeout of %" PRIu64 ".%06" PRIu64
+                  "s, but got %" PRIu64 ".%06" PRIu64 "s.\n",
+                  expected_sec, expected_usec,
+                  static_cast<uint64_t>(timeout.tv_sec),
+                  static_cast<uint64_t>(timeout.tv_usec));
+          return -1;
+        }
+      }
+      // Continue reading.
+      continue;
+    }
+
     if (opcode != kOpcodePacket) {
       fprintf(stderr, "Unknown opcode, %u\n", opcode);
       return -1;
@@ -249,13 +300,14 @@
 
 }  // namespace
 
-bssl::UniquePtr<BIO> PacketedBioCreate(timeval *clock,
-                                       std::function<bool(uint32_t)> set_mtu) {
+bssl::UniquePtr<BIO> PacketedBioCreate(
+    timeval *clock, std::function<bool(timeval *)> get_timeout,
+    std::function<bool(uint32_t)> set_mtu) {
   bssl::UniquePtr<BIO> bio(BIO_new(&g_packeted_bio_method));
   if (!bio) {
     return nullptr;
   }
-  bio->ptr = new PacketedBio(clock, std::move(set_mtu));
+  bio->ptr = new PacketedBio(clock, std::move(get_timeout), std::move(set_mtu));
   return bio;
 }
 
diff --git a/ssl/test/packeted_bio.h b/ssl/test/packeted_bio.h
index a064b10..b99e971 100644
--- a/ssl/test/packeted_bio.h
+++ b/ssl/test/packeted_bio.h
@@ -31,13 +31,17 @@
 
 // PacketedBioCreate creates a filter BIO which implements a reliable in-order
 // blocking datagram socket. It uses the value of |*clock| as the clock.
-// |set_mtu| will be called when the runner asks to change the MTU.
+// |get_timeout| should output what the |SSL| object believes is the next
+// timeout, or return false if there is none. It will be compared against
+// assertions from the runner. |set_mtu| will be called when the runner asks to
+// change the MTU.
 //
 // During a |BIO_read|, the peer may signal the filter BIO to simulate a
 // timeout. The operation will fail immediately. The caller must then call
 // |PacketedBioAdvanceClock| before retrying |BIO_read|.
-bssl::UniquePtr<BIO> PacketedBioCreate(timeval *clock,
-                                       std::function<bool(uint32_t)> set_mtu);
+bssl::UniquePtr<BIO> PacketedBioCreate(
+    timeval *clock, std::function<bool(timeval *)> get_timeout,
+    std::function<bool(uint32_t)> set_mtu);
 
 // PacketedBioAdvanceClock advances |bio|'s clock and returns true if there is a
 // pending timeout. Otherwise, it returns false.
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index 266b9c3..2f3f842 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -1360,3 +1360,27 @@
 		return
 	}
 }
+
+// ExpectNextTimeout indicates the shim's next timeout should be d from now.
+func (c *DTLSController) ExpectNextTimeout(d time.Duration) {
+	if c.err != nil {
+		return
+	}
+	if err := c.conn.dtlsFlushPacket(); err != nil {
+		c.err = err
+		return
+	}
+	c.err = c.conn.config.Bugs.PacketAdaptor.ExpectNextTimeout(d)
+}
+
+// ExpectNoNext indicates the shim should not have a next timeout.
+func (c *DTLSController) ExpectNoNextTimeout() {
+	if c.err != nil {
+		return
+	}
+	if err := c.conn.dtlsFlushPacket(); err != nil {
+		c.err = err
+		return
+	}
+	c.err = c.conn.config.Bugs.PacketAdaptor.ExpectNoNextTimeout()
+}
diff --git a/ssl/test/runner/packet_adapter.go b/ssl/test/runner/packet_adapter.go
index bf1a0bb..9cbca37 100644
--- a/ssl/test/runner/packet_adapter.go
+++ b/ssl/test/runner/packet_adapter.go
@@ -8,6 +8,7 @@
 	"encoding/binary"
 	"fmt"
 	"io"
+	"math"
 	"net"
 	"slices"
 	"time"
@@ -30,6 +31,11 @@
 // opcodeMTU updates the shim's MTU, encoded as a 32-bit number of bytes.
 const opcodeMTU = byte('M')
 
+// opcodeExpectNextTimeout indicates that the shim should report a specified timeout
+// to the calling application. The timeout is encoded as in opcodeTimeout, but
+// MaxUint64 indicates there should be no timeout.
+const opcodeExpectNextTimeout = byte('E')
+
 type packetAdaptor struct {
 	net.Conn
 	debug *recordingConn
@@ -145,6 +151,24 @@
 	return err
 }
 
+// ExpectNextTimeout indicates the peer's next timeout should be d from now.
+func (p *packetAdaptor) ExpectNextTimeout(d time.Duration) error {
+	payload := make([]byte, 1+8)
+	payload[0] = opcodeExpectNextTimeout
+	binary.BigEndian.PutUint64(payload[1:], uint64(d.Nanoseconds()))
+	_, err := p.Conn.Write(payload)
+	return err
+}
+
+// ExpectNoNext indicates the peer should not have a next timeout.
+func (p *packetAdaptor) ExpectNoNextTimeout() error {
+	payload := make([]byte, 1+8)
+	payload[0] = opcodeExpectNextTimeout
+	binary.BigEndian.PutUint64(payload[1:], math.MaxUint64)
+	_, err := p.Conn.Write(payload)
+	return err
+}
+
 type replayAdaptor struct {
 	net.Conn
 	prevWrite []byte
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 2f4a87c..d1a782b 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -11591,9 +11591,11 @@
 					// Exercise every timeout but the last one (which would fail the
 					// connection).
 					for _, t := range useTimeouts[:len(useTimeouts)-1] {
+						c.ExpectNextTimeout(t)
 						c.AdvanceClock(t)
 						c.ReadRetransmit()
 					}
+					c.ExpectNextTimeout(useTimeouts[len(useTimeouts)-1])
 				}
 				// Finally release the whole flight to the shim.
 				c.WriteFlight(next)
@@ -11692,6 +11694,7 @@
 								//
 								// TODO(crbug.com/42290594): The shim should send a partial
 								// ACK to request that we retransmit.
+								c.ExpectNoNextTimeout()
 								for _, t := range useTimeouts {
 									c.AdvanceClock(t)
 								}
@@ -11717,13 +11720,19 @@
 						MaxVersion: vers.version,
 						Bugs: ProtocolBugs{
 							WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
-								// Send a portion of the ServerHello. The rest was lost.
 								msg := next[0]
+								if msg.Type != typeServerHello {
+									// Leave the NewSessionTicket flight alone.
+									c.WriteFlight(next)
+									return
+								}
+								// Send a portion of the ServerHello. The rest was lost.
 								split := len(msg.Data) / 2
 								c.WriteFragments([]DTLSFragment{msg.Fragment(0, split)})
 
 								// The shim did not know this was DTLS 1.3, so it still
 								// retransmits ClientHello.
+								c.ExpectNextTimeout(useTimeouts[0])
 								c.AdvanceClock(useTimeouts[0])
 								c.ReadRetransmit()
 
@@ -11735,6 +11744,7 @@
 								// packet as EncryptedExtensions, which will trigger the case
 								// below.
 								c.WriteFragments([]DTLSFragment{msg.Fragment(split, len(msg.Data)-split)})
+								c.ExpectNextTimeout(useTimeouts[1])
 								c.AdvanceClock(useTimeouts[1])
 								c.ReadRetransmit()
 
@@ -11747,6 +11757,7 @@
 								//
 								// TODO(crbug.com/42290594): The shim should send a partial
 								// ACK to request that we retransmit.
+								c.ExpectNoNextTimeout()
 								for _, t := range useTimeouts[2:] {
 									c.AdvanceClock(t)
 								}
@@ -11775,9 +11786,11 @@
 					Bugs: ProtocolBugs{
 						WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
 							for _, t := range useTimeouts[:len(useTimeouts)-1] {
+								c.ExpectNextTimeout(t)
 								c.AdvanceClock(t)
 								c.ReadRetransmit()
 							}
+							c.ExpectNextTimeout(useTimeouts[len(useTimeouts)-1])
 							c.AdvanceClock(useTimeouts[len(useTimeouts)-1])
 							// The shim should give up at this point.
 						},
@@ -11798,8 +11811,11 @@
 					MaxVersion: vers.version,
 					Bugs: ProtocolBugs{
 						WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
-							c.AdvanceClock(useTimeouts[0] - 10*time.Millisecond)
-							c.ReadRetransmit()
+							if len(received) > 0 {
+								c.ExpectNextTimeout(useTimeouts[0])
+								c.AdvanceClock(useTimeouts[0] - 10*time.Millisecond)
+								c.ReadRetransmit()
+							}
 							c.WriteFlight(next)
 						},
 					},
@@ -11856,9 +11872,11 @@
 							MaxHandshakeRecordLength: 512,
 							WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
 								if len(received) > 0 {
+									c.ExpectNextTimeout(useTimeouts[0])
 									c.WriteACK(c.OutEpoch(), records)
 									// After ACKing everything, the shim should stop the timer
 									// and wait for the next flight.
+									c.ExpectNoNextTimeout()
 									for _, t := range useTimeouts {
 										c.AdvanceClock(t)
 									}