Replace init_msg/init_num with a get_message hook.

Rather than init_msg/init_num, there is a get_message function which
returns either success or "try again". This function does not advance
the current message (see the previous preparatory change). It only
completes the current one if necessary.

Being idempotent means it may be freely placed at the top of states
which otherwise have other asynchronous operations. It also eases
converting the TLS 1.2 state machine. See
https://docs.google.com/a/google.com/document/d/11n7LHsT3GwE34LAJIe3EFs4165TI4UR_3CqiM9LJVpI/edit?usp=sharing
for details.

The read_message hook (later to be replaced by something which doesn't
depend on BIO) intentionally does not finish the handshake, only "makes
progress". A follow-up change will align both TLS and DTLS on consuming
one handshake record and always consuming the entire record (so init_buf
may contain trailing data). In a few places I've gone ahead and
accounted for that case because it was more natural to do so.

This change also removes a couple of pointers of redundant state from
every socket.

Bug: 128
Change-Id: I89d8f3622d3b53147d69ee3ac34bb654ed044a71
Reviewed-on: https://boringssl-review.googlesource.com/18806
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index 1889177..6cccff9 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -172,7 +172,7 @@
 
 namespace bssl {
 
-static int ssl3_process_client_hello(SSL_HANDSHAKE *hs);
+static int ssl3_read_client_hello(SSL_HANDSHAKE *hs);
 static int ssl3_select_certificate(SSL_HANDSHAKE *hs);
 static int ssl3_select_parameters(SSL_HANDSHAKE *hs);
 static int ssl3_send_server_hello(SSL_HANDSHAKE *hs);
@@ -203,7 +203,7 @@
         break;
 
       case SSL3_ST_SR_CLNT_HELLO_A:
-        ret = ssl->method->ssl_get_message(ssl);
+        ret = ssl3_read_client_hello(hs);
         if (ret <= 0) {
           goto end;
         }
@@ -211,24 +211,16 @@
         break;
 
       case SSL3_ST_SR_CLNT_HELLO_B:
-        ret = ssl3_process_client_hello(hs);
-        if (ret <= 0) {
-          goto end;
-        }
-        hs->state = SSL3_ST_SR_CLNT_HELLO_C;
-        break;
-
-      case SSL3_ST_SR_CLNT_HELLO_C:
         ret = ssl3_select_certificate(hs);
         if (ret <= 0) {
           goto end;
         }
         if (hs->state != SSL_ST_TLS13) {
-          hs->state = SSL3_ST_SR_CLNT_HELLO_D;
+          hs->state = SSL3_ST_SR_CLNT_HELLO_C;
         }
         break;
 
-      case SSL3_ST_SR_CLNT_HELLO_D:
+      case SSL3_ST_SR_CLNT_HELLO_C:
         ret = ssl3_select_parameters(hs);
         if (ret <= 0) {
           goto end;
@@ -685,15 +677,19 @@
   return nullptr;
 }
 
-static int ssl3_process_client_hello(SSL_HANDSHAKE *hs) {
+static int ssl3_read_client_hello(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_HELLO)) {
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
+  if (ret <= 0) {
+    return ret;
+  }
+  if (!ssl_check_message_type(ssl, msg, SSL3_MT_CLIENT_HELLO)) {
     return -1;
   }
 
   SSL_CLIENT_HELLO client_hello;
-  if (!ssl_client_hello_init(ssl, &client_hello, ssl->init_msg,
-                             ssl->init_num)) {
+  if (!ssl_client_hello_init(ssl, &client_hello, msg)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
     return -1;
@@ -758,6 +754,12 @@
 
 static int ssl3_select_certificate(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
+  if (ret <= 0) {
+    return ret;
+  }
+
   /* Call |cert_cb| to update server certificates if required. */
   if (ssl->cert->cert_cb != NULL) {
     int rv = ssl->cert->cert_cb(ssl, ssl->cert->cert_cb_arg);
@@ -784,8 +786,7 @@
   }
 
   SSL_CLIENT_HELLO client_hello;
-  if (!ssl_client_hello_init(ssl, &client_hello, ssl->init_msg,
-                             ssl->init_num)) {
+  if (!ssl_client_hello_init(ssl, &client_hello, msg)) {
     return -1;
   }
 
@@ -804,9 +805,13 @@
 
 static int ssl3_select_parameters(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
+  if (ret <= 0) {
+    return ret;
+  }
   SSL_CLIENT_HELLO client_hello;
-  if (!ssl_client_hello_init(ssl, &client_hello, ssl->init_msg,
-                             ssl->init_num)) {
+  if (!ssl_client_hello_init(ssl, &client_hello, msg)) {
     return -1;
   }
 
@@ -914,7 +919,7 @@
    * the ClientHello. */
   if (!hs->transcript.InitHash(ssl3_protocol_version(ssl),
                                hs->new_cipher->algorithm_prf) ||
-      !ssl_hash_current_message(hs)) {
+      !ssl_hash_message(hs, msg)) {
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
     return -1;
   }
@@ -1176,14 +1181,15 @@
   SSL *const ssl = hs->ssl;
   assert(hs->cert_request);
 
-  int msg_ret = ssl->method->ssl_get_message(ssl);
-  if (msg_ret <= 0) {
-    return msg_ret;
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
+  if (ret <= 0) {
+    return ret;
   }
 
-  if (ssl->s3->tmp.message_type != SSL3_MT_CERTIFICATE) {
+  if (msg.type != SSL3_MT_CERTIFICATE) {
     if (ssl->version == SSL3_VERSION &&
-        ssl->s3->tmp.message_type == SSL3_MT_CLIENT_KEY_EXCHANGE) {
+        msg.type == SSL3_MT_CLIENT_KEY_EXCHANGE) {
       /* In SSL 3.0, the Certificate message is omitted to signal no
        * certificate. */
       if (ssl->verify_mode & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
@@ -1203,13 +1209,11 @@
     return -1;
   }
 
-  if (!ssl_hash_current_message(hs)) {
+  if (!ssl_hash_message(hs, msg)) {
     return -1;
   }
 
-  CBS certificate_msg;
-  CBS_init(&certificate_msg, ssl->init_msg, ssl->init_num);
-
+  CBS certificate_msg = msg.body;
   uint8_t alert = SSL_AD_DECODE_ERROR;
   UniquePtr<STACK_OF(CRYPTO_BUFFER)> chain;
   if (!ssl_parse_cert_chain(&alert, &chain, &hs->peer_pubkey,
@@ -1263,21 +1267,21 @@
 
 static int ssl3_get_client_key_exchange(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  CBS client_key_exchange;
   uint8_t *premaster_secret = NULL;
   size_t premaster_secret_len = 0;
   uint8_t *decrypt_buf = NULL;
 
-  int ret = ssl->method->ssl_get_message(ssl);
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
   if (ret <= 0) {
     return ret;
   }
 
-  if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
+  if (!ssl_check_message_type(ssl, msg, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
     return -1;
   }
 
-  CBS_init(&client_key_exchange, ssl->init_msg, ssl->init_num);
+  CBS client_key_exchange = msg.body;
   uint32_t alg_k = hs->new_cipher->algorithm_mkey;
   uint32_t alg_a = hs->new_cipher->algorithm_auth;
 
@@ -1481,7 +1485,7 @@
     premaster_secret_len = new_len;
   }
 
-  if (!ssl_hash_current_message(hs)) {
+  if (!ssl_hash_message(hs, msg)) {
     goto err;
   }
 
@@ -1510,7 +1514,6 @@
 
 static int ssl3_get_cert_verify(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  CBS certificate_verify, signature;
 
   /* Only RSA and ECDSA client certificates are supported, so a
    * CertificateVerify is required if and only if there's a client certificate.
@@ -1520,18 +1523,19 @@
     return 1;
   }
 
-  int msg_ret = ssl->method->ssl_get_message(ssl);
-  if (msg_ret <= 0) {
-    return msg_ret;
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
+  if (ret <= 0) {
+    return ret;
   }
 
-  if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE_VERIFY)) {
+  if (!ssl_check_message_type(ssl, msg, SSL3_MT_CERTIFICATE_VERIFY)) {
     return -1;
   }
 
-  CBS_init(&certificate_verify, ssl->init_msg, ssl->init_num);
+  CBS certificate_verify = msg.body, signature;
 
-  /* Determine the digest type if needbe. */
+  /* Determine the signature algorithm. */
   uint16_t signature_algorithm = 0;
   if (ssl3_protocol_version(ssl) >= TLS1_2_VERSION) {
     if (!CBS_get_u16(&certificate_verify, &signature_algorithm)) {
@@ -1597,7 +1601,7 @@
   /* The handshake buffer is no longer necessary, and we may hash the current
    * message.*/
   hs->transcript.FreeBuffer();
-  if (!ssl_hash_current_message(hs)) {
+  if (!ssl_hash_message(hs, msg)) {
     return -1;
   }
 
@@ -1609,18 +1613,18 @@
  * sets the next_proto member in s if found */
 static int ssl3_get_next_proto(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  int ret = ssl->method->ssl_get_message(ssl);
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
   if (ret <= 0) {
     return ret;
   }
 
-  if (!ssl_check_message_type(ssl, SSL3_MT_NEXT_PROTO) ||
-      !ssl_hash_current_message(hs)) {
+  if (!ssl_check_message_type(ssl, msg, SSL3_MT_NEXT_PROTO) ||
+      !ssl_hash_message(hs, msg)) {
     return -1;
   }
 
-  CBS next_protocol, selected_protocol, padding;
-  CBS_init(&next_protocol, ssl->init_msg, ssl->init_num);
+  CBS next_protocol = msg.body, selected_protocol, padding;
   if (!CBS_get_u8_length_prefixed(&next_protocol, &selected_protocol) ||
       !CBS_get_u8_length_prefixed(&next_protocol, &padding) ||
       CBS_len(&next_protocol) != 0) {
@@ -1641,14 +1645,15 @@
 /* ssl3_get_channel_id reads and verifies a ClientID handshake message. */
 static int ssl3_get_channel_id(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  int msg_ret = ssl->method->ssl_get_message(ssl);
-  if (msg_ret <= 0) {
-    return msg_ret;
+  SSLMessage msg;
+  int ret = ssl_read_message(ssl, &msg);
+  if (ret <= 0) {
+    return ret;
   }
 
-  if (!ssl_check_message_type(ssl, SSL3_MT_CHANNEL_ID) ||
-      !tls1_verify_channel_id(hs) ||
-      !ssl_hash_current_message(hs)) {
+  if (!ssl_check_message_type(ssl, msg, SSL3_MT_CHANNEL_ID) ||
+      !tls1_verify_channel_id(hs, msg) ||
+      !ssl_hash_message(hs, msg)) {
     return -1;
   }
   ssl->method->next_message(ssl);