Touch up ssl3_get_message.
The |skip_message| variable was overly complex and, since we have at
least 32-bit ints, we know that a 24-bit value doesn't overflow an int.
Change-Id: I5c16fa979e1716f39cc47882c033bcf5bce3284c
Reviewed-on: https://boringssl-review.googlesource.com/2610
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index 215082d..a34d221 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -312,73 +312,70 @@
return l + SSL_HM_HEADER_LENGTH(s);
}
-/* Obtain handshake message of message type 'mt' (any if mt == -1),
- * maximum acceptable body length 'max'.
- * The first four bytes (msg_type and length) are read in state 'st1',
- * the body is read in state 'stn'. */
-long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max,
- int hash_message, int *ok) {
+/* Obtain handshake message of message type |msg_type| (any if |msg_type| == -1),
+ * maximum acceptable body length |max|. The first four bytes (msg_type and
+ * length) are read in state |header_state|, the body is read in state |body_state|. */
+long ssl3_get_message(SSL *s, int header_state, int body_state, int msg_type,
+ long max, int hash_message, int *ok) {
uint8_t *p;
unsigned long l;
long n;
- int i, al;
+ int al;
if (s->s3->tmp.reuse_message) {
- /* A SSL_GET_MESSAGE_DONT_HASH_MESSAGE call cannot be combined
- * with reuse_message; the SSL_GET_MESSAGE_DONT_HASH_MESSAGE
- * would have to have been applied to the previous call. */
+ /* A SSL_GET_MESSAGE_DONT_HASH_MESSAGE call cannot be combined with
+ * reuse_message; the SSL_GET_MESSAGE_DONT_HASH_MESSAGE would have to have
+ * been applied to the previous call. */
assert(hash_message != SSL_GET_MESSAGE_DONT_HASH_MESSAGE);
s->s3->tmp.reuse_message = 0;
- if ((mt >= 0) && (s->s3->tmp.message_type != mt)) {
+ if (msg_type >= 0 && s->s3->tmp.message_type != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
OPENSSL_PUT_ERROR(SSL, ssl3_get_message, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
}
*ok = 1;
- s->state = stn;
+ s->state = body_state;
s->init_msg = (uint8_t *)s->init_buf->data + 4;
s->init_num = (int)s->s3->tmp.message_size;
return s->init_num;
}
- p = (unsigned char *)s->init_buf->data;
+ p = (uint8_t *)s->init_buf->data;
- if (s->state == st1) /* s->init_num < 4 */
- {
- int skip_message;
+ if (s->state == header_state) {
+ assert(s->init_num < 4);
- do {
+ for (;;) {
while (s->init_num < 4) {
- i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, &p[s->init_num],
- 4 - s->init_num, 0);
- if (i <= 0) {
+ int bytes_read = s->method->ssl_read_bytes(
+ s, SSL3_RT_HANDSHAKE, &p[s->init_num], 4 - s->init_num, 0);
+ if (bytes_read <= 0) {
s->rwstate = SSL_READING;
*ok = 0;
- return i;
+ return bytes_read;
}
- s->init_num += i;
+ s->init_num += bytes_read;
}
- skip_message = 0;
- if (!s->server)
- if (p[0] == SSL3_MT_HELLO_REQUEST)
- /* The server may always send 'Hello Request' messages --
- * we are doing a handshake anyway now, so ignore them
- * if their format is correct. Does not count for
- * 'Finished' MAC. */
- if (p[1] == 0 && p[2] == 0 && p[3] == 0) {
- s->init_num = 0;
- skip_message = 1;
+ static const uint8_t kHelloRequest[4] = {SSL3_MT_HELLO_REQUEST, 0, 0, 0};
+ if (s->server || memcmp(p, kHelloRequest, sizeof(kHelloRequest)) != 0) {
+ break;
+ }
- if (s->msg_callback)
- s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE, p, 4, s,
- s->msg_callback_arg);
- }
- } while (skip_message);
+ /* The server may always send 'Hello Request' messages -- we are doing
+ * a handshake anyway now, so ignore them if their format is correct.
+ * Does not count for 'Finished' MAC. */
+ s->init_num = 0;
+
+ if (s->msg_callback) {
+ s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE, p, 4, s,
+ s->msg_callback_arg);
+ }
+ }
/* s->init_num == 4 */
- if ((mt >= 0) && (*p != mt)) {
+ if (msg_type >= 0 && *p != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
OPENSSL_PUT_ERROR(SSL, ssl3_get_message, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
@@ -391,43 +388,41 @@
OPENSSL_PUT_ERROR(SSL, ssl3_get_message, SSL_R_EXCESSIVE_MESSAGE_SIZE);
goto f_err;
}
- if (l > (INT_MAX - 4)) /* BUF_MEM_grow takes an 'int' parameter */
- {
- al = SSL_AD_ILLEGAL_PARAMETER;
- OPENSSL_PUT_ERROR(SSL, ssl3_get_message, SSL_R_EXCESSIVE_MESSAGE_SIZE);
- goto f_err;
- }
- if (l && !BUF_MEM_grow_clean(s->init_buf, (int)l + 4)) {
+
+ if (l && !BUF_MEM_grow_clean(s->init_buf, l + 4)) {
OPENSSL_PUT_ERROR(SSL, ssl3_get_message, ERR_R_BUF_LIB);
goto err;
}
s->s3->tmp.message_size = l;
- s->state = stn;
+ s->state = body_state;
s->init_msg = (uint8_t *)s->init_buf->data + 4;
s->init_num = 0;
}
- /* next state (stn) */
+ /* next state (body_state) */
p = s->init_msg;
n = s->s3->tmp.message_size - s->init_num;
while (n > 0) {
- i = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, &p[s->init_num], n, 0);
- if (i <= 0) {
+ int bytes_read =
+ s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, &p[s->init_num], n, 0);
+ if (bytes_read <= 0) {
s->rwstate = SSL_READING;
*ok = 0;
- return i;
+ return bytes_read;
}
- s->init_num += i;
- n -= i;
+ s->init_num += bytes_read;
+ n -= bytes_read;
}
/* Feed this message into MAC computation. */
- if (hash_message != SSL_GET_MESSAGE_DONT_HASH_MESSAGE)
+ if (hash_message != SSL_GET_MESSAGE_DONT_HASH_MESSAGE) {
ssl3_hash_current_message(s);
- if (s->msg_callback)
+ }
+ if (s->msg_callback) {
s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE, s->init_buf->data,
(size_t)s->init_num + 4, s, s->msg_callback_arg);
+ }
*ok = 1;
return s->init_num;
diff --git a/ssl/ssl_locl.h b/ssl/ssl_locl.h
index 26e3edc..cf89a49 100644
--- a/ssl/ssl_locl.h
+++ b/ssl/ssl_locl.h
@@ -584,8 +584,9 @@
int (*ssl_shutdown)(SSL *s);
int (*ssl_renegotiate)(SSL *s);
int (*ssl_renegotiate_check)(SSL *s);
- long (*ssl_get_message)(SSL *s, int st1, int stn, int mt, long
- max, int hash_message, int *ok);
+ long (*ssl_get_message)(SSL *s, int header_state, int body_state,
+ int msg_type, long max, int hash_message,
+ int *ok);
int (*ssl_read_bytes)(SSL *s, int type, unsigned char *buf, int len,
int peek);
int (*ssl_write_bytes)(SSL *s, int type, const void *buf_, int len);
@@ -758,7 +759,8 @@
int ssl3_generate_master_secret(SSL *s, unsigned char *out,
unsigned char *p, int len);
int ssl3_get_req_cert_type(SSL *s,unsigned char *p);
-long ssl3_get_message(SSL *s, int st1, int stn, int mt, long max, int hash_message, int *ok);
+long ssl3_get_message(SSL *s, int header_state, int body_state, int msg_type,
+ long max, int hash_message, int *ok);
/* ssl3_hash_current_message incorporates the current handshake message into
* the handshake hash. */