Port ssl3_get_client_key_exchange to CBS.

Change-Id: I065554d058395322a4ac675155bfe66c874b47ad
Reviewed-on: https://boringssl-review.googlesource.com/1171
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 2882c79..3ea2c4b 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -515,7 +515,6 @@
 #define SSL_OP_SSLREF2_REUSE_CERT_TYPE_BUG		0x00000010L
 #define SSL_OP_MICROSOFT_BIG_SSLV3_BUFFER		0x00000020L
 #define SSL_OP_SAFARI_ECDHE_ECDSA_BUG			0x00000040L
-#define SSL_OP_SSLEAY_080_CLIENT_DH_BUG			0x00000080L
 #define SSL_OP_TLS_D5_BUG				0x00000100L
 #define SSL_OP_TLS_BLOCK_PADDING_BUG			0x00000200L
 
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index 5fb07e8..31c2e37 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -2023,11 +2023,12 @@
 
 int ssl3_get_client_key_exchange(SSL *s)
 	{
-	int i,al,ok;
+	int al,ok;
 	long n;
+	CBS client_key_exchange;
 	unsigned long alg_k;
 	unsigned long alg_a;
-	unsigned char *p;
+	uint8_t *premaster_secret = NULL;
 	RSA *rsa=NULL;
 	EVP_PKEY *pkey=NULL;
 #ifndef OPENSSL_NO_DH
@@ -2054,7 +2055,7 @@
 		&ok);
 
 	if (!ok) return((int)n);
-	p = s->init_msg;
+	CBS_init(&client_key_exchange, s->init_msg, n);
 
 	alg_k=s->s3->tmp.new_cipher->algorithm_mkey;
 	alg_a=s->s3->tmp.new_cipher->algorithm_auth;
@@ -2062,47 +2063,58 @@
 #ifndef OPENSSL_NO_PSK
 	if (alg_a & SSL_aPSK)
 		{
+		CBS psk_identity;
+		int psk_err = -1;
 		unsigned char *t = NULL;
 		unsigned char pre_ms[PSK_MAX_PSK_LEN*2+4];
 		unsigned int pre_ms_len = 0;
-		int psk_err = 1;
-		char tmp_id[PSK_MAX_IDENTITY_LEN+1];
 
-		al=SSL_AD_HANDSHAKE_FAILURE;
-
-		n2s(p, i);
-		if (n != i+2 && !(alg_k & SSL_kEECDH))
+		/* If using PSK, the ClientKeyExchange contains a
+		 * psk_identity. If PSK, then this is the only field
+		 * in the message. */
+		if (!CBS_get_u16_length_prefixed(&client_key_exchange, &psk_identity) ||
+			((alg_k & SSL_kPSK) && CBS_len(&client_key_exchange) != 0))
 			{
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_LENGTH_MISMATCH);
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DECODE_ERROR);
+			al = SSL_AD_DECODE_ERROR;
 			goto psk_err;
 			}
-		if (i > PSK_MAX_IDENTITY_LEN)
-			{
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DATA_LENGTH_TOO_LONG);
-			goto psk_err;
-			}
+
 		if (s->psk_server_callback == NULL)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_PSK_NO_SERVER_CB);
+			al = SSL_AD_INTERNAL_ERROR;
 			goto psk_err;
 			}
 
-		/* Create guaranteed NUL-terminated identity
-		 * string for the callback */
-		memcpy(tmp_id, p, i);
-		memset(tmp_id+i, 0, PSK_MAX_IDENTITY_LEN+1-i);
-		psk_len = s->psk_server_callback(s, tmp_id, psk, sizeof(psk));
+		if (CBS_len(&psk_identity) > PSK_MAX_IDENTITY_LEN ||
+			CBS_contains_zero_byte(&psk_identity))
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DATA_LENGTH_TOO_LONG);
+			al = SSL_AD_ILLEGAL_PARAMETER;
+			goto psk_err;
+			}
 
