Refactor PSK logic in ssl3_get_client_key_exchange.

This avoids duplicating the code to build the final premaster in PSK and
ECDHE_PSK. It also ports it to CBB for an initial trial of the API. Computing
the premaster secret now proceeds in four steps:

1. If a PSK key exchange (alg_a), look up the pre-shared key.
2. Compute the premaster secret based on alg_k. If PSK, it's all zeros.
3. If a PSK key exchange (alg_a), wrap the premaster in a struct with the
   pre-shared key.
4. Use the possibly modified premaster to compute the master secret.

Change-Id: Ib511dd2724cbed42c82b82a676f641114cec5470
Reviewed-on: https://boringssl-review.googlesource.com/1173
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index 905acee..94769ce 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -2123,7 +2123,9 @@
 
 			if (!(alg_k & SSL_kEECDH))
 				{
-				/* Create the shared secret now if we're not using ECDHE-PSK.*/
+				/* Create the shared secret now if we're not using ECDHE-PSK.
+				 * TODO(davidben): Refactor this logic similarly
+				 * to ssl3_get_client_key_exchange. */
 				pre_ms_len = 2+psk_len+2+psk_len;
 				t = pre_ms;
 				s2n(psk_len, t);
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index e8dd6fd..953e22f 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -2029,6 +2029,8 @@
 	unsigned long alg_k;
 	unsigned long alg_a;
 	uint8_t *premaster_secret = NULL;
+	size_t premaster_secret_len = 0;
+	int skip_certificate_verify = 0;
 	RSA *rsa=NULL;
 	EVP_PKEY *pkey=NULL;
 #ifndef OPENSSL_NO_DH
@@ -2061,13 +2063,10 @@
 	alg_a=s->s3->tmp.new_cipher->algorithm_auth;
 
 #ifndef OPENSSL_NO_PSK
+	/* If using a PSK key exchange, prepare the pre-shared key. */
 	if (alg_a & SSL_aPSK)
 		{
 		CBS psk_identity;
-		int psk_err = -1;
-		unsigned char *t = NULL;
-		unsigned char pre_ms[PSK_MAX_PSK_LEN*2+4];
-		unsigned int pre_ms_len = 0;
 
 		/* If using PSK, the ClientKeyExchange contains a
 		 * psk_identity. If PSK, then this is the only field
@@ -2077,14 +2076,14 @@
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DECODE_ERROR);
 			al = SSL_AD_DECODE_ERROR;
-			goto psk_err;
+			goto f_err;
 			}
 
 		if (s->psk_server_callback == NULL)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_PSK_NO_SERVER_CB);
 			al = SSL_AD_INTERNAL_ERROR;
-			goto psk_err;
+			goto f_err;
 			}
 
 		if (CBS_len(&psk_identity) > PSK_MAX_IDENTITY_LEN ||
@@ -2092,14 +2091,14 @@
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DATA_LENGTH_TOO_LONG);
 			al = SSL_AD_ILLEGAL_PARAMETER;
-			goto psk_err;
+			goto f_err;
 			}
 
 		if (!CBS_strdup(&psk_identity, &s->session->psk_identity))
 			{
 			al = SSL_AD_INTERNAL_ERROR;
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
-			goto psk_err;
+			goto f_err;
 			}
 
 		/* Look up the key for the identity. */
@@ -2108,39 +2107,21 @@
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_INTERNAL_ERROR);
 			al = SSL_AD_INTERNAL_ERROR;
-			goto psk_err;
+			goto f_err;
 			}
 		else if (psk_len == 0)
 			{
 			/* PSK related to the given identity not found */
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_PSK_IDENTITY_NOT_FOUND);
 			al = SSL_AD_UNKNOWN_PSK_IDENTITY;
-			goto psk_err;
-			}
-		if (alg_k & SSL_kPSK)
-			{
-			/* Create the shared secret now if we're using plain PSK. */
-			pre_ms_len=2+psk_len+2+psk_len;
-			t = pre_ms;
-			s2n(psk_len, t);
-			memset(t, 0, psk_len);
-			t+=psk_len;
-			s2n(psk_len, t);
-			memcpy(t, psk, psk_len);
-
-			s->session->master_key_length=
-				s->method->ssl3_enc->generate_master_secret(s,
-					s->session->master_key, pre_ms, pre_ms_len);
-			}
-
-		psk_err = 0;
-	psk_err:
-		OPENSSL_cleanse(pre_ms, sizeof(pre_ms));
-		if (psk_err != 0)
 			goto f_err;
+			}
 		}
 #endif /* OPENSSL_NO_PSK */
 
+	/* Depending on the key exchange method, compute |premaster_secret| and
+	 * |premaster_secret_len|. Also, for DH and ECDH, set
+	 * |skip_certificate_verify| as appropriate. */
 	if (alg_k & SSL_kRSA)
 		{
 		CBS encrypted_premaster_secret;
@@ -2318,14 +2299,7 @@
 			       (rand_premaster_secret[j] & ~decrypt_good_mask);
 			}
 
-		s->session->master_key_length=
-			s->method->ssl3_enc->generate_master_secret(s,
-				s->session->master_key,
-				premaster_secret, sizeof(rand_premaster_secret));
-
-		OPENSSL_cleanse(premaster_secret, sizeof(rand_premaster_secret));
-		OPENSSL_free(premaster_secret);
-		premaster_secret = NULL;
+		premaster_secret_len = sizeof(rand_premaster_secret);
 		}
 #ifndef OPENSSL_NO_DH
 	else if (alg_k & (SSL_kEDH|SSL_kDHr|SSL_kDHd))
@@ -2425,21 +2399,16 @@
 		else
 			BN_clear_free(pub);
 		pub=NULL;
