Remove expect and received flight hooks.

Instead, the DTLS driver can detect these states implicitly based on
when we write flights and when the handshake completes. When we flush a
new flight, the peer has enough information to send their reply, so we
start a timer. When we begin assembling a new flight, we must have
received the final message in the peer's flight. (If there are
asynchronous events between, we may stop the timer later, but we may
freely stop the timer anytime before we next try to read something.)

The only place this fails is if we were the last to write a flight,
we'll have a stray timer. Clear it in a handshake completion hook.

Change-Id: I973c592ee5721192949a45c259b93192fa309edb
Reviewed-on: https://boringssl-review.googlesource.com/18864
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 3ac21c3..2fa0183 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -540,6 +540,7 @@
   ssl->d1->outgoing_messages_len = 0;
   ssl->d1->outgoing_written = 0;
   ssl->d1->outgoing_offset = 0;
+  ssl->d1->outgoing_messages_complete = false;
 }
 
 int dtls1_init_message(SSL *ssl, CBB *cbb, CBB *body, uint8_t type) {
@@ -577,6 +578,13 @@
  * it takes ownership of |data| and releases it with |OPENSSL_free| when
  * done. */
 static int add_outgoing(SSL *ssl, int is_ccs, uint8_t *data, size_t len) {
+  if (ssl->d1->outgoing_messages_complete) {
+    /* If we've begun writing a new flight, we received the peer flight. Discard
+     * the timer and the our flight. */
+    dtls1_stop_timer(ssl);
+    dtls_clear_outgoing_messages(ssl);
+  }
+
   static_assert(SSL_MAX_HANDSHAKE_FLIGHT <
                     (1 << 8 * sizeof(ssl->d1->outgoing_messages_len)),
                 "outgoing_messages_len is too small");
@@ -795,7 +803,7 @@
   return 1;
 }
 
-int dtls1_flush_flight(SSL *ssl) {
+static int send_flight(SSL *ssl) {
   dtls1_update_mtu(ssl);
 
   int ret = -1;
@@ -837,6 +845,13 @@
   return ret;
 }
 
+int dtls1_flush_flight(SSL *ssl) {
+  ssl->d1->outgoing_messages_complete = true;
+  /* Start the retransmission timer for the next flight (if any). */
+  dtls1_start_timer(ssl);
+  return send_flight(ssl);
+}
+
 int dtls1_retransmit_outgoing_messages(SSL *ssl) {
   /* Rewind to the start of the flight and write it again.
    *
@@ -845,7 +860,7 @@
   ssl->d1->outgoing_written = 0;
   ssl->d1->outgoing_offset = 0;
 
-  return dtls1_flush_flight(ssl);
+  return send_flight(ssl);
 }
 
 unsigned int dtls1_min_mtu(void) {
diff --git a/ssl/d1_lib.cc b/ssl/d1_lib.cc
index 8ef1aa2..30110b4 100644
--- a/ssl/d1_lib.cc
+++ b/ssl/d1_lib.cc
@@ -150,22 +150,17 @@
   return 1;
 }
 
-void dtls1_double_timeout(SSL *ssl) {
+static void dtls1_double_timeout(SSL *ssl) {
   ssl->d1->timeout_duration_ms *= 2;
   if (ssl->d1->timeout_duration_ms > 60000) {
     ssl->d1->timeout_duration_ms = 60000;
   }
-  dtls1_start_timer(ssl);
 }
 
 void dtls1_stop_timer(SSL *ssl) {
-  /* Reset everything */
   ssl->d1->num_timeouts = 0;
   OPENSSL_memset(&ssl->d1->next_timeout, 0, sizeof(ssl->d1->next_timeout));
   ssl->d1->timeout_duration_ms = ssl->initial_timeout_duration_ms;
-
-  /* Clear retransmission buffer */
-  dtls_clear_outgoing_messages(ssl);
 }
 
 int dtls1_check_timeout_num(SSL *ssl) {
@@ -183,10 +178,10 @@
   if (ssl->d1->num_timeouts > DTLS1_MAX_TIMEOUTS) {
     /* fail the connection, enough alerts have been sent */
     OPENSSL_PUT_ERROR(SSL, SSL_R_READ_TIMEOUT_EXPIRED);
-    return -1;
+    return 0;
   }
 
-  return 0;
+  return 1;
 }
 
 }  // namespace bssl
@@ -202,7 +197,7 @@
     return 0;
   }
 
-  /* If no timeout is set, just return NULL */
+  /* If no timeout is set, just return 0. */
   if (ssl->d1->next_timeout.tv_sec == 0 && ssl->d1->next_timeout.tv_usec == 0) {
     return 0;
   }
@@ -210,7 +205,7 @@
   struct OPENSSL_timeval timenow;
   ssl_get_current_time(ssl, &timenow);
 
