Convert ssl3_send_client_key_exchange to CBB.

This relieves some complexity budget for adding Curve25519 to this
code.

This also adds a BN_bn2cbb_padded helper function since this seems to be a
fairly common need.

Change-Id: Ied0066fdaec9d02659abd6eb1a13f33502c9e198
Reviewed-on: https://boringssl-review.googlesource.com/6767
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/bn/convert.c b/crypto/bn/convert.c
index 0122709..1f7af64 100644
--- a/crypto/bn/convert.c
+++ b/crypto/bn/convert.c
@@ -63,6 +63,7 @@
 #include <string.h>
 
 #include <openssl/bio.h>
+#include <openssl/bytestring.h>
 #include <openssl/err.h>
 #include <openssl/mem.h>
 
@@ -195,6 +196,11 @@
   return 1;
 }
 
+int BN_bn2cbb_padded(CBB *out, size_t len, const BIGNUM *in) {
+  uint8_t *ptr;
+  return CBB_add_space(out, &ptr, len) && BN_bn2bin_padded(ptr, len, in);
+}
+
 static const char hextable[] = "0123456789abcdef";
 
 char *BN_bn2hex(const BIGNUM *bn) {
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index bc30d0a..6e971e4 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -253,6 +253,9 @@
  * returns 0. Otherwise, it returns 1. */
 OPENSSL_EXPORT int BN_bn2bin_padded(uint8_t *out, size_t len, const BIGNUM *in);
 
+/* BN_bn2cbb_padded behaves like |BN_bn2bin_padded| but writes to a |CBB|. */
+OPENSSL_EXPORT int BN_bn2cbb_padded(CBB *out, size_t len, const BIGNUM *in);
+
 /* BN_bn2hex returns an allocated string that contains a NUL-terminated, hex
  * representation of |bn|. If |bn| is negative, the first char in the resulting
  * string will be '-'. Returns NULL on allocation failure. */
diff --git a/ssl/internal.h b/ssl/internal.h
index bbbd939..a72d5d6 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1079,7 +1079,7 @@
 int ssl3_send_cert_verify(SSL *s);
 int ssl3_send_client_certificate(SSL *s);
 int ssl_do_client_cert_cb(SSL *s, X509 **px509, EVP_PKEY **ppkey);
-int ssl3_send_client_key_exchange(SSL *s);
+int ssl3_send_client_key_exchange(SSL *ssl);
 int ssl3_get_server_key_exchange(SSL *s);
 int ssl3_get_server_certificate(SSL *s);
 int ssl3_send_next_proto(SSL *ssl);
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index 43a77d4..7f61c89 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -1588,333 +1588,264 @@
 OPENSSL_COMPILE_ASSERT(sizeof(size_t) >= sizeof(unsigned),
                        SIZE_T_IS_SMALLER_THAN_UNSIGNED);
 
-int ssl3_send_client_key_exchange(SSL *s) {
-  uint8_t *p;
-  int n = 0;
-  uint32_t alg_k;
-  uint32_t alg_a;
-  uint8_t *q;
-  EVP_PKEY *pkey = NULL;
-  EC_KEY *clnt_ecdh = NULL;
-  const EC_POINT *srvr_ecpoint = NULL;
-  EVP_PKEY *srvr_pub_pkey = NULL;
-  uint8_t *encodedPoint = NULL;
-  int encoded_pt_len = 0;
-  BN_CTX *bn_ctx = NULL;
-  unsigned int psk_len = 0;
-  uint8_t psk[PSK_MAX_PSK_LEN];
+int ssl3_send_client_key_exchange(SSL *ssl) {
+  if (ssl->state == SSL3_ST_CW_KEY_EXCH_B) {
+    return ssl_do_write(ssl);
+  }
+  assert(ssl->state == SSL3_ST_CW_KEY_EXCH_A);
+
   uint8_t *pms = NULL;
   size_t pms_len = 0;
+  EC_KEY *eckey = NULL;
+  CBB cbb;
+  if (!CBB_init_fixed(&cbb, ssl_handshake_start(ssl),
+                      ssl->init_buf->max - SSL_HM_HEADER_LENGTH(ssl))) {
+    goto err;
+  }
 
-  if (s->state == SSL3_ST_CW_KEY_EXCH_A) {
-    p = ssl_handshake_start(s);
+  uint32_t alg_k = ssl->s3->tmp.new_cipher->algorithm_mkey;
+  uint32_t alg_a = ssl->s3->tmp.new_cipher->algorithm_auth;
 
-    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
-    alg_a = s->s3->tmp.new_cipher->algorithm_auth;
-
-    /* If using a PSK key exchange, prepare the pre-shared key. */
-    if (alg_a & SSL_aPSK) {
-      char identity[PSK_MAX_IDENTITY_LEN + 1];
-      size_t identity_len;
-
-      if (s->psk_client_callback == NULL) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_NO_CLIENT_CB);
-        goto err;
-      }
-
-      memset(identity, 0, sizeof(identity));
-      psk_len =
-          s->psk_client_callback(s, s->s3->tmp.peer_psk_identity_hint, identity,
-                                 sizeof(identity), psk, sizeof(psk));
-      if (psk_len > PSK_MAX_PSK_LEN) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        goto err;
-      } else if (psk_len == 0) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_IDENTITY_NOT_FOUND);
-        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-        goto err;
-      }
-
-      identity_len = OPENSSL_strnlen(identity, sizeof(identity));
-      if (identity_len > PSK_MAX_IDENTITY_LEN) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        goto err;
-      }
-
-      OPENSSL_free(s->session->psk_identity);
-      s->session->psk_identity = BUF_strdup(identity);
-      if (s->session->psk_identity == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        goto err;
-      }
-
-      /* Write out psk_identity. */
-      s2n(identity_len, p);
-      memcpy(p, identity, identity_len);
-      p += identity_len;
-      n = 2 + identity_len;
+  /* If using a PSK key exchange, prepare the pre-shared key. */
+  unsigned psk_len = 0;
+  uint8_t psk[PSK_MAX_PSK_LEN];
+  if (alg_a & SSL_aPSK) {
+    if (ssl->psk_client_callback == NULL) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_NO_CLIENT_CB);
+      goto err;
     }
 
