Refactor ECDH key exchange to make it asymmetrical

Previously, SSL_ECDH_METHOD consisted of two methods: one to produce a
public key to be sent to the peer, and another to produce the shared key
upon receipt of the peer's message.

This API does not work for NEWHOPE, because the client-to-server message
cannot be produced until the server's message has been received by the
client.

Solve this by introducing a new method which consumes data from the
server key exchange message and produces data for the client key
exchange message.

Change-Id: I1ed5a2bf198ca2d2ddb6d577888c1fa2008ef99a
Reviewed-on: https://boringssl-review.googlesource.com/7961
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 99d7f0f..cb16b34 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -544,18 +544,29 @@
   /* cleanup releases state in |ctx|. */
   void (*cleanup)(SSL_ECDH_CTX *ctx);
 
-  /* generate_keypair generates a keypair and writes the public value to
+  /* offer generates a keypair and writes the public value to
    * |out_public_key|. It returns one on success and zero on error. */
-  int (*generate_keypair)(SSL_ECDH_CTX *ctx, CBB *out_public_key);
+  int (*offer)(SSL_ECDH_CTX *ctx, CBB *out_public_key);
 
-  /* compute_secret performs a key exchange against |peer_key| and, on
-   * success, returns one and sets |*out_secret| and |*out_secret_len| to
-   * a newly-allocated buffer containing the shared secret. The caller must
-   * release this buffer with |OPENSSL_free|. Otherwise, it returns zero and
-   * sets |*out_alert| to an alert to send to the peer. */
-  int (*compute_secret)(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
-                        size_t *out_secret_len, uint8_t *out_alert,
-                        const uint8_t *peer_key, size_t peer_key_len);
+  /* accept performs a key exchange against the |peer_key| generated by |offer|.
+   * On success, it returns one, writes the public value to |out_public_key|,
+   * and sets |*out_secret| and |*out_secret_len| to a newly-allocated buffer
+   * containing the shared secret. The caller must release this buffer with
+   * |OPENSSL_free|. On failure, it returns zero and sets |*out_alert| to an
+   * alert to send to the peer. */
+  int (*accept)(SSL_ECDH_CTX *ctx, CBB *out_public_key, uint8_t **out_secret,
+                size_t *out_secret_len, uint8_t *out_alert,
+                const uint8_t *peer_key, size_t peer_key_len);
+
+  /* finish performs a key exchange against the |peer_key| generated by
+   * |accept|. On success, it returns one and sets |*out_secret| and
+   * |*out_secret_len| to a newly-allocated buffer containing the shared
+   * secret. The caller must release this buffer with |OPENSSL_free|. On
+   * failure, it returns zero and sets |*out_alert| to an alert to send to the
+   * peer. */
+  int (*finish)(SSL_ECDH_CTX *ctx, uint8_t **out_secret, size_t *out_secret_len,
+                uint8_t *out_alert, const uint8_t *peer_key,
+                size_t peer_key_len);
 } /* SSL_ECDH_METHOD */;
 
 /* ssl_nid_to_curve_id looks up the curve corresponding to |nid|. On success, it
@@ -575,13 +586,19 @@
  * call it in the zero state. */
 void SSL_ECDH_CTX_cleanup(SSL_ECDH_CTX *ctx);
 
-/* The following functions call the corresponding method of
- * |SSL_ECDH_METHOD|. */
-int SSL_ECDH_CTX_generate_keypair(SSL_ECDH_CTX *ctx, CBB *out_public_key);
-int SSL_ECDH_CTX_compute_secret(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
-                                size_t *out_secret_len, uint8_t *out_alert,
-                                const uint8_t *peer_key, size_t peer_key_len);
+/* SSL_ECDH_CTX_offer calls the |offer| method of |SSL_ECDH_METHOD|. */
+int SSL_ECDH_CTX_offer(SSL_ECDH_CTX *ctx, CBB *out_public_key);
 
+/* SSL_ECDH_CTX_accept calls the |accept| method of |SSL_ECDH_METHOD|. */
+int SSL_ECDH_CTX_accept(SSL_ECDH_CTX *ctx, CBB *out_public_key,
+                        uint8_t **out_secret, size_t *out_secret_len,
+                        uint8_t *out_alert, const uint8_t *peer_key,
+                        size_t peer_key_len);
+
+/* SSL_ECDH_CTX_finish the |finish| method of |SSL_ECDH_METHOD|. */
+int SSL_ECDH_CTX_finish(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
+                        size_t *out_secret_len, uint8_t *out_alert,
+                        const uint8_t *peer_key, size_t peer_key_len);
 
 /* Handshake messages. */
 
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index 8e79c81..a34659f 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -1660,20 +1660,21 @@
       child_ok = CBB_add_u16_length_prefixed(&cbb, &child);
     }
 