-  /* If timer already expired, set remaining time to 0 */
+  /* If timer already expired, set remaining time to 0. */
   if (ssl->d1->next_timeout.tv_sec < timenow.tv_sec ||
       (ssl->d1->next_timeout.tv_sec == timenow.tv_sec &&
        ssl->d1->next_timeout.tv_usec <= timenow.tv_usec)) {
@@ -218,7 +213,7 @@
     return 1;
   }
 
-  /* Calculate time left until timer expires */
+  /* Calculate time left until timer expires. */
   struct OPENSSL_timeval ret;
   OPENSSL_memcpy(&ret, &ssl->d1->next_timeout, sizeof(ret));
   ret.tv_sec -= timenow.tv_sec;
@@ -251,20 +246,20 @@
   ssl_reset_error_state(ssl);
 
   if (!SSL_is_dtls(ssl)) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
     return -1;
   }
 
-  /* if no timer is expired, don't do anything */
+  /* If no timer is expired, don't do anything. */
   if (!dtls1_is_timer_expired(ssl)) {
     return 0;
   }
 
-  dtls1_double_timeout(ssl);
-
-  if (dtls1_check_timeout_num(ssl) < 0) {
+  if (!dtls1_check_timeout_num(ssl)) {
     return -1;
   }
 
+  dtls1_double_timeout(ssl);
   dtls1_start_timer(ssl);
   return dtls1_retransmit_outgoing_messages(ssl);
 }
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc
index a9f2d7c..3841232 100644
--- a/ssl/d1_pkt.cc
+++ b/ssl/d1_pkt.cc
@@ -239,7 +239,7 @@
         /* Retransmit our last flight of messages. If the peer sends the second
          * Finished, they may not have received ours. Only do this for the
          * first fragment, in case the Finished was fragmented. */
-        if (dtls1_check_timeout_num(ssl) < 0) {
+        if (!dtls1_check_timeout_num(ssl)) {
           return -1;
         }
 
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 59c771b..947cfce 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -72,9 +72,12 @@
   return cipher->algorithm_enc != SSL_eNULL;
 }
 
-static void dtls1_expect_flight(SSL *ssl) { dtls1_start_timer(ssl); }
-
-static void dtls1_received_flight(SSL *ssl) { dtls1_stop_timer(ssl); }
+static void dtls1_on_handshake_complete(SSL *ssl) {
+  /* If we wrote the last flight, we'll have a timer left over without waiting
+   * for a read. Stop the timer but leave the flight around for post-handshake
+   * transmission logic. */
+  dtls1_stop_timer(ssl);
+}
 
 static int dtls1_set_read_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
   /* Cipher changes are illegal when there are buffered incoming messages. */
@@ -124,8 +127,7 @@
     dtls1_add_change_cipher_spec,
     dtls1_add_alert,
     dtls1_flush_flight,
-    dtls1_expect_flight,
-    dtls1_received_flight,
+    dtls1_on_handshake_complete,
     dtls1_set_read_state,
     dtls1_set_write_state,
 };
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 2a3e627..946316d 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -238,7 +238,6 @@
           goto end;
         }
         if (ssl->d1->send_cookie) {
-          ssl->method->received_flight(ssl);
           hs->state = SSL3_ST_CW_CLNT_HELLO_A;
         } else {
           hs->state = SSL3_ST_CR_SRVR_HELLO_A;
@@ -333,7 +332,6 @@
         if (ret <= 0) {
           goto end;
         }
-        ssl->method->received_flight(ssl);
         hs->state = SSL3_ST_CW_CERT_A;
         break;
 
@@ -460,7 +458,6 @@
         if (ret <= 0) {
           goto end;
         }
-        ssl->method->received_flight(ssl);
 
         if (ssl->session != NULL) {
           hs->state = SSL3_ST_CW_CHANGE;
@@ -475,9 +472,6 @@
           goto end;
         }
         hs->state = hs->next_state;
-        if (hs->state != SSL3_ST_FINISH_CLIENT_HANDSHAKE) {
-          ssl->method->expect_flight(ssl);
-        }
         break;
 
       case SSL_ST_TLS13: {
@@ -497,6 +491,7 @@
       }
 
       case SSL3_ST_FINISH_CLIENT_HANDSHAKE:
+        ssl->method->on_handshake_complete(ssl);
         ssl->method->release_current_message(ssl, 1 /* free_buffer */);
 
         SSL_SESSION_free(ssl->s3->established_session);
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index fb1c4d8..47fdc61 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -233,7 +233,6 @@
         if (ret <= 0) {
           goto end;
         }
-        ssl->method->received_flight(ssl);
         hs->state = SSL3_ST_SW_SRVR_HELLO_A;
         break;
 
@@ -362,7 +361,6 @@
           goto end;
         }
 