+		if (!CBS_strdup(&psk_identity, &s->session->psk_identity))
+			{
+			al = SSL_AD_INTERNAL_ERROR;
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto psk_err;
+			}
+
+		/* Look up the key for the identity. */
+		psk_len = s->psk_server_callback(s, s->session->psk_identity, psk, sizeof(psk));
 		if (psk_len > PSK_MAX_PSK_LEN)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_INTERNAL_ERROR);
+			al = SSL_AD_INTERNAL_ERROR;
 			goto psk_err;
 			}
 		else if (psk_len == 0)
 			{
 			/* PSK related to the given identity not found */
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_PSK_IDENTITY_NOT_FOUND);
-			al=SSL_AD_UNKNOWN_PSK_IDENTITY;
+			al = SSL_AD_UNKNOWN_PSK_IDENTITY;
 			goto psk_err;
 			}
 		if (!(alg_k & SSL_kEECDH))
@@ -2120,18 +2132,7 @@
 				s->method->ssl3_enc->generate_master_secret(s,
 					s->session->master_key, pre_ms, pre_ms_len);
 			}
-		if (s->session->psk_identity != NULL)
-			OPENSSL_free(s->session->psk_identity);
-		s->session->psk_identity = BUF_strdup(tmp_id);
-		OPENSSL_cleanse(tmp_id, PSK_MAX_IDENTITY_LEN+1);
-		if (s->session->psk_identity == NULL)
-			{
-			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
-			goto psk_err;
-			}
 
-		p += i;
-		n -= (i + 2);
 		psk_err = 0;
 	psk_err:
 		OPENSSL_cleanse(pre_ms, sizeof(pre_ms));
@@ -2140,9 +2141,9 @@
 		}
 #endif /* OPENSSL_NO_PSK */
 
-	if (0) {}
-	else if (alg_k & SSL_kRSA)
+	if (alg_k & SSL_kRSA)
 		{
+		CBS encrypted_premaster_secret;
 		unsigned char rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
 		int decrypt_len, decrypt_good_mask;
 		unsigned char version_good;
@@ -2177,11 +2178,18 @@
 			rsa=pkey->pkey.rsa;
 			}
 
-		/* TLS and [incidentally] DTLS{0xFEFF} */
-		if (s->version > SSL3_VERSION && s->version != DTLS1_BAD_VER)
+		/* TLS and [incidentally] DTLS{0xFEFF}
+		 *
+		 * TODO(davidben): Should this (and
+		 * ssl3_send_client_key_exchange) include DTLS1_BAD_VER?
+		 * Alternatively, get rid of DTLS1_BAD_VER?
+		 */
+		if (s->version > SSL3_VERSION)
 			{
-			n2s(p,i);
-			if (n != i+2)
+			CBS copy = client_key_exchange;
+			if (!CBS_get_u16_length_prefixed(&client_key_exchange,
+					&encrypted_premaster_secret) ||
+				CBS_len(&client_key_exchange) != 0)
 				{
 				if (!(s->options & SSL_OP_TLS_D5_BUG))
 					{
@@ -2190,11 +2198,11 @@
 					goto f_err;
 					}
 				else
-					p-=2;
+					encrypted_premaster_secret = copy;
 				}
-			else
-				n=i;
 			}
+		else
+			encrypted_premaster_secret = client_key_exchange;
 
 		/* Reject overly short RSA ciphertext because we want to be
 		 * sure that the buffer size makes it safe to iterate over the
@@ -2202,7 +2210,7 @@
 		 * (SSL_MAX_MASTER_KEY_LENGTH). The actual expected size is
 		 * larger due to RSA padding, but the bound is sufficient to be
 		 * safe. */