-		s->session->master_key_length=
-			s->method->ssl3_enc->generate_master_secret(s,
-				s->session->master_key, premaster_secret, dh_len);
-		OPENSSL_cleanse(premaster_secret, dh_len);
-		OPENSSL_free(premaster_secret);
-		premaster_secret = NULL;
+
+		premaster_secret_len = dh_len;
 		if (dh_clnt)
-			return 2;
+			skip_certificate_verify = 1;
 		}
 #endif
 
 #ifndef OPENSSL_NO_ECDH
 	else if (alg_k & (SSL_kEECDH|SSL_kECDHr|SSL_kECDHe))
 		{
-		int ret = 1;
 		int field_size = 0, ecdh_len;
 		const EC_KEY   *tkey;
 		const EC_GROUP *group;
@@ -2519,7 +2488,8 @@
 				OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_EC_LIB);
 				goto err;
 				}
-			ret = 2; /* Skip certificate verify processing */
+			/* Skip certificate verify processing */
+			skip_certificate_verify = 1;
 			}
 		else
 			{
@@ -2582,47 +2552,7 @@
 		EC_KEY_free(s->s3->tmp.ecdh);
 		s->s3->tmp.ecdh = NULL;
 
-#ifndef OPENSSL_NO_PSK
-		/* ECDHE PSK ciphersuites from RFC 5489 */
-		if (alg_a & SSL_aPSK)
-			{
-			unsigned char *pre_ms;
-			unsigned int pre_ms_len;
-			unsigned char *t;
-
-			pre_ms_len = 2+psk_len+2+ecdh_len;
-			pre_ms = OPENSSL_malloc(pre_ms_len);
-			if (pre_ms == NULL)
-				{
-				OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
-				goto err;
-				}
-
-			memset(pre_ms, 0, pre_ms_len);
-			t = pre_ms;
-			s2n(psk_len, t);
-			memcpy(t, psk, psk_len);
-			t += psk_len;
-			s2n(ecdh_len, t);
-			memcpy(t, premaster_secret, ecdh_len);
-			s->session->master_key_length = s->method->ssl3_enc \
-				-> generate_master_secret(s,
-					s->session->master_key, pre_ms, pre_ms_len);
-			OPENSSL_cleanse(pre_ms, pre_ms_len);
-			OPENSSL_free(pre_ms);
-			}
-		else
-#endif /* OPENSSL_NO_PSK */
-			{
-			/* Compute the master secret */
-			s->session->master_key_length = s->method->ssl3_enc \
-				-> generate_master_secret(s,
-					s->session->master_key, premaster_secret, ecdh_len);
-			}
-
-		OPENSSL_cleanse(premaster_secret, ecdh_len);
-		OPENSSL_free(premaster_secret);
-		return ret;
+		premaster_secret_len = ecdh_len;
 		}
 #endif
 	else if (alg_k & SSL_kGOST) 
@@ -2630,6 +2560,21 @@
 		OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_GOST_NOT_SUPPORTED);
 		goto err;
 		}
+#ifndef OPENSSL_NO_PSK
+	else if (alg_k & SSL_kPSK)
+		{
+		/* For plain PSK, other_secret is a block of 0s with the same
+		 * length as the pre-shared key. */
+		premaster_secret_len = psk_len;
+		premaster_secret = OPENSSL_malloc(premaster_secret_len);
+		if (premaster_secret == NULL)
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto err;
+			}
+		memset(premaster_secret, 0, premaster_secret_len);
+		}
+#endif  /* !OPENSSL_NO_PSK */
 	else
 		{
 		al=SSL_AD_HANDSHAKE_FAILURE;
@@ -2637,12 +2582,55 @@
 		goto f_err;
 		}
 
-	return(1);
+#ifndef OPENSSL_NO_PSK
+	/* For a PSK cipher suite, the actual pre-master secret is combined with
+	 * the pre-shared key. */
+	if (alg_a & SSL_aPSK)
+		{
+		CBB new_premaster, child;
+		uint8_t *new_data;
+		size_t new_len;
+
+		if (!CBB_init(&new_premaster, 2 + psk_len + 2 + premaster_secret_len))
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto err;
+			}
+		if (!CBB_add_u16_length_prefixed(&new_premaster, &child) ||
+			!CBB_add_bytes(&child, premaster_secret, premaster_secret_len) ||
+			!CBB_add_u16_length_prefixed(&new_premaster, &child) ||
+			!CBB_add_bytes(&child, psk, psk_len) ||
+			!CBB_finish(&new_premaster, &new_data, &new_len))
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_INTERNAL_ERROR);
+			CBB_cleanup(&new_premaster);
+			goto err;
+			}
+
+		OPENSSL_cleanse(premaster_secret, premaster_secret_len);
+		OPENSSL_free(premaster_secret);
+		premaster_secret = new_data;
+		premaster_secret_len = new_len;
+		}
+#endif  /* !OPENSSL_NO_PSK */
+
+	/* Compute the master secret */
+	s->session->master_key_length = s->method->ssl3_enc
+		->generate_master_secret(s,
+			s->session->master_key, premaster_secret, premaster_secret_len);
+
+	OPENSSL_cleanse(premaster_secret, premaster_secret_len);
+	OPENSSL_free(premaster_secret);
+	return skip_certificate_verify ? 2 : 1;
 f_err:
 	ssl3_send_alert(s,SSL3_AL_FATAL,al);
 err:
 	if (premaster_secret)
+		{
+		if (premaster_secret_len)
+			OPENSSL_cleanse(premaster_secret, premaster_secret_len);
 		OPENSSL_free(premaster_secret);
+		}
 #ifndef OPENSSL_NO_ECDH
 	EVP_PKEY_free(clnt_pub_pkey);
 	EC_POINT_free(clnt_ecpoint);