Fix a number of cases overwriting certificates, keys, etc. with SSL_CREDENTIAL

Field-by-field setters make the worst APIs. This fixes the following:

- Calling SSL_CTX_set_chain_and_key twice should override the old one
  (Regression from SSL_CREDENTIAL.)

- Various APIs forgot to clear the old chain before appending new ones.
  (Regression from SSL_CREDENTIAL.)

- Switching between a custom private key and a concrete one should not
  leave the old one lying around. (I think this was always broken.)

Add tests for all of these cases.

Change-Id: Ief7b3aecf2ada3b123d79d4eddf464c65d5f7d0d
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/66907
Commit-Queue: David Benjamin <davidben@google.com>
Auto-Submit: David Benjamin <davidben@google.com>
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: Bob Beck <bbe@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index d10bb02..d73f9da 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -919,8 +919,9 @@
 OPENSSL_EXPORT int SSL_CREDENTIAL_set1_signing_algorithm_prefs(
     SSL_CREDENTIAL *cred, const uint16_t *prefs, size_t num_prefs);
 
-// SSL_CREDENTIAL_set1_cert_chain sets |cred|'s certificate chain to |num_cert|s
-// certificates from |certs|. It returns one on success and zero on error.
+// SSL_CREDENTIAL_set1_cert_chain sets |cred|'s certificate chain, starting from
+// the leaf, to |num_cert|s certificates from |certs|. It returns one on success
+// and zero on error.
 OPENSSL_EXPORT int SSL_CREDENTIAL_set1_cert_chain(SSL_CREDENTIAL *cred,
                                                   CRYPTO_BUFFER *const *certs,
                                                   size_t num_certs);
@@ -1002,11 +1003,13 @@
 OPENSSL_EXPORT int SSL_use_certificate(SSL *ssl, X509 *x509);
 
 // SSL_CTX_use_PrivateKey sets |ctx|'s private key to |pkey|. It returns one on
-// success and zero on failure.
+// success and zero on failure. If |ctx| had a private key or
+// |SSL_PRIVATE_KEY_METHOD| previously configured, it is replaced.
 OPENSSL_EXPORT int SSL_CTX_use_PrivateKey(SSL_CTX *ctx, EVP_PKEY *pkey);
 
 // SSL_use_PrivateKey sets |ssl|'s private key to |pkey|. It returns one on
-// success and zero on failure.
+// success and zero on failure. If |ssl| had a private key or
+// |SSL_PRIVATE_KEY_METHOD| previously configured, it is replaced.
 OPENSSL_EXPORT int SSL_use_PrivateKey(SSL *ssl, EVP_PKEY *pkey);
 
 // SSL_CTX_set0_chain sets |ctx|'s certificate chain, excluding the leaf, to
diff --git a/ssl/internal.h b/ssl/internal.h
index 0e55739..0c2c2f8 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1640,6 +1640,10 @@
   bool SetLeafCert(bssl::UniquePtr<CRYPTO_BUFFER> leaf,
                    bool discard_key_on_mismatch);
 
+  // ClearIntermediateCerts clears intermediate certificates in the certificate
+  // chain, while preserving the leaf.
+  void ClearIntermediateCerts();
+
   // AppendIntermediateCert appends |cert| to the certificate chain. If there is
   // no leaf certificate configured, it leaves a placeholder null in |chain|. It
   // returns one on success and zero on error.
diff --git a/ssl/ssl_cert.cc b/ssl/ssl_cert.cc
index 39798ba..e30ec73 100644
--- a/ssl/ssl_cert.cc
+++ b/ssl/ssl_cert.cc
@@ -191,6 +191,7 @@
     return 0;
   }
 
+  cert->default_credential->ClearCertAndKey();
   if (!SSL_CREDENTIAL_set1_cert_chain(cert->default_credential.get(), certs,
                                       num_certs)) {
     return 0;
diff --git a/ssl/ssl_credential.cc b/ssl/ssl_credential.cc
index f787098..f4bb55e 100644
--- a/ssl/ssl_credential.cc
+++ b/ssl/ssl_credential.cc
@@ -203,6 +203,16 @@
   return true;
 }
 
+void ssl_credential_st::ClearIntermediateCerts() {
+  if (chain == nullptr) {
+    return;
+  }
+
+  while (sk_CRYPTO_BUFFER_num(chain.get()) > 1) {
+    CRYPTO_BUFFER_free(sk_CRYPTO_BUFFER_pop(chain.get()));
+  }
+}
+
 bool ssl_credential_st::AppendIntermediateCert(UniquePtr<CRYPTO_BUFFER> cert) {
   if (!UsesX509()) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
@@ -249,6 +259,7 @@
   }
 
   cred->privkey = UpRef(key);