-		if (n < SSL_MAX_MASTER_KEY_LENGTH)
+		if (CBS_len(&encrypted_premaster_secret) < SSL_MAX_MASTER_KEY_LENGTH)
 			{
 			al = SSL_AD_DECRYPT_ERROR;
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DECRYPTION_FAILED);
@@ -2219,7 +2227,21 @@
 				      sizeof(rand_premaster_secret)) <= 0)
 			goto err;
 
-		decrypt_len = RSA_private_decrypt((int)n,p,p,rsa,RSA_PKCS1_PADDING);
+		/* Allocate a buffer large enough for an RSA decryption. */
+		premaster_secret = OPENSSL_malloc(RSA_size(rsa));
+		if (premaster_secret == NULL)
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto err;
+			}
+
+		decrypt_len = RSA_private_decrypt(
+			CBS_len(&encrypted_premaster_secret),
+			CBS_data(&encrypted_premaster_secret),
+			premaster_secret,
+			rsa,
+			RSA_PKCS1_PADDING);
+
 		ERR_clear_error();
 
 		/* decrypt_len should be SSL_MAX_MASTER_KEY_LENGTH.
@@ -2233,8 +2255,8 @@
 		 * number check as a "bad version oracle". Thus version checks
 		 * are done in constant time and are treated like any other
 		 * decryption error. */
-		version_good = p[0] ^ (s->client_version>>8);
-		version_good |= p[1] ^ (s->client_version&0xff);
+		version_good = premaster_secret[0] ^ (s->client_version>>8);
+		version_good |= premaster_secret[1] ^ (s->client_version&0xff);
 
 		/* The premaster secret must contain the same version number as
 		 * the ClientHello to detect version rollback attacks
@@ -2256,8 +2278,8 @@
 			workaround_mask |= workaround_mask >> 1;
 			workaround_mask = ~((workaround_mask & 1) - 1);
 
-			workaround = p[0] ^ (s->version>>8);
-			workaround |= p[1] ^ (s->version&0xff);
+			workaround = premaster_secret[0] ^ (s->version>>8);
+			workaround |= premaster_secret[1] ^ (s->version&0xff);
 
 			/* If workaround_mask is 0xff (i.e. there was a version
 			 * mismatch) then we copy the value of workaround over
@@ -2288,42 +2310,39 @@
 		decrypt_good_mask &= 1;
 		decrypt_good_mask--;
 
-		/* Now copy rand_premaster_secret over p using
+		/* Now copy rand_premaster_secret over premaster_secret using
 		 * decrypt_good_mask. */
 		for (j = 0; j < sizeof(rand_premaster_secret); j++)
 			{
-			p[j] = (p[j] & decrypt_good_mask) |
+			premaster_secret[j] = (premaster_secret[j] & decrypt_good_mask) |
 			       (rand_premaster_secret[j] & ~decrypt_good_mask);
 			}
 
 		s->session->master_key_length=
 			s->method->ssl3_enc->generate_master_secret(s,
 				s->session->master_key,
-				p,sizeof(rand_premaster_secret));
-		OPENSSL_cleanse(p,sizeof(rand_premaster_secret));
+				premaster_secret, sizeof(rand_premaster_secret));
+
+		OPENSSL_cleanse(premaster_secret, sizeof(rand_premaster_secret));
+		OPENSSL_free(premaster_secret);
+		premaster_secret = NULL;
 		}
 #ifndef OPENSSL_NO_DH
 	else if (alg_k & (SSL_kEDH|SSL_kDHr|SSL_kDHd))
 		{
+		CBS dh_Yc;
 		int idx = -1;
+		int dh_len;
 		EVP_PKEY *skey = NULL;
-		if (n)
-			n2s(p,i);
-		else
-			i = 0;
-		if (n && n != i+2)
+
+		if (!CBS_get_u16_length_prefixed(&client_key_exchange, &dh_Yc) ||
+			CBS_len(&client_key_exchange) != 0)
 			{
-			if (!(s->options & SSL_OP_SSLEAY_080_CLIENT_DH_BUG))
-				{
-				OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG);
-				goto err;
-				}
-			else
-				{
-				p-=2;
-				i=(int)n;
-				}
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG);
+			al = SSL_R_DECODE_ERROR;
+			goto f_err;
 			}
