Avoid running the X.509 auto-chaining logic twice in TLS clients

The ssl_get_credential_list helper implicitly assumed we only need to
read the credential list immediately after cert_cb finishes. The SPAKE
CL broke this assumption, by needing to look at credentials during
ClientHello. But since that doesn't need the legacy credential, we can
just say that other uses look at the list directly. As a bonus, we avoid
a copy.

Change-Id: I878cba59903889a648c89bf8a781d9f99c8bfd03
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/76427
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index 4e4dbe4..4bd5c5d 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -2879,12 +2879,8 @@
     return true;
   }
 
-  Array<SSL_CREDENTIAL *> creds;
-  if (!ssl_get_credential_list(hs, &creds)) {
-    return false;
-  }
-
-  if (std::none_of(creds.begin(), creds.end(), [](SSL_CREDENTIAL *cred) {
+  const auto &creds = hs->config->cert->credentials;
+  if (std::none_of(creds.begin(), creds.end(), [](const auto &cred) {
         return cred->type == SSLCredentialType::kSPAKE2PlusV1Client;
       })) {
     // If there were no configured PAKE credentials, proceed without filling
@@ -2898,7 +2894,7 @@
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNSUPPORTED_CREDENTIAL_LIST);
     return false;
   }
-  SSL_CREDENTIAL *cred = creds[0];
+  SSL_CREDENTIAL *cred = creds[0].get();
   assert(cred->type == SSLCredentialType::kSPAKE2PlusV1Client);
 
   hs->pake_prover = MakeUnique<spake2plus::Prover>();
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 6e2dfdd..7319df7 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -1301,7 +1301,7 @@
   }
 
   Array<SSL_CREDENTIAL *> creds;
-  if (!ssl_get_credential_list(hs, &creds)) {
+  if (!ssl_get_full_credential_list(hs, &creds)) {
     return ssl_hs_error;
   }
 
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index 1d3619c..12f5662 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -736,7 +736,7 @@
     return ssl_hs_error;
   }
   Array<SSL_CREDENTIAL *> creds;
-  if (!ssl_get_credential_list(hs, &creds)) {
+  if (!ssl_get_full_credential_list(hs, &creds)) {
     return ssl_hs_error;
   }
   TLS12ServerParams params;
diff --git a/ssl/internal.h b/ssl/internal.h
index 15b344a..4b554c2 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1956,13 +1956,20 @@
 
 BSSL_NAMESPACE_BEGIN
 
-// ssl_get_credential_list computes |hs|'s credential list. On success, it
-// writes it to |*out| and returns true. Otherwise, it returns false. The
-// credential list may be empty, in which case this function will successfully
-// return an empty array.
+// ssl_get_full_credential_list computes |hs|'s full credential list, including
+// the legacy credential. On success, it writes it to |*out| and returns true.
+// Otherwise, it returns false. The credential list may be empty, in which case
+// this function will successfully output an empty array.
+//
+// This function should be called at most once during the handshake and is
+// intended to be used for certificate-based credentials. It runs the
+// auto-chaining logic as part of finishing the legacy credential. Other uses of
+// the credential list (e.g. PAKE credentials) should iterate over
+// |hs->config->cert->credentials|.
 //
 // The pointers in the result are only valid until |hs| is next mutated.
-bool ssl_get_credential_list(SSL_HANDSHAKE *hs, Array<SSL_CREDENTIAL *> *out);
+bool ssl_get_full_credential_list(SSL_HANDSHAKE *hs,
+                                  Array<SSL_CREDENTIAL *> *out);
 
 // ssl_credential_matches_requested_issuers returns true if |cred| is a
 // usable match for any requested issuers in |hs|, and false with an error
diff --git a/ssl/ssl_credential.cc b/ssl/ssl_credential.cc
index d50431f..fcd8ca5 100644
--- a/ssl/ssl_credential.cc
+++ b/ssl/ssl_credential.cc
@@ -35,7 +35,8 @@
   return chain;
 }
 
-bool ssl_get_credential_list(SSL_HANDSHAKE *hs, Array<SSL_CREDENTIAL *> *out) {
+bool ssl_get_full_credential_list(SSL_HANDSHAKE *hs,
+                                  Array<SSL_CREDENTIAL *> *out) {
   CERT *cert = hs->config->cert.get();
   // Finish filling in the legacy credential if needed.
   if (!cert->x509_method->ssl_auto_chain_if_needed(hs)) {
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 2f482df..e7d83e4 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -4180,7 +4180,23 @@
 
   EXPECT_TRUE(ChainsEqual(SSL_get_peer_full_cert_chain(client_.get()),
                           {cert_.get(), cert_.get()}));
+  EXPECT_TRUE(ChainsEqual(SSL_get_peer_full_cert_chain(server_.get()),
+                          {cert_.get(), cert_.get()}));
 
+  // Auto-chaining does not override explicitly-configured intermediates that
+  // are configured as late as cert_cb. If this fails, something in the
+  // handshake is likely auto-chaining too early.
+  SSL_CTX_clear_chain_certs(client_ctx_.get());
+  SSL_CTX_clear_chain_certs(server_ctx_.get());
+  auto install_intermediate = [](SSL *ssl, void *arg) -> int {
+    return SSL_add1_chain_cert(ssl, static_cast<X509 *>(arg));
+  };
+  SSL_CTX_set_cert_cb(client_ctx_.get(), install_intermediate, cert_.get());
+  SSL_CTX_set_cert_cb(server_ctx_.get(), install_intermediate, cert_.get());
+  ASSERT_TRUE(Connect());
+
+  EXPECT_TRUE(ChainsEqual(SSL_get_peer_full_cert_chain(client_.get()),
+                          {cert_.get(), cert_.get()}));
   EXPECT_TRUE(ChainsEqual(SSL_get_peer_full_cert_chain(server_.get()),
                           {cert_.get(), cert_.get()}));
 }
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index 607a0ba..cf36d0c 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -935,7 +935,7 @@
   }
 
   Array<SSL_CREDENTIAL *> creds;
-  if (!ssl_get_credential_list(hs, &creds)) {
+  if (!ssl_get_full_credential_list(hs, &creds)) {
     return ssl_hs_error;
   }
 
diff --git a/ssl/tls13_server.cc b/ssl/tls13_server.cc
index 1c92ea7..e34ecf4 100644
--- a/ssl/tls13_server.cc
+++ b/ssl/tls13_server.cc
@@ -318,7 +318,7 @@
   }
 
   Array<SSL_CREDENTIAL *> creds;
-  if (!ssl_get_credential_list(hs, &creds)) {
+  if (!ssl_get_full_credential_list(hs, &creds)) {
     return ssl_hs_error;
   }
   if (creds.empty()) {