Implement asynchronous private key operations for client auth.

This adds a new API, SSL_set_private_key_method, which allows the consumer to
customize private key operations. For simplicity, it is incompatible with the
multiple slots feature (which will hopefully go away) but does not, for now,
break it.

The new method is only routed up for the client for now. The server will
require a decrypt hook as well for the plain RSA key exchange.

BUG=347404

Change-Id: I35d69095c29134c34c2af88c613ad557d6957614
Reviewed-on: https://boringssl-review.googlesource.com/5049
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 7e278ae..b651acb 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -495,6 +495,72 @@
 OPENSSL_EXPORT uint32_t SSL_get_mode(const SSL *ssl);
 
 
+/* Configuring certificates and private keys.
+ *
+ * TODO(davidben): Move the other, more conventional, certificate and key
+ * configuration functions here, possibly after simplifying the multiple slots
+ * machinery first. https://crbug.com/486295. */
+
+enum ssl_private_key_result_t {
+  ssl_private_key_success,
+  ssl_private_key_retry,
+  ssl_private_key_failure,
+};
+
+/* SSL_PRIVATE_KEY_METHOD describes private key hooks. This is used to off-load
+ * signing operations to a custom, potentially asynchronous, backend. */
+typedef struct ssl_private_key_method_st {
+  /* type returns either |EVP_PKEY_RSA| or |EVP_PKEY_EC| to denote the type of
+   * key used by |ssl|. */
+  int (*type)(SSL *ssl);
+
+  /* supports_digest returns one if the key used by |ssl| supports signing
+   * digests of type |md| and zero otherwise. */
+  int (*supports_digest)(SSL *ssl, const EVP_MD *md);
+
+  /* max_signature_len returns the maximum length of a signature signed by the
+   * key used by |ssl|. This must be a constant value for a given |ssl|. */
+  size_t (*max_signature_len)(SSL *ssl);
+
+  /* sign signs |in_len| bytes of digest from |in|. |md| is the hash function
+   * used to calculate |in|. On success, it returns |ssl_private_key_success|
+   * and writes at most |max_out| bytes of signature data to |out|. On failure,
+   * it returns |ssl_private_key_failure|. If the operation has not completed,
+   * it returns |ssl_private_key_retry|. |sign| should arrange for the
+   * high-level operation on |ssl| to be retried when the operation is
+   * completed. This will result in a call to |sign_complete|.
+   *
+   * If the key is an RSA key, implementations must use PKCS#1 padding. |in| is
+   * the digest itself, so the DigestInfo prefix, if any, must be prepended by
+   * |sign|. If |md| is |EVP_md5_sha1|, there is no prefix.
+   *
+   * It is an error to call |sign| while another private key operation is in
+   * progress on |ssl|. */
+  enum ssl_private_key_result_t (*sign)(SSL *ssl, uint8_t *out, size_t *out_len,
+                                        size_t max_out, const EVP_MD *md,
+                                        const uint8_t *in, size_t in_len);
+
+  /* sign_complete completes a pending |sign| operation. If the operation has
+   * completed, it returns |ssl_private_key_success| and writes the result to
+   * |out| as in |sign|. Otherwise, it returns |ssl_private_key_failure| on
+   * failure and |ssl_private_key_retry| if the operation is still in progress.
+   *
+   * |sign_complete| may be called arbitrarily many times before completion, but
+   * it is an error to call |sign_complete| if there is no pending |sign|
+   * operation in progress on |ssl|. */
+  enum ssl_private_key_result_t (*sign_complete)(SSL *ssl, uint8_t *out,
+                                                 size_t *out_len, size_t max_out);
+} SSL_PRIVATE_KEY_METHOD;
+
+/* SSL_use_private_key_method configures a custom private key on
+ * |ssl|. |key_method| must remain valid for the lifetime of |ssl|. Using custom
+ * keys with the multiple certificate slots feature is not supported.
+ *
+ * TODO(davidben): Remove the multiple certificate slots feature. */
+OPENSSL_EXPORT void SSL_set_private_key_method(
+    SSL *ssl, const SSL_PRIVATE_KEY_METHOD *key_method);
+
+
 /* Connection information. */
 
 /* SSL_get_tls_unique writes at most |max_out| bytes of the tls-unique value
@@ -1279,6 +1345,7 @@
 #define SSL_CHANNEL_ID_LOOKUP 5
 #define SSL_PENDING_SESSION 7
 #define SSL_CERTIFICATE_SELECTION_PENDING 8
+#define SSL_PRIVATE_KEY_OPERATION 9
 
 /* These will only be used when doing non-blocking IO */
 #define SSL_want_nothing(s) (SSL_want(s) == SSL_NOTHING)