+  cred->key_method = nullptr;
   return 1;
 }
 
@@ -259,6 +270,7 @@
     return 0;
   }
 
+  cred->privkey = nullptr;
   cred->key_method = key_method;
   return 1;
 }
@@ -275,6 +287,7 @@
     return 0;
   }
 
+  cred->ClearIntermediateCerts();
   for (size_t i = 1; i < num_certs; i++) {
     if (!cred->AppendIntermediateCert(UpRef(certs[i]))) {
       return 0;
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 9247dc3..503ad5f 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -1317,6 +1317,34 @@
       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
 }
 
+static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(const char *pem) {
+  bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem, strlen(pem)));
+  char *name, *header;
+  uint8_t *data;
+  long data_len;
+  if (!PEM_read_bio(bio.get(), &name, &header, &data,
+                    &data_len)) {
+    return nullptr;
+  }
+  OPENSSL_free(name);
+  OPENSSL_free(header);
+
+  auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(
+      CRYPTO_BUFFER_new(data, data_len, nullptr));
+  OPENSSL_free(data);
+  return ret;
+}
+
+static bssl::UniquePtr<X509> X509FromBuffer(
+    bssl::UniquePtr<CRYPTO_BUFFER> buffer) {
+  if (!buffer) {
+    return nullptr;
+  }
+  const uint8_t *derp = CRYPTO_BUFFER_data(buffer.get());
+  return bssl::UniquePtr<X509>(
+      d2i_X509(NULL, &derp, CRYPTO_BUFFER_len(buffer.get())));
+}
+
 static bssl::UniquePtr<X509> GetTestCertificate() {
   static const char kCertPEM[] =
       "-----BEGIN CERTIFICATE-----\n"
@@ -1370,7 +1398,7 @@
   return ctx;
 }
 
-static bssl::UniquePtr<X509> GetECDSATestCertificate() {
+static bssl::UniquePtr<CRYPTO_BUFFER> GetECDSATestCertificateBuffer() {
   static const char kCertPEM[] =
       "-----BEGIN CERTIFICATE-----\n"
       "MIIBzzCCAXagAwIBAgIJANlMBNpJfb/rMAkGByqGSM49BAEwRTELMAkGA1UEBhMC\n"
@@ -1384,9 +1412,14 @@
       "BgcqhkjOPQQBA0gAMEUCIQDyoDVeUTo2w4J5m+4nUIWOcAZ0lVfSKXQA9L4Vh13E\n"
       "BwIgfB55FGohg/B6dGh5XxSZmmi08cueFV7mHzJSYV51yRQ=\n"
       "-----END CERTIFICATE-----\n";
-  return CertFromPEM(kCertPEM);
+  return BufferFromPEM(kCertPEM);
 }
 
+static bssl::UniquePtr<X509> GetECDSATestCertificate() {
+  return X509FromBuffer(GetECDSATestCertificateBuffer());
+}
+
+
 static bssl::UniquePtr<EVP_PKEY> GetECDSATestKey() {
   static const char kKeyPEM[] =
       "-----BEGIN PRIVATE KEY-----\n"
@@ -1397,24 +1430,6 @@
   return KeyFromPEM(kKeyPEM);
 }
 
-static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(const char *pem) {
-  bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem, strlen(pem)));
-  char *name, *header;
-  uint8_t *data;
-  long data_len;
-  if (!PEM_read_bio(bio.get(), &name, &header, &data,
-                    &data_len)) {
-    return nullptr;
-  }
-  OPENSSL_free(name);
-  OPENSSL_free(header);
-
-  auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(
-      CRYPTO_BUFFER_new(data, data_len, nullptr));
-  OPENSSL_free(data);
-  return ret;
-}
-
 static bssl::UniquePtr<CRYPTO_BUFFER> GetChainTestCertificateBuffer() {
   static const char kCertPEM[] =
       "-----BEGIN CERTIFICATE-----\n"
@@ -1438,16 +1453,6 @@
   return BufferFromPEM(kCertPEM);
 }
 