-    /* Depending on the key exchange method, compute |pms| and |pms_len|. */
-    if (alg_k & SSL_kRSA) {
-      size_t enc_pms_len;
+    char identity[PSK_MAX_IDENTITY_LEN + 1];
+    memset(identity, 0, sizeof(identity));
+    psk_len = ssl->psk_client_callback(
+        ssl, ssl->s3->tmp.peer_psk_identity_hint, identity, sizeof(identity),
+        psk, sizeof(psk));
+    if (psk_len == 0) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_IDENTITY_NOT_FOUND);
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+      goto err;
+    }
+    assert(psk_len <= PSK_MAX_PSK_LEN);
 
-      pms_len = SSL_MAX_MASTER_KEY_LENGTH;
-      pms = OPENSSL_malloc(pms_len);
-      if (pms == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        goto err;
-      }
+    OPENSSL_free(ssl->session->psk_identity);
+    ssl->session->psk_identity = BUF_strdup(identity);
+    if (ssl->session->psk_identity == NULL) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+      goto err;
+    }
 
-      pkey = X509_get_pubkey(s->session->peer);
-      if (pkey == NULL) {
-        goto err;
-      }
+    /* Write out psk_identity. */
+    CBB child;
+    if (!CBB_add_u16_length_prefixed(&cbb, &child) ||
+        !CBB_add_bytes(&child, (const uint8_t *)identity,
+                       OPENSSL_strnlen(identity, sizeof(identity))) ||
+        !CBB_flush(&cbb)) {
+      goto err;
+    }
+  }
 
-      RSA *rsa = EVP_PKEY_get0_RSA(pkey);
-      if (rsa == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        EVP_PKEY_free(pkey);
-        goto err;
-      }
+  /* Depending on the key exchange method, compute |pms| and |pms_len|. */
+  if (alg_k & SSL_kRSA) {
+    pms_len = SSL_MAX_MASTER_KEY_LENGTH;
+    pms = OPENSSL_malloc(pms_len);
+    if (pms == NULL) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+      goto err;
+    }
 