@@ -1289,6 +1356,8 @@
 #define SSL_want_session(s) (SSL_want(s) == SSL_PENDING_SESSION)
 #define SSL_want_certificate(s) \
   (SSL_want(s) == SSL_CERTIFICATE_SELECTION_PENDING)
+#define SSL_want_private_key_operation(s) \
+  (SSL_want(s) == SSL_PRIVATE_KEY_OPERATION)
 
 struct ssl_st {
   /* version is the protocol version. */
@@ -1637,6 +1706,7 @@
 #define SSL_ERROR_WANT_CHANNEL_ID_LOOKUP 9
 #define SSL_ERROR_PENDING_SESSION 11
 #define SSL_ERROR_PENDING_CERTIFICATE 12
+#define SSL_ERROR_WANT_PRIVATE_KEY_OPERATION 13
 
 #define SSL_CTRL_EXTRA_CHAIN_CERT 14
 
diff --git a/include/openssl/ssl3.h b/include/openssl/ssl3.h
index f6c8972..9021309 100644
--- a/include/openssl/ssl3.h
+++ b/include/openssl/ssl3.h
@@ -582,6 +582,7 @@
 #define SSL3_ST_CW_KEY_EXCH_B (0x181 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_CERT_VRFY_A (0x190 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_CERT_VRFY_B (0x191 | SSL_ST_CONNECT)
+#define SSL3_ST_CW_CERT_VRFY_C (0x192 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_CHANGE_A (0x1A0 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_CHANGE_B (0x1A1 | SSL_ST_CONNECT)
 #define SSL3_ST_CW_NEXT_PROTO_A (0x200 | SSL_ST_CONNECT)
diff --git a/ssl/internal.h b/ssl/internal.h
index 00eccfa..1736fb0 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -350,6 +350,30 @@
                       size_t in_len);
 
 
+/* Private key operations. */
+
+/* ssl_private_key_* call the corresponding function on the
+ * |SSL_PRIVATE_KEY_METHOD| for |ssl|, if configured. Otherwise, they implement
+ * the operation on |pkey|.
+ *
+ * TODO(davidben): The |EVP_PKEY| must be passed in to due to the multiple
+ * certificate slots feature. Remove it. */
+
+int ssl_private_key_type(SSL *ssl, const EVP_PKEY *pkey);
+
+int ssl_private_key_supports_digest(SSL *ssl, const EVP_PKEY *pkey,
+                                    const EVP_MD *md);
+
+size_t ssl_private_key_max_signature_len(SSL *ssl, const EVP_PKEY *pkey);
+
+enum ssl_private_key_result_t ssl_private_key_sign(
+    SSL *ssl, EVP_PKEY *pkey, uint8_t *out, size_t *out_len, size_t max_out,
+    const EVP_MD *md, const uint8_t *in, size_t in_len);
+
+enum ssl_private_key_result_t ssl_private_key_sign_complete(
+    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out);
+
+
 /* Underdocumented functions.
  *
  * Functions below here haven't been touched up and may be underdocumented. */
@@ -508,6 +532,10 @@
                    * Probably it would make more sense to store
                    * an index, not a pointer. */
 