-    if (!child_ok ||
-        !SSL_ECDH_CTX_generate_keypair(&ssl->s3->tmp.ecdh_ctx, &child) ||
-        !CBB_flush(&cbb)) {
+    if (!child_ok) {
       goto err;
     }
 
     /* Compute the premaster. */
     uint8_t alert;
-    if (!SSL_ECDH_CTX_compute_secret(&ssl->s3->tmp.ecdh_ctx, &pms, &pms_len,
-                                     &alert, ssl->s3->tmp.peer_key,
-                                     ssl->s3->tmp.peer_key_len)) {
+    if (!SSL_ECDH_CTX_accept(&ssl->s3->tmp.ecdh_ctx, &child, &pms, &pms_len,
+                             &alert, ssl->s3->tmp.peer_key,
+                             ssl->s3->tmp.peer_key_len)) {
       ssl3_send_alert(ssl, SSL3_AL_FATAL, alert);
       goto err;
     }
+    if (!CBB_flush(&cbb)) {
+      goto err;
+    }
 
     /* The key exchange state may now be discarded. */
     SSL_ECDH_CTX_cleanup(&ssl->s3->tmp.ecdh_ctx);
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index 4b14d65..a13d0e6 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -1237,7 +1237,7 @@
           !CBB_add_u16_length_prefixed(&cbb, &child) ||
           !BN_bn2cbb_padded(&child, BN_num_bytes(params->g), params->g) ||
           !CBB_add_u16_length_prefixed(&cbb, &child) ||
-          !SSL_ECDH_CTX_generate_keypair(&ssl->s3->tmp.ecdh_ctx, &child)) {
+          !SSL_ECDH_CTX_offer(&ssl->s3->tmp.ecdh_ctx, &child)) {
         goto err;
       }
     } else if (alg_k & SSL_kECDHE) {
@@ -1255,7 +1255,7 @@
           !CBB_add_u8(&cbb, NAMED_CURVE_TYPE) ||
           !CBB_add_u16(&cbb, curve_id) ||
           !CBB_add_u8_length_prefixed(&cbb, &child) ||
-          !SSL_ECDH_CTX_generate_keypair(&ssl->s3->tmp.ecdh_ctx, &child)) {
+          !SSL_ECDH_CTX_offer(&ssl->s3->tmp.ecdh_ctx, &child)) {
         goto err;
       }
     } else {
@@ -1639,9 +1639,9 @@
 
     /* Compute the premaster. */
     uint8_t alert;
-    if (!SSL_ECDH_CTX_compute_secret(&ssl->s3->tmp.ecdh_ctx, &premaster_secret,
-                                     &premaster_secret_len, &alert,
-                                     CBS_data(&peer_key), CBS_len(&peer_key))) {
+    if (!SSL_ECDH_CTX_finish(&ssl->s3->tmp.ecdh_ctx, &premaster_secret,
+                             &premaster_secret_len, &alert, CBS_data(&peer_key),
+                             CBS_len(&peer_key))) {
       al = alert;
       goto f_err;
     }
diff --git a/ssl/ssl_ecdh.c b/ssl/ssl_ecdh.c
index d48c93f..305a5af 100644
--- a/ssl/ssl_ecdh.c
+++ b/ssl/ssl_ecdh.c
@@ -35,7 +35,7 @@
   BN_clear_free(private_key);
 }
 
-static int ssl_ec_point_generate_keypair(SSL_ECDH_CTX *ctx, CBB *out) {
+static int ssl_ec_point_offer(SSL_ECDH_CTX *ctx, CBB *out) {
   assert(ctx->data == NULL);
   BIGNUM *private_key = BN_new();
   if (private_key == NULL) {
@@ -84,12 +84,9 @@
   return ret;
 }
 
-static int ssl_ec_point_compute_secret(SSL_ECDH_CTX *ctx,
-                                       uint8_t **out_secret,
-                                       size_t *out_secret_len,
-                                       uint8_t *out_alert,
-                                       const uint8_t *peer_key,
-                                       size_t peer_key_len) {
+static int ssl_ec_point_finish(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
+                               size_t *out_secret_len, uint8_t *out_alert,
+                               const uint8_t *peer_key, size_t peer_key_len) {
   BIGNUM *private_key = (BIGNUM *)ctx->data;
   assert(private_key != NULL);
   *out_alert = SSL_AD_INTERNAL_ERROR;
@@ -150,6 +147,18 @@
   return ret;
 }
 
+static int ssl_ec_point_accept(SSL_ECDH_CTX *ctx, CBB *out_public_key,
+                               uint8_t **out_secret, size_t *out_secret_len,
+                               uint8_t *out_alert, const uint8_t *peer_key,
+                               size_t peer_key_len) {
+  *out_alert = SSL_AD_INTERNAL_ERROR;
+  if (!ssl_ec_point_offer(ctx, out_public_key) ||
+      !ssl_ec_point_finish(ctx, out_secret, out_secret_len, out_alert, peer_key,
+                           peer_key_len)) {
+    return 0;
+  }
+  return 1;
+}
 
 /* X25119 implementation. */
 
@@ -161,7 +170,7 @@
   OPENSSL_free(ctx->data);
 }
 
-static int ssl_x25519_generate_keypair(SSL_ECDH_CTX *ctx, CBB *out) {
+static int ssl_x25519_offer(SSL_ECDH_CTX *ctx, CBB *out) {
   assert(ctx->data == NULL);
 
   ctx->data = OPENSSL_malloc(32);
@@ -174,10 +183,9 @@
   return CBB_add_bytes(out, public_key, sizeof(public_key));
 }
 
-static int ssl_x25519_compute_secret(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
-                                     size_t *out_secret_len, uint8_t *out_alert,
-                                     const uint8_t *peer_key,
-                                     size_t peer_key_len) {
+static int ssl_x25519_finish(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
+                             size_t *out_secret_len, uint8_t *out_alert,
+                             const uint8_t *peer_key, size_t peer_key_len) {
   assert(ctx->data != NULL);
   *out_alert = SSL_AD_INTERNAL_ERROR;
 
@@ -199,6 +207,18 @@
   return 1;
 }
 
+static int ssl_x25519_accept(SSL_ECDH_CTX *ctx, CBB *out_public_key,
+                             uint8_t **out_secret, size_t *out_secret_len,
+                             uint8_t *out_alert, const uint8_t *peer_key,
+                             size_t peer_key_len) {
+  *out_alert = SSL_AD_INTERNAL_ERROR;
+  if (!ssl_x25519_offer(ctx, out_public_key) ||
+      !ssl_x25519_finish(ctx, out_secret, out_secret_len, out_alert, peer_key,
+                         peer_key_len)) {
+    return 0;
+  }
+  return 1;
+}
 
 /* Legacy DHE-based implementation. */
 
@@ -206,7 +226,7 @@
   DH_free((DH *)ctx->data);
 }
 
-static int ssl_dhe_generate_keypair(SSL_ECDH_CTX *ctx, CBB *out) {
+static int ssl_dhe_offer(SSL_ECDH_CTX *ctx, CBB *out) {
   DH *dh = (DH *)ctx->data;
   /* The group must have been initialized already, but not the key. */
   assert(dh != NULL);
@@ -218,10 +238,9 @@
          BN_bn2cbb_padded(out, BN_num_bytes(dh->p), dh->pub_key);
 }
 
-static int ssl_dhe_compute_secret(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
-                                  size_t *out_secret_len, uint8_t *out_alert,
-                                  const uint8_t *peer_key,
-                                  size_t peer_key_len) {
+static int ssl_dhe_finish(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
+                          size_t *out_secret_len, uint8_t *out_alert,
+                          const uint8_t *peer_key, size_t peer_key_len) {
   DH *dh = (DH *)ctx->data;
   assert(dh != NULL);
   assert(dh->priv_key != NULL);
@@ -257,46 +276,63 @@
   return 0;
 }
 
+static int ssl_dhe_accept(SSL_ECDH_CTX *ctx, CBB *out_public_key,
+                          uint8_t **out_secret, size_t *out_secret_len,
+                          uint8_t *out_alert, const uint8_t *peer_key,
+                          size_t peer_key_len) {
+  *out_alert = SSL_AD_INTERNAL_ERROR;
+  if (!ssl_dhe_offer(ctx, out_public_key) ||
+      !ssl_dhe_finish(ctx, out_secret, out_secret_len, out_alert, peer_key,
+                      peer_key_len)) {
+    return 0;
+  }
+  return 1;
+}
+
 static const SSL_ECDH_METHOD kDHEMethod = {
     NID_undef, 0, "",
     ssl_dhe_cleanup,
-    ssl_dhe_generate_keypair,
-    ssl_dhe_compute_secret,
+    ssl_dhe_offer,
+    ssl_dhe_accept,
+    ssl_dhe_finish,
 };
 
-
 static const SSL_ECDH_METHOD kMethods[] = {
     {
         NID_X9_62_prime256v1,
         SSL_CURVE_SECP256R1,
         "P-256",
         ssl_ec_point_cleanup,
-        ssl_ec_point_generate_keypair,
-        ssl_ec_point_compute_secret,
+        ssl_ec_point_offer,
+        ssl_ec_point_accept,
+        ssl_ec_point_finish,
     },
     {
         NID_secp384r1,
         SSL_CURVE_SECP384R1,
         "P-384",
         ssl_ec_point_cleanup,
-        ssl_ec_point_generate_keypair,
-        ssl_ec_point_compute_secret,
+        ssl_ec_point_offer,
+        ssl_ec_point_accept,
+        ssl_ec_point_finish,
     },
     {
         NID_secp521r1,
         SSL_CURVE_SECP521R1,
         "P-521",
         ssl_ec_point_cleanup,
-        ssl_ec_point_generate_keypair,
-        ssl_ec_point_compute_secret,
+        ssl_ec_point_offer,
+        ssl_ec_point_accept,
+        ssl_ec_point_finish,
     },
     {
         NID_X25519,
         SSL_CURVE_X25519,
         "X25519",
         ssl_x25519_cleanup,
-        ssl_x25519_generate_keypair,
-        ssl_x25519_compute_secret,
+        ssl_x25519_offer,
+        ssl_x25519_accept,
+        ssl_x25519_finish,
     },
 };
 
@@ -365,13 +401,21 @@
   ctx->data = NULL;
 }
 
-int SSL_ECDH_CTX_generate_keypair(SSL_ECDH_CTX *ctx, CBB *out_public_key) {
-  return ctx->method->generate_keypair(ctx, out_public_key);
+int SSL_ECDH_CTX_offer(SSL_ECDH_CTX *ctx, CBB *out_public_key) {
+  return ctx->method->offer(ctx, out_public_key);
 }
 
-int SSL_ECDH_CTX_compute_secret(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
-                                size_t *out_secret_len, uint8_t *out_alert,
-                                const uint8_t *peer_key, size_t peer_key_len) {
-  return ctx->method->compute_secret(ctx, out_secret, out_secret_len, out_alert,
-                                     peer_key, peer_key_len);
+int SSL_ECDH_CTX_accept(SSL_ECDH_CTX *ctx, CBB *out_public_key,
+                        uint8_t **out_secret, size_t *out_secret_len,
+                        uint8_t *out_alert, const uint8_t *peer_key,
+                        size_t peer_key_len) {
+  return ctx->method->accept(ctx, out_public_key, out_secret, out_secret_len,
+                             out_alert, peer_key, peer_key_len);
+}
+
+int SSL_ECDH_CTX_finish(SSL_ECDH_CTX *ctx, uint8_t **out_secret,
+                        size_t *out_secret_len, uint8_t *out_alert,
+                        const uint8_t *peer_key, size_t peer_key_len) {
+  return ctx->method->finish(ctx, out_secret, out_secret_len, out_alert,
+                             peer_key, peer_key_len);
 }
diff --git a/ssl/test/runner/key_agreement.go b/ssl/test/runner/key_agreement.go
index 54aa3d3..f1d44f2 100644
--- a/ssl/test/runner/key_agreement.go
+++ b/ssl/test/runner/key_agreement.go
@@ -252,13 +252,16 @@
 
 // A ecdhCurve is an instance of ECDH-style key agreement for TLS.
 type ecdhCurve interface {
-	// generateKeypair generates a keypair using rand. It returns the
-	// encoded public key.
-	generateKeypair(rand io.Reader) (publicKey []byte, err error)
+	// offer generates a keypair using rand. It returns the encoded |publicKey|.
+	offer(rand io.Reader) (publicKey []byte, err error)
 
-	// computeSecret performs a key exchange against peerKey and returns
-	// the resulting shared secret.
-	computeSecret(peerKey []byte) (preMasterSecret []byte, err error)
+	// accept responds to the |peerKey| generated by |offer| with the acceptor's
+	// |publicKey|, and returns agreed-upon |preMasterSecret| to the acceptor.
+	accept(rand io.Reader, peerKey []byte) (publicKey []byte, preMasterSecret []byte, err error)
+
+	// finish returns the computed |preMasterSecret|, given the |peerKey|
+	// generated by |accept|.
+	finish(peerKey []byte) (preMasterSecret []byte, err error)
 }
 
 // ellipticECDHCurve implements ecdhCurve with an elliptic.Curve.
@@ -267,7 +270,7 @@
 	privateKey []byte
 }
 
-func (e *ellipticECDHCurve) generateKeypair(rand io.Reader) (publicKey []byte, err error) {
+func (e *ellipticECDHCurve) offer(rand io.Reader) (publicKey []byte, err error) {
 	var x, y *big.Int
 	e.privateKey, x, y, err = elliptic.GenerateKey(e.curve, rand)
 	if err != nil {
@@ -276,7 +279,19 @@
 	return elliptic.Marshal(e.curve, x, y), nil
 }
 
-func (e *ellipticECDHCurve) computeSecret(peerKey []byte) (preMasterSecret []byte, err error) {
+func (e *ellipticECDHCurve) accept(rand io.Reader, peerKey []byte) (publicKey []byte, preMasterSecret []byte, err error) {
+	publicKey, err = e.offer(rand)
+	if err != nil {
+		return nil, nil, err
+	}
+	preMasterSecret, err = e.finish(peerKey)
+	if err != nil {
+		return nil, nil, err
+	}
+	return
+}
+
+func (e *ellipticECDHCurve) finish(peerKey []byte) (preMasterSecret []byte, err error) {
 	x, y := elliptic.Unmarshal(e.curve, peerKey)
 	if x == nil {
 		return nil, errors.New("tls: invalid peer key")
@@ -294,7 +309,7 @@
 	privateKey [32]byte
 }
 
-func (e *x25519ECDHCurve) generateKeypair(rand io.Reader) (publicKey []byte, err error) {
+func (e *x25519ECDHCurve) offer(rand io.Reader) (publicKey []byte, err error) {
 	_, err = io.ReadFull(rand, e.privateKey[:])
 	if err != nil {
 		return
@@ -304,7 +319,19 @@
 	return out[:], nil
 }
 
-func (e *x25519ECDHCurve) computeSecret(peerKey []byte) (preMasterSecret []byte, err error) {
+func (e *x25519ECDHCurve) accept(rand io.Reader, peerKey []byte) (publicKey []byte, preMasterSecret []byte, err error) {
+	publicKey, err = e.offer(rand)
+	if err != nil {
+		return nil, nil, err
+	}
+	preMasterSecret, err = e.finish(peerKey)
+	if err != nil {
+		return nil, nil, err
+	}
+	return
+}
+
+func (e *x25519ECDHCurve) finish(peerKey []byte) (preMasterSecret []byte, err error) {
 	if len(peerKey) != 32 {
 		return nil, errors.New("tls: invalid peer key")
 	}
@@ -551,7 +578,7 @@
 		return nil, errors.New("tls: preferredCurves includes unsupported curve")
 	}
 
-	publicKey, err := ka.curve.generateKeypair(config.rand())
+	publicKey, err := ka.curve.offer(config.rand())
 	if err != nil {
 		return nil, err
 	}
@@ -577,7 +604,7 @@
 	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
 		return nil, errClientKeyExchange
 	}
-	return ka.curve.computeSecret(ckx.ciphertext[1:])
+	return ka.curve.finish(ckx.ciphertext[1:])
 }
 
 func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
@@ -612,11 +639,7 @@
 		return nil, nil, errors.New("missing ServerKeyExchange message")
 	}
 
-	publicKey, err := ka.curve.generateKeypair(config.rand())
-	if err != nil {
-		return nil, nil, err
-	}
-	preMasterSecret, err := ka.curve.computeSecret(ka.peerKey)
+	publicKey, preMasterSecret, err := ka.curve.accept(config.rand(), ka.peerKey)
 	if err != nil {
 		return nil, nil, err
 	}