Replace reuse_message with an explicit next_message call.

This means that ssl_get_message (soon to be replaced with a BIO-less
version) is idempotent which avoids the SSL3_ST_SR_KEY_EXCH_B
contortion. It also eases converting the TLS 1.2 state machine. See
https://docs.google.com/a/google.com/document/d/11n7LHsT3GwE34LAJIe3EFs4165TI4UR_3CqiM9LJVpI/edit?usp=sharing
for details.

Bug: 128
Change-Id: Iddd4f951389e8766da07a9de595b552e75f8acf0
Reviewed-on: https://boringssl-review.googlesource.com/18805
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/include/openssl/ssl3.h b/include/openssl/ssl3.h
index 39cd07b..1eea758 100644
--- a/include/openssl/ssl3.h
+++ b/include/openssl/ssl3.h
@@ -348,7 +348,6 @@
 /* read from client */
 #define SSL3_ST_SR_CERT_A (0x180 | SSL_ST_ACCEPT)
 #define SSL3_ST_SR_KEY_EXCH_A (0x190 | SSL_ST_ACCEPT)
-#define SSL3_ST_SR_KEY_EXCH_B (0x191 | SSL_ST_ACCEPT)
 #define SSL3_ST_SR_CERT_VRFY_A (0x1A0 | SSL_ST_ACCEPT)
 #define SSL3_ST_SR_CHANGE (0x1B0 | SSL_ST_ACCEPT)
 #define SSL3_ST_SR_NEXT_PROTO_A (0x210 | SSL_ST_ACCEPT)
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 5fc93cb..70af6c1 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -420,14 +420,6 @@
 }
 
 int dtls1_get_message(SSL *ssl) {
-  if (ssl->s3->tmp.reuse_message) {
-    /* There must be a current message. */
-    assert(ssl->init_msg != NULL);
-    ssl->s3->tmp.reuse_message = 0;
-  } else {
-    dtls1_release_current_message(ssl);
-  }
-
   /* Process handshake records until the current message is ready. */
   while (!dtls1_is_current_message_complete(ssl)) {
     int ret = dtls1_process_handshake_record(ssl);
@@ -463,11 +455,8 @@
   CBS_init(out, frag->data, DTLS1_HM_HEADER_LENGTH + frag->msg_len);
 }
 
-void dtls1_release_current_message(SSL *ssl) {
-  if (ssl->init_msg == NULL) {
-    return;
-  }
-
+void dtls1_next_message(SSL *ssl) {
+  assert(ssl->init_msg != NULL);
   assert(dtls1_is_current_message_complete(ssl));
   size_t index = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
   dtls1_hm_fragment_free(ssl->d1->incoming_messages[index]);
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 1d089e8..d17afa6 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -73,7 +73,6 @@
 }
 
 static void dtls1_on_handshake_complete(SSL *ssl) {
-  dtls1_release_current_message(ssl);
   /* If we wrote the last flight, we'll have a timer left over without waiting
    * for a read. Stop the timer but leave the flight around for post-handshake
    * transmission logic. */
@@ -115,7 +114,7 @@
     dtls1_free,
     dtls1_get_message,
     dtls1_get_current_message,
-    dtls1_release_current_message,
+    dtls1_next_message,
     dtls1_read_app_data,
     dtls1_read_change_cipher_spec,
     dtls1_read_close_notify,
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 385f726..c43bda3 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -776,7 +776,6 @@
 
   if (ssl->s3->tmp.message_type != DTLS1_MT_HELLO_VERIFY_REQUEST) {
     ssl->d1->send_cookie = false;
-    ssl->s3->tmp.reuse_message = 1;
     return 1;
   }
 
@@ -794,6 +793,7 @@
   ssl->d1->cookie_len = CBS_len(&cookie);
 
   ssl->d1->send_cookie = true;
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1060,6 +1060,7 @@
     return -1;
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1130,6 +1131,7 @@
     }
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1143,7 +1145,6 @@
   if (ssl->s3->tmp.message_type != SSL3_MT_CERTIFICATE_STATUS) {
     /* A server may send status_request in ServerHello and then change
      * its mind about sending CertificateStatus. */
-    ssl->s3->tmp.reuse_message = 1;
     return 1;
   }
 
@@ -1171,6 +1172,7 @@
     return -1;
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1189,7 +1191,6 @@
       return -1;
     }
 
-    ssl->s3->tmp.reuse_message = 1;
     return 1;
   }
 
@@ -1359,6 +1360,8 @@
       return -1;
     }
   }