-        ssl->method->received_flight(ssl);
         if (ssl->session != NULL) {
           hs->state = SSL_ST_OK;
         } else {
@@ -400,9 +398,6 @@
         }
 
         hs->state = hs->next_state;
-        if (hs->state != SSL_ST_OK) {
-          ssl->method->expect_flight(ssl);
-        }
         break;
 
       case SSL_ST_TLS13: {
@@ -422,6 +417,7 @@
       }
 
       case SSL_ST_OK:
+        ssl->method->on_handshake_complete(ssl);
         ssl->method->release_current_message(ssl, 1 /* free_buffer */);
 
         /* If we aren't retaining peer certificates then we can discard it
diff --git a/ssl/internal.h b/ssl/internal.h
index 2dbf063..2d3557c 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1800,6 +1800,11 @@
    * the peer in this epoch. */
   bool has_change_cipher_spec:1;
 
+  /* outgoing_messages_complete is true if |outgoing_messages| has been
+   * completed by an attempt to flush it. Future calls to |add_message| and
+   * |add_change_cipher_spec| will start a new flight. */
+  bool outgoing_messages_complete:1;
+
   uint8_t cookie[DTLS1_COOKIE_LENGTH];
   size_t cookie_len;
 
@@ -2202,7 +2207,6 @@
 void dtls1_start_timer(SSL *ssl);
 void dtls1_stop_timer(SSL *ssl);
 int dtls1_is_timer_expired(SSL *ssl);
-void dtls1_double_timeout(SSL *ssl);
 unsigned int dtls1_min_mtu(void);
 
 int dtls1_new(SSL *ssl);
@@ -2398,12 +2402,8 @@
   /* flush_flight flushes the pending flight to the transport. It returns one on
    * success and <= 0 on error. */
   int (*flush_flight)(SSL *ssl);
-  /* expect_flight is called when the handshake expects a flight of messages from
-   * the peer. */
-  void (*expect_flight)(SSL *ssl);
-  /* received_flight is called when the handshake has received a flight of
-   * messages from the peer. */
-  void (*received_flight)(SSL *ssl);
+  /* on_handshake_complete is called when the handshake is complete. */
+  void (*on_handshake_complete)(SSL *ssl);
   /* set_read_state sets |ssl|'s read cipher state to |aead_ctx|. It returns
    * one on success and zero if changing the read state is forbidden at this
    * point. */
diff --git a/ssl/tls13_both.cc b/ssl/tls13_both.cc
index a5b9c53..1c2e7f7 100644
--- a/ssl/tls13_both.cc
+++ b/ssl/tls13_both.cc
@@ -55,7 +55,6 @@
         if (hs->wait != ssl_hs_flush_and_read_message) {
           break;
         }
-        ssl->method->expect_flight(ssl);
         hs->wait = ssl_hs_read_message;
         SSL_FALLTHROUGH;
       }
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index 4cc7e60..2940265 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -530,7 +530,6 @@
     return ssl_hs_error;
   }
 
-  ssl->method->received_flight(ssl);
   hs->tls13_state = state_send_end_of_early_data;
   return ssl_hs_ok;
 }
diff --git a/ssl/tls13_server.cc b/ssl/tls13_server.cc
index 933affa..03f8bdd 100644
--- a/ssl/tls13_server.cc
+++ b/ssl/tls13_server.cc
@@ -452,8 +452,6 @@
     ssl->s3->skip_early_data = 1;
   }
 
-  ssl->method->received_flight(ssl);
-
   /* Resolve ECDHE and incorporate it into the secret. */
   int need_retry;
   if (!resolve_ecdhe_secret(hs, &need_retry, &client_hello)) {
@@ -519,7 +517,6 @@
     return ssl_hs_error;
   }
 
-  ssl->method->received_flight(ssl);
   hs->tls13_state = state_send_server_hello;
   return ssl_hs_ok;
 }
@@ -803,8 +800,6 @@
     return ssl_hs_error;
   }
 
-  ssl->method->received_flight(ssl);
-
   if (!ssl->early_data_accepted) {
     if (!ssl_hash_current_message(hs) ||
         !tls13_derive_resumption_secret(hs)) {
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 4751e2e..02f5c07 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -69,9 +69,7 @@
 
 static int ssl3_supports_cipher(const SSL_CIPHER *cipher) { return 1; }
 
-static void ssl3_expect_flight(SSL *ssl) {}
-
-static void ssl3_received_flight(SSL *ssl) {}
+static void ssl3_on_handshake_complete(SSL *ssl) {}
 
 static int ssl3_set_read_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
   if (ssl->s3->rrec.length != 0) {
@@ -115,8 +113,7 @@
     ssl3_add_change_cipher_spec,
     ssl3_add_alert,
     ssl3_flush_flight,
-    ssl3_expect_flight,
-    ssl3_received_flight,
+    ssl3_on_handshake_complete,
     ssl3_set_read_state,
     ssl3_set_write_state,
 };