Maintain SSL_HANDSHAKE lifetime outside of handshake_func.

We currently look up SSL_HANDSHAKE off of ssl->s3->hs everywhere, but
this is a little dangerous. Unlike ssl->s3->tmp, ssl->s3->hs may not be
present. Right now we just know not to call some functions outside the
handshake.

Instead, code which expects to only be called during a handshake should
take an explicit SSL_HANDSHAKE * parameter and can assume it non-NULL.
This replaces the SSL * parameter. Instead, that is looked up from
hs->ssl.

Code which is called in both cases, reads from ssl->s3->hs. Ultimately,
we should get to the point that all direct access of ssl->s3->hs needs
to be NULL-checked.

As a start, manage the lifetime of the ssl->s3->hs in SSL_do_handshake.
This allows the top-level handshake_func hooks to be passed in the
SSL_HANDSHAKE *. Later work will route it through the stack. False Start
is a little wonky, but I think this is cleaner overall.

Change-Id: I26dfeb95f1bc5a0a630b5c442c90c26a6b9e2efe
Reviewed-on: https://boringssl-review.googlesource.com/12236
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 796949d..62c12d9 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -4051,6 +4051,8 @@
   int freelist_max_len;
 };
 
+typedef struct ssl_handshake_st SSL_HANDSHAKE;
+
 struct ssl_st {
   /* method is the method table corresponding to the current protocol (DTLS or
    * TLS). */
@@ -4089,7 +4091,7 @@
    * with a better mechanism. */
   BIO *bbio;
 
-  int (*handshake_func)(SSL *);
+  int (*handshake_func)(SSL_HANDSHAKE *hs);
 
   BUF_MEM *init_buf; /* buffer used during init */
 
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index 70d8d96..8d503a5 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -186,7 +186,8 @@
 static int ssl3_send_channel_id(SSL *ssl);
 static int ssl3_get_new_session_ticket(SSL *ssl);
 
-int ssl3_connect(SSL *ssl) {
+int ssl3_connect(SSL_HANDSHAKE *hs) {
+  SSL *const ssl = hs->ssl;
   int ret = -1;
   int state, skip = 0;
 
@@ -205,12 +206,6 @@
       case SSL_ST_CONNECT:
         ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
 
-        ssl->s3->hs = ssl_handshake_new(tls13_client_handshake);
-        if (ssl->s3->hs == NULL) {
-          ret = -1;
-          goto end;
-        }
-
         if (!ssl_init_wbio_buffer(ssl)) {
           ret = -1;
           goto end;
@@ -277,7 +272,7 @@
         break;
 
       case SSL3_ST_CR_CERT_STATUS_A:
-        if (ssl->s3->hs->certificate_status_expected) {
+        if (hs->certificate_status_expected) {
           ret = ssl3_get_cert_status(ssl);
           if (ret <= 0) {
             goto end;
@@ -332,7 +327,7 @@
       case SSL3_ST_CW_CERT_A:
       case SSL3_ST_CW_CERT_B:
       case SSL3_ST_CW_CERT_C:
-        if (ssl->s3->hs->cert_request) {
+        if (hs->cert_request) {
           ret = ssl3_send_client_certificate(ssl);
           if (ret <= 0) {
             goto end;
@@ -355,7 +350,7 @@
       case SSL3_ST_CW_CERT_VRFY_A:
       case SSL3_ST_CW_CERT_VRFY_B:
       case SSL3_ST_CW_CERT_VRFY_C:
-        if (ssl->s3->hs->cert_request) {
+        if (hs->cert_request) {
           ret = ssl3_send_cert_verify(ssl);
           if (ret <= 0) {
             goto end;
@@ -383,7 +378,7 @@
 
       case SSL3_ST_CW_NEXT_PROTO_A:
       case SSL3_ST_CW_NEXT_PROTO_B:
-        if (ssl->s3->hs->next_proto_neg_seen) {
+        if (hs->next_proto_neg_seen) {
           ret = ssl3_send_next_proto(ssl);
           if (ret <= 0) {
             goto end;
@@ -441,14 +436,14 @@
 
       case SSL3_ST_FALSE_START:
         ssl->state = SSL3_ST_CR_SESSION_TICKET_A;
-        ssl->s3->hs->in_false_start = 1;
+        hs->in_false_start = 1;
 
         ssl_free_wbio_buffer(ssl);
         ret = 1;
         goto end;
 
       case SSL3_ST_CR_SESSION_TICKET_A:
-        if (ssl->s3->hs->ticket_expected) {
+        if (hs->ticket_expected) {
           ret = ssl3_get_new_session_ticket(ssl);
           if (ret <= 0) {
             goto end;
@@ -543,9 +538,6 @@
           ssl_update_cache(ssl, SSL_SESS_CACHE_CLIENT);
         }
 
-        ssl_handshake_free(ssl->s3->hs);
-        ssl->s3->hs = NULL;
-
         ret = 1;
         ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1);
         goto end;
@@ -894,6 +886,7 @@
 
   if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION) {
     ssl->state = SSL_ST_TLS13;
+    ssl->s3->hs->do_tls13_handshake = tls13_client_handshake;
     return 1;
   }
 
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index 3b66ab7..ccf0e8b 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -185,7 +185,8 @@
 static int ssl3_get_channel_id(SSL *ssl);
 static int ssl3_send_new_session_ticket(SSL *ssl);
 
-int ssl3_accept(SSL *ssl) {
+int ssl3_accept(SSL_HANDSHAKE *hs) {
+  SSL *const ssl = hs->ssl;
   uint32_t alg_a;
   int ret = -1;
   int state, skip = 0;
@@ -205,12 +206,6 @@
       case SSL_ST_ACCEPT:
         ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
 
-        ssl->s3->hs = ssl_handshake_new(tls13_server_handshake);
-        if (ssl->s3->hs == NULL) {
-          ret = -1;
-          goto end;
-        }
-
         /* Enable a write buffer. This groups handshake messages within a flight
          * into a single write. */
         if (!ssl_init_wbio_buffer(ssl)) {
@@ -271,7 +266,7 @@
 
       case SSL3_ST_SW_CERT_STATUS_A:
       case SSL3_ST_SW_CERT_STATUS_B:
-        if (ssl->s3->hs->certificate_status_expected) {
+        if (hs->certificate_status_expected) {
           ret = ssl3_send_certificate_status(ssl);
           if (ret <= 0) {
             goto end;
@@ -303,7 +298,7 @@
 
       case SSL3_ST_SW_CERT_REQ_A:
       case SSL3_ST_SW_CERT_REQ_B:
-        if (ssl->s3->hs->cert_request) {
+        if (hs->cert_request) {
           ret = ssl3_send_certificate_request(ssl);
           if (ret <= 0) {
             goto end;
@@ -325,7 +320,7 @@
         break;
 
       case SSL3_ST_SR_CERT_A:
-        if (ssl->s3->hs->cert_request) {
+        if (hs->cert_request) {
           ret = ssl3_get_client_certificate(ssl);
           if (ret <= 0) {
             goto end;
@@ -367,7 +362,7 @@
         break;
 
       case SSL3_ST_SR_NEXT_PROTO_A:
-        if (ssl->s3->hs->next_proto_neg_seen) {
+        if (hs->next_proto_neg_seen) {
           ret = ssl3_get_next_proto(ssl);
           if (ret <= 0) {
             goto end;
@@ -416,7 +411,7 @@
 
       case SSL3_ST_SW_SESSION_TICKET_A:
       case SSL3_ST_SW_SESSION_TICKET_B:
-        if (ssl->s3->hs->ticket_expected) {
+        if (hs->ticket_expected) {
           ret = ssl3_send_new_session_ticket(ssl);
           if (ret <= 0) {
             goto end;
@@ -505,9 +500,6 @@
         ssl->s3->initial_handshake_complete = 1;
         ssl_update_cache(ssl, SSL_SESS_CACHE_SERVER);
 
-        ssl_handshake_free(ssl->s3->hs);
-        ssl->s3->hs = NULL;
-
         ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1);
         ret = 1;
         goto end;
@@ -717,6 +709,7 @@
 
     if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION) {
       ssl->state = SSL_ST_TLS13;
+      ssl->s3->hs->do_tls13_handshake = tls13_server_handshake;
       return 1;
     }
   }
diff --git a/ssl/internal.h b/ssl/internal.h
index 5893d4d..af833fb 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -878,15 +878,18 @@
   ssl_hs_private_key_operation,
 };
 
-typedef struct ssl_handshake_st {
-  /* wait contains the operation |do_handshake| is currently blocking on or
-   * |ssl_hs_ok| if none. */
+struct ssl_handshake_st {
+  /* ssl is a non-owning pointer to the parent |SSL| object. */
+  SSL *ssl;
+
+  /* wait contains the operation |do_tls13_handshake| is currently blocking on
+   * or |ssl_hs_ok| if none. */
   enum ssl_hs_wait_t wait;
 
-  /* do_handshake runs the handshake. On completion, it returns |ssl_hs_ok|.
-   * Otherwise, it returns a value corresponding to what operation is needed to
-   * progress. */
-  enum ssl_hs_wait_t (*do_handshake)(SSL *ssl);
+  /* do_tls13_handshake runs the TLS 1.3 handshake. On completion, it returns
+   * |ssl_hs_ok|. Otherwise, it returns a value corresponding to what operation
+   * is needed to progress. */
+  enum ssl_hs_wait_t (*do_tls13_handshake)(SSL *ssl);
 
   int state;
 
@@ -1022,9 +1025,9 @@
 
   /* hostname, on the server, is the value of the SNI extension. */
   char *hostname;
-} SSL_HANDSHAKE;
+} /* SSL_HANDSHAKE */;
 
-SSL_HANDSHAKE *ssl_handshake_new(enum ssl_hs_wait_t (*do_handshake)(SSL *ssl));
+SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl);
 
 /* ssl_handshake_free releases all memory associated with |hs|. */
 void ssl_handshake_free(SSL_HANDSHAKE *hs);
@@ -1033,7 +1036,7 @@
  * 0 on error. */
 int tls13_handshake(SSL *ssl);
 
-/* The following are implementations of |do_handshake| for the client and
+/* The following are implementations of |do_tls13_handshake| for the client and
  * server. */
 enum ssl_hs_wait_t tls13_client_handshake(SSL *ssl);
 enum ssl_hs_wait_t tls13_server_handshake(SSL *ssl);
@@ -1760,8 +1763,8 @@
 
 int ssl3_new(SSL *ssl);
 void ssl3_free(SSL *ssl);
-int ssl3_accept(SSL *ssl);
-int ssl3_connect(SSL *ssl);
+int ssl3_accept(SSL_HANDSHAKE *hs);
+int ssl3_connect(SSL_HANDSHAKE *hs);
 
 int ssl3_init_message(SSL *ssl, CBB *cbb, CBB *body, uint8_t type);
 int ssl3_finish_message(SSL *ssl, CBB *cbb, uint8_t **out_msg, size_t *out_len);
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index d872020..b27938a 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -130,14 +130,14 @@
 #include "internal.h"
 
 
-SSL_HANDSHAKE *ssl_handshake_new(enum ssl_hs_wait_t (*do_handshake)(SSL *ssl)) {
+SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl) {
   SSL_HANDSHAKE *hs = OPENSSL_malloc(sizeof(SSL_HANDSHAKE));
   if (hs == NULL) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
     return NULL;
   }
   memset(hs, 0, sizeof(SSL_HANDSHAKE));
-  hs->do_handshake = do_handshake;
+  hs->ssl = ssl;
   hs->wait = ssl_hs_ok;
   return hs;
 }
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index aafad33..76a9de0 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -624,7 +624,28 @@
     return 1;
   }
 
-  return ssl->handshake_func(ssl);
+  /* Set up a new handshake if necessary. */
+  if (ssl->state == SSL_ST_INIT && ssl->s3->hs == NULL) {
+    ssl->s3->hs = ssl_handshake_new(ssl);
+    if (ssl->s3->hs == NULL) {
+      return -1;
+    }
+  }
+
+  /* Run the handshake. */
+  assert(ssl->s3->hs != NULL);
+  int ret = ssl->handshake_func(ssl->s3->hs);
+  if (ret <= 0) {
+    return ret;
+  }
+
+  /* Destroy the handshake object if the handshake has completely finished. */
+  if (!SSL_in_init(ssl)) {
+    ssl_handshake_free(ssl->s3->hs);
+    ssl->s3->hs = NULL;
+  }
+
+  return 1;
 }
 
 int SSL_connect(SSL *ssl) {
diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c
index 17f7161..c8d32a1 100644
--- a/ssl/tls13_both.c
+++ b/ssl/tls13_both.c
@@ -95,7 +95,7 @@
     }
 
     /* Run the state machine again. */
-    hs->wait = hs->do_handshake(ssl);
+    hs->wait = hs->do_tls13_handshake(ssl);
     if (hs->wait == ssl_hs_error) {
       /* Don't loop around to avoid a stray |SSL_R_SSL_HANDSHAKE_FAILURE| the
        * first time around. */