Port ssl3_get_key_exchange to CBS.

Also tidy up some variable names and update the RSA_verify call to account
for it no longer returning -1. Add CBS helper functions for dealing with C
strings.

Change-Id: Ibc398d27714744f5d99d4f94ae38210cbc89471a
Reviewed-on: https://boringssl-review.googlesource.com/1164
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index a998a1e..3165805 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -1255,10 +1255,8 @@
 
 int ssl3_get_key_exchange(SSL *s)
 	{
-	unsigned char *q,md_buf[EVP_MAX_MD_SIZE*2];
 	EVP_MD_CTX md_ctx;
-	unsigned char *param,*p;
-	int al,i,j,param_len,ok;
+	int al,ok;
 	long n,alg_k,alg_a;
 	EVP_PKEY *pkey=NULL;
 	const EVP_MD *md = NULL;
@@ -1270,9 +1268,8 @@
 	EC_KEY *ecdh = NULL;
 	BN_CTX *bn_ctx = NULL;
 	EC_POINT *srvr_ecpoint = NULL;
-	int curve_nid = 0;
-	int encoded_pt_len = 0;
 #endif
+	CBS server_key_exchange, server_key_exchange_orig, parameter;
 
 	/* use same message size as in ssl3_get_certificate_request()
 	 * as ServerKeyExchange message may be skipped */
@@ -1316,7 +1313,11 @@
 		return(1);
 		}
 
-	param = p = s->init_msg;
+	/* Retain a copy of the original CBS to compute the signature
+	 * over. */
+	CBS_init(&server_key_exchange, s->init_msg, n);
+	server_key_exchange_orig = server_key_exchange;
+
 	if (s->session->sess_cert != NULL)
 		{
 		if (s->session->sess_cert->peer_rsa_tmp != NULL)
@@ -1344,7 +1345,6 @@
 		s->session->sess_cert=ssl_sess_cert_new();
 		}
 
-	param_len=0;
 	alg_k=s->s3->tmp.new_cipher->algorithm_mkey;
 	alg_a=s->s3->tmp.new_cipher->algorithm_auth;
 	EVP_MD_CTX_init(&md_ctx);
@@ -1352,90 +1352,80 @@
 #ifndef OPENSSL_NO_PSK
 	if (alg_a & SSL_aPSK)
 		{
-		char tmp_id_hint[PSK_MAX_IDENTITY_LEN+1];
+		CBS psk_identity_hint;
 
-		al=SSL_AD_HANDSHAKE_FAILURE;
-		n2s(p,i);
-		param_len=i+2;
-		if (s->session->psk_identity_hint)
+		/* Each of the PSK key exchanges begins with a
+		 * psk_identity_hint. */
+		if (!CBS_get_u16_length_prefixed(&server_key_exchange, &psk_identity_hint))
 			{
-			OPENSSL_free(s->session->psk_identity_hint);
-			s->session->psk_identity_hint = NULL;
-			}
-		if (i != 0)
-			{
-			/* Store PSK identity hint for later use, hint is used
-			 * in ssl3_send_client_key_exchange.  Assume that the
-			 * maximum length of a PSK identity hint can be as
-			 * long as the maximum length of a PSK identity. */
-			if (i > PSK_MAX_IDENTITY_LEN)
-				{
-				OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DATA_LENGTH_TOO_LONG);
-				goto f_err;
-				}
-			if (param_len > n)
-				{
-				al=SSL_AD_DECODE_ERROR;
-				OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_PSK_IDENTITY_HINT_LENGTH);
-				goto f_err;
-				}
-			/* If received PSK identity hint contains NULL
-			 * characters, the hint is truncated from the first
-			 * NULL. p may not be ending with NULL, so create a
-			 * NULL-terminated string. */
-			memcpy(tmp_id_hint, p, i);
-			memset(tmp_id_hint+i, 0, PSK_MAX_IDENTITY_LEN+1-i);
-			s->session->psk_identity_hint = BUF_strdup(tmp_id_hint);
-			if (s->session->psk_identity_hint == NULL)
-				{
-				OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_MALLOC_FAILURE);
-				goto f_err;
-				}
+			al = SSL_AD_DECODE_ERROR;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DECODE_ERROR);
+			goto f_err;
 			}
 
