Maintain SSL_HANDSHAKE lifetime outside of handshake_func. We currently look up the SSL_HANDSHAKE off of ssl->s3->hs everywhere, but this is a little dangerous. Unlike ssl->s3->tmp, ssl->s3->hs may not be present. Right now we just know not to call some functions outside the handshake. Instead, code which expects to only be called during a handshake should take an explicit SSL_HANDSHAKE * parameter and can assume it is non-NULL. This replaces the SSL * parameter; the SSL is instead looked up from hs->ssl. Code which is called both during and outside a handshake reads from ssl->s3->hs. Ultimately, we should get to the point where every remaining direct access of ssl->s3->hs is one that needs to be NULL-checked. As a start, manage the lifetime of ssl->s3->hs in SSL_do_handshake. This allows the top-level handshake_func hooks to be passed the SSL_HANDSHAKE *. Later work will route it through the stack. False Start is a little wonky, but I think this is cleaner overall. Change-Id: I26dfeb95f1bc5a0a630b5c442c90c26a6b9e2efe Reviewed-on: https://boringssl-review.googlesource.com/12236 Reviewed-by: Steven Valdez <svaldez@google.com> Reviewed-by: David Benjamin <davidben@google.com> Commit-Queue: David Benjamin <davidben@google.com> CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c index 70d8d96..8d503a5 100644 --- a/ssl/handshake_client.c +++ b/ssl/handshake_client.c
@@ -186,7 +186,8 @@ static int ssl3_send_channel_id(SSL *ssl); static int ssl3_get_new_session_ticket(SSL *ssl); -int ssl3_connect(SSL *ssl) { +int ssl3_connect(SSL_HANDSHAKE *hs) { + SSL *const ssl = hs->ssl; int ret = -1; int state, skip = 0; @@ -205,12 +206,6 @@ case SSL_ST_CONNECT: ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1); - ssl->s3->hs = ssl_handshake_new(tls13_client_handshake); - if (ssl->s3->hs == NULL) { - ret = -1; - goto end; - } - if (!ssl_init_wbio_buffer(ssl)) { ret = -1; goto end; @@ -277,7 +272,7 @@ break; case SSL3_ST_CR_CERT_STATUS_A: - if (ssl->s3->hs->certificate_status_expected) { + if (hs->certificate_status_expected) { ret = ssl3_get_cert_status(ssl); if (ret <= 0) { goto end; @@ -332,7 +327,7 @@ case SSL3_ST_CW_CERT_A: case SSL3_ST_CW_CERT_B: case SSL3_ST_CW_CERT_C: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_send_client_certificate(ssl); if (ret <= 0) { goto end; @@ -355,7 +350,7 @@ case SSL3_ST_CW_CERT_VRFY_A: case SSL3_ST_CW_CERT_VRFY_B: case SSL3_ST_CW_CERT_VRFY_C: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_send_cert_verify(ssl); if (ret <= 0) { goto end; @@ -383,7 +378,7 @@ case SSL3_ST_CW_NEXT_PROTO_A: case SSL3_ST_CW_NEXT_PROTO_B: - if (ssl->s3->hs->next_proto_neg_seen) { + if (hs->next_proto_neg_seen) { ret = ssl3_send_next_proto(ssl); if (ret <= 0) { goto end; @@ -441,14 +436,14 @@ case SSL3_ST_FALSE_START: ssl->state = SSL3_ST_CR_SESSION_TICKET_A; - ssl->s3->hs->in_false_start = 1; + hs->in_false_start = 1; ssl_free_wbio_buffer(ssl); ret = 1; goto end; case SSL3_ST_CR_SESSION_TICKET_A: - if (ssl->s3->hs->ticket_expected) { + if (hs->ticket_expected) { ret = ssl3_get_new_session_ticket(ssl); if (ret <= 0) { goto end; @@ -543,9 +538,6 @@ ssl_update_cache(ssl, SSL_SESS_CACHE_CLIENT); } - ssl_handshake_free(ssl->s3->hs); - ssl->s3->hs = NULL; - ret = 1; ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1); goto end; @@ -894,6 +886,7 @@ if (ssl3_protocol_version(ssl) >= 
TLS1_3_VERSION) { ssl->state = SSL_ST_TLS13; + ssl->s3->hs->do_tls13_handshake = tls13_client_handshake; return 1; }
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c index 3b66ab7..ccf0e8b 100644 --- a/ssl/handshake_server.c +++ b/ssl/handshake_server.c
@@ -185,7 +185,8 @@ static int ssl3_get_channel_id(SSL *ssl); static int ssl3_send_new_session_ticket(SSL *ssl); -int ssl3_accept(SSL *ssl) { +int ssl3_accept(SSL_HANDSHAKE *hs) { + SSL *const ssl = hs->ssl; uint32_t alg_a; int ret = -1; int state, skip = 0; @@ -205,12 +206,6 @@ case SSL_ST_ACCEPT: ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1); - ssl->s3->hs = ssl_handshake_new(tls13_server_handshake); - if (ssl->s3->hs == NULL) { - ret = -1; - goto end; - } - /* Enable a write buffer. This groups handshake messages within a flight * into a single write. */ if (!ssl_init_wbio_buffer(ssl)) { @@ -271,7 +266,7 @@ case SSL3_ST_SW_CERT_STATUS_A: case SSL3_ST_SW_CERT_STATUS_B: - if (ssl->s3->hs->certificate_status_expected) { + if (hs->certificate_status_expected) { ret = ssl3_send_certificate_status(ssl); if (ret <= 0) { goto end; @@ -303,7 +298,7 @@ case SSL3_ST_SW_CERT_REQ_A: case SSL3_ST_SW_CERT_REQ_B: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_send_certificate_request(ssl); if (ret <= 0) { goto end; @@ -325,7 +320,7 @@ break; case SSL3_ST_SR_CERT_A: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_get_client_certificate(ssl); if (ret <= 0) { goto end; @@ -367,7 +362,7 @@ break; case SSL3_ST_SR_NEXT_PROTO_A: - if (ssl->s3->hs->next_proto_neg_seen) { + if (hs->next_proto_neg_seen) { ret = ssl3_get_next_proto(ssl); if (ret <= 0) { goto end; @@ -416,7 +411,7 @@ case SSL3_ST_SW_SESSION_TICKET_A: case SSL3_ST_SW_SESSION_TICKET_B: - if (ssl->s3->hs->ticket_expected) { + if (hs->ticket_expected) { ret = ssl3_send_new_session_ticket(ssl); if (ret <= 0) { goto end; @@ -505,9 +500,6 @@ ssl->s3->initial_handshake_complete = 1; ssl_update_cache(ssl, SSL_SESS_CACHE_SERVER); - ssl_handshake_free(ssl->s3->hs); - ssl->s3->hs = NULL; - ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1); ret = 1; goto end; @@ -717,6 +709,7 @@ if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION) { ssl->state = SSL_ST_TLS13; + 
ssl->s3->hs->do_tls13_handshake = tls13_server_handshake; return 1; } }
diff --git a/ssl/internal.h b/ssl/internal.h index 5893d4d..af833fb 100644 --- a/ssl/internal.h +++ b/ssl/internal.h
@@ -878,15 +878,18 @@ ssl_hs_private_key_operation, }; -typedef struct ssl_handshake_st { - /* wait contains the operation |do_handshake| is currently blocking on or - * |ssl_hs_ok| if none. */ +struct ssl_handshake_st { + /* ssl is a non-owning pointer to the parent |SSL| object. */ + SSL *ssl; + + /* wait contains the operation |do_tls13_handshake| is currently blocking on + * or |ssl_hs_ok| if none. */ enum ssl_hs_wait_t wait; - /* do_handshake runs the handshake. On completion, it returns |ssl_hs_ok|. - * Otherwise, it returns a value corresponding to what operation is needed to - * progress. */ - enum ssl_hs_wait_t (*do_handshake)(SSL *ssl); + /* do_tls13_handshake runs the TLS 1.3 handshake. On completion, it returns + * |ssl_hs_ok|. Otherwise, it returns a value corresponding to what operation + * is needed to progress. */ + enum ssl_hs_wait_t (*do_tls13_handshake)(SSL *ssl); int state; @@ -1022,9 +1025,9 @@ /* hostname, on the server, is the value of the SNI extension. */ char *hostname; -} SSL_HANDSHAKE; +} /* SSL_HANDSHAKE */; -SSL_HANDSHAKE *ssl_handshake_new(enum ssl_hs_wait_t (*do_handshake)(SSL *ssl)); +SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl); /* ssl_handshake_free releases all memory associated with |hs|. */ void ssl_handshake_free(SSL_HANDSHAKE *hs); @@ -1033,7 +1036,7 @@ * 0 on error. */ int tls13_handshake(SSL *ssl); -/* The following are implementations of |do_handshake| for the client and +/* The following are implementations of |do_tls13_handshake| for the client and * server. */ enum ssl_hs_wait_t tls13_client_handshake(SSL *ssl); enum ssl_hs_wait_t tls13_server_handshake(SSL *ssl); @@ -1760,8 +1763,8 @@ int ssl3_new(SSL *ssl); void ssl3_free(SSL *ssl); -int ssl3_accept(SSL *ssl); -int ssl3_connect(SSL *ssl); +int ssl3_accept(SSL_HANDSHAKE *hs); +int ssl3_connect(SSL_HANDSHAKE *hs); int ssl3_init_message(SSL *ssl, CBB *cbb, CBB *body, uint8_t type); int ssl3_finish_message(SSL *ssl, CBB *cbb, uint8_t **out_msg, size_t *out_len);
diff --git a/ssl/s3_both.c b/ssl/s3_both.c index d872020..b27938a 100644 --- a/ssl/s3_both.c +++ b/ssl/s3_both.c
@@ -130,14 +130,14 @@ #include "internal.h" -SSL_HANDSHAKE *ssl_handshake_new(enum ssl_hs_wait_t (*do_handshake)(SSL *ssl)) { +SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl) { SSL_HANDSHAKE *hs = OPENSSL_malloc(sizeof(SSL_HANDSHAKE)); if (hs == NULL) { OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); return NULL; } memset(hs, 0, sizeof(SSL_HANDSHAKE)); - hs->do_handshake = do_handshake; + hs->ssl = ssl; hs->wait = ssl_hs_ok; return hs; }
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c index aafad33..76a9de0 100644 --- a/ssl/ssl_lib.c +++ b/ssl/ssl_lib.c
@@ -624,7 +624,28 @@ return 1; } - return ssl->handshake_func(ssl); + /* Set up a new handshake if necessary. */ + if (ssl->state == SSL_ST_INIT && ssl->s3->hs == NULL) { + ssl->s3->hs = ssl_handshake_new(ssl); + if (ssl->s3->hs == NULL) { + return -1; + } + } + + /* Run the handshake. */ + assert(ssl->s3->hs != NULL); + int ret = ssl->handshake_func(ssl->s3->hs); + if (ret <= 0) { + return ret; + } + + /* Destroy the handshake object if the handshake has completely finished. */ + if (!SSL_in_init(ssl)) { + ssl_handshake_free(ssl->s3->hs); + ssl->s3->hs = NULL; + } + + return 1; } int SSL_connect(SSL *ssl) {
diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c index 17f7161..c8d32a1 100644 --- a/ssl/tls13_both.c +++ b/ssl/tls13_both.c
@@ -95,7 +95,7 @@ } /* Run the state machine again. */ - hs->wait = hs->do_handshake(ssl); + hs->wait = hs->do_tls13_handshake(ssl); if (hs->wait == ssl_hs_error) { /* Don't loop around to avoid a stray |SSL_R_SSL_HANDSHAKE_FAILURE| the * first time around. */