Rewrite ssl3_send_server_key_exchange to use CBB.

There is some messiness around saving and restoring the CBB, but this is
still significantly clearer.

Note that the BUF_MEM_grow line is gone in favor of a fixed CBB like the
other functions ported thus far. This line was never necessary as
init_buf is initialized to 16k and none of our key exchanges get that
large. (The largest one can get is DHE_RSA. Even so, it'd take a roughly
30k-bit DH group with a 30k-bit RSA key.)

Having such limits and tight assumptions on init_buf's initial size is
poor (but on par for the old code which usually just blindly assumed the
message would not get too large) and the size of the certificate chain
is much less obviously bounded, so those BUF_MEM_grows can't easily go.

My current plan is convert everything but those which legitimately need
BUF_MEM_grow to CBB, then atomically convert the rest, remove init_buf,
and switch everything to non-fixed CBBs. This will hopefully also
simplify async resumption. In the meantime, having a story for
resumption means the future atomic change is smaller and, more
importantly, relieves some complexity budget in the ServerKeyExchange
code for adding Curve25519.

Change-Id: I1de6af9856caaed353453d92a502ba461a938fbd
Reviewed-on: https://boringssl-review.googlesource.com/6770
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/bytestring/cbb.c b/crypto/bytestring/cbb.c
index a9e9b3c..8fc5187 100644
--- a/crypto/bytestring/cbb.c
+++ b/crypto/bytestring/cbb.c
@@ -261,6 +261,11 @@
   return 1;
 }
 