+
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1370,7 +1373,6 @@
   }
 
   if (ssl->s3->tmp.message_type == SSL3_MT_SERVER_HELLO_DONE) {
-    ssl->s3->tmp.reuse_message = 1;
     /* If we get here we don't need the handshake buffer as we won't be doing
      * client auth. */
     hs->transcript.FreeBuffer();
@@ -1426,6 +1428,7 @@
   hs->cert_request = 1;
   hs->ca_names = std::move(ca_names);
   ssl->ctx->x509_method->hs_flush_cached_ca_names(hs);
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1448,6 +1451,7 @@
     return -1;
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1818,6 +1822,7 @@
      * negotiating the extension. The value of |ticket_expected| is checked in
      * |ssl_update_cache| so is cleared here to avoid an unnecessary update. */
     hs->ticket_expected = 0;
+    ssl->method->next_message(ssl);
     return 1;
   }
 
@@ -1861,6 +1866,7 @@
     ssl->session = renewed_session.release();
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index 907943f..1889177 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -304,7 +304,6 @@
         break;
 
       case SSL3_ST_SR_KEY_EXCH_A:
-      case SSL3_ST_SR_KEY_EXCH_B:
         ret = ssl3_get_client_key_exchange(hs);
         if (ret <= 0) {
           goto end;
@@ -925,6 +924,7 @@
     hs->transcript.FreeBuffer();
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1195,7 +1195,6 @@
       /* OpenSSL returns X509_V_OK when no certificates are received. This is
        * classed by them as a bug, but it's assumed by at least NGINX. */
       hs->new_session->verify_result = X509_V_OK;
-      ssl->s3->tmp.reuse_message = 1;
       return 1;
     }
 
@@ -1253,14 +1252,12 @@
     /* OpenSSL returns X509_V_OK when no certificates are received. This is
      * classed by them as a bug, but it's assumed by at least NGINX. */
     hs->new_session->verify_result = X509_V_OK;
-    return 1;
-  }
-
-  /* The hash will have been filled in. */
-  if (ssl->retain_only_sha256_of_client_certs) {
+  } else if (ssl->retain_only_sha256_of_client_certs) {
+    /* The hash will have been filled in. */
     hs->new_session->peer_sha256_valid = 1;
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1271,11 +1268,9 @@
   size_t premaster_secret_len = 0;
   uint8_t *decrypt_buf = NULL;
 
-  if (hs->state == SSL3_ST_SR_KEY_EXCH_A) {
-    int ret = ssl->method->ssl_get_message(ssl);
-    if (ret <= 0) {
-      return ret;
-    }
+  int ret = ssl->method->ssl_get_message(ssl);
+  if (ret <= 0) {
+    return ret;
   }
 
   if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
@@ -1349,7 +1344,6 @@
         goto err;
       case ssl_private_key_retry:
         ssl->rwstate = SSL_PRIVATE_KEY_OPERATION;
-        hs->state = SSL3_ST_SR_KEY_EXCH_B;
         goto err;
     }
 
@@ -1501,6 +1495,7 @@
 
   OPENSSL_cleanse(premaster_secret, premaster_secret_len);
   OPENSSL_free(premaster_secret);
+  ssl->method->next_message(ssl);
   return 1;
 
 err:
@@ -1606,6 +1601,7 @@
     return -1;
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1630,14 +1626,15 @@
       CBS_len(&next_protocol) != 0) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-    return 0;
+    return -1;
   }
 
   if (!CBS_stow(&selected_protocol, &ssl->s3->next_proto_negotiated,
                 &ssl->s3->next_proto_negotiated_len)) {
-    return 0;
+    return -1;
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -1654,6 +1651,7 @@
       !ssl_hash_current_message(hs)) {
     return -1;
   }