-      s->session->key_exchange_info = EVP_PKEY_bits(pkey);
+    EVP_PKEY *pkey = X509_get_pubkey(ssl->session->peer);
+    if (pkey == NULL) {
+      goto err;
+    }
+
+    RSA *rsa = EVP_PKEY_get0_RSA(pkey);
+    if (rsa == NULL) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
       EVP_PKEY_free(pkey);
+      goto err;
+    }
 
-      pms[0] = s->client_version >> 8;
-      pms[1] = s->client_version & 0xff;
-      if (!RAND_bytes(&pms[2], SSL_MAX_MASTER_KEY_LENGTH - 2)) {
+    ssl->session->key_exchange_info = EVP_PKEY_bits(pkey);
+    EVP_PKEY_free(pkey);
+
+    pms[0] = ssl->client_version >> 8;
+    pms[1] = ssl->client_version & 0xff;
+    if (!RAND_bytes(&pms[2], SSL_MAX_MASTER_KEY_LENGTH - 2)) {
+      goto err;
+    }
+
+    CBB child, *enc_pms = &cbb;
+    size_t enc_pms_len;
+    /* In TLS, there is a length prefix. */
+    if (ssl->version > SSL3_VERSION) {
+      if (!CBB_add_u16_length_prefixed(&cbb, &child)) {
         goto err;
       }
+      enc_pms = &child;
+    }
 
-      s->session->master_key_length = SSL_MAX_MASTER_KEY_LENGTH;
-
-      q = p;
-      /* In TLS and beyond, reserve space for the length prefix. */
-      if (s->version > SSL3_VERSION) {
-        p += 2;
-        n += 2;
-      }
-      if (!RSA_encrypt(rsa, &enc_pms_len, p, RSA_size(rsa), pms, pms_len,
-                       RSA_PKCS1_PADDING)) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_RSA_ENCRYPT);
-        goto err;
-      }
-      n += enc_pms_len;
-
-      /* Log the premaster secret, if logging is enabled. */
-      if (!ssl_log_rsa_client_key_exchange(s, p, enc_pms_len, pms, pms_len)) {
-        goto err;
-      }
-
-      /* Fill in the length prefix. */
-      if (s->version > SSL3_VERSION) {
-        s2n(enc_pms_len, q);
-      }
-    } else if (alg_k & SSL_kDHE) {
-      DH *dh_srvr, *dh_clnt;
-      int dh_len;
-      size_t pub_len;
-
-      if (s->s3->tmp.peer_dh_tmp == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        goto err;
-      }
-      dh_srvr = s->s3->tmp.peer_dh_tmp;
-
-      /* generate a new random key */
-      dh_clnt = DHparams_dup(dh_srvr);
-      if (dh_clnt == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
-        goto err;
-      }
-      if (!DH_generate_key(dh_clnt)) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
-        DH_free(dh_clnt);
-        goto err;
-      }
-
-      pms_len = DH_size(dh_clnt);
-      pms = OPENSSL_malloc(pms_len);
-      if (pms == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        DH_free(dh_clnt);
-        goto err;
-      }
-
-      dh_len = DH_compute_key(pms, dh_srvr->pub_key, dh_clnt);
-      if (dh_len <= 0) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
-        DH_free(dh_clnt);
-        goto err;
-      }
-      pms_len = dh_len;
-
-      /* send off the data */
-      pub_len = BN_num_bytes(dh_clnt->pub_key);
-      s2n(pub_len, p);
-      BN_bn2bin(dh_clnt->pub_key, p);
-      n += 2 + pub_len;
-
-      DH_free(dh_clnt);
-    } else if (alg_k & SSL_kECDHE) {
-      const EC_GROUP *srvr_group = NULL;
-      EC_KEY *tkey;
-      int ecdh_len;
-
-      if (s->s3->tmp.peer_ecdh_tmp == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        goto err;
-      }
-
-      tkey = s->s3->tmp.peer_ecdh_tmp;
-
-      srvr_group = EC_KEY_get0_group(tkey);
-      srvr_ecpoint = EC_KEY_get0_public_key(tkey);
-      if (srvr_group == NULL || srvr_ecpoint == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        goto err;
-      }
-
-      clnt_ecdh = EC_KEY_new();
-      if (clnt_ecdh == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        goto err;
-      }
-
-      if (!EC_KEY_set_group(clnt_ecdh, srvr_group)) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_EC_LIB);
-        goto err;
-      }
-
-      /* Generate a new ECDH key pair */
-      if (!EC_KEY_generate_key(clnt_ecdh)) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_ECDH_LIB);
-        goto err;
-      }
-
-      unsigned field_size = EC_GROUP_get_degree(srvr_group);
-      if (field_size == 0) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_ECDH_LIB);
-        goto err;
-      }
-
-      pms_len = (field_size + 7) / 8;
-      pms = OPENSSL_malloc(pms_len);
-      if (pms == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        goto err;
-      }
-
-      ecdh_len = ECDH_compute_key(pms, pms_len, srvr_ecpoint, clnt_ecdh, NULL);
-      if (ecdh_len <= 0) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_ECDH_LIB);
-        goto err;
-      }
-      pms_len = ecdh_len;
-
-      /* First check the size of encoding and allocate memory accordingly. */
-      encoded_pt_len =
-          EC_POINT_point2oct(srvr_group, EC_KEY_get0_public_key(clnt_ecdh),
-                             POINT_CONVERSION_UNCOMPRESSED, NULL, 0, NULL);
-
-      encodedPoint =
-          (uint8_t *)OPENSSL_malloc(encoded_pt_len * sizeof(uint8_t));
-      bn_ctx = BN_CTX_new();
-      if (encodedPoint == NULL || bn_ctx == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        goto err;
-      }
-
-      /* Encode the public key */
-      encoded_pt_len = EC_POINT_point2oct(
-          srvr_group, EC_KEY_get0_public_key(clnt_ecdh),
-          POINT_CONVERSION_UNCOMPRESSED, encodedPoint, encoded_pt_len, bn_ctx);
-
-      *p = encoded_pt_len; /* length of encoded point */
-      /* Encoded point will be copied here */
-      p += 1;
-      n += 1;
-      /* copy the point */
-      memcpy(p, encodedPoint, encoded_pt_len);
-      /* increment n to account for length field */
-      n += encoded_pt_len;
-
-      /* Free allocated memory */
-      BN_CTX_free(bn_ctx);
-      bn_ctx = NULL;
-      OPENSSL_free(encodedPoint);
-      encodedPoint = NULL;
-      EC_KEY_free(clnt_ecdh);
-      clnt_ecdh = NULL;
-      EVP_PKEY_free(srvr_pub_pkey);
-      srvr_pub_pkey = NULL;
-    } else if (alg_k & SSL_kPSK) {
-      /* For plain PSK, other_secret is a block of 0s with the same length as
-       * the pre-shared key. */
-      pms_len = psk_len;
-      pms = OPENSSL_malloc(pms_len);
-      if (pms == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        goto err;
-      }
-      memset(pms, 0, pms_len);
-    } else {
-      ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+    uint8_t *ptr;
+    if (!CBB_reserve(enc_pms, &ptr, RSA_size(rsa)) ||
+        !RSA_encrypt(rsa, &enc_pms_len, ptr, RSA_size(rsa), pms, pms_len,
+                     RSA_PKCS1_PADDING) ||
+        /* Log the premaster secret, if logging is enabled. */
+        !ssl_log_rsa_client_key_exchange(ssl, ptr, enc_pms_len, pms, pms_len) ||
+        !CBB_did_write(enc_pms, enc_pms_len) ||
+        !CBB_flush(&cbb)) {
+      goto err;
+    }
+  } else if (alg_k & SSL_kDHE) {
+    if (ssl->s3->tmp.peer_dh_tmp == NULL) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
       goto err;
     }
