Emulate the client_cert_cb with cert_cb.

This avoids needing a extra state around client certificates to avoid
calling the callbacks twice. This does, however, come with a behavior
change: configuring both callbacks won't work. No consumer does this.

(Except bssl_shim which needed slight tweaks.)

Change-Id: Ia5426ed2620e40eecdcf352216c4a46764e31a9a
Reviewed-on: https://boringssl-review.googlesource.com/12690
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 8b59ea5..b123b34 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -3460,7 +3460,8 @@
  * |SSL_get_client_CA_list| for information on the server's certificate request.
  *
  * Use |SSL_CTX_set_cert_cb| instead. Configuring intermediate certificates with
- * this function is confusing. */
+ * this function is confusing. This callback may not be registered concurrently
+ * with |SSL_CTX_set_cert_cb| or |SSL_set_cert_cb|. */
 OPENSSL_EXPORT void SSL_CTX_set_client_cert_cb(
     SSL_CTX *ctx,
     int (*client_cert_cb)(SSL *ssl, X509 **out_x509, EVP_PKEY **out_pkey));
diff --git a/include/openssl/ssl3.h b/include/openssl/ssl3.h
index 4cf51e1..e75b70d 100644
--- a/include/openssl/ssl3.h
+++ b/include/openssl/ssl3.h
@@ -321,7 +321,6 @@
 /* write to server */
 #define SSL3_ST_CW_CERT_A (0x170 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_CERT_B (0x171 | SSL_ST_CONNECT)
-#define SSL3_ST_CW_CERT_C (0x172 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_KEY_EXCH_A (0x180 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_KEY_EXCH_B (0x181 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_CERT_VRFY_A (0x190 | SSL_ST_CONNECT)
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index 028b905..b73215b 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -326,7 +326,6 @@
 
       case SSL3_ST_CW_CERT_A:
       case SSL3_ST_CW_CERT_B:
-      case SSL3_ST_CW_CERT_C:
         if (hs->cert_request) {
           ret = ssl3_send_client_certificate(hs);
           if (ret <= 0) {
@@ -1459,53 +1458,41 @@
 
 static int ssl3_send_client_certificate(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  if (ssl->state == SSL3_ST_CW_CERT_A) {
-    /* Call cert_cb to update the certificate. */
-    if (ssl->cert->cert_cb) {
-      int ret = ssl->cert->cert_cb(ssl, ssl->cert->cert_cb_arg);
-      if (ret < 0) {
-        ssl->rwstate = SSL_X509_LOOKUP;
-        return -1;
-      }
-      if (ret == 0) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_CERT_CB_ERROR);
-        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-        return -1;
-      }
-    }
-
-    ssl->state = SSL3_ST_CW_CERT_B;
-  }
-
   if (ssl->state == SSL3_ST_CW_CERT_B) {
-    /* Call client_cert_cb to update the certificate. */
-    int should_retry;
-    if (!ssl_do_client_cert_cb(ssl, &should_retry)) {
-      if (should_retry) {
-        ssl->rwstate = SSL_X509_LOOKUP;
-      }
+    return ssl->method->write_message(ssl);
+  }
+  assert(ssl->state == SSL3_ST_CW_CERT_A);
+
+  /* Call cert_cb to update the certificate. */
+  if (ssl->cert->cert_cb) {
+    int ret = ssl->cert->cert_cb(ssl, ssl->cert->cert_cb_arg);
+    if (ret < 0) {
+      ssl->rwstate = SSL_X509_LOOKUP;
       return -1;
     }
-
-    if (!ssl_has_certificate(ssl)) {
-      hs->cert_request = 0;
-      /* Without a client certificate, the handshake buffer may be released. */
-      ssl3_free_handshake_buffer(ssl);
-
-      if (ssl->version == SSL3_VERSION) {
-        /* In SSL 3.0, send no certificate by skipping both messages. */
-        ssl3_send_alert(ssl, SSL3_AL_WARNING, SSL_AD_NO_CERTIFICATE);
-        return 1;
-      }
-    }
-
-    if (!ssl3_output_cert_chain(ssl)) {
+    if (ret == 0) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_CERT_CB_ERROR);
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
       return -1;
     }
-    ssl->state = SSL3_ST_CW_CERT_C;
   }
 
-  assert(ssl->state == SSL3_ST_CW_CERT_C);
+  if (!ssl_has_certificate(ssl)) {
+    hs->cert_request = 0;
+    /* Without a client certificate, the handshake buffer may be released. */
+    ssl3_free_handshake_buffer(ssl);
+
+    if (ssl->version == SSL3_VERSION) {
+      /* In SSL 3.0, send no certificate by skipping both messages. */
+      ssl3_send_alert(ssl, SSL3_AL_WARNING, SSL_AD_NO_CERTIFICATE);
+      return 1;
+    }
+  }
+
+  if (!ssl3_output_cert_chain(ssl)) {
+    return -1;
+  }
+  ssl->state = SSL3_ST_CW_CERT_B;
   return ssl->method->write_message(ssl);
 }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index 1801ba2..69ea489 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -783,11 +783,6 @@
  * error queue. */
 int ssl_check_leaf_certificate(SSL *ssl, X509 *leaf);
 
