Schedule ACKs when we receive a partial flight

This implements a very simple ACK policy: for every ACKable record (i.e.
contains no far-future fragments that we ignore), add it to a queue of
record numbers to ACK and set an ACK timer at 1/4 of the current
retransmission timeout.

RFC 9147 doesn't say a whole lot, but this is arguably slightly
different from the recommended policy. RFC 9147 has some text with
implies you're only meant to ACK the current flight and not arbitrarily
old message fragments. However, tracking that in the fully general case with
post-handshake messages is unclear. There's no harm in ACKing those
packets, so start with this. See discussion in
https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/

Something kind of fun is that this provision in the spec happens for
free:

> ACKs SHOULD NOT be sent for these flights unless the responding flight
> cannot be generated immediately. All other flights MUST be ACKed. In
> this case, implementations MAY send explicit ACKs for the complete
> received flight even though it will eventually also be implicitly
> acknowledged through the responding flight. A notable example for this
> is the case of client authentication in constrained environments,
> where generating the CertificateVerify message can take considerable
> time on the client.

If we generate the next flight before the ACK timer, it will be canceled
and we don't ACK. If the next flight is async and takes too long, the
ACK timer will win and we tell the peer not to retransmit.

Bug: 42290594
Change-Id: I7974499f82ce2b2c7da91f02ca65886b0f82896c
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72952
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index a736954..e5cc7ac 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -445,6 +445,22 @@
 
   if (!skipped_fragments) {
     ssl->d1->records_to_ack.PushBack(record_number);
+
+    if (ssl_has_final_version(ssl) &&
+        ssl_protocol_version(ssl) >= TLS1_3_VERSION &&
+        !ssl->d1->ack_timer.IsSet()) {
+      // Schedule sending an ACK. The delay serves several purposes:
+      // - If there are more records to come, we send only one ACK.
+      // - If there are more records to come and the flight is now complete, we
+      //   will send the reply (which implicitly ACKs the previous flight) and
+      //   cancel the timer.
+      // - If there are more records to come, the flight is now complete, but
+      //   generating the response is delayed (e.g. a slow, async private key),
+      //   the timer will fire and we send an ACK anyway.
+      OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
+      ssl->d1->ack_timer.StartMicroseconds(
+          now, uint64_t{ssl->d1->timeout_duration_ms} * 1000 / 4);
+    }
   }
 
   return true;
@@ -998,10 +1014,8 @@
     // to ACK previous records. This clears the ACK buffer slightly earlier than
     // the specification suggests. See the discussion in
     // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/
-    //
-    // TODO(crbug.com/42290594): When we introduce the ACK timer, this should
-    // also stop the ACK timer.
     ssl->d1->records_to_ack.Clear();
+    ssl->d1->ack_timer.Stop();
   }
   // Start the retransmission timer for the next flight (if any).
   dtls1_start_timer(ssl);
@@ -1010,6 +1024,7 @@
 
 int dtls1_send_ack(SSL *ssl) {
   assert(ssl_protocol_version(ssl) >= TLS1_3_VERSION);
+  ssl->d1->ack_timer.Stop();
   if (ssl->d1->records_to_ack.empty()) {
     return 1;
   }
diff --git a/ssl/d1_lib.cc b/ssl/d1_lib.cc
index c73a102..2c7a1cd 100644
--- a/ssl/d1_lib.cc
+++ b/ssl/d1_lib.cc
@@ -247,6 +247,8 @@
   OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
   uint64_t remaining_usec =
       ssl->d1->retransmit_timer.MicrosecondsRemaining(now);
+  remaining_usec =
+      std::min(remaining_usec, ssl->d1->ack_timer.MicrosecondsRemaining(now));
   if (remaining_usec == DTLSTimer::kNever) {
     return 0;  // No timeout is set.
   }
@@ -274,17 +276,34 @@
     return -1;
   }
 