+    DH *peer_dh = ssl->s3->tmp.peer_dh_tmp;
 
-    /* For a PSK cipher suite, other_secret is combined with the pre-shared
-     * key. */
-    if (alg_a & SSL_aPSK) {
-      CBB cbb, child;
-      uint8_t *new_pms;
-      size_t new_pms_len;
-
-      CBB_zero(&cbb);
-      if (!CBB_init(&cbb, 2 + psk_len + 2 + pms_len) ||
-          !CBB_add_u16_length_prefixed(&cbb, &child) ||
-          !CBB_add_bytes(&child, pms, pms_len) ||
-          !CBB_add_u16_length_prefixed(&cbb, &child) ||
-          !CBB_add_bytes(&child, psk, psk_len) ||
-          !CBB_finish(&cbb, &new_pms, &new_pms_len)) {
-        CBB_cleanup(&cbb);
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-        goto err;
-      }
-      OPENSSL_cleanse(pms, pms_len);
-      OPENSSL_free(pms);
-      pms = new_pms;
-      pms_len = new_pms_len;
-    }
-
-    /* The message must be added to the finished hash before calculating the
-     * master secret. */
-    if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) {
+    /* Generate a keypair. */
+    DH *dh = DHparams_dup(peer_dh);
+    if (dh == NULL || !DH_generate_key(dh)) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
+      DH_free(dh);
       goto err;
     }