+
 		if (alg_k & SSL_kDHr)
 			idx = SSL_PKEY_DH_RSA;
 		else if (alg_k & SSL_kDHd)
@@ -2350,9 +2369,16 @@
 		else
 			dh_srvr=s->s3->tmp.dh;
 
-		if (n == 0L)
+		if (CBS_len(&dh_Yc) == 0)
 			{
-			/* Get pubkey from cert */
+			/* Get pubkey from the client certificate. This is the
+			 * 'implicit' case of ClientDiffieHellman.
+			 *
+			 * TODO(davidben): Either lose this code or fix a bug
+			 * (or get the spec changed): if there is a fixed_dh
+			 * client certificate, per spec, the 'implicit' mode
+			 * MUST be used. This logic will still accept 'explicit'
+			 * mode. */
 			EVP_PKEY *clkey=X509_get_pubkey(s->session->peer);
 			if (clkey)
 				{
@@ -2369,16 +2395,23 @@
 			pub = dh_clnt->pub_key;
 			}
 		else
-			pub=BN_bin2bn(p,i,NULL);
+			pub = BN_bin2bn(CBS_data(&dh_Yc), CBS_len(&dh_Yc), NULL);
 		if (pub == NULL)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_BN_LIB);
 			goto err;
 			}
 
-		i=DH_compute_key(p,pub,dh_srvr);
+		/* Allocate a buffer for the premaster secret. */
+		premaster_secret = OPENSSL_malloc(DH_size(dh_srvr));
+		if (premaster_secret == NULL)
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto err;
+			}
 
-		if (i <= 0)
+		dh_len = DH_compute_key(premaster_secret, pub, dh_srvr);
+		if (dh_len <= 0)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_DH_LIB);
 			BN_clear_free(pub);
@@ -2394,8 +2427,10 @@
 		pub=NULL;
 		s->session->master_key_length=
 			s->method->ssl3_enc->generate_master_secret(s,
-				s->session->master_key,p,i);
-		OPENSSL_cleanse(p,i);
+				s->session->master_key, premaster_secret, dh_len);
+		OPENSSL_cleanse(premaster_secret, dh_len);
+		OPENSSL_free(premaster_secret);
+		premaster_secret = NULL;
 		if (dh_clnt)
 			return 2;
 		}
@@ -2405,15 +2440,10 @@
 	else if (alg_k & (SSL_kEECDH|SSL_kECDHr|SSL_kECDHe))
 		{
 		int ret = 1;
-		int field_size = 0;
+		int field_size = 0, ecdh_len;
 		const EC_KEY   *tkey;
 		const EC_GROUP *group;
 		const BIGNUM *priv_key;
-#ifndef OPENSSL_NO_PSK
-		unsigned char *pre_ms;
-		unsigned int pre_ms_len;
-		unsigned char *t;
-#endif /* OPENSSL_NO_PSK */
 
 		/* initialize structures for server's ECDH key pair */
 		if ((srvr_ecdh = EC_KEY_new()) == NULL) 
@@ -2453,7 +2483,7 @@
 			goto err;
 			}
 
-		if (n == 0L) 
+		if (CBS_len(&client_key_exchange) == 0)
 			{
 			/* Client Publickey was in Client Certificate */
 
@@ -2493,44 +2523,53 @@
 			}
 		else
 			{
+			CBS ecdh_Yc;
+
 			/* Get client's public key from encoded point
 			 * in the ClientKeyExchange message.
 			 */
+			if (!CBS_get_u8_length_prefixed(&client_key_exchange, &ecdh_Yc) ||
+				CBS_len(&client_key_exchange) != 0)
+				{
+				al = SSL_AD_DECODE_ERROR;
+				OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DECODE_ERROR);
+				goto f_err;
+				}
+
 			if ((bn_ctx = BN_CTX_new()) == NULL)
 				{
 				OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
 				goto err;
 				}
 
