Only negotiate ECDHE curves and sigalgs once

In the process of picking credentials, we pick the signature algorithm
and, in TLS 1.2, also the cipher suite. Save those decisions so we don't
repeat them. Right now we have to recompute it, even though it cannot
fail, which means there were a handful of error paths that were actually
impossible.

Bug: 249
Change-Id: If8d5cbf4dc07e722bf7c33b4b4ccf967c451a5f9
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/66707
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index b10a27b..b958dce 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -1332,7 +1332,8 @@
   return ssl_hs_ok;
 }
 
-static bool check_credential(SSL_HANDSHAKE *hs, const SSL_CREDENTIAL *cred) {
+static bool check_credential(SSL_HANDSHAKE *hs, const SSL_CREDENTIAL *cred,
+                             uint16_t *out_sigalg) {
   if (cred->type != SSLCredentialType::kX509) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_CERTIFICATE_TYPE);
     return false;
@@ -1360,12 +1361,11 @@
     }
   }
 
-  // Check that we will be able to generate a signature. Note this does not
+  // All currently supported credentials require a signature. Note this does not
   // check the ECDSA curve. Prior to TLS 1.3, there is no way to determine which
   // ECDSA curves are supported by the peer, so we must assume all curves are
   // supported.
-  uint16_t unused;
-  return tls1_choose_signature_algorithm(hs, cred, &unused);
+  return tls1_choose_signature_algorithm(hs, cred, out_sigalg);
 }
 
 static enum ssl_hs_wait_t do_send_client_certificate(SSL_HANDSHAKE *hs) {
@@ -1406,13 +1406,12 @@
     hs->transcript.FreeBuffer();
   } else {
     // Select the credential to use.
-    //
-    // TODO(davidben): In doing so, we pick the signature algorithm. Save that
-    // decision to avoid redoing it later.
     for (SSL_CREDENTIAL *cred : creds) {
       ERR_clear_error();
-      if (check_credential(hs, cred)) {
+      uint16_t sigalg;
+      if (check_credential(hs, cred, &sigalg)) {
         hs->credential = UpRef(cred);
+        hs->signature_algorithm = sigalg;
         break;
       }
     }
@@ -1615,15 +1614,10 @@
     return ssl_hs_error;
   }
 
-  uint16_t signature_algorithm;
-  if (!tls1_choose_signature_algorithm(hs, hs->credential.get(),
-                                       &signature_algorithm)) {
-    ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-    return ssl_hs_error;
-  }
+  assert(hs->signature_algorithm != 0);
   if (ssl_protocol_version(ssl) >= TLS1_2_VERSION) {
     // Write out the digest type in TLS 1.2.
-    if (!CBB_add_u16(&body, signature_algorithm)) {
+    if (!CBB_add_u16(&body, hs->signature_algorithm)) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
       return ssl_hs_error;
     }
@@ -1639,7 +1633,7 @@
 
   size_t sig_len = max_sig_len;
   switch (ssl_private_key_sign(hs, ptr, &sig_len, max_sig_len,
-                               signature_algorithm,
+                               hs->signature_algorithm,
                                hs->transcript.buffer())) {
     case ssl_private_key_success:
       break;
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index fc5202a..06d9025 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -292,71 +292,9 @@
   return sk;
 }
 
-// ssl_get_compatible_server_ciphers determines the key exchange and
-// authentication cipher suite masks compatible with the server configuration
-// and current ClientHello parameters of |hs|. It sets |*out_mask_k| to the key
-// exchange mask and |*out_mask_a| to the authentication mask. It returns true
-// on success and false on error.
-static bool ssl_get_compatible_server_ciphers(SSL_HANDSHAKE *hs,
-                                              const SSL_CREDENTIAL *cred,
-                                              uint32_t *out_mask_k,
-                                              uint32_t *out_mask_a) {
-  uint32_t mask_k = 0;
-  uint32_t mask_a = 0;
-
-  // Check for a shared group to consider ECDHE ciphers.
-  uint16_t unused;
-  if (tls1_get_shared_group(hs, &unused)) {
-    mask_k |= SSL_kECDHE;
-  }
-
-  // PSK requires a server callback.
-  if (hs->config->psk_server_callback != nullptr) {
-    mask_k |= SSL_kPSK;
-    mask_a |= SSL_aPSK;
-  }
-
-  if (cred != nullptr && cred->type == SSLCredentialType::kX509) {
-    bool sign_ok = tls1_choose_signature_algorithm(hs, cred, &unused);
-    ERR_clear_error();
-
-    // ECDSA keys must additionally be checked against the peer's supported
-    // curve list.
-    int key_type = EVP_PKEY_id(cred->pubkey.get());
-    if (hs->config->check_ecdsa_curve && key_type == EVP_PKEY_EC) {
-      EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(cred->pubkey.get());
-      uint16_t group_id;
-      if (!ssl_nid_to_group_id(
-              &group_id, EC_GROUP_get_curve_name(EC_KEY_get0_group(ec_key))) ||
-          std::find(hs->peer_supported_group_list.begin(),
-                    hs->peer_supported_group_list.end(),
-                    group_id) == hs->peer_supported_group_list.end()) {
-        sign_ok = false;
-
-        // If this would make us unable to pick any cipher, return an error.
-        // This is not strictly necessary, but it gives us a more specific
-        // error to help the caller diagnose issues.
-        if (mask_a == 0) {
-          OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CURVE);
-          return false;
-        }
-      }
-    }
-
-    mask_a |= ssl_cipher_auth_mask_for_key(cred->pubkey.get(), sign_ok);
-    if (key_type == EVP_PKEY_RSA) {
-      mask_k |= SSL_kRSA;
-    }
-  }
-
-  *out_mask_k = mask_k;
-  *out_mask_a = mask_a;
-  return true;
-}
-
-static const SSL_CIPHER *choose_cipher(
-    SSL_HANDSHAKE *hs, const SSL_CREDENTIAL *cred,
-    const STACK_OF(SSL_CIPHER) *client_pref) {
+static const SSL_CIPHER *choose_cipher(SSL_HANDSHAKE *hs,
+                                       const STACK_OF(SSL_CIPHER) *client_pref,
+                                       uint32_t mask_k, uint32_t mask_a) {
   SSL *const ssl = hs->ssl;
   const STACK_OF(SSL_CIPHER) *prio, *allow;
   // in_group_flags will either be NULL, or will point to an array of bytes
@@ -381,11 +319,6 @@
     allow = server_pref->ciphers.get();
   }
 
-  uint32_t mask_k, mask_a;
-  if (!ssl_get_compatible_server_ciphers(hs, cred, &mask_k, &mask_a)) {
-    return nullptr;
-  }
-
   for (size_t i = 0; i < sk_SSL_CIPHER_num(prio); i++) {
     const SSL_CIPHER *c = sk_SSL_CIPHER_value(prio, i);
 
@@ -423,6 +356,72 @@
   return nullptr;
 }
 
+struct TLS12ServerParams {
+  bool ok() const { return cipher != nullptr; }
+
+  const SSL_CIPHER *cipher = nullptr;
+  uint16_t signature_algorithm = 0;
+};
+
+static TLS12ServerParams choose_params(SSL_HANDSHAKE *hs,
+                                       const SSL_CREDENTIAL *cred,
+                                       const STACK_OF(SSL_CIPHER) *client_pref,
+                                       bool has_ecdhe_group) {
+  // Determine the usable cipher suites.
+  uint32_t mask_k = 0, mask_a = 0;
+  if (has_ecdhe_group) {
+    mask_k |= SSL_kECDHE;
+  }
+  if (hs->config->psk_server_callback != nullptr) {
+    mask_k |= SSL_kPSK;
+    mask_a |= SSL_aPSK;
+  }
+  uint16_t sigalg = 0;
+  if (cred != nullptr && cred->type == SSLCredentialType::kX509) {
+    bool sign_ok = tls1_choose_signature_algorithm(hs, cred, &sigalg);
+    ERR_clear_error();
+
+    // ECDSA keys must additionally be checked against the peer's supported
+    // curve list.
+    int key_type = EVP_PKEY_id(cred->pubkey.get());
+    if (hs->config->check_ecdsa_curve && key_type == EVP_PKEY_EC) {
+      EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(cred->pubkey.get());
+      uint16_t group_id;
+      if (!ssl_nid_to_group_id(
+              &group_id, EC_GROUP_get_curve_name(EC_KEY_get0_group(ec_key))) ||
+          std::find(hs->peer_supported_group_list.begin(),
+                    hs->peer_supported_group_list.end(),
+                    group_id) == hs->peer_supported_group_list.end()) {
+        sign_ok = false;
+
+        // If this would make us unable to pick any cipher, return an error.
+        // This is not strictly necessary, but it gives us a more specific
+        // error to help the caller diagnose issues.
+        if (mask_a == 0) {
+          OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CURVE);
+          return TLS12ServerParams();
+        }
+      }
+    }
+
+    mask_a |= ssl_cipher_auth_mask_for_key(cred->pubkey.get(), sign_ok);
+    if (key_type == EVP_PKEY_RSA) {
+      mask_k |= SSL_kRSA;
+    }
+  }
+
+  TLS12ServerParams params;
+  params.cipher = choose_cipher(hs, client_pref, mask_k, mask_a);
+  if (params.cipher == nullptr) {
+    return TLS12ServerParams();
+  }
+  if (ssl_cipher_requires_server_key_exchange(params.cipher) &&
+      ssl_cipher_uses_certificate_auth(params.cipher)) {
+    params.signature_algorithm = sigalg;
+  }
+  return params;
+}
+
 static enum ssl_hs_wait_t do_start_accept(SSL_HANDSHAKE *hs) {
   ssl_do_info_callback(hs->ssl, SSL_CB_HANDSHAKE_START, 1);
   hs->state = state12_read_client_hello;
@@ -848,6 +847,10 @@
     return ssl_hs_error;
   }
 
+  // Determine the ECDHE group to use, if we are to use ECDHE.
+  uint16_t group_id = 0;
+  bool has_ecdhe_group = tls1_get_shared_group(hs, &group_id);
+
   // Select the credential and cipher suite. This must be done after |cert_cb|
   // runs, so the final credential list is known.
   //
@@ -863,26 +866,29 @@
   if (!ssl_get_credential_list(hs, &creds)) {
     return ssl_hs_error;
   }
+  TLS12ServerParams params;
   if (creds.empty()) {
     // The caller may have configured no credentials, but set a PSK callback.
-    hs->new_cipher = choose_cipher(hs, /*cred=*/nullptr, client_pref.get());
+    params =
+        choose_params(hs, /*cred=*/nullptr, client_pref.get(), has_ecdhe_group);
   } else {
     // Select the first credential which works.
     for (SSL_CREDENTIAL *cred : creds) {
       ERR_clear_error();
-      hs->new_cipher = choose_cipher(hs, cred, client_pref.get());
-      if (hs->new_cipher != nullptr) {
+      params = choose_params(hs, cred, client_pref.get(), has_ecdhe_group);
+      if (params.ok()) {
         hs->credential = UpRef(cred);
         break;
       }
     }
   }
-
-  if (hs->new_cipher == nullptr) {
+  if (!params.ok()) {
     // The error from the last attempt is in the error queue.
     ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
     return ssl_hs_error;
   }
+  hs->new_cipher = params.cipher;
+  hs->signature_algorithm = params.signature_algorithm;
 
   hs->session_id_len = client_hello.session_id_len;
   // This is checked in |ssl_client_hello_init|.
@@ -947,6 +953,10 @@
 
   if (ssl->session == NULL) {
     hs->new_session->cipher = hs->new_cipher;
+    if (hs->new_session->cipher->algorithm_mkey & SSL_kECDHE) {
+      assert(has_ecdhe_group);
+      hs->new_session->group_id = group_id;
+    }
 
     // Determine whether to request a client certificate.
     hs->cert_request = !!(hs->config->verify_mode & SSL_VERIFY_PEER);
@@ -1146,19 +1156,11 @@
     }
 
     if (alg_k & SSL_kECDHE) {
-      // Determine the group to use.
-      uint16_t group_id;
-      if (!tls1_get_shared_group(hs, &group_id)) {
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-        return ssl_hs_error;
-      }
-      hs->new_session->group_id = group_id;
-
-      hs->key_shares[0] = SSLKeyShare::Create(group_id);
+      assert(hs->new_session->group_id != 0);
+      hs->key_shares[0] = SSLKeyShare::Create(hs->new_session->group_id);
       if (!hs->key_shares[0] ||
           !CBB_add_u8(cbb.get(), NAMED_CURVE_TYPE) ||
-          !CBB_add_u16(cbb.get(), group_id) ||
+          !CBB_add_u16(cbb.get(), hs->new_session->group_id) ||
           !CBB_add_u8_length_prefixed(cbb.get(), &child)) {
         return ssl_hs_error;
       }
@@ -1166,7 +1168,7 @@
       SSL_HANDSHAKE_HINTS *const hints = hs->hints.get();
       bool hint_ok = false;
       if (hints && !hs->hints_requested &&
-          hints->ecdhe_group_id == group_id &&
+          hints->ecdhe_group_id == hs->new_session->group_id &&
           !hints->ecdhe_public_key.empty() &&
           !hints->ecdhe_private_key.empty()) {
         CBS cbs = MakeConstSpan(hints->ecdhe_private_key);
@@ -1194,7 +1196,7 @@
                               &hints->ecdhe_private_key)) {
             return ssl_hs_error;
           }
-          hints->ecdhe_group_id = group_id;
+          hints->ecdhe_group_id = hs->new_session->group_id;
         }
       }
     } else {
diff --git a/ssl/internal.h b/ssl/internal.h
index cc98538..0e55739 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2167,6 +2167,10 @@
   // record layer.
   uint16_t early_data_written = 0;
 
+  // signature_algorithm is the signature algorithm to be used in signing with
+  // the selected credential, or zero if not applicable or not yet selected.
+  uint16_t signature_algorithm = 0;
+
   // ech_config_id is the ECH config sent by the client.
   uint8_t ech_config_id = 0;
 
diff --git a/ssl/tls13_both.cc b/ssl/tls13_both.cc
index 8058057..4a9b78e 100644
--- a/ssl/tls13_both.cc
+++ b/ssl/tls13_both.cc
@@ -553,18 +553,12 @@
 
 enum ssl_private_key_result_t tls13_add_certificate_verify(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  uint16_t signature_algorithm;
-  if (!tls1_choose_signature_algorithm(hs, hs->credential.get(),
-                                       &signature_algorithm)) {
-    ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-    return ssl_private_key_failure;
-  }
-
+  assert(hs->signature_algorithm != 0);
   ScopedCBB cbb;
   CBB body;
   if (!ssl->method->init_message(ssl, cbb.get(), &body,
                                  SSL3_MT_CERTIFICATE_VERIFY) ||
-      !CBB_add_u16(&body, signature_algorithm)) {
+      !CBB_add_u16(&body, hs->signature_algorithm)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return ssl_private_key_failure;
   }
@@ -588,7 +582,7 @@
   }
 
   enum ssl_private_key_result_t sign_result = ssl_private_key_sign(
-      hs, sig, &sig_len, max_sig_len, signature_algorithm, msg);
+      hs, sig, &sig_len, max_sig_len, hs->signature_algorithm, msg);
   if (sign_result != ssl_private_key_success) {
     return sign_result;
   }
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index cdb46c6..6c9ea75 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -832,15 +832,15 @@
   return ssl_hs_ok;
 }
 
