Port early callback support to CBS.

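Also resolve one of the TODOs, since it's quick. Adjust the
-expect-server-name test to check the server name both in the normal
codepath and in the early callback, providing test coverage for
SSL_early_callback_ctx_extension_get. The early-callback ClientHello
parser now also rejects an empty cipher suite list or an empty
compression methods list.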

Change-Id: I4d71158b9fd2f4fbb54d3e51184bd25d117bdc91
Reviewed-on: https://boringssl-review.googlesource.com/1120
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c
index 86914dc..2fa4101 100644
--- a/ssl/t1_lib.c
+++ b/ssl/t1_lib.c
@@ -296,102 +296,64 @@
 
 char ssl_early_callback_init(struct ssl_early_callback_ctx *ctx)
 	{
-	size_t len = ctx->client_hello_len;
-	const unsigned char *p = ctx->client_hello;
-	CBS extensions;
+	CBS client_hello, session_id, cipher_suites, compression_methods, extensions;
+	/* |client_hello| is consumed field by field by the CBS_skip/CBS_get_* calls below. */
+	CBS_init(&client_hello, ctx->client_hello, ctx->client_hello_len);
 
 	/* Skip client version. */
-	if (len < 2)
+	if (!CBS_skip(&client_hello, 2))
 		return 0;
-	len -= 2; p += 2;
 
 	/* Skip client nonce. */
-	if (len < 32)
+	if (!CBS_skip(&client_hello, 32))
 		return 0;
-	len -= 32; p += 32;
 
-	/* Get length of session id. */
-	if (len < 1)
+	/* Extract the u8-length-prefixed session_id; an empty one is accepted. */
+	if (!CBS_get_u8_length_prefixed(&client_hello, &session_id))
 		return 0;
-	ctx->session_id_len = *p;
-	p++; len--;
-
-	ctx->session_id = p;
-	if (len < ctx->session_id_len)
-		return 0;
-	p += ctx->session_id_len; len -= ctx->session_id_len;
+	ctx->session_id = CBS_data(&session_id);
+	ctx->session_id_len = CBS_len(&session_id);
 
 	/* Skip past DTLS cookie */
 	if (ctx->ssl->version == DTLS1_VERSION || ctx->ssl->version == DTLS1_BAD_VER)
 		{
-		unsigned cookie_len;
+		CBS cookie;
 
-		if (len < 1)
+		if (!CBS_get_u8_length_prefixed(&client_hello, &cookie))
 			return 0;
-		cookie_len = *p;
-		p++; len--;
-		if (len < cookie_len)
-			return 0;
-		p += cookie_len; len -= cookie_len;
 		}
 
-	/* Skip cipher suites. */
-	if (len < 2)
+	/* Extract the u16-length-prefixed cipher_suites; it must be non-empty and of even length. */
+	if (!CBS_get_u16_length_prefixed(&client_hello, &cipher_suites) ||
+		CBS_len(&cipher_suites) < 2 ||
+		(CBS_len(&cipher_suites) & 1) != 0)
 		return 0;
-	n2s(p, ctx->cipher_suites_len);
-	len -= 2;
+	ctx->cipher_suites = CBS_data(&cipher_suites);
+	ctx->cipher_suites_len = CBS_len(&cipher_suites);
 
-	if ((ctx->cipher_suites_len & 1) != 0)
+	/* Extract the u8-length-prefixed compression_methods; it must be non-empty. */
+	if (!CBS_get_u8_length_prefixed(&client_hello, &compression_methods) ||
+		CBS_len(&compression_methods) < 1)
 		return 0;
-
-	ctx->cipher_suites = p;
-	if (len < ctx->cipher_suites_len)
-		return 0;
-	p += ctx->cipher_suites_len; len -= ctx->cipher_suites_len;
-
-	/* Skip compression methods. */
-	if (len < 1)
-		return 0;
-	ctx->compression_methods_len = *p;
-	p++; len--;
-
-	ctx->compression_methods = p;
-	if (len < ctx->compression_methods_len)
-		return 0;
-	p += ctx->compression_methods_len; len -= ctx->compression_methods_len;
+	ctx->compression_methods = CBS_data(&compression_methods);
+	ctx->compression_methods_len = CBS_len(&compression_methods);
 
 	/* If the ClientHello ends here then it's valid, but doesn't have any
 	 * extensions. (E.g. SSLv3.) */
-	if (len == 0)
+	if (CBS_len(&client_hello) == 0)
 		{
 		ctx->extensions = NULL;
 		ctx->extensions_len = 0;
 		return 1;
 		}
 
-	if (len < 2)
+	/* Extract the extensions block, check lengths/duplicates, and require no trailing bytes. */
+	if (!CBS_get_u16_length_prefixed(&client_hello, &extensions) ||
+		!tls1_check_duplicate_extensions(&extensions) ||
+		CBS_len(&client_hello) != 0)
 		return 0;
-	n2s(p, ctx->extensions_len);
-	len -= 2;
-
-	if (ctx->extensions_len == 0 && len == 0)
-		{
-		ctx->extensions = NULL;
-		return 1;
-		}
-
-	ctx->extensions = p;
-	if (len != ctx->extensions_len)
-		return 0;
-
-	/* Verify that the extensions have valid lengths and that there are
-	 * no duplicates.
-	 *
-	 * TODO(fork): Port the rest of this processing to CBS.
-	 */
-	CBS_init(&extensions, ctx->extensions, ctx->extensions_len);
-	if (!tls1_check_duplicate_extensions(&extensions))
-		return 0;
+	ctx->extensions = CBS_data(&extensions);
+	ctx->extensions_len = CBS_len(&extensions);
 
 	return 1;
 	}
@@ -402,29 +364,26 @@
 				     const unsigned char **out_data,
 				     size_t *out_len)
 	{
-	size_t len = ctx->extensions_len;
-	const unsigned char *p = ctx->extensions;
+	CBS extensions;
 
-	while (len != 0)
+	CBS_init(&extensions, ctx->extensions, ctx->extensions_len);
+
+	while (CBS_len(&extensions) != 0)
 		{
-		uint16_t ext_type, ext_len;
+		uint16_t type;
+		CBS extension;
 
-		if (len < 4)
+		/* Decode the next extension. */
+		if (!CBS_get_u16(&extensions, &type) ||
+			!CBS_get_u16_length_prefixed(&extensions, &extension))
 			return 0;
-		n2s(p, ext_type);
-		n2s(p, ext_len);
-		len -= 4;
 
-		if (len < ext_len)
-			return 0;
-		if (ext_type == extension_type)
+		if (type == extension_type)
 			{
-			*out_data = p;
-			*out_len = ext_len;
+			*out_data = CBS_data(&extension);
+			*out_len = CBS_len(&extension);
 			return 1;
 			}
-
-		p += ext_len; len -= ext_len;
 		}
 
 	return 0;