Simplify ssl3_get_message.
Rather than this confusing coordination with the handshake state machine and
init_num changing meaning partway through, use the length field already in
BUF_MEM. As with the new record layer parsing, there is no need to keep track
of whether we are reading the header or the body. Simply keep extending the
handshake message until it's far enough along: first the 4-byte header, then
the full body it announces.
ssl3_get_message still needs tons of work, but this allows us to disentangle it
from the handshake state.
Change-Id: Ic2b3e7cfe6152a7e28a04980317d3c7c396d9b08
Reviewed-on: https://boringssl-review.googlesource.com/7948
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index 5235b97..48519cb 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -310,124 +310,113 @@
return kMaxMessageLen;
}
+/* Reads handshake record data into |ssl->init_buf| until the buffer holds at
+ * least |length| bytes; if it already does, no read is performed. Progress is
+ * kept in |ssl->init_buf->length|, so a later call made after a non-positive
+ * return resumes where this one stopped. Returns 1 once |length| bytes are
+ * buffered, -1 if |BUF_MEM_reserve| fails, or otherwise the non-positive value
+ * returned by |ssl3_read_bytes|. */
+static int extend_handshake_buffer(SSL *ssl, size_t length) {
+ /* Ensure capacity up front so the reads below can write directly into the
+  * buffer's data. */
+ if (!BUF_MEM_reserve(ssl->init_buf, length)) {
+ return -1;
+ }
+ while (ssl->init_buf->length < length) {
+ /* Request only the bytes still missing, appended after those already
+  * buffered. */
+ int ret =
+ ssl3_read_bytes(ssl, SSL3_RT_HANDSHAKE,
+ (uint8_t *)ssl->init_buf->data + ssl->init_buf->length,
+ length - ssl->init_buf->length, 0);
+ if (ret <= 0) {
+ return ret;
+ }
+ ssl->init_buf->length += (size_t)ret;
+ }
+ return 1;
+}
+
/* Obtain handshake message of message type |msg_type| (any if |msg_type| ==
* -1). The first four bytes (msg_type and length) are read in state
* |header_state|, the body is read in state |body_state|. */
+/* NOTE(review): the comment above predates this change. |header_state| is no
+ * longer read: the 4-byte header and the body are both accumulated in
+ * |ssl->init_buf| by |extend_handshake_buffer|, and |ssl->state| is set to
+ * |body_state| only once a complete message of the expected type is
+ * available. On success |*ok| is set to one, |ssl->init_msg| points at the
+ * body, and the body length is returned. Otherwise (an error, or a
+ * non-positive result from |ssl3_read_bytes|) |*ok| is zero and the return
+ * value is zero or negative. A HelloRequest received by a client is skipped
+ * and is not hashed. */
long ssl3_get_message(SSL *ssl, int header_state, int body_state, int msg_type,
enum ssl_hash_message_t hash_message, int *ok) {
- uint8_t *p;
- unsigned long l;
- long n;
- int al;
+ *ok = 0;
if (ssl->s3->tmp.reuse_message) {
/* A ssl_dont_hash_message call cannot be combined with reuse_message; the
* ssl_dont_hash_message would have to have been applied to the previous
* call. */
assert(hash_message == ssl_hash_message);
+ assert(ssl->s3->tmp.message_complete);
ssl->s3->tmp.reuse_message = 0;
if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
- al = SSL_AD_UNEXPECTED_MESSAGE;
+ ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
- goto f_err;
+ return -1;
}
*ok = 1;
ssl->state = body_state;
+ assert(ssl->init_buf->length >= 4);
ssl->init_msg = (uint8_t *)ssl->init_buf->data + 4;
- ssl->init_num = (int)ssl->s3->tmp.message_size;
+ ssl->init_num = (int)ssl->init_buf->length - 4;
return ssl->init_num;
}
- p = (uint8_t *)ssl->init_buf->data;
-
- if (ssl->state == header_state) {
- assert(ssl->init_num < 4);
-
- for (;;) {
- while (ssl->init_num < 4) {
- int bytes_read = ssl3_read_bytes(
- ssl, SSL3_RT_HANDSHAKE, &p[ssl->init_num], 4 - ssl->init_num, 0);
- if (bytes_read <= 0) {
- *ok = 0;
- return bytes_read;
- }
- ssl->init_num += bytes_read;
- }
-
- static const uint8_t kHelloRequest[4] = {SSL3_MT_HELLO_REQUEST, 0, 0, 0};
- if (ssl->server || memcmp(p, kHelloRequest, sizeof(kHelloRequest)) != 0) {
- break;
- }
-
- /* The server may always send 'Hello Request' messages -- we are doing
- * a handshake anyway now, so ignore them if their format is correct.
- * Does not count for 'Finished' MAC. */
- ssl->init_num = 0;
-
- if (ssl->msg_callback) {
- ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, p, 4, ssl,
- ssl->msg_callback_arg);
- }
- }
-
- /* ssl->init_num == 4 */
-
- if (msg_type >= 0 && *p != msg_type) {
- al = SSL_AD_UNEXPECTED_MESSAGE;
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
- goto f_err;
- }
- ssl->s3->tmp.message_type = *(p++);
-
- n2l3(p, l);
- if (l > ssl_max_handshake_message_len(ssl)) {
- al = SSL_AD_ILLEGAL_PARAMETER;
- OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
- goto f_err;
- }
-
- if (l && !BUF_MEM_grow_clean(ssl->init_buf, l + 4)) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_BUF_LIB);
- goto err;
- }
- ssl->s3->tmp.message_size = l;
- ssl->state = body_state;
-
- ssl->init_msg = (uint8_t *)ssl->init_buf->data + 4;
- ssl->init_num = 0;
+again:
+ /* Discard a message completed by an earlier call (or a HelloRequest skipped
+  * below) so the next message is read from the start of |ssl->init_buf|. */
+ if (ssl->s3->tmp.message_complete) {
+ ssl->s3->tmp.message_complete = 0;
+ ssl->init_buf->length = 0;
}
- /* next state (body_state) */
- p = ssl->init_msg;
- n = ssl->s3->tmp.message_size - ssl->init_num;
- while (n > 0) {
- int bytes_read =
- ssl3_read_bytes(ssl, SSL3_RT_HANDSHAKE, &p[ssl->init_num], n, 0);
- if (bytes_read <= 0) {
- *ok = 0;
- return bytes_read;
- }
- ssl->init_num += bytes_read;
- n -= bytes_read;
+ /* Read the message header, if we haven't yet. */
+ int ret = extend_handshake_buffer(ssl, 4);
+ if (ret <= 0) {
+ return ret;
}
+ /* Parse out the length. Cap it so the peer cannot force us to buffer up to
+ * 2^24 bytes. */
+ const uint8_t *p = (uint8_t *)ssl->init_buf->data;
+ size_t msg_len = (((uint32_t)p[1]) << 16) | (((uint32_t)p[2]) << 8) | p[3];
+ if (msg_len > ssl_max_handshake_message_len(ssl)) {
+ ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+ return -1;
+ }
+
+ /* Read the message body, if we haven't yet. */
+ ret = extend_handshake_buffer(ssl, 4 + msg_len);
+ if (ret <= 0) {
+ return ret;
+ }
+
+ /* We have now received a complete message. */
+ ssl->s3->tmp.message_complete = 1;
+ /* The callback fires before the HelloRequest and message-type checks below,
+  * so skipped HelloRequests and unexpected messages are reported as well. */
+ if (ssl->msg_callback) {
+ ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
+ ssl->init_buf->length, ssl, ssl->msg_callback_arg);
+ }
+
+ static const uint8_t kHelloRequest[4] = {SSL3_MT_HELLO_REQUEST, 0, 0, 0};
+ if (!ssl->server && ssl->init_buf->length == sizeof(kHelloRequest) &&
+ memcmp(kHelloRequest, ssl->init_buf->data, sizeof(kHelloRequest)) == 0) {
+ /* The server may always send 'Hello Request' messages -- we are doing a
+ * handshake anyway now, so ignore them if their format is correct. Does
+ * not count for 'Finished' MAC. */
+ goto again;
+ }
+
+ uint8_t actual_type = ((const uint8_t *)ssl->init_buf->data)[0];
+ if (msg_type >= 0 && actual_type != msg_type) {
+ ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+ return -1;
+ }
+ ssl->s3->tmp.message_type = actual_type;
+ ssl->state = body_state;
+
+ /* NOTE(review): |init_num| below is assigned from a size_t without a cast.
+  * The body is expected to be |msg_len| bytes (a 24-bit value, further capped
+  * above), so it should fit in an int; confirm if the cap or the read logic
+  * changes. */
+ ssl->init_msg = (uint8_t*)ssl->init_buf->data + 4;
+ ssl->init_num = ssl->init_buf->length - 4;
+
/* Feed this message into MAC computation. */
if (hash_message == ssl_hash_message && !ssl3_hash_current_message(ssl)) {
- goto err;
+ return -1;
}
- if (ssl->msg_callback) {
- ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
- (size_t)ssl->init_num + 4, ssl, ssl->msg_callback_arg);
- }
+
*ok = 1;
return ssl->init_num;
-
-f_err:
- ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
-
-err:
- *ok = 0;
- return -1;
}
int ssl3_hash_current_message(SSL *ssl) {