-static bool check_credential(SSL_HANDSHAKE *hs, const SSL_CREDENTIAL *cred) {
+static bool check_credential(SSL_HANDSHAKE *hs, const SSL_CREDENTIAL *cred,
+                             uint16_t *out_sigalg) {
   if (cred->type != SSLCredentialType::kX509) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_CERTIFICATE_TYPE);
     return false;
   }
 
-  // Check that we will be able to generate a signature.
-  uint16_t unused;
-  return tls1_choose_signature_algorithm(hs, cred, &unused);
+  // All currently supported credentials require a signature.
+  return tls1_choose_signature_algorithm(hs, cred, out_sigalg);
 }
 
 static enum ssl_hs_wait_t do_send_client_certificate(SSL_HANDSHAKE *hs) {
@@ -877,13 +877,12 @@
 
   if (!creds.empty()) {
     // Select the credential to use.
-    //
-    // TODO(davidben): In doing so, we pick the signature algorithm. Save that
-    // decision to avoid redoing it later.
     for (SSL_CREDENTIAL *cred : creds) {
       ERR_clear_error();
-      if (check_credential(hs, cred)) {
+      uint16_t sigalg;
+      if (check_credential(hs, cred, &sigalg)) {
         hs->credential = UpRef(cred);
+        hs->signature_algorithm = sigalg;
         break;
       }
     }
diff --git a/ssl/tls13_server.cc b/ssl/tls13_server.cc
index ebe0cb4..67e1f78 100644
--- a/ssl/tls13_server.cc
+++ b/ssl/tls13_server.cc
@@ -207,7 +207,8 @@
   return true;
 }
 
-static bool check_credential(SSL_HANDSHAKE *hs, const SSL_CREDENTIAL *cred) {
+static bool check_credential(SSL_HANDSHAKE *hs, const SSL_CREDENTIAL *cred,
+                             uint16_t *out_sigalg) {
   switch (cred->type) {
     case SSLCredentialType::kX509:
       break;
@@ -222,11 +223,10 @@
       break;
   }
 
-  // Check that we will be able to generate a signature. If |cred| is a
+  // All currently supported credentials require a signature. If |cred| is a
   // delegated credential, this also checks that the peer supports delegated
   // credentials and matched |dc_cert_verify_algorithm|.
-  uint16_t unused;
-  return tls1_choose_signature_algorithm(hs, cred, &unused);
+  return tls1_choose_signature_algorithm(hs, cred, out_sigalg);
 }
 
 static enum ssl_hs_wait_t do_select_parameters(SSL_HANDSHAKE *hs) {
@@ -259,13 +259,12 @@
   }
 
   // Select the credential to use.
-  //
-  // TODO(davidben): In doing so, we pick the signature algorithm. Save that
-  // decision to avoid redoing it later.
   for (SSL_CREDENTIAL *cred : creds) {
     ERR_clear_error();
-    if (check_credential(hs, cred)) {
+    uint16_t sigalg;
+    if (check_credential(hs, cred, &sigalg)) {
       hs->credential = UpRef(cred);
+      hs->signature_algorithm = sigalg;
       break;
     }
   }