+  ssl->method->next_message(ssl);
   return 1;
 }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index c8ed13b..90236e5 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1713,8 +1713,6 @@
   struct {
     int message_type;
 
-    int reuse_message;
-
     uint8_t new_mac_secret_len;
     uint8_t new_key_len;
     uint8_t new_fixed_iv_len;
@@ -2141,7 +2139,7 @@
 int ssl3_send_alert(SSL *ssl, int level, int desc);
 int ssl3_get_message(SSL *ssl);
 void ssl3_get_current_message(const SSL *ssl, CBS *out);
-void ssl3_release_current_message(SSL *ssl);
+void ssl3_next_message(SSL *ssl);
 
 int ssl3_send_finished(SSL_HANDSHAKE *hs);
 int ssl3_dispatch_alert(SSL *ssl);
@@ -2220,7 +2218,7 @@
 
 int dtls1_get_message(SSL *ssl);
 void dtls1_get_current_message(const SSL *ssl, CBS *out);
-void dtls1_release_current_message(SSL *ssl);
+void dtls1_next_message(SSL *ssl);
 int dtls1_dispatch_alert(SSL *ssl);
 
 int tls1_change_cipher_state(SSL_HANDSHAKE *hs, int which);
@@ -2359,16 +2357,15 @@
   char is_dtls;
   int (*ssl_new)(SSL *ssl);
   void (*ssl_free)(SSL *ssl);
-  /* ssl_get_message reads the next handshake message. On success, it returns
-   * one and sets |ssl->s3->tmp.message_type|, |ssl->init_msg|, and
+  /* ssl_get_message completes the current next handshake message. On success,
+   * it returns one and sets |ssl->s3->tmp.message_type|, |ssl->init_msg|, and
    * |ssl->init_num|. Otherwise, it returns <= 0. */
   int (*ssl_get_message)(SSL *ssl);
   /* get_current_message sets |*out| to the current handshake message. This
    * includes the protocol-specific message header. */
   void (*get_current_message)(const SSL *ssl, CBS *out);
-  /* release_current_message is called to release the current handshake
-   * message. */
-  void (*release_current_message)(SSL *ssl);
+  /* next_message is called to release the current handshake message. */
+  void (*next_message)(SSL *ssl);
   /* read_app_data reads up to |len| bytes of application data into |buf|. On
    * success, it returns the number of bytes read. Otherwise, it returns <= 0
    * and sets |*out_got_handshake| to whether the failure was due to a
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index a96b910..4ae6f70 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -469,6 +469,7 @@
     }
   }
 
+  ssl->method->next_message(ssl);
   return 1;
 }
 
@@ -683,14 +684,6 @@
 }
 
 int ssl3_get_message(SSL *ssl) {
-  if (ssl->s3->tmp.reuse_message) {
-    /* There must be a current message. */
-    assert(ssl->init_msg != NULL);
-    ssl->s3->tmp.reuse_message = 0;
-  } else {
-    ssl3_release_current_message(ssl);
-  }
-
   /* Re-create the handshake buffer if needed. */
   if (ssl->init_buf == NULL) {
     ssl->init_buf = BUF_MEM_new();
@@ -757,10 +750,8 @@
   return hs->transcript.Update(CBS_data(&cbs), CBS_len(&cbs));
 }
 
-void ssl3_release_current_message(SSL *ssl) {
-  if (ssl->init_msg == NULL) {
-    return;
-  }
+void ssl3_next_message(SSL *ssl) {
+  assert(ssl->init_msg != NULL);
 
   /* |init_buf| never contains data beyond the current message. */
   assert(SSL3_HM_HEADER_LENGTH + ssl->init_num == ssl->init_buf->length);
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index a28cc2d..f929fe3 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -377,8 +377,6 @@
   assert(!ssl->s3->aead_read_ctx->is_null_cipher());
   *out_got_handshake = 0;
 
-  ssl->method->release_current_message(ssl);
-
   SSL3_RECORD *rr = &ssl->s3->rrec;
 
   for (;;) {
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index becf3ad..1ca7a95 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -942,7 +942,7 @@
     if (!ssl_do_post_handshake(ssl)) {
       return -1;
     }
-    ssl->method->release_current_message(ssl);
+    ssl->method->next_message(ssl);
   }
 }
 
diff --git a/ssl/ssl_stat.cc b/ssl/ssl_stat.cc
index 22149e2..56e4f2b 100644
--- a/ssl/ssl_stat.cc
+++ b/ssl/ssl_stat.cc
@@ -188,9 +188,6 @@
     case SSL3_ST_SR_KEY_EXCH_A:
       return "SSLv3 read client key exchange A";
 
-    case SSL3_ST_SR_KEY_EXCH_B:
-      return "SSLv3 read client key exchange B";
-
     case SSL3_ST_SR_CERT_VRFY_A:
       return "SSLv3 read certificate verify A";
 
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index 2940265..fa4731f 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -144,6 +144,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->received_hello_retry_request = 1;
   hs->tls13_state = state_send_second_client_hello;
   /* 0-RTT is rejected if we receive a HelloRetryRequest. */
@@ -341,6 +342,8 @@
       !tls13_derive_handshake_secrets(hs)) {
     return ssl_hs_error;
   }
+
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_process_change_cipher_spec;
   return ssl->version == TLS1_3_EXPERIMENT_VERSION
              ? ssl_hs_read_change_cipher_spec
