Route DHE through the SSL_ECDH abstraction as well.

This unifies the ClientKeyExchange code rather nicely. ServerKeyExchange
is still pretty specialized. For simplicity, I've extended the yaSSL bug
workaround for clients as well as servers rather than route in a
boolean.

Chrome's already banished DHE to a fallback with intention to remove
altogether later, and the spec doesn't say anything useful about
ClientDiffieHellmanPublic encoding, so this is unlikely to cause
problems.

Change-Id: I0355cd1fd0fab5729e8812e4427dd689124f53a2
Reviewed-on: https://boringssl-review.googlesource.com/6784
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 681066c..af48419 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -3997,7 +3997,6 @@
 
     /* used to hold the new cipher we are going to use */
     const SSL_CIPHER *new_cipher;
-    DH *dh;
 
     /* used when SSL_ST_FLUSH_DATA is entered */
     int next_state;
@@ -4098,15 +4097,12 @@
      * |TLSEXT_hash_none|. */
     uint8_t server_key_exchange_hash;
 
-    /* peer_dh_tmp, on a client, is the server's DHE public key. */
-    DH *peer_dh_tmp;
-
     /* ecdh_ctx is the current ECDH instance. */
     SSL_ECDH_CTX ecdh_ctx;
 
     /* peer_key is the peer's ECDH key. */
     uint8_t *peer_key;
-    uint8_t peer_key_len;
+    uint16_t peer_key_len;
   } tmp;
 
   /* Connection binding to prevent renegotiation attacks */
diff --git a/ssl/internal.h b/ssl/internal.h
index 7741527..2f34907 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -554,6 +554,10 @@
  * on success and zero on error. */
 int SSL_ECDH_CTX_init(SSL_ECDH_CTX *ctx, uint16_t curve_id);
 
+/* SSL_ECDH_CTX_init_for_dhe sets up |ctx| for use with legacy DHE-based ciphers
+ * where the server specifies a group. It takes ownership of |params|. */
+void SSL_ECDH_CTX_init_for_dhe(SSL_ECDH_CTX *ctx, DH *params);
+
 /* SSL_ECDH_CTX_cleanup releases memory associated with |ctx|. It is legal to
  * call it in the zero state. */
 void SSL_ECDH_CTX_cleanup(SSL_ECDH_CTX *ctx);
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index 6d5e3b1..792af1f 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -1137,7 +1137,6 @@
 
   if (alg_k & SSL_kDHE) {
     CBS dh_p, dh_g, dh_Ys;
-
     if (!CBS_get_u16_length_prefixed(&server_key_exchange, &dh_p) ||
         CBS_len(&dh_p) == 0 ||
         !CBS_get_u16_length_prefixed(&server_key_exchange, &dh_g) ||
@@ -1151,15 +1150,12 @@
 
     dh = DH_new();
     if (dh == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
       goto err;
     }
 
     dh->p = BN_bin2bn(CBS_data(&dh_p), CBS_len(&dh_p), NULL);
     dh->g = BN_bin2bn(CBS_data(&dh_g), CBS_len(&dh_g), NULL);
-    dh->pub_key = BN_bin2bn(CBS_data(&dh_Ys), CBS_len(&dh_Ys), NULL);
-    if (dh->p == NULL || dh->g == NULL || dh->pub_key == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_BN_LIB);
+    if (dh->p == NULL || dh->g == NULL) {
       goto err;
     }
 
@@ -1167,17 +1163,25 @@
     if (s->session->key_exchange_info < 1024) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_DH_P_LENGTH);
       goto err;
-    }
-    if (s->session->key_exchange_info > 4096) {
+    } else if (s->session->key_exchange_info > 4096) {
       /* Overly large DHE groups are prohibitively expensive, so enforce a limit
        * to prevent a server from causing us to perform too expensive of a
        * computation. */
       OPENSSL_PUT_ERROR(SSL, SSL_R_DH_P_TOO_LONG);
       goto err;
     }
-    DH_free(s->s3->tmp.peer_dh_tmp);
-    s->s3->tmp.peer_dh_tmp = dh;
+
+    SSL_ECDH_CTX_init_for_dhe(&s->s3->tmp.ecdh_ctx, dh);
     dh = NULL;