-/* ssl_do_client_cert_cb runs the client_cert_cb, if any, and returns one on
- * success and zero on error. On error, it sets |*out_should_retry| to one if
- * the callback failed and should be retried and zero otherwise. */
-int ssl_do_client_cert_cb(SSL *ssl, int *out_should_retry);
-
 
 /* TLS 1.3 key derivation. */
 
diff --git a/ssl/ssl_cert.c b/ssl/ssl_cert.c
index c0bdb5c..d1ad4ec 100644
--- a/ssl/ssl_cert.c
+++ b/ssl/ssl_cert.c
@@ -674,33 +674,6 @@
   return CBB_flush(cbb);
 }
 
-int ssl_do_client_cert_cb(SSL *ssl, int *out_should_retry) {
-  if (ssl_has_certificate(ssl) || ssl->ctx->client_cert_cb == NULL) {
-    return 1;
-  }
-
-  X509 *x509 = NULL;
-  EVP_PKEY *pkey = NULL;
-  int ret = ssl->ctx->client_cert_cb(ssl, &x509, &pkey);
-  if (ret < 0) {
-    *out_should_retry = 1;
-    return 0;
-  }
-
-  if (ret != 0) {
-    if (!SSL_use_certificate(ssl, x509) ||
-        !SSL_use_PrivateKey(ssl, pkey)) {
-      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-      *out_should_retry = 0;
-      return 0;
-    }
-  }
-
-  X509_free(x509);
-  EVP_PKEY_free(pkey);
-  return 1;
-}
-
 static int set_cert_store(X509_STORE **store_ptr, X509_STORE *new_store, int take_ref) {
   X509_STORE_free(*store_ptr);
   *store_ptr = new_store;
@@ -852,3 +825,35 @@
   EVP_PKEY_free(pkey);
   return ret;
 }
+
+static int do_client_cert_cb(SSL *ssl, void *arg) {
+  if (ssl_has_certificate(ssl) || ssl->ctx->client_cert_cb == NULL) {
+    return 1;
+  }
+
+  X509 *x509 = NULL;
+  EVP_PKEY *pkey = NULL;
+  int ret = ssl->ctx->client_cert_cb(ssl, &x509, &pkey);
+  if (ret < 0) {
+    return -1;
+  }
+
+  if (ret != 0) {
+    if (!SSL_use_certificate(ssl, x509) ||
+        !SSL_use_PrivateKey(ssl, pkey)) {
+      return 0;
+    }
+  }
+
+  X509_free(x509);
+  EVP_PKEY_free(pkey);
+  return 1;
+}
+
+void SSL_CTX_set_client_cert_cb(SSL_CTX *ctx, int (*cb)(SSL *ssl,
+                                                        X509 **out_x509,
+                                                        EVP_PKEY **out_pkey)) {
+  /* Emulate the old client certificate callback with the new one. */
+  SSL_CTX_set_cert_cb(ctx, do_client_cert_cb, NULL);
+  ctx->client_cert_cb = cb;
+}
diff --git a/ssl/ssl_session.c b/ssl/ssl_session.c
index e6bac0c..1e6e21b 100644
--- a/ssl/ssl_session.c
+++ b/ssl/ssl_session.c
@@ -1007,12 +1007,6 @@
   return ctx->info_callback;
 }
 