-    s->state = SSL3_ST_CW_KEY_EXCH_B;
 
-    s->session->master_key_length = s->enc_method->generate_master_secret(
-        s, s->session->master_key, pms, pms_len);
-    if (s->session->master_key_length == 0) {
+    pms_len = DH_size(dh);
+    pms = OPENSSL_malloc(pms_len);
+    if (pms == NULL) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+      DH_free(dh);
       goto err;
     }
-    s->session->extended_master_secret = s->s3->tmp.extended_master_secret;
-    OPENSSL_cleanse(pms, pms_len);
-    OPENSSL_free(pms);
+
+    int dh_len = DH_compute_key(pms, peer_dh->pub_key, dh);
+    if (dh_len <= 0) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
+      DH_free(dh);
+      goto err;
+    }
+    pms_len = dh_len;
+
+    /* Write the public key. */
+    CBB child;
+    if (!CBB_add_u16_length_prefixed(&cbb, &child) ||
+        !BN_bn2cbb_padded(&child, BN_num_bytes(dh->pub_key), dh->pub_key) ||
+        !CBB_flush(&cbb)) {
+      DH_free(dh);
+      goto err;
+    }
+
+    DH_free(dh);
+  } else if (alg_k & SSL_kECDHE) {
+    if (ssl->s3->tmp.peer_ecdh_tmp == NULL) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+      goto err;
+    }
+    EC_KEY *peer_eckey = ssl->s3->tmp.peer_ecdh_tmp;
+
+    const EC_GROUP *group = EC_KEY_get0_group(peer_eckey);
+    eckey = EC_KEY_new();
+    if (eckey == NULL ||
+        !EC_KEY_set_group(eckey, group) ||
+        !EC_KEY_generate_key(eckey)) {
+      goto err;
+    }
+
+    pms_len = (EC_GROUP_get_degree(group) + 7) / 8;
+    pms = OPENSSL_malloc(pms_len);
+    if (pms == NULL) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+      goto err;
+    }
+
+    int ecdh_len = ECDH_compute_key(
+        pms, pms_len, EC_KEY_get0_public_key(peer_eckey), eckey, NULL);
+    if (ecdh_len <= 0) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_ECDH_LIB);
+      goto err;
+    }
+    pms_len = ecdh_len;
+
+    size_t encoded_len =
+        EC_POINT_point2oct(group, EC_KEY_get0_public_key(eckey),
+                           POINT_CONVERSION_UNCOMPRESSED, NULL, 0, NULL);
+    uint8_t *ptr;
+    CBB child;
+    if (encoded_len == 0 ||
+        !CBB_add_u8_length_prefixed(&cbb, &child) ||
+        !CBB_add_space(&child, &ptr, encoded_len) ||
+        EC_POINT_point2oct(group, EC_KEY_get0_public_key(eckey),
+                           POINT_CONVERSION_UNCOMPRESSED, ptr, encoded_len,
+                           NULL) != encoded_len ||
+        !CBB_flush(&cbb)) {
+      goto err;
+    }
+
+    EC_KEY_free(eckey);
+    eckey = NULL;
+  } else if (alg_k & SSL_kPSK) {
+    /* For plain PSK, other_secret is a block of 0s with the same length as
+     * the pre-shared key. */
+    pms_len = psk_len;
+    pms = OPENSSL_malloc(pms_len);
+    if (pms == NULL) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+      goto err;
+    }
+    memset(pms, 0, pms_len);
+  } else {
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    goto err;
   }
 