@@ -416,6 +419,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_continue_second_server_flight;
   if (hs->in_early_data && !ssl->early_data_accepted) {
     return ssl_hs_early_data_rejected;
@@ -480,6 +484,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_process_server_certificate;
   return ssl_hs_read_message;
 }
@@ -492,6 +497,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_process_server_certificate_verify;
   return ssl_hs_read_message;
 }
@@ -515,6 +521,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_process_server_finished;
   return ssl_hs_read_message;
 }
@@ -530,6 +537,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_send_end_of_early_data;
   return ssl_hs_ok;
 }
diff --git a/ssl/tls13_server.cc b/ssl/tls13_server.cc
index 03f8bdd..c2cd682 100644
--- a/ssl/tls13_server.cc
+++ b/ssl/tls13_server.cc
@@ -458,12 +458,14 @@
     if (need_retry) {
       ssl->early_data_accepted = 0;
       ssl->s3->skip_early_data = 1;
+      ssl->method->next_message(ssl);
       hs->tls13_state = state_send_hello_retry_request;
       return ssl_hs_ok;
     }
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_send_server_hello;
   return ssl_hs_ok;
 }
@@ -517,6 +519,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_send_server_hello;
   return ssl_hs_ok;
 }
@@ -669,7 +672,8 @@
                          static_cast<uint8_t>(hs->hash_len)};
     if (!hs->transcript.Update(header, sizeof(header)) ||
         !hs->transcript.Update(hs->expected_client_finished, hs->hash_len) ||
-        !tls13_derive_resumption_secret(hs) || !add_new_session_tickets(hs)) {
+        !tls13_derive_resumption_secret(hs) ||
+        !add_new_session_tickets(hs)) {
       return ssl_hs_error;
     }
   }
@@ -739,6 +743,7 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_process_client_certificate_verify;
   return ssl_hs_read_message;
 }
@@ -768,22 +773,25 @@
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_process_channel_id;
   return ssl_hs_read_message;
 }
 
 static enum ssl_hs_wait_t do_process_channel_id(SSL_HANDSHAKE *hs) {
-  if (!hs->ssl->s3->tlsext_channel_id_valid) {
+  SSL *const ssl = hs->ssl;
+  if (!ssl->s3->tlsext_channel_id_valid) {
     hs->tls13_state = state_process_client_finished;
     return ssl_hs_ok;
   }
 
-  if (!ssl_check_message_type(hs->ssl, SSL3_MT_CHANNEL_ID) ||
+  if (!ssl_check_message_type(ssl, SSL3_MT_CHANNEL_ID) ||
       !tls1_verify_channel_id(hs) ||
       !ssl_hash_current_message(hs)) {
     return ssl_hs_error;
   }
 
+  ssl->method->next_message(ssl);
   hs->tls13_state = state_process_client_finished;
   return ssl_hs_read_message;
 }
@@ -808,10 +816,12 @@
 
     /* We send post-handshake tickets as part of the handshake in 1-RTT. */
     hs->tls13_state = state_send_new_session_ticket;
-    return ssl_hs_ok;
+  } else {
+    /* We already sent half-RTT tickets. */
+    hs->tls13_state = state_done;
   }
 
-  hs->tls13_state = state_done;
+  ssl->method->next_message(ssl);
   return ssl_hs_ok;
 }
 
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index b2c7b46..1063ca9 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -70,9 +70,19 @@
 static int ssl3_supports_cipher(const SSL_CIPHER *cipher) { return 1; }
 
 static void ssl3_on_handshake_complete(SSL *ssl) {
-  ssl3_release_current_message(ssl);
-  BUF_MEM_free(ssl->init_buf);
-  ssl->init_buf = NULL;
+  /* The handshake should have released its final message. */
+  assert(ssl->init_msg == NULL);
+
+  /* During the handshake, |init_buf| is retained. Release if it there is no
+   * excess in it.
+   *
+   * TODO(davidben): The second check is always true but will not be once we
+   * switch to copying the entire handshake record. Replace this comment with an
+   * explanation when that happens and a TODO to reject it. */
+  if (ssl->init_buf != NULL && ssl->init_buf->length == 0) {
+    BUF_MEM_free(ssl->init_buf);
+    ssl->init_buf = NULL;
+  }
 }
 
 static int ssl3_set_read_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
@@ -104,7 +114,7 @@
     ssl3_free,
     ssl3_get_message,
     ssl3_get_current_message,
-    ssl3_release_current_message,
+    ssl3_next_message,
     ssl3_read_app_data,
     ssl3_read_change_cipher_spec,
     ssl3_read_close_notify,