+
+    /* Save the peer public key for later. */
+    size_t peer_key_len;
+    if (!CBS_stow(&dh_Ys, &s->s3->tmp.peer_key, &peer_key_len)) {
+      goto err;
+    }
+    /* |dh_Ys| has a u16 length prefix, so this fits in a |uint16_t|. */
+    assert(sizeof(s->s3->tmp.peer_key_len) == 2 && peer_key_len <= 0xffff);
+    s->s3->tmp.peer_key_len = (uint16_t)peer_key_len;
   } else if (alg_k & SSL_kECDHE) {
     /* Parse the server parameters. */
     uint8_t curve_type;
@@ -1206,9 +1210,9 @@
         !CBS_stow(&point, &s->s3->tmp.peer_key, &peer_key_len)) {
       goto err;
     }
-    /* |point| has a u8 length prefix, so this fits in a |uint8_t|. */
-    assert(peer_key_len <= 0xff);
-    s->s3->tmp.peer_key_len = (uint8_t)peer_key_len;
+    /* |point| has a u8 length prefix, so this fits in a |uint16_t|. */
+    assert(sizeof(s->s3->tmp.peer_key_len) == 2 && peer_key_len <= 0xffff);
+    s->s3->tmp.peer_key_len = (uint16_t)peer_key_len;
   } else if (!(alg_k & SSL_kPSK)) {
     al = SSL_AD_UNEXPECTED_MESSAGE;
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
@@ -1670,51 +1674,18 @@
         !CBB_flush(&cbb)) {
       goto err;
     }
