Move optional message type checks out of ssl_get_message.

This brings the TLS 1.2 state machine closer in line with the TLS 1.3
state machine. It means more work in the handshake code, but the eventual
plan is to take the ssl_get_message call out of the handshake (so that the
handshake is just the state machine rather than calling into the BIO), so
the parameters of ssl_get_message need to be lifted out into the callers,
as was done in TLS 1.3.

The WrongMessageType-* family of tests should ensure we don't miss any of
these message-type checks.

BUG=128

Change-Id: I17a1e6177c52a7540b2bc6b0b3f926ab386c4950
Reviewed-on: https://boringssl-review.googlesource.com/13264
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index d53cb01..acd4370 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -808,7 +808,7 @@
   CBS hello_verify_request, cookie;
   uint16_t server_version;
 
-  int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+  int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (ret <= 0) {
     return ret;
   }
@@ -852,7 +852,7 @@
   uint16_t server_wire_version, cipher_suite;
   uint8_t compression_method;
 
-  int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+  int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (ret <= 0) {
     uint32_t err = ERR_peek_error();
     if (ERR_GET_LIB(err) == ERR_LIB_SSL &&
@@ -914,9 +914,7 @@
 
   ssl_clear_tls13_state(hs);
 
-  if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO) {
-    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
-    OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+  if (!ssl_check_message_type(ssl, SSL3_MT_SERVER_HELLO)) {
     return -1;
   }
 
@@ -1053,12 +1051,15 @@
 
 static int ssl3_get_server_certificate(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  int ret =
-      ssl->method->ssl_get_message(ssl, SSL3_MT_CERTIFICATE, ssl_hash_message);
+  int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (ret <= 0) {
     return ret;
   }
 
+  if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE)) {
+    return -1;
+  }
+
   CBS cbs;
   CBS_init(&cbs, ssl->init_msg, ssl->init_num);
 
@@ -1097,7 +1098,7 @@
   CBS certificate_status, ocsp_response;
   uint8_t status_type;
 
-  int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+  int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (ret <= 0) {
     return ret;
   }
@@ -1150,7 +1151,7 @@
   EC_KEY *ecdh = NULL;
   EC_POINT *srvr_ecpoint = NULL;
 
-  int ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+  int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (ret <= 0) {
     return ret;
   }
@@ -1380,7 +1381,7 @@
 
 static int ssl3_get_certificate_request(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  int msg_ret = ssl->method->ssl_get_message(ssl, -1, ssl_hash_message);
+  int msg_ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (msg_ret <= 0) {
     return msg_ret;
   }
@@ -1393,9 +1394,7 @@
     return 1;
   }
 
-  if (ssl->s3->tmp.message_type != SSL3_MT_CERTIFICATE_REQUEST) {
-    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
-    OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+  if (!ssl_check_message_type(ssl, SSL3_MT_CERTIFICATE_REQUEST)) {
     return -1;
   }
 
@@ -1448,12 +1447,15 @@
 
 static int ssl3_get_server_hello_done(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  int ret = ssl->method->ssl_get_message(ssl, SSL3_MT_SERVER_HELLO_DONE,
-                                         ssl_hash_message);
+  int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (ret <= 0) {
     return ret;
   }
 
+  if (!ssl_check_message_type(ssl, SSL3_MT_SERVER_HELLO_DONE)) {
+    return -1;
+  }
+
   /* ServerHelloDone is empty. */
   if (ssl->init_num > 0) {
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
@@ -1829,12 +1831,15 @@
 
 static int ssl3_get_new_session_ticket(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  int ret = ssl->method->ssl_get_message(ssl, SSL3_MT_NEW_SESSION_TICKET,
-                                         ssl_hash_message);
+  int ret = ssl->method->ssl_get_message(ssl, ssl_hash_message);
   if (ret <= 0) {
     return ret;
   }
 
+  if (!ssl_check_message_type(ssl, SSL3_MT_NEW_SESSION_TICKET)) {
+    return -1;
+  }
+
   CBS new_session_ticket, ticket;
   uint32_t tlsext_tick_lifetime_hint;
   CBS_init(&new_session_ticket, ssl->init_msg, ssl->init_num);