+  /* key_method, if non-NULL, is a set of callbacks to call for private key
+   * operations. */
+  const SSL_PRIVATE_KEY_METHOD *key_method;
+
   /* For clients the following masks are of *disabled* key and auth algorithms
    * based on the current session.
    *
@@ -1069,8 +1097,9 @@
 int tls1_process_ticket(SSL *s, const struct ssl_early_callback_ctx *ctx,
                         SSL_SESSION **ret);
 
-int tls12_get_sigandhash(uint8_t *p, const EVP_PKEY *pk, const EVP_MD *md);
-int tls12_get_sigid(const EVP_PKEY *pk);
+int tls12_get_sigandhash(SSL *ssl, uint8_t *p, const EVP_PKEY *pk,
+                         const EVP_MD *md);
+int tls12_get_sigid(int pkey_type);
 const EVP_MD *tls12_get_hash(uint8_t hash_alg);
 
 int tls1_channel_id_hash(EVP_MD_CTX *ctx, SSL *s);
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index 06338b7..6fb0e3f 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -458,6 +458,8 @@
 
 int ssl3_cert_verify_hash(SSL *s, uint8_t *out, size_t *out_len,
                           const EVP_MD **out_md, EVP_PKEY *pkey) {
+  const int type = ssl_private_key_type(s, pkey);
+
   /* For TLS v1.2 send signature algorithm and signature using
    * agreed digest and cached handshake records. Otherwise, use
    * SHA1 or MD5 + SHA1 depending on key type.  */
@@ -480,7 +482,7 @@
       return 0;
     }
     *out_len = len;