-  // If no timer is expired, don't do anything.
-  OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
-  if (!ssl->d1->retransmit_timer.IsExpired(now)) {
+  if (!ssl->d1->ack_timer.IsSet() && !ssl->d1->retransmit_timer.IsSet()) {
+    // No timers are running. Don't bother querying the clock.
     return 0;
   }
 
-  if (!dtls1_check_timeout_num(ssl)) {
-    return -1;
+  OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
+  bool any_timer_expired = false;
+  if (ssl->d1->ack_timer.IsExpired(now)) {
+    any_timer_expired = true;
+    int ret = dtls1_send_ack(ssl);
+    if (ret <= 0) {
+      return ret;
+    }
   }
 
-  dtls1_double_timeout(ssl);
-  dtls1_start_timer(ssl);
-  return dtls1_retransmit_outgoing_messages(ssl);
+  if (ssl->d1->retransmit_timer.IsExpired(now)) {
+    any_timer_expired = true;
+    if (!dtls1_check_timeout_num(ssl)) {
+      return -1;
+    }
+
+    dtls1_double_timeout(ssl);
+    dtls1_start_timer(ssl);
+    int ret = dtls1_retransmit_outgoing_messages(ssl);
+    if (ret <= 0) {
+      return ret;
+    }
+  }
+
+  return any_timer_expired ? 1 : 0;
 }
diff --git a/ssl/internal.h b/ssl/internal.h
index 8dcc097..a8ceb16 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3613,6 +3613,9 @@
   // not hear from the peer.
   DTLSTimer retransmit_timer;
 
+  // ack_timer tracks when to send an ACK.
+  DTLSTimer ack_timer;
+
   // timeout_duration_ms is the timeout duration in milliseconds.
   uint32_t timeout_duration_ms = 0;
 };
diff --git a/ssl/test/handshake_util.cc b/ssl/test/handshake_util.cc
index 65219d8..d5e43e7 100644
--- a/ssl/test/handshake_util.cc
+++ b/ssl/test/handshake_util.cc
@@ -96,6 +96,29 @@
       test_state->early_callback_ready = true;
       return true;
     case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION:
+      if (config->private_key_delay_ms != 0 &&
+          test_state->private_key_retries == 0) {
+        // The first time around, simulate the private key operation taking a
+        // long time to run.
+        if (test_state->packeted_bio == nullptr) {
+          fprintf(stderr, "-private-key-delay-ms requires DTLS.\n");
+          return false;
+        }
+        timeval *clock = PacketedBioGetClock(test_state->packeted_bio);
+        clock->tv_sec += config->private_key_delay_ms / 1000;
+        clock->tv_usec += config->private_key_delay_ms * 1000;
+        if (clock->tv_usec >= 1000000) {
+          clock->tv_usec -= 1000000;
+          clock->tv_sec++;
+        }
+        AsyncBioEnforceWriteQuota(test_state->async_bio, false);
+        int timeout_ret = DTLSv1_handle_timeout(ssl);
+        AsyncBioEnforceWriteQuota(test_state->async_bio, true);
+        if (timeout_ret < 0) {
+          fprintf(stderr, "Error retransmitting.\n");
+          return false;
+        }
+      }
       test_state->private_key_retries++;
       return true;
     case SSL_ERROR_WANT_CERTIFICATE_VERIFY:
diff --git a/ssl/test/packeted_bio.cc b/ssl/test/packeted_bio.cc
index f5e2a54..9406550 100644
--- a/ssl/test/packeted_bio.cc
+++ b/ssl/test/packeted_bio.cc
@@ -328,3 +328,11 @@
   OPENSSL_memset(&data->timeout, 0, sizeof(data->timeout));
   return true;
 }
+
+timeval *PacketedBioGetClock(BIO *bio) {
+  PacketedBio *data = GetData(bio);
+  if (data == nullptr) {
+    return nullptr;
+  }
+  return data->clock;
+}
diff --git a/ssl/test/packeted_bio.h b/ssl/test/packeted_bio.h
index b99e971..9b2e50e 100644
--- a/ssl/test/packeted_bio.h
+++ b/ssl/test/packeted_bio.h
@@ -47,5 +47,8 @@
 // pending timeout. Otherwise, it returns false.
 bool PacketedBioAdvanceClock(BIO *bio);
 
+// PacketedBioAdvanceClock return's |bio|'s clock.
+timeval *PacketedBioGetClock(BIO *bio);
+
 
 #endif  // HEADER_PACKETED_BIO
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 9dddd13..9f90722 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -1223,11 +1223,14 @@
 	// from the client certificate are sent over these keys.
 	c.useOutTrafficSecret(uint16(encryptionApplication), c.wireVersion, hs.suite, serverTrafficSecret)
 