-void SSL_CTX_set_client_cert_cb(SSL_CTX *ctx, int (*cb)(SSL *ssl,
-                                                        X509 **out_x509,
-                                                        EVP_PKEY **out_pkey)) {
-  ctx->client_cert_cb = cb;
-}
-
 void SSL_CTX_set_channel_id_cb(SSL_CTX *ctx,
                                void (*cb)(SSL *ssl, EVP_PKEY **pkey)) {
   ctx->channel_id_cb = cb;
diff --git a/ssl/ssl_stat.c b/ssl/ssl_stat.c
index 1ed0bdf..51f6fd3 100644
--- a/ssl/ssl_stat.c
+++ b/ssl/ssl_stat.c
@@ -131,9 +131,6 @@
     case SSL3_ST_CW_CERT_B:
       return "SSLv3 write client certificate B";
 
-    case SSL3_ST_CW_CERT_C:
-      return "SSLv3 write client certificate C";
-
     case SSL3_ST_CW_KEY_EXCH_A:
       return "SSLv3 write client key exchange A";
 
@@ -288,9 +285,6 @@
     case SSL3_ST_CW_CERT_B:
       return "3WCC_B";
 
-    case SSL3_ST_CW_CERT_C:
-      return "3WCC_C";
-
     case SSL3_ST_CW_KEY_EXCH_A:
       return "3WCKEA";
 
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index dbfaf10..140e666 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -488,7 +488,31 @@
   return 1;
 }
 
+static bool CheckCertificateRequest(SSL *ssl) {
+  const TestConfig *config = GetTestConfig(ssl);
+
+  if (!config->expected_certificate_types.empty()) {
+    const uint8_t *certificate_types;
+    size_t certificate_types_len =
+        SSL_get0_certificate_types(ssl, &certificate_types);
+    if (certificate_types_len != config->expected_certificate_types.size() ||
+        memcmp(certificate_types,
+               config->expected_certificate_types.data(),
+               certificate_types_len) != 0) {
+      fprintf(stderr, "certificate types mismatch\n");
+      return false;
+    }
+  }
+
+  // TODO(davidben): Test |SSL_get_client_CA_list|.
+  return true;
+}
+
 static int ClientCertCallback(SSL *ssl, X509 **out_x509, EVP_PKEY **out_pkey) {
+  if (!CheckCertificateRequest(ssl)) {
+    return -1;
+  }
+
   if (GetTestConfig(ssl)->async && !GetTestState(ssl)->cert_ready) {
     return -1;
   }
@@ -511,6 +535,32 @@
   return 1;
 }
 
+static int CertCallback(SSL *ssl, void *arg) {
+  const TestConfig *config = GetTestConfig(ssl);
+
+  // Check the CertificateRequest metadata is as expected.
+  if (!SSL_is_server(ssl) && !CheckCertificateRequest(ssl)) {
+    return -1;
+  }
+
+  if (config->fail_cert_callback) {
+    return 0;
+  }
+
+  // The certificate will be installed via other means.
+  if (!config->async || config->use_early_callback) {
+    return 1;
+  }
+
+  if (!GetTestState(ssl)->cert_ready) {
+    return -1;
+  }
+  if (!InstallCertificate(ssl)) {
+    return 0;
+  }
+  return 1;
+}
+
 static int VerifySucceed(X509_STORE_CTX *store_ctx, void *arg) {
   SSL* ssl = (SSL*)X509_STORE_CTX_get_ex_data(store_ctx,
       SSL_get_ex_data_X509_STORE_CTX_idx());
@@ -643,45 +693,6 @@
   *out_pkey = GetTestState(ssl)->channel_id.release();
 }
 