-			/* Get encoded point length */
-			i = *p;
-			p += 1;
-			if (n != 1 + i)
+			if (!EC_POINT_oct2point(group, clnt_ecpoint,
+					CBS_data(&ecdh_Yc), CBS_len(&ecdh_Yc), bn_ctx))
 				{
 				OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_EC_LIB);
 				goto err;
 				}
-			if (EC_POINT_oct2point(group, 
-			    clnt_ecpoint, p, i, bn_ctx) == 0)
-				{
-				OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_EC_LIB);
-				goto err;
-				}
-			/* p is pointing to somewhere in the buffer
-			 * currently, so set it to the start 
-			 */ 
-			p=(unsigned char *)s->init_buf->data;
 			}
 
-		/* Compute the shared pre-master secret */
+		/* Allocate a buffer for both the secret and the PSK. */
 		field_size = EC_GROUP_get_degree(group);
 		if (field_size <= 0)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_ECDH_LIB);
 			goto err;
 			}
-		i = ECDH_compute_key(p, (field_size+7)/8, clnt_ecpoint, srvr_ecdh, NULL);
-		if (i <= 0)
+
+		ecdh_len = (field_size + 7) / 8;
+		premaster_secret = OPENSSL_malloc(ecdh_len);
+		if (premaster_secret == NULL)
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto err;
+			}
+
+		/* Compute the shared pre-master secret */
+		ecdh_len = ECDH_compute_key(premaster_secret,
+			ecdh_len, clnt_ecpoint, srvr_ecdh, NULL);
+		if (ecdh_len <= 0)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_ECDH_LIB);
 			goto err;
@@ -2545,9 +2584,13 @@
 
 #ifndef OPENSSL_NO_PSK
 		/* ECDHE PSK ciphersuites from RFC 5489 */
-	    if ((alg_a & SSL_aPSK) && psk_len != 0)
+		if ((alg_a & SSL_aPSK) && psk_len != 0)
 			{
-			pre_ms_len = 2+psk_len+2+i;
+			unsigned char *pre_ms;
+			unsigned int pre_ms_len;
+			unsigned char *t;
+
+			pre_ms_len = 2+psk_len+2+ecdh_len;
 			pre_ms = OPENSSL_malloc(pre_ms_len);
 			if (pre_ms == NULL)
 				{
@@ -2560,8 +2603,8 @@
 			s2n(psk_len, t);
 			memcpy(t, psk, psk_len);
 			t += psk_len;
-			s2n(i, t);
-			memcpy(t, p, i);
+			s2n(ecdh_len, t);
+			memcpy(t, premaster_secret, ecdh_len);
 			s->session->master_key_length = s->method->ssl3_enc \
 				-> generate_master_secret(s,
 					s->session->master_key, pre_ms, pre_ms_len);
@@ -2574,10 +2617,11 @@
 			/* Compute the master secret */
 			s->session->master_key_length = s->method->ssl3_enc \
 				-> generate_master_secret(s,
-					s->session->master_key, p, i);
+					s->session->master_key, premaster_secret, ecdh_len);
 			}
 
-		OPENSSL_cleanse(p, i);
+		OPENSSL_cleanse(premaster_secret, ecdh_len);
+		OPENSSL_free(premaster_secret);
 		return ret;
 		}
 #endif
@@ -2597,6 +2641,8 @@
 f_err:
 	ssl3_send_alert(s,SSL3_AL_FATAL,al);
 err:
+	if (premaster_secret)
+		OPENSSL_free(premaster_secret);
 #ifndef OPENSSL_NO_ECDH
 	EVP_PKEY_free(clnt_pub_pkey);
 	EC_POINT_free(clnt_ecpoint);