-		p+=i;
-		n-=param_len;
+		/* Store PSK identity hint for later use, hint is used in
+		 * ssl3_send_client_key_exchange.  Assume that the maximum
+		 * length of a PSK identity hint can be as long as the maximum
+		 * length of a PSK identity. Also do not allow NUL
+		 * bytes; identities are saved as C strings.
+		 *
+		 * TODO(davidben): Should invalid hints be ignored? It's a hint
+		 * rather than a specific identity. */
+		if (CBS_len(&psk_identity_hint) > PSK_MAX_IDENTITY_LEN ||
+			CBS_contains_zero_byte(&psk_identity_hint))
+			{
+			al = SSL_AD_HANDSHAKE_FAILURE;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DATA_LENGTH_TOO_LONG);
+			goto f_err;
+			}
+
+		/* Save the identity hint as a C string. */
+		if (!CBS_strdup(&psk_identity_hint, &s->session->psk_identity_hint))
+			{
+			al = SSL_AD_HANDSHAKE_FAILURE;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto f_err;
+			}
 		}
 #endif /* !OPENSSL_NO_PSK */
 
 	if (0) {}
 	else if (alg_k & SSL_kRSA)
 		{
+		CBS rsa_modulus, rsa_exponent;
+
+		/* TODO(davidben): This was originally for export
+		 * reasons. Do we still need to support it? */
+
+		if (!CBS_get_u16_length_prefixed(&server_key_exchange, &rsa_modulus) ||
+			CBS_len(&rsa_modulus) == 0 ||
+			!CBS_get_u16_length_prefixed(&server_key_exchange, &rsa_exponent) ||
+			CBS_len(&rsa_exponent) == 0)
+			{
+			al = SSL_AD_DECODE_ERROR;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DECODE_ERROR);
+			goto f_err;
+			}
+
 		if ((rsa=RSA_new()) == NULL)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_MALLOC_FAILURE);
 			goto err;
 			}
-		n2s(p,i);
-		param_len=i+2;
-		if (param_len > n)
-			{
-			al=SSL_AD_DECODE_ERROR;
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_RSA_MODULUS_LENGTH);
-			goto f_err;
-			}
-		if (!(rsa->n=BN_bin2bn(p,i,rsa->n)))
-			{
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_BN_LIB);
-			goto err;
-			}
-		p+=i;
 
-		n2s(p,i);
-		param_len+=i+2;
-		if (param_len > n)
-			{
-			al=SSL_AD_DECODE_ERROR;
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_RSA_E_LENGTH);
-			goto f_err;
-			}
-		if (!(rsa->e=BN_bin2bn(p,i,rsa->e)))
+		if (!(rsa->n = BN_bin2bn(CBS_data(&rsa_modulus),
+					CBS_len(&rsa_modulus), rsa->n)))
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_BN_LIB);
 			goto err;
 			}
-		p+=i;
-		n-=param_len;
+
+		if (!(rsa->e = BN_bin2bn(CBS_data(&rsa_exponent),
+					CBS_len(&rsa_exponent), rsa->e)))
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_BN_LIB);
+			goto err;
+			}
 
 		/* this should be because we are using an export cipher */
 		if (alg_a & SSL_aRSA)
