Move optional message type checks out of ssl_get_message.
This brings the TLS 1.2 state machine closer to the TLS 1.3 state
machine. It adds some work to the handshake, but ultimately the
plan is to take the ssl_get_message call out of the handshake (so it is
just the state machine rather than calling into BIO), so the expected
message type parameter needs to be folded out into explicit checks, as in TLS 1.3.
The WrongMessageType-* family of tests should ensure we don't miss
any of these.
BUG=128
Change-Id: I17a1e6177c52a7540b2bc6b0b3f926ab386c4950
Reviewed-on: https://boringssl-review.googlesource.com/13264
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index dc39c93..8c85799 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -815,12 +815,15 @@
if (hs->state == SSL3_ST_SR_CLNT_HELLO_A) {
/* The first time around, read the ClientHello. */
- int msg_ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CLIENT_HELLO,
- ssl_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_HELLO)) {
+ return -1;
+ }
+
hs->state = SSL3_ST_SR_CLNT_HELLO_B;
}
@@ -1399,7 +1402,7 @@
SSL *const ssl = hs->ssl;
assert(hs->cert_request);
- int msg_ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
@@ -1503,11 +1506,14 @@
uint8_t psk[PSK_MAX_PSK_LEN];
if (hs->state == SSL3_ST_SR_KEY_EXCH_A) {
- int ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE,
- ssl_hash_message);
+ int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
+
+ if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
+ return -1;
+ }
}
CBS_init(&client_key_exchange, ssl->init_msg, ssl->init_num);
@@ -1771,12 +1777,15 @@
return 1;
}
- int msg_ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CERTIFICATE_VERIFY,
- ssl_dont_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_dont_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE_VERIFY)) {
+ return -1;
+ }
+
CBS_init(&certificate_verify, ssl->init_msg, ssl->init_num);
/* Determine the digest type if needbe. */
@@ -1865,11 +1874,15 @@
static int ssl3_get_next_proto(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
int ret =
- ssl->method->ssl_get_message(ssl, SSL3_MT_NEXT_PROTO, ssl_hash_message);
+ ssl->method->ssl_get_message(ssl, ssl_hash_message);
if (ret <= 0) {
return ret;
}
+ if (!ssl_check_message_type(ssl, SSL3_MT_NEXT_PROTO)) {
+ return -1;
+ }
+
CBS next_protocol, selected_protocol, padding;
CBS_init(&next_protocol, ssl->init_msg, ssl->init_num);
if (!CBS_get_u8_length_prefixed(&next_protocol, &selected_protocol) ||
@@ -1891,13 +1904,13 @@
/* ssl3_get_channel_id reads and verifies a ClientID handshake message. */
static int ssl3_get_channel_id(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- int msg_ret = ssl->method->ssl_get_message(ssl, SSL3_MT_CHANNEL_ID,
- ssl_dont_hash_message);
+ int msg_ret = ssl->method->ssl_get_message(ssl, ssl_dont_hash_message);
if (msg_ret <= 0) {
return msg_ret;
}
- if (!tls1_verify_channel_id(ssl) ||
+ if (!ssl_check_message_type(ssl, SSL3_MT_CHANNEL_ID) ||
+ !tls1_verify_channel_id(ssl) ||
!ssl_hash_current_message(ssl)) {
return -1;
}