+  /* For a PSK cipher suite, other_secret is combined with the pre-shared
+   * key. */
+  if (alg_a & SSL_aPSK) {
+    CBB pms_cbb, child;
+    uint8_t *new_pms;
+    size_t new_pms_len;
+
+    CBB_zero(&pms_cbb);
+    if (!CBB_init(&pms_cbb, 2 + psk_len + 2 + pms_len) ||
+        !CBB_add_u16_length_prefixed(&pms_cbb, &child) ||
+        !CBB_add_bytes(&child, pms, pms_len) ||
+        !CBB_add_u16_length_prefixed(&pms_cbb, &child) ||
+        !CBB_add_bytes(&child, psk, psk_len) ||
+        !CBB_finish(&pms_cbb, &new_pms, &new_pms_len)) {
+      CBB_cleanup(&pms_cbb);
+      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+      goto err;
+    }
+    OPENSSL_cleanse(pms, pms_len);
+    OPENSSL_free(pms);
+    pms = new_pms;
+    pms_len = new_pms_len;
+  }
+
+  /* The message must be added to the finished hash before calculating the
+   * master secret. */
+  size_t length;
+  if (!CBB_finish(&cbb, NULL, &length) ||
+      !ssl_set_handshake_header(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE, length)) {
+    goto err;
+  }
+  ssl->state = SSL3_ST_CW_KEY_EXCH_B;
+
+  ssl->session->master_key_length = ssl->enc_method->generate_master_secret(
+      ssl, ssl->session->master_key, pms, pms_len);
+  if (ssl->session->master_key_length == 0) {
+    goto err;
+  }
+  ssl->session->extended_master_secret = ssl->s3->tmp.extended_master_secret;
+  OPENSSL_cleanse(pms, pms_len);
+  OPENSSL_free(pms);
+
   /* SSL3_ST_CW_KEY_EXCH_B */
-  return s->method->do_write(s);
+  return ssl_do_write(ssl);
 
 err:
-  BN_CTX_free(bn_ctx);
-  OPENSSL_free(encodedPoint);
-  EC_KEY_free(clnt_ecdh);
-  EVP_PKEY_free(srvr_pub_pkey);
-  if (pms) {
+  EC_KEY_free(eckey);
+  if (pms != NULL) {
     OPENSSL_cleanse(pms, pms_len);
     OPENSSL_free(pms);
   }
@@ -2124,12 +2055,6 @@
   return ssl_do_write(ssl);
 }
 
-static int write_32_byte_big_endian(CBB *out, const BIGNUM *in) {
-  uint8_t *ptr;
-  return CBB_add_space(out, &ptr, 32) &&
-         BN_bn2bin_padded(ptr, 32, in);
-}
-
 int ssl3_send_channel_id(SSL *ssl) {
   if (ssl->state == SSL3_ST_CW_CHANNEL_ID_B) {
     return ssl_do_write(ssl);
@@ -2190,10 +2115,10 @@
                       ssl->init_buf->max - SSL_HM_HEADER_LENGTH(ssl)) ||
       !CBB_add_u16(&cbb, TLSEXT_TYPE_channel_id) ||
       !CBB_add_u16_length_prefixed(&cbb, &child) ||
-      !write_32_byte_big_endian(&child, x) ||
-      !write_32_byte_big_endian(&child, y) ||
-      !write_32_byte_big_endian(&child, sig->r) ||
-      !write_32_byte_big_endian(&child, sig->s) ||
+      !BN_bn2cbb_padded(&child, 32, x) ||
+      !BN_bn2cbb_padded(&child, 32, y) ||
+      !BN_bn2cbb_padded(&child, 32, sig->r) ||
+      !BN_bn2cbb_padded(&child, 32, sig->s) ||
       !CBB_finish(&cbb, NULL, &length) ||
       !ssl_set_handshake_header(ssl, SSL3_MT_ENCRYPTED_EXTENSIONS, length)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);