-	if err := c.flushHandshake(); err != nil {
-		return err
-	}
-
-	if encryptedExtensions.extensions.hasEarlyData {
+	// In TLS, we need to consume EndOfEarlyData, and also test early data that
+	// was only partially written while reading the ServerHello. Both of these
+	// require flushing ServerHello first. Neither of these apply to DTLS, where
+	// we need to flush after installing handshake keys.
+	if encryptedExtensions.extensions.hasEarlyData && !c.isDTLS {
+		if err := c.flushHandshake(); err != nil {
+			return err
+		}
 		for _, expectedMsg := range config.Bugs.ExpectLateEarlyData {
 			if err := c.readRecord(recordTypeApplicationData); err != nil {
 				return err
@@ -1251,6 +1254,12 @@
 		return err
 	}
 
+	// DTLS testing requires this flush occur after installing handshake keys,
+	// so that we can process ACKs.
+	if err := c.flushHandshake(); err != nil {
+		return err
+	}
+
 	// If we sent an ALPS extension, the client must respond with a single EncryptedExtensions.
 	if encryptedExtensions.extensions.hasApplicationSettings || encryptedExtensions.extensions.hasApplicationSettingsOld {
 		clientEncryptedExtensions, err := readHandshakeType[clientEncryptedExtensionsMsg](c)
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 6ee603f..29c1231 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -11686,15 +11686,24 @@
 						DefaultCurves: []CurveID{}, // Force HelloRetryRequest.
 						Bugs: ProtocolBugs{
 							WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
+								if len(received) == 0 && next[0].Type == typeClientHello {
+									// Send the initial ClientHello as-is.
+									c.WriteFlight(next)
+									return
+								}
+
 								// Send a portion of the first message. The rest was lost.
 								msg := next[0]
 								split := len(msg.Data) / 2
 								c.WriteFragments([]DTLSFragment{msg.Fragment(0, split)})
-								// This is enough to ACK the previous flight. The shim
-								// should stop retransmitting and even stop the timer.
-								//
-								// TODO(crbug.com/42290594): The shim should send a partial
-								// ACK to request that we retransmit.
+								// After waiting the current timeout, the shim should ACK
+								// the partial flight.
+								c.ExpectNextTimeout(useTimeouts[0] / 4)
+								c.AdvanceClock(useTimeouts[0] / 4)
+								c.ReadACK(c.InEpoch())
+								// The partial flight is enough to ACK the previous flight.
+								// The shim should stop retransmitting and even stop the
+								// retransmit timer.
 								c.ExpectNoNextTimeout()
 								for _, t := range useTimeouts {
 									c.AdvanceClock(t)
@@ -11723,7 +11732,9 @@
 							WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
 								msg := next[0]
 								if msg.Type != typeServerHello {
-									// Leave the NewSessionTicket flight alone.
+									// TODO(crbug.com/42290594): Do not manipulate NewSessionTicket
+									// flights for now. The shim actually does now ACK those on a
+									// timer, but we'll need to test those more explicitly.
 									c.WriteFlight(next)
 									return
 								}
@@ -11749,15 +11760,16 @@
 								c.AdvanceClock(useTimeouts[1])
 								c.ReadRetransmit()
 
-								// Send EncryptedExtensions. The shim now knows the version
-								// and implicitly ACKs everything.
+								// Send EncryptedExtensions. The shim now knows the version.
 								c.WriteFragments([]DTLSFragment{next[1].Fragment(0, len(next[1].Data))})
 
+								// The shim should ACK the partial flight. The shim hasn't
+								// gotten to epoch 3 yet, so the ACK will come in epoch 2.
+								c.AdvanceClock(useTimeouts[2] / 4)
+								c.ReadACK(uint16(encryptionHandshake))
+
 								// This is enough to ACK the previous flight. The shim
 								// should stop retransmitting and even stop the timer.
-								//
-								// TODO(crbug.com/42290594): The shim should send a partial
-								// ACK to request that we retransmit.
 								c.ExpectNoNextTimeout()
 								for _, t := range useTimeouts[2:] {
 									c.AdvanceClock(t)
@@ -12318,6 +12330,97 @@
 					shouldFail:    true,
 					expectedError: ":READ_TIMEOUT_EXPIRED:",
 				})
+
+				// If generating the reply to a flight takes time (generating a
+				// CertificateVerify for a client certificate), the shim should
+				// send an ACK.
+				testCases = append(testCases, testCase{
+					protocol: dtls,
+					name:     "DTLS-Retransmit-SlowReplyGeneration" + suffix,
+					config: Config{
+						MaxVersion: vers.version,
+						ClientAuth: RequireAnyClientCert,
+						Bugs: ProtocolBugs{
+							WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
+								c.WriteFlight(next)
+								if next[0].Type == typeServerHello {
+									// The shim will reply with Certificate..Finished, but
+									// take time to do so. In that time, it should schedule
+									// an ACK so the runner knows not to retransmit.
+									c.ReadACK(c.InEpoch())
+								}
+							},
+						},
+					},
+					shimCertificate: &rsaCertificate,
+					// Simulate it taking time to generate the reply.
+					flags: slices.Concat(flags, []string{"-private-key-delay-ms", strconv.Itoa(int(useTimeouts[0].Milliseconds()))}),
+				})
+
+				// BoringSSL's ACK policy may schedule both retransmit and ACK
+				// timers in parallel.
+				//
+				// TODO(crbug.com/42290594): This is only possible during the
+				// handshake because we're willing to ACK old flights without
+				// trying to distinguish these cases. However, post-handshake
+				// messages will exercise this, so that may be a better version
+				// of this test. In-handshake, it's kind of a waste to ACK this,
+				// so maybe we should stop.
+				testCases = append(testCases, testCase{
+					protocol: dtls,
+					name:     "DTLS-Retransmit-BothTimers" + suffix,
+					config: Config{
+						MaxVersion: vers.version,
+						Bugs: ProtocolBugs{
+							// Arrange for there to be two server flights.
+							SendHelloRetryRequestCookie: []byte("cookie"),
+							WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
+								if next[0].Sequence == 0 || next[0].Type != typeServerHello {
+									// Send the first flight (HelloRetryRequest) as-is,
+									// as well as any post-handshake flights.
+									c.WriteFlight(next)
+									return
+								}
+
+								// The shim just send the ClientHello2 and is
+								// waiting for ServerHello..Finished. If it hears
+								// nothing, it will retransmit ClientHello2 on the
+								// assumption the packet was lost.
+								c.ExpectNextTimeout(useTimeouts[0])
+
+								// Retransmit a portion of HelloRetryRequest.
+								c.WriteFragments([]DTLSFragment{prev[0].Fragment(0, 1)})
+
+								// The shim does not actually need to ACK this,
+								// but BoringSSL does. Now both timers are active.
+								// Fire the first...
+								c.ExpectNextTimeout(useTimeouts[0] / 4)
+								c.AdvanceClock(useTimeouts[0] / 4)
+								c.ReadACK(0)
+
+								// ...followed by the second.
+								c.ExpectNextTimeout(3 * useTimeouts[0] / 4)
+								c.AdvanceClock(3 * useTimeouts[0] / 4)
+								c.ReadRetransmit()
+
+								// The shim is now set for the next retransmit.
+								c.ExpectNextTimeout(useTimeouts[1])
+
+								// Start the ACK timer again.
+								c.WriteFragments([]DTLSFragment{prev[0].Fragment(0, 1)})
+								c.ExpectNextTimeout(useTimeouts[1] / 4)
+
+								// Expire both timers at once.
+								c.AdvanceClock(useTimeouts[1])
+								c.ReadACK(0)
+								c.ReadRetransmit()
+
+								c.WriteFlight(next)
+							},
+						},
+					},
+					flags: flags,
+				})
 			}
 		}
 	}
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 4c3c1cc..6583b74 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -514,6 +514,7 @@
                        &TestConfig::signed_cert_timestamps),
             Base64Flag("-signed-cert-timestamps",
                        &CredentialConfig::signed_cert_timestamps)),
+        IntFlag("-private-key-delay-ms", &TestConfig::private_key_delay_ms),
     };
     std::sort(ret.begin(), ret.end(), FlagNameComparator{});
     return ret;
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 093d241..b4b290d 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -221,6 +221,7 @@
   bool no_check_ecdsa_curve = false;
   int expect_selected_credential = -1;
   std::vector<CredentialConfig> credentials;
+  int private_key_delay_ms = 0;
 
   std::vector<const char*> handshaker_args;