Move optional message type checks out of ssl_get_message.
This aligns the TLS 1.2 state machine closer with the TLS 1.3 state
machine. This is more work for the handshake, but ultimately the
plan is to take the ssl_get_message call out of the handshake (so it is
just the state machine rather than calling into BIO), so the parameters
need to be folded out as in TLS 1.3.
The WrongMessageType-* family of tests should make sure we don't miss
one of these.
BUG=128
Change-Id: I17a1e6177c52a7540b2bc6b0b3f926ab386c4950
Reviewed-on: https://boringssl-review.googlesource.com/13264
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 049f7a4..96bae41 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -395,8 +395,7 @@
return 1;
}
-int dtls1_get_message(SSL *ssl, int msg_type,
- enum ssl_hash_message_t hash_message) {
+int dtls1_get_message(SSL *ssl, enum ssl_hash_message_t hash_message) {
if (ssl->s3->tmp.reuse_message) {
/* A ssl_dont_hash_message call cannot be combined with reuse_message; the
* ssl_dont_hash_message would have to have been applied to the previous
@@ -430,11 +429,6 @@
ssl->init_msg = frag->data + DTLS1_HM_HEADER_LENGTH;
ssl->init_num = frag->msg_len;
- if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
- ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
- return -1;
- }
if (hash_message == ssl_hash_message && !ssl_hash_current_message(ssl)) {
return -1;
}
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index d53cb01..acd4370 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -808,7 +808,7 @@
CBS hello_verify_request, cookie;
uint16_t server_version;
- int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
@@ -852,7 +852,7 @@
uint16_t server_wire_version, cipher_suite;
uint8_t compression_method;
- int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
uint32_t err = ERR_peek_error();
if (ERR_GET_LIB(err) == ERR_LIB_SSL &&
@@ -914,9 +914,7 @@
ssl_clear_tls13_state(hs);
- if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO) {
- ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+ if (!ssl_check_message_type(ssl, SSL3_MT_SERVER_HELLO)) {
return -1;
}
@@ -1053,12 +1051,15 @@
static int ssl3_get_server_certificate(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- int ret =
- ssl->method->ssl_get_message(ssl, SSL3_MT_CERTIFICATE, ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE)) {
+ return -1;
+ }
+
CBS cbs;
CBS_init(&cbs, ssl->init_msg, ssl->init_num);
@@ -1097,7 +1098,7 @@
CBS certificate_status, ocsp_response;
uint8_t status_type;
- int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
@@ -1150,7 +1151,7 @@
EC_KEY *ecdh = NULL;
EC_POINT *srvr_ecpoint = NULL;
- int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
@@ -1380,7 +1381,7 @@
static int ssl3_get_certificate_request(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- int msg_ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
@@ -1393,9 +1394,7 @@
return 1;
}
- if (ssl->s3->tmp.message_type != SSL3_MT_CERTIFICATE_REQUEST) {
- ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE_REQUEST)) {
return -1;
}
@@ -1448,12 +1447,15 @@
static int ssl3_get_server_hello_done(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- int ret = ssl->method->ssl_get_message(ssl, SSL3_MT_SERVER_HELLO_DONE,
- ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_SERVER_HELLO_DONE)) {
+ return -1;
+ }
+
/* ServerHelloDone is empty. */
if (ssl->init_num > 0) {
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
@@ -1829,12 +1831,15 @@
static int ssl3_get_new_session_ticket(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- int ret = ssl->method->ssl_get_message(ssl, SSL3_MT_NEW_SESSION_TICKET,
- ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_NEW_SESSION_TICKET)) {
+ return -1;
+ }
+
CBS new_session_ticket, ticket;
uint32_t tlsext_tick_lifetime_hint;
CBS_init(&new_session_ticket, ssl->init_msg, ssl->init_num);
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index dc39c93..8c85799 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -815,12 +815,15 @@
if (hs->state == SSL3_ST_SR_CLNT_HELLO_A) {
/* The first time around, read the ClientHello. */
- int msg_ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CLIENT_HELLO,
- ssl_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_HELLO)) {
+ return -1;
+ }
+
hs->state = SSL3_ST_SR_CLNT_HELLO_B;
}
@@ -1399,7 +1402,7 @@
SSL *const ssl = hs->ssl;
assert(hs->cert_request);
- int msg_ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
@@ -1503,11 +1506,14 @@
uint8_t psk[PSK_MAX_PSK_LEN];
if (hs->state == SSL3_ST_SR_KEY_EXCH_A) {
- int ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE,
- ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
+
+ if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
+ return -1;
+ }
}
CBS_init(&client_key_exchange, ssl->init_msg, ssl->init_num);
@@ -1771,12 +1777,15 @@
return 1;
}
- int msg_ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CERTIFICATE_VERIFY,
- ssl_dont_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_dont_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE_VERIFY)) {
+ return -1;
+ }
+
CBS_init(&certificate_verify, ssl->init_msg, ssl->init_num);
/* Determine the digest type if needbe. */
@@ -1865,11 +1874,15 @@
static int ssl3_get_next_proto(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
int ret =
- ssl->method->ssl_get_message(ssl, SSL3_MT_NEXT_PROTO, ssl_hash_message);
+ ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_NEXT_PROTO)) {
+ return -1;
+ }
+
CBS next_protocol, selected_protocol, padding;
CBS_init(&next_protocol, ssl->init_msg, ssl->init_num);
if (!CBS_get_u8_length_prefixed(&next_protocol, &selected_protocol) ||
@@ -1891,13 +1904,13 @@
/* ssl3_get_channel_id reads and verifies a ClientID handshake message. */
static int ssl3_get_channel_id(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- int msg_ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CHANNEL_ID,
- ssl_dont_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_dont_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
- if (!tls1_verify_channel_id(ssl) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_CHANNEL_ID) ||
+ !tls1_verify_channel_id(ssl) ||
!ssl_hash_current_message(ssl)) {
return -1;
}
diff --git a/ssl/internal.h b/ssl/internal.h
index 8b94689..bb2c30f 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1064,6 +1064,10 @@
/* ssl_handshake_free releases all memory associated with |hs|. */
void ssl_handshake_free(SSL_HANDSHAKE *hs);
+/* ssl_check_message_type checks if the current message has type |type|. If so
+ * it returns one. Otherwise, it sends an alert and returns zero. */
+int ssl_check_message_type(SSL *ssl, int type);
+
/* tls13_handshake runs the TLS 1.3 handshake. It returns one on success and <=
* 0 on error. */
int tls13_handshake(SSL_HANDSHAKE *hs);
@@ -1077,10 +1081,6 @@
* success and zero on failure. */
int tls13_post_handshake(SSL *ssl);
-/* tls13_check_message_type checks if the current message has type |type|. If so
- * it returns one. Otherwise, it sends an alert and returns zero. */
-int tls13_check_message_type(SSL *ssl, int type);
-
int tls13_process_certificate(SSL_HANDSHAKE *hs, int allow_anonymous);
int tls13_process_certificate_verify(SSL_HANDSHAKE *hs);
int tls13_process_finished(SSL_HANDSHAKE *hs);
@@ -1304,12 +1304,10 @@
uint16_t (*version_to_wire)(uint16_t version);
int (*ssl_new)(SSL *ssl);
void (*ssl_free)(SSL *ssl);
- /* ssl_get_message reads the next handshake message. If |msg_type| is not -1,
- * the message must have the specified type. On success, it returns one and
- * sets |ssl->s3->tmp.message_type|, |ssl->init_msg|, and |ssl->init_num|.
- * Otherwise, it returns <= 0. */
- int (*ssl_get_message)(SSL *ssl, int msg_type,
- enum ssl_hash_message_t hash_message);
+ /* ssl_get_message reads the next handshake message. On success, it returns
+ * one and sets |ssl->s3->tmp.message_type|, |ssl->init_msg|, and
+ * |ssl->init_num|. Otherwise, it returns <= 0. */
+ int (*ssl_get_message)(SSL *ssl, enum ssl_hash_message_t hash_message);
/* get_current_message sets |*out| to the current handshake message. This
* includes the protocol-specific message header. */
void (*get_current_message)(const SSL *ssl, CBS *out);
@@ -1766,8 +1764,7 @@
int ssl3_get_finished(SSL_HANDSHAKE *hs);
int ssl3_send_alert(SSL *ssl, int level, int desc);
-int ssl3_get_message(SSL *ssl, int msg_type,
- enum ssl_hash_message_t hash_message);
+int ssl3_get_message(SSL *ssl, enum ssl_hash_message_t hash_message);
void ssl3_get_current_message(const SSL *ssl, CBS *out);
void ssl3_release_current_message(SSL *ssl, int free_buffer);
@@ -1853,7 +1850,7 @@
int dtls1_connect(SSL *ssl);
void dtls1_free(SSL *ssl);
-int dtls1_get_message(SSL *ssl, int mt, enum ssl_hash_message_t hash_message);
+int dtls1_get_message(SSL *ssl, enum ssl_hash_message_t hash_message);
void dtls1_get_current_message(const SSL *ssl, CBS *out);
void dtls1_release_current_message(SSL *ssl, int free_buffer);
int dtls1_dispatch_alert(SSL *ssl);
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index a7dac7c..70ed435 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -180,6 +180,18 @@
OPENSSL_free(hs);
}
+int ssl_check_message_type(SSL *ssl, int type) {
+ if (ssl->s3->tmp.message_type != type) {
+ ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+ ERR_add_error_dataf("got type %d, wanted type %d",
+ ssl->s3->tmp.message_type, type);
+ return 0;
+ }
+
+ return 1;
+}
+
static int add_record_to_flight(SSL *ssl, uint8_t type, const uint8_t *in,
size_t in_len) {
/* We'll never add a flight while in the process of writing it out. */
@@ -386,12 +398,15 @@
int ssl3_get_finished(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- int ret = ssl->method->ssl_get_message(ssl, SSL3_MT_FINISHED,
- ssl_dont_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_dont_hash_message);
if (ret <= 0) {
return ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_FINISHED)) {
+ return -1;
+ }
+
/* Snapshot the finished hash before incorporating the new message. */
uint8_t finished[EVP_MAX_MD_SIZE];
size_t finished_len =
@@ -645,8 +660,7 @@
return 1;
}
-int ssl3_get_message(SSL *ssl, int msg_type,
- enum ssl_hash_message_t hash_message) {
+int ssl3_get_message(SSL *ssl, enum ssl_hash_message_t hash_message) {
again:
/* Re-create the handshake buffer if needed. */
if (ssl->init_buf == NULL) {
@@ -725,12 +739,6 @@
goto again;
}
- if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
- ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
- return -1;
- }
-
/* Feed this message into MAC computation. */
if (hash_message == ssl_hash_message && !ssl_hash_current_message(ssl)) {
return -1;
diff --git a/ssl/s3_pkt.c b/ssl/s3_pkt.c
index 5a31e1f..af7f4be 100644
--- a/ssl/s3_pkt.c
+++ b/ssl/s3_pkt.c
@@ -372,7 +372,7 @@
}
/* Parse post-handshake handshake messages. */
- int ret = ssl3_get_message(ssl, -1, ssl_dont_hash_message);
+ int ret = ssl3_get_message(ssl, ssl_dont_hash_message);
if (ret <= 0) {
return ret;
}
diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c
index d4c1f1a..67308b6 100644
--- a/ssl/tls13_both.c
+++ b/ssl/tls13_both.c
@@ -58,7 +58,7 @@
}
case ssl_hs_read_message: {
- int ret = ssl->method->ssl_get_message(ssl, -1, ssl_dont_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_dont_hash_message);
if (ret <= 0) {
return ret;
}
@@ -398,18 +398,6 @@
return ret;
}
-int tls13_check_message_type(SSL *ssl, int type) {
- if (ssl->s3->tmp.message_type != type) {
- ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
- ERR_add_error_dataf("got type %d, wanted type %d",
- ssl->s3->tmp.message_type, type);
- return 0;
- }
-
- return 1;
-}
-
int tls13_process_finished(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
uint8_t verify_data[EVP_MAX_MD_SIZE];
diff --git a/ssl/tls13_client.c b/ssl/tls13_client.c
index ad279f5..5000d17 100644
--- a/ssl/tls13_client.c
+++ b/ssl/tls13_client.c
@@ -150,7 +150,7 @@
static enum ssl_hs_wait_t do_process_server_hello(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_SERVER_HELLO)) {
+ if (!ssl_check_message_type(ssl, SSL3_MT_SERVER_HELLO)) {
return ssl_hs_error;
}
@@ -338,7 +338,7 @@
static enum ssl_hs_wait_t do_process_encrypted_extensions(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_ENCRYPTED_EXTENSIONS)) {
+ if (!ssl_check_message_type(ssl, SSL3_MT_ENCRYPTED_EXTENSIONS)) {
return ssl_hs_error;
}
@@ -420,7 +420,7 @@
static enum ssl_hs_wait_t do_process_server_certificate(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_CERTIFICATE) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE) ||
!tls13_process_certificate(hs, 0 /* certificate required */) ||
!ssl_hash_current_message(ssl)) {
return ssl_hs_error;
@@ -433,7 +433,7 @@
static enum ssl_hs_wait_t do_process_server_certificate_verify(
SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_CERTIFICATE_VERIFY) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE_VERIFY) ||
!tls13_process_certificate_verify(hs) ||
!ssl_hash_current_message(ssl)) {
return ssl_hs_error;
@@ -445,7 +445,7 @@
static enum ssl_hs_wait_t do_process_server_finished(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_FINISHED) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_FINISHED) ||
!tls13_process_finished(hs) ||
!ssl_hash_current_message(ssl) ||
/* Update the secret to the master secret and derive traffic keys. */
diff --git a/ssl/tls13_server.c b/ssl/tls13_server.c
index 09b9cfe..52b2672 100644
--- a/ssl/tls13_server.c
+++ b/ssl/tls13_server.c
@@ -90,7 +90,7 @@
static enum ssl_hs_wait_t do_process_client_hello(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_CLIENT_HELLO)) {
+ if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_HELLO)) {
return ssl_hs_error;
}
@@ -354,7 +354,7 @@
static enum ssl_hs_wait_t do_process_second_client_hello(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_CLIENT_HELLO)) {
+ if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_HELLO)) {
return ssl_hs_error;
}
@@ -536,7 +536,7 @@
const int allow_anonymous =
(ssl->verify_mode & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) == 0;
- if (!tls13_check_message_type(ssl, SSL3_MT_CERTIFICATE) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE) ||
!tls13_process_certificate(hs, allow_anonymous) ||
!ssl_hash_current_message(ssl)) {
return ssl_hs_error;
@@ -555,7 +555,7 @@
return ssl_hs_ok;
}
- if (!tls13_check_message_type(ssl, SSL3_MT_CERTIFICATE_VERIFY) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE_VERIFY) ||
!tls13_process_certificate_verify(hs) ||
!ssl_hash_current_message(ssl)) {
return ssl_hs_error;
@@ -572,7 +572,7 @@
return ssl_hs_ok;
}
- if (!tls13_check_message_type(ssl, SSL3_MT_CHANNEL_ID) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_CHANNEL_ID) ||
!tls1_verify_channel_id(ssl) ||
!ssl_hash_current_message(ssl)) {
return ssl_hs_error;
@@ -584,7 +584,7 @@
static enum ssl_hs_wait_t do_process_client_finished(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- if (!tls13_check_message_type(ssl, SSL3_MT_FINISHED) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_FINISHED) ||
!tls13_process_finished(hs) ||
!ssl_hash_current_message(ssl) ||
/* evp_aead_seal keys have already been switched. */