-  } else if (alg_k & SSL_kDHE) {
-    if (ssl->s3->tmp.peer_dh_tmp == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-      goto err;
-    }
-    DH *peer_dh = ssl->s3->tmp.peer_dh_tmp;
-
-    /* Generate a keypair. */
-    DH *dh = DHparams_dup(peer_dh);
-    if (dh == NULL || !DH_generate_key(dh)) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
-      DH_free(dh);
-      goto err;
-    }
-
-    pms_len = DH_size(dh);
-    pms = OPENSSL_malloc(pms_len);
-    if (pms == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-      DH_free(dh);
-      goto err;
-    }
-
-    int dh_len = DH_compute_key(pms, peer_dh->pub_key, dh);
-    if (dh_len <= 0) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
-      DH_free(dh);
-      goto err;
-    }
-    pms_len = dh_len;
-
-    /* Write the public key. */
+  } else if (alg_k & (SSL_kECDHE|SSL_kDHE)) {
+    /* Generate a keypair and serialize the public half. ECDHE uses a u8 length
+     * prefix while DHE uses u16. */
     CBB child;
-    if (!CBB_add_u16_length_prefixed(&cbb, &child) ||
-        !BN_bn2cbb_padded(&child, BN_num_bytes(dh->pub_key), dh->pub_key) ||
-        !CBB_flush(&cbb)) {
-      DH_free(dh);
-      goto err;
+    int child_ok;
+    if (alg_k & SSL_kECDHE) {
+      child_ok = CBB_add_u8_length_prefixed(&cbb, &child);
+    } else {
+      child_ok = CBB_add_u16_length_prefixed(&cbb, &child);
     }
 
-    DH_free(dh);
-  } else if (alg_k & SSL_kECDHE) {
-    /* Generate a keypair and serialize the public half. */
-    CBB child;
-    if (!CBB_add_u8_length_prefixed(&cbb, &child) ||
+    if (!child_ok ||
         !SSL_ECDH_CTX_generate_keypair(&ssl->s3->tmp.ecdh_ctx, &child) ||
         !CBB_flush(&cbb)) {
       goto err;
diff --git a/ssl/s3_lib.c b/ssl/s3_lib.c
index e8f7cc2..680d95e 100644
--- a/ssl/s3_lib.c
+++ b/ssl/s3_lib.c
@@ -228,7 +228,6 @@
   ssl3_cleanup_key_block(s);
   ssl_read_buffer_clear(s);
   ssl_write_buffer_clear(s);
-  DH_free(s->s3->tmp.dh);
   SSL_ECDH_CTX_cleanup(&s->s3->tmp.ecdh_ctx);
   OPENSSL_free(s->s3->tmp.peer_key);
 
@@ -236,7 +235,6 @@
   OPENSSL_free(s->s3->tmp.certificate_types);
   OPENSSL_free(s->s3->tmp.peer_ellipticcurvelist);
   OPENSSL_free(s->s3->tmp.peer_psk_identity_hint);
-  DH_free(s->s3->tmp.peer_dh_tmp);
   ssl3_free_handshake_buffer(s);
   ssl3_free_handshake_hash(s);
   OPENSSL_free(s->s3->alpn_selected);
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index ef2c396..08856c7 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -1239,24 +1239,19 @@
       }
       ssl->session->key_exchange_info = DH_num_bits(params);
 
-      /* Generate and save a keypair. */
+      /* Set up DH, generate a key, and emit the public half. */
       DH *dh = DHparams_dup(params);
-      if (dh == NULL || !DH_generate_key(dh)) {
-        DH_free(dh);
-        OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
+      if (dh == NULL) {
         goto err;
       }
-      DH_free(ssl->s3->tmp.dh);
-      ssl->s3->tmp.dh = dh;
 
+      SSL_ECDH_CTX_init_for_dhe(&ssl->s3->tmp.ecdh_ctx, dh);
       if (!CBB_add_u16_length_prefixed(&cbb, &child) ||
-          !BN_bn2cbb_padded(&child, BN_num_bytes(dh->p), dh->p) ||
+          !BN_bn2cbb_padded(&child, BN_num_bytes(params->p), params->p) ||
           !CBB_add_u16_length_prefixed(&cbb, &child) ||
-          !BN_bn2cbb_padded(&child, BN_num_bytes(dh->g), dh->g) ||
+          !BN_bn2cbb_padded(&child, BN_num_bytes(params->g), params->g) ||
           !CBB_add_u16_length_prefixed(&cbb, &child) ||
-          /* Due to a bug in yaSSL, the public key must be zero padded to the
-           * size of the prime. */
-          !BN_bn2cbb_padded(&child, BN_num_bytes(dh->p), dh->pub_key)) {
+          !SSL_ECDH_CTX_generate_keypair(&ssl->s3->tmp.ecdh_ctx, &child)) {
         goto err;
       }
     } else if (alg_k & SSL_kECDHE) {
@@ -1464,8 +1459,6 @@
   uint8_t *premaster_secret = NULL;
   size_t premaster_secret_len = 0;
   uint8_t *decrypt_buf = NULL;
-  BIGNUM *pub = NULL;
-  DH *dh_srvr;
 
   unsigned psk_len = 0;
   uint8_t psk[PSK_MAX_PSK_LEN];
@@ -1639,56 +1632,19 @@
 
     OPENSSL_free(decrypt_buf);
     decrypt_buf = NULL;
-  } else if (alg_k & SSL_kDHE) {
-    CBS dh_Yc;
-    int dh_len;
-
-    if (!CBS_get_u16_length_prefixed(&client_key_exchange, &dh_Yc) ||
-        CBS_len(&dh_Yc) == 0 || CBS_len(&client_key_exchange) != 0) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG);
-      al = SSL_R_DECODE_ERROR;
-      goto f_err;
+  } else if (alg_k & (SSL_kECDHE|SSL_kDHE)) {
+    /* Parse the ClientKeyExchange. ECDHE uses a u8 length prefix while DHE uses
+     * u16. */
+    CBS peer_key;
+    int peer_key_ok;
+    if (alg_k & SSL_kECDHE) {
+      peer_key_ok = CBS_get_u8_length_prefixed(&client_key_exchange, &peer_key);
+    } else {
+      peer_key_ok =
+          CBS_get_u16_length_prefixed(&client_key_exchange, &peer_key);
     }
 
-    if (s->s3->tmp.dh == NULL) {
-      al = SSL_AD_HANDSHAKE_FAILURE;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_TMP_DH_KEY);
-      goto f_err;
-    }
-    dh_srvr = s->s3->tmp.dh;
-
-    pub = BN_bin2bn(CBS_data(&dh_Yc), CBS_len(&dh_Yc), NULL);
-    if (pub == NULL) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BN_LIB);
-      goto err;
-    }
-
-    /* Allocate a buffer for the premaster secret. */
-    premaster_secret = OPENSSL_malloc(DH_size(dh_srvr));
-    if (premaster_secret == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-      BN_clear_free(pub);
-      goto err;
-    }
-
-    dh_len = DH_compute_key(premaster_secret, pub, dh_srvr);
-    if (dh_len <= 0) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_DH_LIB);
-      BN_clear_free(pub);
-      goto err;
-    }
-
-    DH_free(s->s3->tmp.dh);
-    s->s3->tmp.dh = NULL;
-    BN_clear_free(pub);
-    pub = NULL;
-
-    premaster_secret_len = dh_len;
-  } else if (alg_k & SSL_kECDHE) {
-    /* Parse the ClientKeyExchange. */
-    CBS ecdh_Yc;
-    if (!CBS_get_u8_length_prefixed(&client_key_exchange, &ecdh_Yc) ||
-        CBS_len(&client_key_exchange) != 0) {
+    if (!peer_key_ok || CBS_len(&client_key_exchange) != 0) {
       al = SSL_AD_DECODE_ERROR;
       OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
       goto f_err;
@@ -1698,7 +1654,7 @@
     uint8_t alert;
     if (!SSL_ECDH_CTX_compute_secret(&s->s3->tmp.ecdh_ctx, &premaster_secret,
                                      &premaster_secret_len, &alert,
-                                     CBS_data(&ecdh_Yc), CBS_len(&ecdh_Yc))) {
+                                     CBS_data(&peer_key), CBS_len(&peer_key))) {
       al = alert;
       goto f_err;
     }
diff --git a/ssl/ssl_ecdh.c b/ssl/ssl_ecdh.c
index 3328f5a..45c5b26 100644
--- a/ssl/ssl_ecdh.c
+++ b/ssl/ssl_ecdh.c
@@ -206,6 +206,71 @@
 }
 
 
+/* Legacy DHE-based implementation. */
+
+static void ssl_dhe_cleanup(SSL_ECDH_CTX *ctx) {
+  DH_free((DH *)ctx->data);
+}
+
+static int ssl_dhe_generate_keypair(SSL_ECDH_CTX *ctx, CBB *out) {
+  DH *dh = (DH *)ctx->data;
+  /* The group must have been initialized already, but not the key. */
+  assert(dh != NULL);
+  assert(dh->priv_key == NULL);
+
+  /* Due to a bug in yaSSL, the public key must be zero padded to the size of
+   * the prime. */
+  return DH_generate_key(dh) &&
+         BN_bn2cbb_padded(out, BN_num_bytes(dh->p), dh->pub_key);
+}
+
+static int ssl_dhe_compute_secret(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
+                                  size_t *out_secret_len, uint8_t *out_alert,
+                                  const uint8_t *peer_key,
+                                  size_t peer_key_len) {
+  DH *dh = (DH *)ctx->data;
+  assert(dh != NULL);
+  assert(dh->priv_key != NULL);
+  *out_alert = SSL_AD_INTERNAL_ERROR;
+
+  int secret_len = 0;
+  uint8_t *secret = NULL;
+  BIGNUM *peer_point = BN_bin2bn(peer_key, peer_key_len, NULL);
+  if (peer_point == NULL) {
+    goto err;
+  }
+
+  secret = OPENSSL_malloc(DH_size(dh));
+  if (secret == NULL) {
+    goto err;
+  }
+  secret_len = DH_compute_key(secret, peer_point, dh);
+  if (secret_len <= 0) {
+    goto err;
+  }
+
+  *out_secret = secret;
+  *out_secret_len = (size_t)secret_len;
+  BN_free(peer_point);
+  return 1;
+
+err:
+  if (secret_len > 0) {
+    OPENSSL_cleanse(secret, (size_t)secret_len);
+  }
+  OPENSSL_free(secret);
+  BN_free(peer_point);
+  return 0;
+}
+
+static const SSL_ECDH_METHOD kDHEMethod = {
+    NID_undef, 0, "",
+    ssl_dhe_cleanup,
+    ssl_dhe_generate_keypair,
+    ssl_dhe_compute_secret,
+};
+
+
 static const SSL_ECDH_METHOD kMethods[] = {
     {
         NID_X9_62_prime256v1,
@@ -290,6 +355,13 @@
   return 1;
 }
 
+void SSL_ECDH_CTX_init_for_dhe(SSL_ECDH_CTX *ctx, DH *params) {
+  SSL_ECDH_CTX_cleanup(ctx);
+
+  ctx->method = &kDHEMethod;
+  ctx->data = params;
+}
+
 void SSL_ECDH_CTX_cleanup(SSL_ECDH_CTX *ctx) {
   if (ctx->method == NULL) {
     return;