-static int CertCallback(SSL *ssl, void *arg) {
-  const TestConfig *config = GetTestConfig(ssl);
-
-  // Check the CertificateRequest metadata is as expected.
-  //
-  // TODO(davidben): Test |SSL_get_client_CA_list|.
-  if (!SSL_is_server(ssl) &&
-      !config->expected_certificate_types.empty()) {
-    const uint8_t *certificate_types;
-    size_t certificate_types_len =
-        SSL_get0_certificate_types(ssl, &certificate_types);
-    if (certificate_types_len != config->expected_certificate_types.size() ||
-        memcmp(certificate_types,
-               config->expected_certificate_types.data(),
-               certificate_types_len) != 0) {
-      fprintf(stderr, "certificate types mismatch\n");
-      return 0;
-    }
-  }
-
-  if (config->fail_cert_callback) {
-    return 0;
-  }
-
-  // The certificate will be installed via other means.
-  if (!config->async || config->use_early_callback ||
-      config->use_old_client_cert_callback) {
-    return 1;
-  }
-
-  if (!GetTestState(ssl)->cert_ready) {
-    return -1;
-  }
-  if (!InstallCertificate(ssl)) {
-    return 0;
-  }
-  return 1;
-}
-
 static SSL_SESSION *GetSessionCallback(SSL *ssl, uint8_t *data, int len,
                                        int *copy) {
   TestState *async_state = GetTestState(ssl);
@@ -1484,7 +1495,9 @@
       !InstallCertificate(ssl.get())) {
     return false;
   }
-  SSL_set_cert_cb(ssl.get(), CertCallback, nullptr);
+  if (!config->use_old_client_cert_callback) {
+    SSL_set_cert_cb(ssl.get(), CertCallback, nullptr);
+  }
   if (config->require_any_client_certificate) {
     SSL_set_verify(ssl.get(), SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
                    NULL);
diff --git a/ssl/tls13_client.c b/ssl/tls13_client.c
index ba1589f..b9c3c68 100644
--- a/ssl/tls13_client.c
+++ b/ssl/tls13_client.c
@@ -38,7 +38,6 @@
   state_process_server_certificate,
   state_process_server_certificate_verify,
   state_process_server_finished,
-  state_certificate_callback,
   state_send_client_certificate,
   state_send_client_certificate_verify,
   state_complete_client_certificate_verify,
@@ -439,11 +438,11 @@
   }
 
   ssl->method->received_flight(ssl);
-  hs->tls13_state = state_certificate_callback;
+  hs->tls13_state = state_send_client_certificate;
   return ssl_hs_ok;
 }
 
-static enum ssl_hs_wait_t do_certificate_callback(SSL_HANDSHAKE *hs) {
+static enum ssl_hs_wait_t do_send_client_certificate(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
   /* The peer didn't request a certificate. */
   if (!hs->cert_request) {
@@ -460,25 +459,9 @@
       return ssl_hs_error;
     }
     if (rv < 0) {
-      hs->tls13_state = state_certificate_callback;
-      return ssl_hs_x509_lookup;
-    }
-  }
-
-  hs->tls13_state = state_send_client_certificate;
-  return ssl_hs_ok;
-}
-
-static enum ssl_hs_wait_t do_send_client_certificate(SSL_HANDSHAKE *hs) {
-  SSL *const ssl = hs->ssl;
-  /* Call client_cert_cb to update the certificate. */
-  int should_retry;
-  if (!ssl_do_client_cert_cb(ssl, &should_retry)) {
-    if (should_retry) {
       hs->tls13_state = state_send_client_certificate;
       return ssl_hs_x509_lookup;
     }
-    return ssl_hs_error;
   }
 
   if (!tls13_prepare_certificate(hs)) {
@@ -597,9 +580,6 @@
       case state_process_server_finished:
         ret = do_process_server_finished(hs);
         break;
-      case state_certificate_callback:
-        ret = do_certificate_callback(hs);
-        break;
       case state_send_client_certificate:
         ret = do_send_client_certificate(hs);
         break;