@@ -1451,56 +1441,41 @@
 #ifndef OPENSSL_NO_DH
 	else if (alg_k & SSL_kEDH)
 		{
+		CBS dh_p, dh_g, dh_Ys;
+
+		if (!CBS_get_u16_length_prefixed(&server_key_exchange, &dh_p) ||
+			CBS_len(&dh_p) == 0 ||
+			!CBS_get_u16_length_prefixed(&server_key_exchange, &dh_g) ||
+			CBS_len(&dh_g) == 0 ||
+			!CBS_get_u16_length_prefixed(&server_key_exchange, &dh_Ys) ||
+			CBS_len(&dh_Ys) == 0)
+			{
+			al = SSL_AD_DECODE_ERROR;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DECODE_ERROR);
+			goto f_err;
+			}
+
 		if ((dh=DH_new()) == NULL)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_DH_LIB);
 			goto err;
 			}
-		n2s(p,i);
-		param_len=i+2;
-		if (param_len > n)
-			{
-			al=SSL_AD_DECODE_ERROR;
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_DH_P_LENGTH);
-			goto f_err;
-			}
-		if (!(dh->p=BN_bin2bn(p,i,NULL)))
-			{
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_BN_LIB);
-			goto err;
-			}
-		p+=i;
 
-		n2s(p,i);
-		param_len+=i+2;
-		if (param_len > n)
-			{
-			al=SSL_AD_DECODE_ERROR;
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_DH_G_LENGTH);
-			goto f_err;
-			}
-		if (!(dh->g=BN_bin2bn(p,i,NULL)))
+		if (!(dh->p = BN_bin2bn(CBS_data(&dh_p), CBS_len(&dh_p), NULL)))
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_BN_LIB);
 			goto err;
 			}
-		p+=i;
-
-		n2s(p,i);
-		param_len+=i+2;
-		if (param_len > n)
-			{
-			al=SSL_AD_DECODE_ERROR;
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_DH_PUB_KEY_LENGTH);
-			goto f_err;
-			}
-		if (!(dh->pub_key=BN_bin2bn(p,i,NULL)))
+		if (!(dh->g=BN_bin2bn(CBS_data(&dh_g), CBS_len(&dh_g), NULL)))
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_BN_LIB);
 			goto err;
 			}
-		p+=i;
-		n-=param_len;
+		if (!(dh->pub_key = BN_bin2bn(CBS_data(&dh_Ys), CBS_len(&dh_Ys), NULL)))
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_BN_LIB);
+			goto err;
+			}
 
 		if (alg_a & SSL_aRSA)
 			pkey=X509_get_pubkey(s->session->sess_cert->peer_pkeys[SSL_PKEY_RSA_ENC].x509);
@@ -1524,42 +1499,36 @@
 #ifndef OPENSSL_NO_ECDH
 	else if (alg_k & SSL_kEECDH)
 		{
+		uint16_t curve_id;
+		int curve_nid = 0;
 		EC_GROUP *ngroup;
 		const EC_GROUP *group;
+		CBS point;
 
-		if ((ecdh=EC_KEY_new()) == NULL)
+		/* Extract elliptic curve parameters and the server's
+		 * ephemeral ECDH public key.  Check curve is one of
+		 * our preferences, if not server has sent an invalid
+		 * curve.
+		 */
+		if (!tls1_check_curve(s, &server_key_exchange, &curve_id))
 			{
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_MALLOC_FAILURE);
-			goto err;
-			}
-
-		/* Extract elliptic curve parameters and the
-		 * server's ephemeral ECDH public key.
-		 * Keep accumulating lengths of various components in
-		 * param_len and make sure it never exceeds n.
-		 */
-
-		/* XXX: For now we only support named (not generic) curves
-		 * and the ECParameters in this case is just three bytes.
-		 */
-		param_len=3;
-		/* Check curve is one of our prefrences, if not server has
-		 * sent an invalid curve.
-		 */
-		if (!tls1_check_curve(s, p, param_len))
-			{
-			al=SSL_AD_DECODE_ERROR;
+			al = SSL_AD_DECODE_ERROR;
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_WRONG_CURVE);
 			goto f_err;
 			}
 