-static bssl::UniquePtr<X509> X509FromBuffer(
-    bssl::UniquePtr<CRYPTO_BUFFER> buffer) {
-  if (!buffer) {
-    return nullptr;
-  }
-  const uint8_t *derp = CRYPTO_BUFFER_data(buffer.get());
-  return bssl::UniquePtr<X509>(
-      d2i_X509(NULL, &derp, CRYPTO_BUFFER_len(buffer.get())));
-}
-
 static bssl::UniquePtr<X509> GetChainTestCertificate() {
   return X509FromBuffer(GetChainTestCertificateBuffer());
 }
@@ -4047,7 +4052,7 @@
   ASSERT_FALSE(SSL_clear(server_.get()));
 }
 
-static bool ChainsEqual(STACK_OF(X509) * chain,
+static bool ChainsEqual(const STACK_OF(X509) *chain,
                         const std::vector<X509 *> &expected) {
   if (sk_X509_num(chain) != expected.size()) {
     return false;
@@ -4062,6 +4067,24 @@
   return true;
 }
 
+static bool BuffersEqual(const STACK_OF(CRYPTO_BUFFER) *chain,
+                         const std::vector<CRYPTO_BUFFER *> &expected) {
+  if (sk_CRYPTO_BUFFER_num(chain) != expected.size()) {
+    return false;
+  }
+
+  for (size_t i = 0; i < expected.size(); i++) {
+    const CRYPTO_BUFFER *buf = sk_CRYPTO_BUFFER_value(chain, i);
+    if (Bytes(CRYPTO_BUFFER_data(buf), CRYPTO_BUFFER_len(buf)) !=
+        Bytes(CRYPTO_BUFFER_data(expected[i]),
+              CRYPTO_BUFFER_len(expected[i]))) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 TEST_P(SSLVersionTest, AutoChain) {
   cert_ = GetChainTestCertificate();
   ASSERT_TRUE(cert_);
@@ -4630,6 +4653,133 @@
   ASSERT_TRUE(SSL_CTX_use_PrivateKey(ctx.get(), key2.get()));
 }
 
+TEST(SSLTest, OverrideKeyMethodWithKey) {
+  // Make an SSL_PRIVATE_KEY_METHOD that should never be called.
+  static const SSL_PRIVATE_KEY_METHOD kErrorMethod = {
+      [](SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
+         uint16_t signature_algorithm, const uint8_t *in,
+         size_t in_len) { return ssl_private_key_failure; },
+      [](SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
+         const uint8_t *in, size_t in_len) { return ssl_private_key_failure; },
+      [](SSL *ssl, uint8_t *out, size_t *out_len, size_t max_oun) {
+        return ssl_private_key_failure;
+      },
+  };
+
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(key);
+  bssl::UniquePtr<X509> leaf = GetTestCertificate();
+  ASSERT_TRUE(leaf);
+
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(ctx);
+  ASSERT_TRUE(SSL_CTX_use_certificate(ctx.get(), leaf.get()));
+
+  // Configuring an |SSL_PRIVATE_KEY_METHOD| and then overwriting it with an
+  // |EVP_PKEY| should clear the |SSL_PRIVATE_KEY_METHOD|.
+  SSL_CTX_set_private_key_method(ctx.get(), &kErrorMethod);
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, ctx.get(), ctx.get()));
+}
+
+// Configuring a chain and then overwriting it with a different chain should
+// clear the old one.
+TEST(SSLTest, OverrideChain) {
+  bssl::UniquePtr<EVP_PKEY> key = GetChainTestKey();
+  ASSERT_TRUE(key);
+  bssl::UniquePtr<X509> leaf = GetChainTestCertificate();
+  ASSERT_TRUE(leaf);
+  bssl::UniquePtr<X509> ca = GetChainTestIntermediate();
+  ASSERT_TRUE(ca);
+
+  bssl::UniquePtr<STACK_OF(X509)> chain(sk_X509_new_null());
+  ASSERT_TRUE(chain);
+  ASSERT_TRUE(bssl::PushToStack(chain.get(), bssl::UpRef(ca)));
+
+  bssl::UniquePtr<STACK_OF(X509)> wrong_chain(sk_X509_new_null());
+  ASSERT_TRUE(wrong_chain);
+  ASSERT_TRUE(bssl::PushToStack(wrong_chain.get(), bssl::UpRef(leaf)));
+  ASSERT_TRUE(bssl::PushToStack(wrong_chain.get(), bssl::UpRef(leaf)));
+
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(ctx);
+  ASSERT_TRUE(SSL_CTX_use_certificate(ctx.get(), leaf.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(ctx.get(), key.get()));
+
+  // Configure one chain, then replace it with another. Note this API considers
+  // the chain to exclude the leaf.
+  ASSERT_TRUE(SSL_CTX_set1_chain(ctx.get(), wrong_chain.get()));
+  ASSERT_TRUE(SSL_CTX_set1_chain(ctx.get(), chain.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, ctx.get(), ctx.get()));
+  EXPECT_TRUE(ChainsEqual(SSL_get_peer_full_cert_chain(client.get()),
+                          {leaf.get(), ca.get()}));
+}
+
+TEST(SSLTest, OverrideChainAndKey) {
+  bssl::UniquePtr<EVP_PKEY> key1 = GetChainTestKey();
+  ASSERT_TRUE(key1);
+  bssl::UniquePtr<CRYPTO_BUFFER> leaf1 = GetChainTestCertificateBuffer();
+  ASSERT_TRUE(leaf1);
+  bssl::UniquePtr<CRYPTO_BUFFER> ca1 = GetChainTestIntermediateBuffer();
+  ASSERT_TRUE(ca1);
+  bssl::UniquePtr<EVP_PKEY> key2 = GetECDSATestKey();
+  ASSERT_TRUE(key2);
+  bssl::UniquePtr<CRYPTO_BUFFER> leaf2 = GetECDSATestCertificateBuffer();
+  ASSERT_TRUE(leaf2);
+
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(ctx);
+
+  // Configure one cert and key pair, then replace it with noather.
+  std::vector<CRYPTO_BUFFER *> certs = {leaf1.get(), ca1.get()};
+  ASSERT_TRUE(SSL_CTX_set_chain_and_key(ctx.get(), certs.data(), certs.size(),
+                                        key1.get(), nullptr));
+  certs = {leaf2.get()};
+  ASSERT_TRUE(SSL_CTX_set_chain_and_key(ctx.get(), certs.data(), certs.size(),
+                                        key2.get(), nullptr));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, ctx.get(), ctx.get()));
+  EXPECT_TRUE(
+      BuffersEqual(SSL_get0_peer_certificates(client.get()), {leaf2.get()}));
+}
+
+TEST(SSLTest, OverrideCredentialChain) {
+  bssl::UniquePtr<EVP_PKEY> key = GetChainTestKey();
+  ASSERT_TRUE(key);
+  bssl::UniquePtr<CRYPTO_BUFFER> leaf = GetChainTestCertificateBuffer();
+  ASSERT_TRUE(leaf);
+  bssl::UniquePtr<CRYPTO_BUFFER> ca = GetChainTestIntermediateBuffer();
+  ASSERT_TRUE(ca);
+
+  std::vector<CRYPTO_BUFFER *> chain = {leaf.get(), ca.get()};
+  std::vector<CRYPTO_BUFFER *> wrong_chain = {leaf.get(), leaf.get(),
+                                              leaf.get()};
+
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(ctx);
+  bssl::UniquePtr<SSL_CREDENTIAL> cred(SSL_CREDENTIAL_new_x509());
+  ASSERT_TRUE(cred);
+
+  // Configure one chain (including the leaf), then replace it with another.
+  ASSERT_TRUE(SSL_CREDENTIAL_set1_cert_chain(cred.get(), wrong_chain.data(),
+                                             wrong_chain.size()));
+  ASSERT_TRUE(
+      SSL_CREDENTIAL_set1_cert_chain(cred.get(), chain.data(), chain.size()));
+
+  ASSERT_TRUE(SSL_CREDENTIAL_set1_private_key(cred.get(), key.get()));
+  ASSERT_TRUE(SSL_CTX_add1_credential(ctx.get(), cred.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, ctx.get(), ctx.get()));
+  EXPECT_TRUE(BuffersEqual(SSL_get0_peer_certificates(client.get()),
+                           {leaf.get(), ca.get()}));
+}
+
 TEST(SSLTest, SetChainAndKeyCtx) {
   bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_with_buffers_method()));
   ASSERT_TRUE(client_ctx);
diff --git a/ssl/ssl_x509.cc b/ssl/ssl_x509.cc
index d7f1083..66c3210 100644
--- a/ssl/ssl_x509.cc
+++ b/ssl/ssl_x509.cc
@@ -198,6 +198,7 @@
 // which case no change to |cert->chain| is made. It preverses the existing
 // leaf from |cert->chain|, if any.
 static bool ssl_cert_set1_chain(CERT *cert, STACK_OF(X509) *chain) {
+  cert->default_credential->ClearIntermediateCerts();
   for (X509 *x509 : chain) {
     UniquePtr<CRYPTO_BUFFER> buffer = x509_to_buffer(x509);
     if (!buffer ||