+const uint8_t *CBB_data(const CBB *cbb) {
+  assert(cbb->child == NULL);
+  return cbb->base->buf + cbb->offset + cbb->pending_len_len;
+}
+
 size_t CBB_len(const CBB *cbb) {
   assert(cbb->child == NULL);
   assert(cbb->offset + cbb->pending_len_len <= cbb->base->len);
diff --git a/include/openssl/bytestring.h b/include/openssl/bytestring.h
index fe26111..9193e11 100644
--- a/include/openssl/bytestring.h
+++ b/include/openssl/bytestring.h
@@ -292,6 +292,13 @@
  * on error. */
 OPENSSL_EXPORT int CBB_flush(CBB *cbb);
 
+/* CBB_data returns a pointer to the bytes written to |cbb|. It does not flush
+ * |cbb|. The pointer is valid until the next operation to |cbb|.
+ *
+ * To avoid unfinalized length prefixes, it is a fatal error to call this on a
+ * CBB with any active children. */
+OPENSSL_EXPORT const uint8_t *CBB_data(const CBB *cbb);
+
 /* CBB_len returns the number of bytes written to |cbb|. It does not flush
  * |cbb|.
  *
diff --git a/ssl/internal.h b/ssl/internal.h
index a72d5d6..e3748a2 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1091,7 +1091,7 @@
 int ssl3_get_v2_client_hello(SSL *s);
 int ssl3_get_client_hello(SSL *s);
 int ssl3_send_server_hello(SSL *ssl);
-int ssl3_send_server_key_exchange(SSL *s);
+int ssl3_send_server_key_exchange(SSL *ssl);
 int ssl3_send_certificate_request(SSL *s);
 int ssl3_send_server_done(SSL *s);
 int ssl3_get_client_certificate(SSL *s);
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index 8a573fe..8ca2cf9 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -1199,301 +1199,212 @@
   return ssl_do_write(s);
 }
 
-int ssl3_send_server_key_exchange(SSL *s) {
-  DH *dh = NULL, *dhp;
-  EC_KEY *ecdh = NULL;
-  uint8_t *encodedPoint = NULL;
-  int encodedlen = 0;
-  uint16_t curve_id = 0;
-  const char *psk_identity_hint = NULL;
-  size_t psk_identity_hint_len = 0;
-  size_t sig_len;
-  size_t max_sig_len;
-  uint8_t *p, *d;
-  int al, i;
-  uint32_t alg_k;
-  uint32_t alg_a;
-  int n;
-  CERT *cert;
-  BIGNUM *r[4];
-  /* r_pad_bytes[i] contains the number of zero padding bytes that need to
-   * precede |r[i]| when serialising it. */
-  unsigned r_pad_bytes[4] = {0};
-  int nr[4];
-  BUF_MEM *buf;
-  EVP_MD_CTX md_ctx;
-
-  if (s->state == SSL3_ST_SW_KEY_EXCH_C) {
-    return ssl_do_write(s);
+int ssl3_send_server_key_exchange(SSL *ssl) {
+  if (ssl->state == SSL3_ST_SW_KEY_EXCH_C) {
+    return ssl_do_write(ssl);
   }
 
-  EVP_MD_CTX_init(&md_ctx);
-
-  if (ssl_cipher_has_server_public_key(s->s3->tmp.new_cipher)) {
-    if (!ssl_has_private_key(s)) {
-      al = SSL_AD_INTERNAL_ERROR;
-      goto f_err;
-    }
-    max_sig_len = ssl_private_key_max_signature_len(s);
-  } else {
-    max_sig_len = 0;
+  CBB cbb, child;
+  if (!CBB_init_fixed(&cbb, ssl_handshake_start(ssl),
+                      ssl->init_buf->max - SSL_HM_HEADER_LENGTH(ssl))) {
+    goto err;
   }
 
-  enum ssl_private_key_result_t sign_result;
-  if (s->state == SSL3_ST_SW_KEY_EXCH_A) {
-    alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
-    alg_a = s->s3->tmp.new_cipher->algorithm_auth;
-    cert = s->cert;
+  if (ssl->state == SSL3_ST_SW_KEY_EXCH_A) {
+    /* This is the first iteration, so write parameters. */
+    uint32_t alg_k = ssl->s3->tmp.new_cipher->algorithm_mkey;
+    uint32_t alg_a = ssl->s3->tmp.new_cipher->algorithm_auth;
 
-    buf = s->init_buf;
-
-    r[0] = r[1] = r[2] = r[3] = NULL;
-    n = 0;
+    /* PSK ciphers begin with an identity hint. */
     if (alg_a & SSL_aPSK) {
-      /* size for PSK identity hint */
-      psk_identity_hint = s->psk_identity_hint;
-      if (psk_identity_hint) {
-        psk_identity_hint_len = strlen(psk_identity_hint);
-      } else {
-        psk_identity_hint_len = 0;
+      size_t len =
+          (ssl->psk_identity_hint == NULL) ? 0 : strlen(ssl->psk_identity_hint);
+      if (!CBB_add_u16_length_prefixed(&cbb, &child) ||
+          !CBB_add_bytes(&child, (const uint8_t *)ssl->psk_identity_hint,
+                         len)) {
+        goto err;
       }
-      n += 2 + psk_identity_hint_len;
     }
 
     if (alg_k & SSL_kDHE) {
-      dhp = cert->dh_tmp;
-      if (dhp == NULL && s->cert->dh_tmp_cb != NULL) {
-        dhp = s->cert->dh_tmp_cb(s, 0, 1024);
+      /* Determine the group to use. */
+      DH *params = ssl->cert->dh_tmp;
+      if (params == NULL && ssl->cert->dh_tmp_cb != NULL) {
+        params = ssl->cert->dh_tmp_cb(ssl, 0, 1024);
       }
-      if (dhp == NULL) {
-        al = SSL_AD_HANDSHAKE_FAILURE;
+      if (params == NULL) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_TMP_DH_KEY);
-        goto f_err;
-      }
-
-      if (s->s3->tmp.dh != NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
         goto err;
       }
-      dh = DHparams_dup(dhp);
-      if (dh == NULL) {
+
+      /* Generate and save a keypair. */
+      DH *dh = DHparams_dup(params);
+      if (dh == NULL || !DH_generate_key(dh)) {
+        DH_free(dh);
         OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
         goto err;
       }
-      s->s3->tmp.dh = dh;
+      DH_free(ssl->s3->tmp.dh);
+      ssl->s3->tmp.dh = dh;
 
-      if (!DH_generate_key(dh)) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
+      if (!CBB_add_u16_length_prefixed(&cbb, &child) ||
+          !BN_bn2cbb_padded(&child, BN_num_bytes(dh->p), dh->p) ||
+          !CBB_add_u16_length_prefixed(&cbb, &child) ||
+          !BN_bn2cbb_padded(&child, BN_num_bytes(dh->g), dh->g) ||
+          !CBB_add_u16_length_prefixed(&cbb, &child) ||
+          /* Due to a bug in yaSSL, the public key must be zero padded to the
+           * size of the prime. */
+          !BN_bn2cbb_padded(&child, BN_num_bytes(dh->p), dh->pub_key)) {
         goto err;
       }
-
-      r[0] = dh->p;
-      r[1] = dh->g;
-      r[2] = dh->pub_key;
-      /* Due to a bug in yaSSL, the public key must be zero padded to the size
-       * of the prime. */
-      assert(BN_num_bytes(dh->pub_key) <= BN_num_bytes(dh->p));
-      r_pad_bytes[2] = BN_num_bytes(dh->p) - BN_num_bytes(dh->pub_key);
     } else if (alg_k & SSL_kECDHE) {
       /* Determine the curve to use. */
-      int nid = tls1_get_shared_curve(s);
+      int nid = tls1_get_shared_curve(ssl);
       if (nid == NID_undef) {
-        al = SSL_AD_HANDSHAKE_FAILURE;
         OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_TMP_ECDH_KEY);
-        goto f_err;
-      }
-
-      if (s->s3->tmp.ecdh != NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
         goto err;
       }
-      ecdh = EC_KEY_new_by_curve_name(nid);
-      if (ecdh == NULL) {
-        goto err;
-      }
-      s->s3->tmp.ecdh = ecdh;
-
-      if (!EC_KEY_generate_key(ecdh)) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_ECDH_LIB);
-        goto err;
-      }
-
       /* We only support ephemeral ECDH keys over named (not generic) curves. */
-      const EC_GROUP *group = EC_KEY_get0_group(ecdh);
-      if (!tls1_ec_nid2curve_id(&curve_id, EC_GROUP_get_curve_name(group))) {
+      uint16_t curve_id;
+      if (!tls1_ec_nid2curve_id(&curve_id, nid)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_UNSUPPORTED_ELLIPTIC_CURVE);
         goto err;
       }
 
-      /* Encode the public key. First check the size of encoding and allocate
-       * memory accordingly. */
-      encodedlen =
-          EC_POINT_point2oct(group, EC_KEY_get0_public_key(ecdh),
-                             POINT_CONVERSION_UNCOMPRESSED, NULL, 0, NULL);
-
-      encodedPoint = (uint8_t *)OPENSSL_malloc(encodedlen * sizeof(uint8_t));
-      if (encodedPoint == NULL) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+      /* Generate and save a keypair. */
+      EC_KEY *ecdh = EC_KEY_new_by_curve_name(nid);
+      if (ecdh == NULL || !EC_KEY_generate_key(ecdh)) {
+        EC_KEY_free(ecdh);
         goto err;
       }
+      EC_KEY_free(ssl->s3->tmp.ecdh);
+      ssl->s3->tmp.ecdh = ecdh;
 
-      encodedlen = EC_POINT_point2oct(group, EC_KEY_get0_public_key(ecdh),
-                                      POINT_CONVERSION_UNCOMPRESSED,
-                                      encodedPoint, encodedlen, NULL);
-
-      if (encodedlen == 0) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_ECDH_LIB);
+      const EC_GROUP *group = EC_KEY_get0_group(ecdh);
+      const EC_POINT *public_key = EC_KEY_get0_public_key(ecdh);
+      size_t point_len = EC_POINT_point2oct(
+          group, public_key, POINT_CONVERSION_UNCOMPRESSED, NULL, 0, NULL);
+      uint8_t *ptr;
+      if (point_len == 0 ||
+          !CBB_add_u8(&cbb, NAMED_CURVE_TYPE) ||
+          !CBB_add_u16(&cbb, curve_id) ||
+          !CBB_add_u8_length_prefixed(&cbb, &child) ||
+          !CBB_add_space(&child, &ptr, point_len) ||
+          EC_POINT_point2oct(group, public_key, POINT_CONVERSION_UNCOMPRESSED,
+                             ptr, point_len, NULL) != point_len) {
         goto err;
       }
-
-      /* We only support named (not generic) curves in ECDH ephemeral key
-       * exchanges. In this situation, we need four additional bytes to encode
-       * the entire ServerECDHParams structure. */
-      n += 4 + encodedlen;
-
-      /* We'll generate the serverKeyExchange message explicitly so we can set
-       * these to NULLs */
-      r[0] = NULL;
-      r[1] = NULL;
-      r[2] = NULL;
-      r[3] = NULL;
-    } else if (!(alg_k & SSL_kPSK)) {
-      al = SSL_AD_HANDSHAKE_FAILURE;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE);
-      goto f_err;
+    } else {
+      assert(alg_k & SSL_kPSK);
     }
 
-    for (i = 0; i < 4 && r[i] != NULL; i++) {
-      nr[i] = BN_num_bytes(r[i]) + r_pad_bytes[i];
-      n += 2 + nr[i];
-    }
+    /* Otherwise, restore |cbb| from the previous iteration.
+     * TODO(davidben): When |ssl->init_buf| is gone, come up with a simpler
+     * pattern. Probably keep the |CBB| around in the handshake state. */
+  } else if (!CBB_did_write(&cbb, ssl->init_num - SSL_HM_HEADER_LENGTH(ssl))) {
+    goto err;
+  }
 
-    if (!BUF_MEM_grow_clean(buf, n + SSL_HM_HEADER_LENGTH(s) + max_sig_len)) {
-      OPENSSL_PUT_ERROR(SSL, ERR_LIB_BUF);
+  /* Add a signature. */
+  if (ssl_cipher_has_server_public_key(ssl->s3->tmp.new_cipher)) {
+    if (!ssl_has_private_key(ssl)) {
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
       goto err;
     }
-    d = p = ssl_handshake_start(s);
 
-    for (i = 0; i < 4 && r[i] != NULL; i++) {
-      s2n(nr[i], p);
-      if (!BN_bn2bin_padded(p, nr[i], r[i])) {
-        OPENSSL_PUT_ERROR(SSL, ERR_LIB_BN);
+    const size_t max_sig_len = ssl_private_key_max_signature_len(ssl);
+    size_t sig_len;
+    enum ssl_private_key_result_t sign_result;
+    if (ssl->state == SSL3_ST_SW_KEY_EXCH_A) {
+      /* This is the first iteration, so set up the signature. Sample the
+       * parameter length before adding a signature algorithm. */
+      if (!CBB_flush(&cbb)) {
         goto err;
       }
-      p += nr[i];
-    }
-
-    /* Note: ECDHE PSK ciphersuites use SSL_kECDHE and SSL_aPSK. When one of
-     * them is used, the server key exchange record needs to have both the
-     * psk_identity_hint and the ServerECDHParams. */
-    if (alg_a & SSL_aPSK) {
-      /* copy PSK identity hint (if provided) */
-      s2n(psk_identity_hint_len, p);
-      if (psk_identity_hint_len > 0) {
-        memcpy(p, psk_identity_hint, psk_identity_hint_len);
-        p += psk_identity_hint_len;
-      }
-    }
-
-    if (alg_k & SSL_kECDHE) {
-      /* We only support named (not generic) curves. In this situation, the
-       * serverKeyExchange message has:
-       * [1 byte CurveType], [2 byte CurveName]
-       * [1 byte length of encoded point], followed by
-       * the actual encoded point itself. */
-      *(p++) = NAMED_CURVE_TYPE;
-      *(p++) = (uint8_t)(curve_id >> 8);
-      *(p++) = (uint8_t)(curve_id & 0xff);
-      *(p++) = encodedlen;
-      memcpy(p, encodedPoint, encodedlen);
-      p += encodedlen;
-      OPENSSL_free(encodedPoint);
-      encodedPoint = NULL;
-    }
-
-    if (ssl_cipher_has_server_public_key(s->s3->tmp.new_cipher)) {
-      /* n is the length of the params, they start at d and p points to
-       * the space at the end. */
-      const EVP_MD *md;
-      uint8_t digest[EVP_MAX_MD_SIZE];
-      unsigned int digest_length;
-
-      const int pkey_type = ssl_private_key_type(s);
+      size_t params_len = CBB_len(&cbb);
 
       /* Determine signature algorithm. */
-      if (SSL_USE_SIGALGS(s)) {
-        md = tls1_choose_signing_digest(s);
-        if (!tls12_get_sigandhash(s, p, md)) {
-          /* Should never happen */
-          al = SSL_AD_INTERNAL_ERROR;
+      const EVP_MD *md;
+      uint8_t *ptr;
+      if (SSL_USE_SIGALGS(ssl)) {
+        md = tls1_choose_signing_digest(ssl);
+        if (!CBB_add_space(&cbb, &ptr, 2) ||
+            !tls12_get_sigandhash(ssl, ptr, md)) {
           OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-          goto f_err;
+          ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+          goto err;
         }
-        p += 2;
-      } else if (pkey_type == EVP_PKEY_RSA) {
+      } else if (ssl_private_key_type(ssl) == EVP_PKEY_RSA) {
         md = EVP_md5_sha1();
       } else {
         md = EVP_sha1();
       }
 
-      if (!EVP_DigestInit_ex(&md_ctx, md, NULL) ||
-          !EVP_DigestUpdate(&md_ctx, s->s3->client_random, SSL3_RANDOM_SIZE) ||
-          !EVP_DigestUpdate(&md_ctx, s->s3->server_random, SSL3_RANDOM_SIZE) ||
-          !EVP_DigestUpdate(&md_ctx, d, n) ||
-          !EVP_DigestFinal_ex(&md_ctx, digest, &digest_length)) {
-        OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP);
+      /* Compute the digest and sign it. */
+      uint8_t digest[EVP_MAX_MD_SIZE];
+      unsigned digest_len;
+      EVP_MD_CTX md_ctx;
+      EVP_MD_CTX_init(&md_ctx);
+      int digest_ret =
+          EVP_DigestInit_ex(&md_ctx, md, NULL) &&
+          EVP_DigestUpdate(&md_ctx, ssl->s3->client_random, SSL3_RANDOM_SIZE) &&
+          EVP_DigestUpdate(&md_ctx, ssl->s3->server_random, SSL3_RANDOM_SIZE) &&
+          EVP_DigestUpdate(&md_ctx, CBB_data(&cbb), params_len) &&
+          EVP_DigestFinal_ex(&md_ctx, digest, &digest_len);
+      EVP_MD_CTX_cleanup(&md_ctx);
+      if (!digest_ret ||
+          !CBB_add_u16_length_prefixed(&cbb, &child) ||
+          !CBB_reserve(&child, &ptr, max_sig_len)) {
         goto err;
       }
-
-      sign_result = ssl_private_key_sign(s, &p[2], &sig_len, max_sig_len,
-                                         EVP_MD_CTX_md(&md_ctx), digest,
-                                         digest_length);
+      sign_result = ssl_private_key_sign(ssl, ptr, &sig_len, max_sig_len, md,
+                                         digest, digest_len);
     } else {
-      /* This key exchange doesn't involve a signature. */
-      sign_result = ssl_private_key_success;
-      sig_len = 0;
+      assert(ssl->state == SSL3_ST_SW_KEY_EXCH_B);
+
+      /* Retry the signature. */
+      uint8_t *ptr;
+      if (!CBB_add_u16_length_prefixed(&cbb, &child) ||
+          !CBB_reserve(&child, &ptr, max_sig_len)) {
+        goto err;
+      }
+      sign_result =
+          ssl_private_key_sign_complete(ssl, ptr, &sig_len, max_sig_len);
     }
-  } else {
-    assert(s->state == SSL3_ST_SW_KEY_EXCH_B);
-    /* Restore |p|. */
-    p = ssl_handshake_start(s) + s->init_num - SSL_HM_HEADER_LENGTH(s);
-    sign_result = ssl_private_key_sign_complete(s, &p[2], &sig_len,
-                                                max_sig_len);
+
+    switch (sign_result) {
+      case ssl_private_key_success:
+        ssl->rwstate = SSL_NOTHING;
+        if (!CBB_did_write(&child, sig_len)) {
+          goto err;
+        }
+        break;
+      case ssl_private_key_failure:
+        ssl->rwstate = SSL_NOTHING;
+        goto err;
+      case ssl_private_key_retry:
+        /* Discard the unfinished signature and save the state of |cbb| for the
+         * next iteration. */
+        CBB_discard_child(&cbb);
+        ssl->init_num = SSL_HM_HEADER_LENGTH(ssl) + CBB_len(&cbb);
+        ssl->rwstate = SSL_PRIVATE_KEY_OPERATION;
+        ssl->state = SSL3_ST_SW_KEY_EXCH_B;
+        goto err;
+    }
   }
 
-  switch (sign_result) {
-    case ssl_private_key_success:
-      s->rwstate = SSL_NOTHING;
-      break;
-    case ssl_private_key_failure:
-      s->rwstate = SSL_NOTHING;
-      goto err;
-    case ssl_private_key_retry:
-      s->rwstate = SSL_PRIVATE_KEY_OPERATION;
-      /* Stash away |p|. */
-      s->init_num = p - ssl_handshake_start(s) + SSL_HM_HEADER_LENGTH(s);
-      s->state = SSL3_ST_SW_KEY_EXCH_B;
-      goto err;
-  }
-
-  if (ssl_cipher_has_server_public_key(s->s3->tmp.new_cipher)) {
-    s2n(sig_len, p);
-    p += sig_len;
-  }
-  if (!ssl_set_handshake_header(s, SSL3_MT_SERVER_KEY_EXCHANGE,
-                                p - ssl_handshake_start(s))) {
+  size_t length;
+  if (!CBB_finish(&cbb, NULL, &length) ||
+      !ssl_set_handshake_header(ssl, SSL3_MT_SERVER_KEY_EXCHANGE, length)) {
     goto err;
   }
-  s->state = SSL3_ST_SW_KEY_EXCH_C;
+  ssl->state = SSL3_ST_SW_KEY_EXCH_C;
+  return ssl_do_write(ssl);
 
-  EVP_MD_CTX_cleanup(&md_ctx);
-  return ssl_do_write(s);
-
-f_err:
-  ssl3_send_alert(s, SSL3_AL_FATAL, al);
 err:
-  OPENSSL_free(encodedPoint);
-  EVP_MD_CTX_cleanup(&md_ctx);
+  CBB_cleanup(&cbb);
   return -1;
 }