-		if ((curve_nid = tls1_ec_curve_id2nid(*(p + 2))) == 0) 
+		if ((curve_nid = tls1_ec_curve_id2nid(curve_id)) == 0)
 			{
 			al=SSL_AD_INTERNAL_ERROR;
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_UNABLE_TO_FIND_ECDH_PARAMETERS);
 			goto f_err;
 			}
 
+		if ((ecdh=EC_KEY_new()) == NULL)
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto err;
+			}
 		ngroup = EC_GROUP_new_by_curve_name(curve_nid);
 		if (ngroup == NULL)
 			{
@@ -1583,9 +1552,14 @@
 			goto f_err;
 			}
 
-		p+=3;
-
 		/* Next, get the encoded ECPoint */
+		if (!CBS_get_u8_length_prefixed(&server_key_exchange, &point))
+			{
+			al = SSL_AD_DECODE_ERROR;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DECODE_ERROR);
+			goto f_err;
+			}
+
 		if (((srvr_ecpoint = EC_POINT_new(group)) == NULL) ||
 		    ((bn_ctx = BN_CTX_new()) == NULL))
 			{
@@ -1593,21 +1567,14 @@
 			goto err;
 			}
 
-		encoded_pt_len = *p;  /* length of encoded point */
-		p+=1;
-		param_len += (1 + encoded_pt_len);
-		if ((param_len > n) ||
-		    (EC_POINT_oct2point(group, srvr_ecpoint, 
-			p, encoded_pt_len, bn_ctx) == 0))
+		if (!EC_POINT_oct2point(group, srvr_ecpoint,
+				CBS_data(&point), CBS_len(&point), bn_ctx))
 			{
-			al=SSL_AD_DECODE_ERROR;
+			al = SSL_AD_DECODE_ERROR;
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_ECPOINT);
 			goto f_err;
 			}
 
-		n-=param_len;
-		p+=encoded_pt_len;
-
 		/* The ECC/TLS specification does not mention
 		 * the use of DSA to sign ECParameters in the server
 		 * key exchange message. We do support RSA and ECDSA.
@@ -1637,14 +1604,34 @@
 		goto f_err;
 		}
 
-	/* p points to the next byte, there are 'n' bytes left */
+	/* At this point, |server_key_exchange| contains the
+	 * signature, if any, while |server_key_exchange_orig|
+	 * contains the entire message. From that, derive a CBS
+	 * containing just the parameter. */
+	CBS_init(&parameter, CBS_data(&server_key_exchange_orig),
+		CBS_len(&server_key_exchange_orig) -
+		CBS_len(&server_key_exchange));
 
 	/* if it was signed, check the signature */
 	if (pkey != NULL)
 		{
+		CBS signature;
+
 		if (SSL_USE_SIGALGS(s))
 			{
-			int rv = tls12_check_peer_sigalg(&md, s, p, pkey);
+			int rv;
+			const uint8_t *sigalg;
+
+			/* The first two bytes are the hash and signature
+			 * algorithm identifiers (SignatureAndHashAlgorithm). */
+			sigalg = CBS_data(&server_key_exchange);
+			if (!CBS_skip(&server_key_exchange, 2))
+				{
+				al = SSL_AD_DECODE_ERROR;
+				OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DECODE_ERROR);
+				goto f_err;
+				}
+			rv = tls12_check_peer_sigalg(&md, s, sigalg, pkey);
 			if (rv == -1)
 				goto err;
 			else if (rv == 0)
@@ -1652,56 +1639,44 @@
 				al = SSL_AD_DECODE_ERROR;
 				goto f_err;
 				}
-#ifdef SSL_DEBUG
-fprintf(stderr, "USING TLSv1.2 HASH %s\n", EVP_MD_name(md));
-#endif
-			p += 2;
-			n -= 2;
 			}
 		else
 			md = EVP_sha1();