-  } else if (pkey->type == EVP_PKEY_RSA) {
+  } else if (type == EVP_PKEY_RSA) {
     if (s->enc_method->cert_verify_mac(s, NID_md5, out) == 0 ||
         s->enc_method->cert_verify_mac(s, NID_sha1, out + MD5_DIGEST_LENGTH) ==
             0) {
@@ -488,7 +490,7 @@
     }
     *out_len = MD5_DIGEST_LENGTH + SHA_DIGEST_LENGTH;
     *out_md = EVP_md5_sha1();
-  } else if (pkey->type == EVP_PKEY_EC) {
+  } else if (type == EVP_PKEY_EC) {
     if (s->enc_method->cert_verify_mac(s, NID_sha1, out) == 0) {
       return 0;
     }
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index 8d192b6..a7fee64 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -350,6 +350,7 @@
 
       case SSL3_ST_CW_CERT_VRFY_A:
       case SSL3_ST_CW_CERT_VRFY_B:
+      case SSL3_ST_CW_CERT_VRFY_C:
         ret = ssl3_send_cert_verify(s);
         if (ret <= 0) {
           goto end;
@@ -2009,87 +2010,92 @@
 }
 
 int ssl3_send_cert_verify(SSL *s) {
-  uint8_t *buf, *p;
-  const EVP_MD *md = NULL;
-  uint8_t digest[EVP_MAX_MD_SIZE];
-  size_t digest_length;
-  EVP_PKEY *pkey;
-  EVP_PKEY_CTX *pctx = NULL;
-  size_t signature_length = 0;
-  unsigned long n = 0;
+  if (s->state == SSL3_ST_CW_CERT_VRFY_A ||
+      s->state == SSL3_ST_CW_CERT_VRFY_B) {
+    enum ssl_private_key_result_t sign_result;
+    uint8_t *p = ssl_handshake_start(s);
+    size_t signature_length = 0;
+    unsigned long n = 0;
+    EVP_PKEY *pkey = s->cert->key->privatekey;
+    assert(pkey != NULL || s->cert->key_method != NULL);
 
-  buf = (uint8_t *)s->init_buf->data;
+    if (s->state == SSL3_ST_CW_CERT_VRFY_A) {
+      uint8_t *buf = (uint8_t *)s->init_buf->data;
+      const EVP_MD *md = NULL;
+      uint8_t digest[EVP_MAX_MD_SIZE];
+      size_t digest_length;
 
-  if (s->state == SSL3_ST_CW_CERT_VRFY_A) {
-    p = ssl_handshake_start(s);
-    pkey = s->cert->key->privatekey;
-
-    /* Write out the digest type if needbe. */
-    if (SSL_USE_SIGALGS(s)) {
-      md = tls1_choose_signing_digest(s, pkey);
-      if (!tls12_get_sigandhash(p, pkey, md)) {
-        OPENSSL_PUT_ERROR(SSL, ssl3_send_cert_verify, ERR_R_INTERNAL_ERROR);
-        goto err;
+      /* Write out the digest type if need be. */
+      if (SSL_USE_SIGALGS(s)) {
+        md = tls1_choose_signing_digest(s, pkey);
+        if (!tls12_get_sigandhash(s, p, pkey, md)) {
+          OPENSSL_PUT_ERROR(SSL, ssl3_send_cert_verify, ERR_R_INTERNAL_ERROR);
+          return -1;
+        }
+        p += 2;
+        n += 2;
       }
-      p += 2;
-      n += 2;
+
+      /* Compute the digest. */
+      if (!ssl3_cert_verify_hash(s, digest, &digest_length, &md, pkey)) {
+        return -1;
+      }
+
+      /* The handshake buffer is no longer necessary. */
+      if (s->s3->handshake_buffer &&
+          !ssl3_digest_cached_records(s, free_handshake_buffer)) {
+        return -1;
+      }
+
+      /* Sign the digest. */
+      signature_length = ssl_private_key_max_signature_len(s, pkey);
+      if (p + 2 + signature_length > buf + SSL3_RT_MAX_PLAIN_LENGTH) {
+        OPENSSL_PUT_ERROR(SSL, ssl3_send_cert_verify,
+                          SSL_R_DATA_LENGTH_TOO_LONG);
+        return -1;
+      }
+
+      s->rwstate = SSL_PRIVATE_KEY_OPERATION;
+      sign_result = ssl_private_key_sign(s, pkey, &p[2], &signature_length,
+                                         signature_length, md, digest,
+                                         digest_length);
+    } else {
+      if (SSL_USE_SIGALGS(s)) {
+        /* The digest has already been selected and written. */
+        p += 2;
+        n += 2;
+      }
+      signature_length = ssl_private_key_max_signature_len(s, pkey);
+      s->rwstate = SSL_PRIVATE_KEY_OPERATION;
+      sign_result = ssl_private_key_sign_complete(s, &p[2], &signature_length,
+                                                  signature_length);
     }
 
-    /* Compute the digest. */
-    if (!ssl3_cert_verify_hash(s, digest, &digest_length, &md, pkey)) {
-      goto err;
+    if (sign_result == ssl_private_key_retry) {
+      s->state = SSL3_ST_CW_CERT_VRFY_B;
+      return -1;
     }
-
-    /* The handshake buffer is no longer necessary. */
-    if (s->s3->handshake_buffer &&
-        !ssl3_digest_cached_records(s, free_handshake_buffer)) {
-      goto err;
-    }
-
-    /* Sign the digest. */
-    pctx = EVP_PKEY_CTX_new(pkey, NULL);
-    if (pctx == NULL) {
-      goto err;
-    }
-
-    /* Initialize the EVP_PKEY_CTX and determine the size of the signature. */
-    if (!EVP_PKEY_sign_init(pctx) || !EVP_PKEY_CTX_set_signature_md(pctx, md) ||
-        !EVP_PKEY_sign(pctx, NULL, &signature_length, digest, digest_length)) {
-      OPENSSL_PUT_ERROR(SSL, ssl3_send_cert_verify, ERR_R_EVP_LIB);
-      goto err;
-    }
-
-    if (p + 2 + signature_length > buf + SSL3_RT_MAX_PLAIN_LENGTH) {
-      OPENSSL_PUT_ERROR(SSL, ssl3_send_cert_verify, SSL_R_DATA_LENGTH_TOO_LONG);
-      goto err;
-    }
-
-    if (!EVP_PKEY_sign(pctx, &p[2], &signature_length, digest, digest_length)) {
-      OPENSSL_PUT_ERROR(SSL, ssl3_send_cert_verify, ERR_R_EVP_LIB);
-      goto err;
+    s->rwstate = SSL_NOTHING;
+    if (sign_result != ssl_private_key_success) {
+      return -1;
     }
 
     s2n(signature_length, p);
     n += signature_length + 2;
-
     if (!ssl_set_handshake_header(s, SSL3_MT_CERTIFICATE_VERIFY, n)) {
-      goto err;
+      return -1;
     }
-    s->state = SSL3_ST_CW_CERT_VRFY_B;
+    s->state = SSL3_ST_CW_CERT_VRFY_C;
   }
 
-  EVP_PKEY_CTX_free(pctx);
   return ssl_do_write(s);
-
-err:
-  EVP_PKEY_CTX_free(pctx);
-  return -1;
 }
 
 /* ssl3_has_client_certificate returns true if a client certificate is
  * configured. */
 static int ssl3_has_client_certificate(SSL *s) {
-  return s->cert && s->cert->key->x509 && s->cert->key->privatekey;
+  return s->cert && s->cert->key->x509 && (s->cert->key->privatekey ||
+                                           s->cert->key_method);
 }
 
 int ssl3_send_client_certificate(SSL *s) {
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index a72e17e..3e6903d 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -1468,7 +1468,7 @@
       /* Determine signature algorithm. */
       if (SSL_USE_SIGALGS(s)) {
         md = tls1_choose_signing_digest(s, pkey);
-        if (!tls12_get_sigandhash(p, pkey, md)) {
+        if (!tls12_get_sigandhash(s, p, pkey, md)) {
           /* Should never happen */
           al = SSL_AD_INTERNAL_ERROR;
           OPENSSL_PUT_ERROR(SSL, ssl3_send_server_key_exchange,
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 5979008..856a599 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -2089,6 +2089,10 @@
     return SSL_ERROR_WANT_CHANNEL_ID_LOOKUP;
   }
 
+  if (SSL_want_private_key_operation(s)) {
+    return SSL_ERROR_WANT_PRIVATE_KEY_OPERATION;
+  }
+
   return SSL_ERROR_SYSCALL;
 }
 
diff --git a/ssl/ssl_rsa.c b/ssl/ssl_rsa.c
index 87f4c1c..b95b231 100644
--- a/ssl/ssl_rsa.c
+++ b/ssl/ssl_rsa.c
@@ -644,3 +644,64 @@
   BIO_free(in);
   return ret;
 }
+
+void SSL_set_private_key_method(SSL *ssl,
+                                const SSL_PRIVATE_KEY_METHOD *key_method) {
+  ssl->cert->key_method = key_method;
+}
+
+int ssl_private_key_type(SSL *ssl, const EVP_PKEY *pkey) {
+  if (ssl->cert->key_method != NULL) {
+    return ssl->cert->key_method->type(ssl);
+  }
+  return EVP_PKEY_id(pkey);
+}
+
+int ssl_private_key_supports_digest(SSL *ssl, const EVP_PKEY *pkey,
+                                    const EVP_MD *md) {
+  if (ssl->cert->key_method != NULL) {
+    return ssl->cert->key_method->supports_digest(ssl, md);
+  }
+  return EVP_PKEY_supports_digest(pkey, md);
+}
+
+size_t ssl_private_key_max_signature_len(SSL *ssl, const EVP_PKEY *pkey) {
+  if (ssl->cert->key_method != NULL) {
+    return ssl->cert->key_method->max_signature_len(ssl);
+  }
+  return EVP_PKEY_size(pkey);
+}
+
+enum ssl_private_key_result_t ssl_private_key_sign(
+    SSL *ssl, EVP_PKEY *pkey, uint8_t *out, size_t *out_len, size_t max_out,
+    const EVP_MD *md, const uint8_t *in, size_t in_len) {
+  if (ssl->cert->key_method != NULL) {
+    return ssl->cert->key_method->sign(ssl, out, out_len, max_out, md, in,
+                                       in_len);
+  }
+
+  enum ssl_private_key_result_t ret = ssl_private_key_failure;
+  EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(pkey, NULL);
+  if (ctx == NULL) {
+    goto end;
+  }
+
+  size_t len = max_out;
+  if (!EVP_PKEY_sign_init(ctx) ||
+      !EVP_PKEY_CTX_set_signature_md(ctx, md) ||
+      !EVP_PKEY_sign(ctx, out, &len, in, in_len)) {
+    goto end;
+  }
+  *out_len = len;
+  ret = ssl_private_key_success;
+
+end:
+  EVP_PKEY_CTX_free(ctx);
+  return ret;
+}
+
+enum ssl_private_key_result_t ssl_private_key_sign_complete(
+    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out) {
+  /* Only custom keys may be asynchronous. */
+  return ssl->cert->key_method->sign_complete(ssl, out, out_len, max_out);
+}
diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c
index 6a57660..98cb383 100644
--- a/ssl/t1_lib.c
+++ b/ssl/t1_lib.c
@@ -681,7 +681,7 @@
                             CBS *cbs, EVP_PKEY *pkey) {
   const uint8_t *sent_sigs;
   size_t sent_sigslen, i;
-  int sigalg = tls12_get_sigid(pkey);
+  int sigalg = tls12_get_sigid(pkey->type);
   uint8_t hash, signature;
 
   /* Should never happen */
@@ -2241,7 +2241,13 @@
   return NID_undef;
 }
 
-int tls12_get_sigandhash(uint8_t *p, const EVP_PKEY *pk, const EVP_MD *md) {
+int tls12_get_sigid(int pkey_type) {
+  return tls12_find_id(pkey_type, tls12_sig,
+                       sizeof(tls12_sig) / sizeof(tls12_lookup));
+}
+
+int tls12_get_sigandhash(SSL *ssl, uint8_t *p, const EVP_PKEY *pk,
+                         const EVP_MD *md) {
   int sig_id, md_id;
 
   if (!md) {
@@ -2254,7 +2260,7 @@
     return 0;
   }
 
-  sig_id = tls12_get_sigid(pk);
+  sig_id = tls12_get_sigid(ssl_private_key_type(ssl, pk));
   if (sig_id == -1) {
     return 0;
   }
@@ -2264,11 +2270,6 @@
   return 1;
 }
 
-int tls12_get_sigid(const EVP_PKEY *pk) {
-  return tls12_find_id(pk->type, tls12_sig,
-                       sizeof(tls12_sig) / sizeof(tls12_lookup));
-}
-
 const EVP_MD *tls12_get_hash(uint8_t hash_alg) {
   switch (hash_alg) {
     case TLSEXT_hash_md5:
@@ -2446,7 +2447,7 @@
 
 const EVP_MD *tls1_choose_signing_digest(SSL *s, EVP_PKEY *pkey) {
   CERT *c = s->cert;
-  int type = EVP_PKEY_id(pkey);
+  int type = ssl_private_key_type(s, pkey);
   size_t i;
 
   /* Select the first shared digest supported by our key. */
@@ -2454,7 +2455,7 @@
     const EVP_MD *md = tls12_get_hash(c->shared_sigalgs[i].rhash);
     if (md == NULL ||
         tls12_get_pkey_type(c->shared_sigalgs[i].rsign) != type ||
-        !EVP_PKEY_supports_digest(pkey, md)) {
+        !ssl_private_key_supports_digest(s, pkey, md)) {
       continue;
     }
     return md;
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 3b95d7e..f4ae982 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -90,6 +90,12 @@
   ScopedSSL_SESSION pending_session;
   bool early_callback_called = false;
   bool handshake_done = false;
+  // private_key is the underlying private key used when testing custom keys.
+  ScopedEVP_PKEY private_key;
+  std::vector<uint8_t> signature;
+  // signature_retries is the number of times an asynchronous sign operation has
+  // been retried.
+  unsigned signature_retries = 0;
 };
 
 static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad,
@@ -129,12 +135,100 @@
   return pkey;
 }
 
+static int AsyncPrivateKeyType(SSL *ssl) {
+  return EVP_PKEY_id(GetTestState(ssl)->private_key.get());
+}
+
+static int AsyncPrivateKeySupportsDigest(SSL *ssl, const EVP_MD *md) {
+  return EVP_PKEY_supports_digest(GetTestState(ssl)->private_key.get(), md);
+}
+
+static size_t AsyncPrivateKeyMaxSignatureLen(SSL *ssl) {
+  return EVP_PKEY_size(GetTestState(ssl)->private_key.get());
+}
+
+static ssl_private_key_result_t AsyncPrivateKeySign(
+    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
+    const EVP_MD *md, const uint8_t *in, size_t in_len) {
+  TestState *test_state = GetTestState(ssl);
+  if (!test_state->signature.empty()) {
+    fprintf(stderr, "AsyncPrivateKeySign called with operation pending.\n");
+    abort();
+  }
+
+  ScopedEVP_PKEY_CTX ctx(EVP_PKEY_CTX_new(test_state->private_key.get(),
+                                          nullptr));
+  if (!ctx) {
+    return ssl_private_key_failure;
+  }
+
+  // Write the signature into |test_state|.
+  size_t len = 0;
+  if (!EVP_PKEY_sign_init(ctx.get()) ||
+      !EVP_PKEY_CTX_set_signature_md(ctx.get(), md) ||
+      !EVP_PKEY_sign(ctx.get(), nullptr, &len, in, in_len)) {
+    return ssl_private_key_failure;
+  }
+  test_state->signature.resize(len);
+  if (!EVP_PKEY_sign(ctx.get(), bssl::vector_data(&test_state->signature), &len,
+                     in, in_len)) {
+    return ssl_private_key_failure;
+  }
+  test_state->signature.resize(len);
+
+  // The signature will be released asynchronously in |AsyncPrivateKeySignComplete|.
+  return ssl_private_key_retry;
+}
+
+static ssl_private_key_result_t AsyncPrivateKeySignComplete(
+    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out) {
+  TestState *test_state = GetTestState(ssl);
+  if (test_state->signature.empty()) {
+    fprintf(stderr,
+            "AsyncPrivateKeySignComplete called without operation pending.\n");
+    abort();
+  }
+
+  if (test_state->signature_retries < 2) {
+    // Only return the signature on the second attempt, to test both incomplete
+    // |sign| and |sign_complete|.
+    return ssl_private_key_retry;
+  }
+
+  if (max_out < test_state->signature.size()) {
+    fprintf(stderr, "Output buffer too small.\n");
+    return ssl_private_key_failure;
+  }
+  memcpy(out, bssl::vector_data(&test_state->signature),
+         test_state->signature.size());
+
+  test_state->signature.clear();
+  test_state->signature_retries = 0;
+  return ssl_private_key_success;
+}
+
+static const SSL_PRIVATE_KEY_METHOD g_async_private_key_method = {
+    AsyncPrivateKeyType,
+    AsyncPrivateKeySupportsDigest,
+    AsyncPrivateKeyMaxSignatureLen,
+    AsyncPrivateKeySign,
+    AsyncPrivateKeySignComplete,
+};
+
 static bool InstallCertificate(SSL *ssl) {
   const TestConfig *config = GetConfigPtr(ssl);
-  if (!config->key_file.empty() &&
-      !SSL_use_PrivateKey_file(ssl, config->key_file.c_str(),
-                               SSL_FILETYPE_PEM)) {
-    return false;
+  TestState *test_state = GetTestState(ssl);
+  if (!config->key_file.empty()) {
+    if (config->use_async_private_key) {
+      test_state->private_key = LoadPrivateKey(config->key_file.c_str());
+      if (!test_state->private_key) {
+        return false;
+      }
+      SSL_set_private_key_method(ssl, &g_async_private_key_method);
+    } else if (!SSL_use_PrivateKey_file(ssl, config->key_file.c_str(),
+                                        SSL_FILETYPE_PEM)) {
+      return false;
+    }
   }
   if (!config->cert_file.empty() &&
       !SSL_use_certificate_file(ssl, config->cert_file.c_str(),
@@ -500,6 +594,9 @@
     case SSL_ERROR_PENDING_CERTIFICATE:
       // The handshake will resume without a second call to the early callback.
       return InstallCertificate(ssl);
+    case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION:
+      test_state->signature_retries++;
+      return true;
     default:
       return false;
   }
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index b0eef42..1186313 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2202,6 +2202,20 @@
 			"-key-file", path.Join(*resourceDir, rsaKeyFile),
 		},
 	})
+	if async {
+		tests = append(tests, testCase{
+			testType: clientTest,
+			name:     "ClientAuth-Client-AsyncKey",
+			config: Config{
+				ClientAuth: RequireAnyClientCert,
+			},
+			flags: []string{
+				"-cert-file", rsaCertificateFile,
+				"-key-file", rsaKeyFile,
+				"-use-async-private-key",
+			},
+		})
+	}
 	tests = append(tests, testCase{
 		testType: serverTest,
 		name:     "ClientAuth-Server",
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 363b6f3..031ad93 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -82,6 +82,7 @@
   { "-reject-peer-renegotiations", &TestConfig::reject_peer_renegotiations },
   { "-no-legacy-server-connect", &TestConfig::no_legacy_server_connect },
   { "-tls-unique", &TestConfig::tls_unique },
+  { "-use-async-private-key", &TestConfig::use_async_private_key },
 };
 
 const Flag<std::string> kStringFlags[] = {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 5d753c8..e9af0de 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -79,6 +79,7 @@
   bool reject_peer_renegotiations = false;
   bool no_legacy_server_connect = false;
   bool tls_unique = false;
+  bool use_async_private_key = false;
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config);