-			
-		n2s(p,i);
-		n-=2;
-		j=EVP_PKEY_size(pkey);
 
-		if ((i != n) || (n > j) || (n <= 0))
+		/* The last field in |server_key_exchange| is the
+		 * signature. */
+		if (!CBS_get_u16_length_prefixed(&server_key_exchange, &signature) ||
+			CBS_len(&server_key_exchange) != 0)
 			{
-			/* wrong packet length */
-			al=SSL_AD_DECODE_ERROR;
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_WRONG_SIGNATURE_LENGTH);
+			al = SSL_AD_DECODE_ERROR;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_DECODE_ERROR);
 			goto f_err;
 			}
 
 		if (pkey->type == EVP_PKEY_RSA && !SSL_USE_SIGALGS(s))
 			{
 			int num;
+			unsigned char *q, md_buf[EVP_MAX_MD_SIZE*2];
+			size_t md_len = 0;
 
-			j=0;
 			q=md_buf;
 			for (num=2; num > 0; num--)
 				{
+				unsigned int digest_len;
 				EVP_DigestInit_ex(&md_ctx,(num == 2)
 					?s->ctx->md5:s->ctx->sha1, NULL);
 				EVP_DigestUpdate(&md_ctx,&(s->s3->client_random[0]),SSL3_RANDOM_SIZE);
 				EVP_DigestUpdate(&md_ctx,&(s->s3->server_random[0]),SSL3_RANDOM_SIZE);
-				EVP_DigestUpdate(&md_ctx,param,param_len);
-				EVP_DigestFinal_ex(&md_ctx,q,(unsigned int *)&i);
-				q+=i;
-				j+=i;
+				EVP_DigestUpdate(&md_ctx, CBS_data(&parameter), CBS_len(&parameter));
+				EVP_DigestFinal_ex(&md_ctx, q, &digest_len);
+				q += digest_len;
+				md_len += digest_len;
 				}
-			i=RSA_verify(NID_md5_sha1, md_buf, j, p, n,
-								pkey->pkey.rsa);
-			if (i < 0)
+			if (!RSA_verify(NID_md5_sha1, md_buf, md_len,
+					CBS_data(&signature), CBS_len(&signature),
+					pkey->pkey.rsa))
 				{
-				al=SSL_AD_DECRYPT_ERROR;
-				OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_RSA_DECRYPT);
-				goto f_err;
-				}
-			if (i == 0)
-				{
-				/* bad signature */
-				al=SSL_AD_DECRYPT_ERROR;
+				al = SSL_AD_DECRYPT_ERROR;
 				OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_BAD_SIGNATURE);
 				goto f_err;
 				}
@@ -1711,8 +1686,8 @@
 			EVP_VerifyInit_ex(&md_ctx, md, NULL);
 			EVP_VerifyUpdate(&md_ctx,&(s->s3->client_random[0]),SSL3_RANDOM_SIZE);
 			EVP_VerifyUpdate(&md_ctx,&(s->s3->server_random[0]),SSL3_RANDOM_SIZE);
-			EVP_VerifyUpdate(&md_ctx,param,param_len);
-			if (EVP_VerifyFinal(&md_ctx,p,(int)n,pkey) <= 0)
+			EVP_VerifyUpdate(&md_ctx, CBS_data(&parameter), CBS_len(&parameter));
+			if (EVP_VerifyFinal(&md_ctx, CBS_data(&signature), CBS_len(&signature), pkey) <= 0)
 				{
 				/* bad signature */
 				al=SSL_AD_DECRYPT_ERROR;
@@ -1732,7 +1707,7 @@
 			goto err;
 			}
 		/* still data left over */
-		if (n != 0)
+		if (CBS_len(&server_key_exchange) > 0)
 			{
 			al=SSL_AD_DECODE_ERROR;
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_key_exchange, SSL_R_EXTRA_DATA_IN_MESSAGE);