Give SSL_PRIVATE_KEY_METHOD a message-based API.

This allows us to implement custom RSA-PSS-based keys, so the async TLS
1.3 tests can proceed. For now, both sign and sign_digest exist, so
downstreams only need to manage a small change atomically. We'll remove
sign_digest separately.

In doing so, fold all the *_complete hooks into a single complete hook
as no one who implemented two operations ever used different function
pointers for them.

While I'm here, I've bumped BORINGSSL_API_VERSION. I do not believe we
have any SSL_PRIVATE_KEY_METHOD versions who cannot update atomically,
but save a round-trip in case we do. It's free.

Change-Id: I7f031aabfb3343805deee429b9e244aed5d76aed
Reviewed-on: https://boringssl-review.googlesource.com/8786
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/include/openssl/base.h b/include/openssl/base.h
index 7a3adfb..4fe2230 100644
--- a/include/openssl/base.h
+++ b/include/openssl/base.h
@@ -125,7 +125,7 @@
  * A consumer may use this symbol in the preprocessor to temporarily build
  * against multiple revisions of BoringSSL at the same time. It is not
  * recommended to do so for longer than is necessary. */
-#define BORINGSSL_API_VERSION 1
+#define BORINGSSL_API_VERSION 2
 
 #if defined(BORINGSSL_SHARED_LIBRARY)
 
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 75ea320..1caca73 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -992,35 +992,50 @@
    * key used by |ssl|. This must be a constant value for a given |ssl|. */
   size_t (*max_signature_len)(SSL *ssl);
 
-  /* sign signs |in_len| bytes of digest from |in|. |md| is the hash function
-   * used to calculate |in|. On success, it returns |ssl_private_key_success|
-   * and writes at most |max_out| bytes of signature data to |out|. On failure,
-   * it returns |ssl_private_key_failure|. If the operation has not completed,
-   * it returns |ssl_private_key_retry|. |sign| should arrange for the
-   * high-level operation on |ssl| to be retried when the operation is
-   * completed. This will result in a call to |sign_complete|.
+  /* sign signs the message |in| in using the specified signature algorithm. On
+   * success, it returns |ssl_private_key_success| and writes at most |max_out|
+   * bytes of signature data to |out| and sets |*out_len| to the number of bytes
+   * written. On failure, it returns |ssl_private_key_failure|. If the operation
+   * has not completed, it returns |ssl_private_key_retry|. |sign| should
+   * arrange for the high-level operation on |ssl| to be retried when the
+   * operation is completed. This will result in a call to |complete|.
+   *
+   * |signature_algorithm| is one of the |SSL_SIGN_*| values, as defined in TLS
+   * 1.3. Note that, in TLS 1.2, ECDSA algorithms do not require that curve
+   * sizes match hash sizes, so the curve portion of |SSL_SIGN_ECDSA_*| values
+   * must be ignored. BoringSSL will internally handle the curve matching logic
+   * where appropriate.
+   *
+   * It is an error to call |sign| while another private key operation is in
+   * progress on |ssl|. */
+  enum ssl_private_key_result_t (*sign)(SSL *ssl, uint8_t *out, size_t *out_len,
+                                        size_t max_out,
+                                        uint16_t signature_algorithm,
+                                        const uint8_t *in, size_t in_len);
+
+  /* sign_digest signs |in_len| bytes of digest from |in|. |md| is the hash
+   * function used to calculate |in|. On success, it returns
+   * |ssl_private_key_success| and writes at most |max_out| bytes of signature
+   * data to |out|. On failure, it returns |ssl_private_key_failure|. If the
+   * operation has not completed, it returns |ssl_private_key_retry|. |sign|
+   * should arrange for the high-level operation on |ssl| to be retried when the
+   * operation is completed. This will result in a call to |complete|.
    *
    * If the key is an RSA key, implementations must use PKCS#1 padding. |in| is
    * the digest itself, so the DigestInfo prefix, if any, must be prepended by
    * |sign|. If |md| is |EVP_md5_sha1|, there is no prefix.
    *
-   * It is an error to call |sign| while another private key operation is in
-   * progress on |ssl|. */
-  enum ssl_private_key_result_t (*sign)(SSL *ssl, uint8_t *out, size_t *out_len,
-                                        size_t max_out, const EVP_MD *md,
-                                        const uint8_t *in, size_t in_len);
-
-  /* sign_complete completes a pending |sign| operation. If the operation has
-   * completed, it returns |ssl_private_key_success| and writes the result to
-   * |out| as in |sign|. Otherwise, it returns |ssl_private_key_failure| on
-   * failure and |ssl_private_key_retry| if the operation is still in progress.
+   * It is an error to call |sign_digest| while another private key operation is
+   * in progress on |ssl|.
    *
-   * |sign_complete| may be called arbitrarily many times before completion, but
-   * it is an error to call |sign_complete| if there is no pending |sign|
-   * operation in progress on |ssl|. */
-  enum ssl_private_key_result_t (*sign_complete)(SSL *ssl, uint8_t *out,
-                                                 size_t *out_len,
-                                                 size_t max_out);
+   * This function is deprecated. Implement |sign| instead.
+   *
+   * TODO(davidben): Remove this function. */
+  enum ssl_private_key_result_t (*sign_digest)(SSL *ssl, uint8_t *out,
+                                               size_t *out_len, size_t max_out,
+                                               const EVP_MD *md,
+                                               const uint8_t *in,
+                                               size_t in_len);
 
   /* decrypt decrypts |in_len| bytes of encrypted data from |in|. On success it
    * returns |ssl_private_key_success|, writes at most |max_out| bytes of
@@ -1028,9 +1043,9 @@
    * written. On failure it returns |ssl_private_key_failure|. If the operation
    * has not completed, it returns |ssl_private_key_retry|. The caller should
    * arrange for the high-level operation on |ssl| to be retried when the
-   * operation is completed, which will result in a call to |decrypt_complete|.
-   * This function only works with RSA keys and should perform a raw RSA
-   * decryption operation with no padding.
+   * operation is completed, which will result in a call to |complete|. This
+   * function only works with RSA keys and should perform a raw RSA decryption
+   * operation with no padding.
    *
    * It is an error to call |decrypt| while another private key operation is in
    * progress on |ssl|. */
@@ -1038,18 +1053,16 @@
                                            size_t *out_len, size_t max_out,
                                            const uint8_t *in, size_t in_len);
 
-  /* decrypt_complete completes a pending |decrypt| operation. If the operation
-   * has completed, it returns |ssl_private_key_success| and writes the result
-   * to |out| as in |decrypt|. Otherwise, it returns |ssl_private_key_failure|
-   * on failure and |ssl_private_key_retry| if the operation is still in
-   * progress.
+  /* complete completes a pending operation. If the operation has completed, it
+   * returns |ssl_private_key_success| and writes the result to |out| as in
+   * |sign|. Otherwise, it returns |ssl_private_key_failure| on failure and
+   * |ssl_private_key_retry| if the operation is still in progress.
    *
-   * |decrypt_complete| may be called arbitrarily many times before completion,
-   * but it is an error to call |decrypt_complete| if there is no pending
-   * |decrypt| operation in progress on |ssl|. */
-  enum ssl_private_key_result_t (*decrypt_complete)(SSL *ssl, uint8_t *out,
-                                                    size_t *out_len,
-                                                    size_t max_out);
+   * |complete| may be called arbitrarily many times before completion, but it
+   * is an error to call |complete| if there is no pending operation in progress
+   * on |ssl|. */
+  enum ssl_private_key_result_t (*complete)(SSL *ssl, uint8_t *out,
+                                            size_t *out_len, size_t max_out);
 } SSL_PRIVATE_KEY_METHOD;
 
 /* SSL_set_private_key_method configures a custom private key on |ssl|.
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index 09f8070..ee2de2b 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -1727,8 +1727,7 @@
     ssl3_free_handshake_buffer(ssl);
   } else {
     assert(ssl->state == SSL3_ST_CW_CERT_VRFY_B);
-    sign_result =
-        ssl_private_key_sign_complete(ssl, ptr, &sig_len, max_sig_len);
+    sign_result = ssl_private_key_complete(ssl, ptr, &sig_len, max_sig_len);
   }
 
   switch (sign_result) {
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index 7cd8265..4b1b16d 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -1112,8 +1112,7 @@
       OPENSSL_free(transcript_data);
     } else {
       assert(ssl->state == SSL3_ST_SW_KEY_EXCH_B);
-      sign_result =
-          ssl_private_key_sign_complete(ssl, ptr, &sig_len, max_sig_len);
+      sign_result = ssl_private_key_complete(ssl, ptr, &sig_len, max_sig_len);
     }
 
     switch (sign_result) {
@@ -1452,8 +1451,8 @@
     } else {
       assert(ssl->state == SSL3_ST_SR_KEY_EXCH_B);
       /* Complete async decrypt. */
-      decrypt_result = ssl_private_key_decrypt_complete(
-          ssl, decrypt_buf, &decrypt_len, rsa_size);
+      decrypt_result =
+          ssl_private_key_complete(ssl, decrypt_buf, &decrypt_len, rsa_size);
     }
 
     switch (decrypt_result) {
diff --git a/ssl/internal.h b/ssl/internal.h
index 06148f1..d2cb04c 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -484,15 +484,13 @@
     SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
     uint16_t signature_algorithm, const uint8_t *in, size_t in_len);
 
-enum ssl_private_key_result_t ssl_private_key_sign_complete(
-    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out);
-
 enum ssl_private_key_result_t ssl_private_key_decrypt(
     SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
     const uint8_t *in, size_t in_len);
 
-enum ssl_private_key_result_t ssl_private_key_decrypt_complete(
-    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out);
+enum ssl_private_key_result_t ssl_private_key_complete(SSL *ssl, uint8_t *out,
+                                                       size_t *out_len,
+                                                       size_t max_out);
 
 /* ssl_private_key_supports_signature_algorithm returns one if |ssl|'s private
  * key supports |signature_algorithm| and zero otherwise. */
diff --git a/ssl/ssl_rsa.c b/ssl/ssl_rsa.c
index 9678493..295f37a 100644
--- a/ssl/ssl_rsa.c
+++ b/ssl/ssl_rsa.c
@@ -626,10 +626,13 @@
     SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
     uint16_t signature_algorithm, const uint8_t *in, size_t in_len) {
   if (ssl->cert->key_method != NULL) {
-    /* For now, custom private keys can only handle pre-TLS-1.3 signature
-     * algorithms.
-     *
-     * TODO(davidben): Switch SSL_PRIVATE_KEY_METHOD to message-based APIs. */
+    if (ssl->cert->key_method->sign != NULL) {
+      return ssl->cert->key_method->sign(ssl, out, out_len, max_out,
+                                         signature_algorithm, in, in_len);
+    }
+
+    /* TODO(davidben): Remove support for |sign_digest|-only
+     * |SSL_PRIVATE_KEY_METHOD|s. */
     const EVP_MD *md;
     int curve;
     if (!is_rsa_pkcs1(&md, signature_algorithm) &&
@@ -644,8 +647,8 @@
       return ssl_private_key_failure;
     }
 
-    return ssl->cert->key_method->sign(ssl, out, out_len, max_out, md, hash,
-                                       hash_len);
+    return ssl->cert->key_method->sign_digest(ssl, out, out_len, max_out, md,
+                                              hash, hash_len);
   }
 
   const EVP_MD *md;
@@ -673,12 +676,6 @@
   return ssl_private_key_failure;
 }
 
-enum ssl_private_key_result_t ssl_private_key_sign_complete(
-    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out) {
-  /* Only custom keys may be asynchronous. */
-  return ssl->cert->key_method->sign_complete(ssl, out, out_len, max_out);
-}
-
 int ssl_public_key_verify(SSL *ssl, const uint8_t *signature,
                           size_t signature_len, uint16_t signature_algorithm,
                           EVP_PKEY *pkey, const uint8_t *in, size_t in_len) {
@@ -727,10 +724,11 @@
   return ssl_private_key_success;
 }
 
-enum ssl_private_key_result_t ssl_private_key_decrypt_complete(
-    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out) {
+enum ssl_private_key_result_t ssl_private_key_complete(SSL *ssl, uint8_t *out,
+                                                       size_t *out_len,
+                                                       size_t max_out) {
   /* Only custom keys may be asynchronous. */
-  return ssl->cert->key_method->decrypt_complete(ssl, out, out_len, max_out);
+  return ssl->cert->key_method->complete(ssl, out, out_len, max_out);
 }
 
 int ssl_private_key_supports_signature_algorithm(SSL *ssl,
@@ -772,6 +770,11 @@
       return 0;
     }
 
+    /* RSA-PSS is only supported by message-based private keys. */
+    if (ssl->cert->key_method != NULL && ssl->cert->key_method->sign == NULL) {
+      return 0;
+    }
+
     return 1;
   }
 
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index c505f21..0bdfc94 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -165,66 +165,80 @@
 
 static ssl_private_key_result_t AsyncPrivateKeySign(
     SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
-    const EVP_MD *md, const uint8_t *in, size_t in_len) {
+    uint16_t signature_algorithm, const uint8_t *in, size_t in_len) {
   TestState *test_state = GetTestState(ssl);
   if (!test_state->private_key_result.empty()) {
     fprintf(stderr, "AsyncPrivateKeySign called with operation pending.\n");
     abort();
   }
 
-  ScopedEVP_PKEY_CTX ctx(EVP_PKEY_CTX_new(test_state->private_key.get(),
-                                          nullptr));
-  if (!ctx) {
+  // Determine the hash.
+  const EVP_MD *md;
+  switch (signature_algorithm) {
+    case SSL_SIGN_RSA_PKCS1_SHA1:
+    case SSL_SIGN_ECDSA_SHA1:
+      md = EVP_sha1();
+      break;
+    case SSL_SIGN_RSA_PKCS1_SHA256:
+    case SSL_SIGN_ECDSA_SECP256R1_SHA256:
+    case SSL_SIGN_RSA_PSS_SHA256:
+      md = EVP_sha256();
+      break;
+    case SSL_SIGN_RSA_PKCS1_SHA384:
+    case SSL_SIGN_ECDSA_SECP384R1_SHA384:
+    case SSL_SIGN_RSA_PSS_SHA384:
+      md = EVP_sha384();
+      break;
+    case SSL_SIGN_RSA_PKCS1_SHA512:
+    case SSL_SIGN_ECDSA_SECP521R1_SHA512:
+    case SSL_SIGN_RSA_PSS_SHA512:
+      md = EVP_sha512();
+      break;
+    case SSL_SIGN_RSA_PKCS1_MD5_SHA1:
+      md = EVP_md5_sha1();
+      break;
+    default:
+      fprintf(stderr, "Unknown signature algorithm %04x.\n",
+              signature_algorithm);
+      return ssl_private_key_failure;
+  }
+
+  ScopedEVP_MD_CTX ctx;
+  EVP_PKEY_CTX *pctx;
+  if (!EVP_DigestSignInit(ctx.get(), &pctx, md, nullptr,
+                          test_state->private_key.get())) {
     return ssl_private_key_failure;
   }
 
+  // Configure additional signature parameters.
+  switch (signature_algorithm) {
+    case SSL_SIGN_RSA_PSS_SHA256:
+    case SSL_SIGN_RSA_PSS_SHA384:
+    case SSL_SIGN_RSA_PSS_SHA512:
+      if (!EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) ||
+          !EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx,
+                                            -1 /* salt len = hash len */)) {
+        return ssl_private_key_failure;
+      }
+  }
+
   // Write the signature into |test_state|.
   size_t len = 0;
-  if (!EVP_PKEY_sign_init(ctx.get()) ||
-      !EVP_PKEY_CTX_set_signature_md(ctx.get(), md) ||
-      !EVP_PKEY_sign(ctx.get(), nullptr, &len, in, in_len)) {
+  if (!EVP_DigestSignUpdate(ctx.get(), in, in_len) ||
+      !EVP_DigestSignFinal(ctx.get(), nullptr, &len)) {
     return ssl_private_key_failure;
   }
   test_state->private_key_result.resize(len);
-  if (!EVP_PKEY_sign(ctx.get(), test_state->private_key_result.data(), &len, in,
-                     in_len)) {
+  if (!EVP_DigestSignFinal(ctx.get(), test_state->private_key_result.data(),
+                           &len)) {
     return ssl_private_key_failure;
   }
   test_state->private_key_result.resize(len);
 
-  // The signature will be released asynchronously in
-  // |AsyncPrivateKeySignComplete|.
+  // The signature will be released asynchronously in |AsyncPrivateKeyComplete|.
   return ssl_private_key_retry;
 }
 
-static ssl_private_key_result_t AsyncPrivateKeySignComplete(
-    SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out) {
-  TestState *test_state = GetTestState(ssl);
-  if (test_state->private_key_result.empty()) {
-    fprintf(stderr,
-            "AsyncPrivateKeySignComplete called without operation pending.\n");
-    abort();
-  }
-
-  if (test_state->private_key_retries < 2) {
-    // Only return the signature on the second attempt, to test both incomplete
-    // |sign| and |sign_complete|.
-    return ssl_private_key_retry;
-  }
-
-  if (max_out < test_state->private_key_result.size()) {
-    fprintf(stderr, "Output buffer too small.\n");
-    return ssl_private_key_failure;
-  }
-  memcpy(out, test_state->private_key_result.data(),
-         test_state->private_key_result.size());
-  *out_len = test_state->private_key_result.size();
-
-  test_state->private_key_result.clear();
-  test_state->private_key_retries = 0;
-  return ssl_private_key_success;
-}
-
 static ssl_private_key_result_t AsyncPrivateKeyDecrypt(
     SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
     const uint8_t *in, size_t in_len) {
@@ -249,18 +263,16 @@
 
   test_state->private_key_result.resize(*out_len);
 
-  // The decryption will be released asynchronously in
-  // |AsyncPrivateKeyDecryptComplete|.
+  // The decryption will be released asynchronously in |AsyncPrivateComplete|.
   return ssl_private_key_retry;
 }
 
-static ssl_private_key_result_t AsyncPrivateKeyDecryptComplete(
+static ssl_private_key_result_t AsyncPrivateKeyComplete(
     SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out) {
   TestState *test_state = GetTestState(ssl);
   if (test_state->private_key_result.empty()) {
     fprintf(stderr,
-            "AsyncPrivateKeyDecryptComplete called without operation "
-            "pending.\n");
+            "AsyncPrivateKeyComplete called without operation pending.\n");
     abort();
   }
 
@@ -287,9 +299,9 @@
     AsyncPrivateKeyType,
     AsyncPrivateKeyMaxSignatureLen,
     AsyncPrivateKeySign,
-    AsyncPrivateKeySignComplete,
+    nullptr /* sign_digest */,
     AsyncPrivateKeyDecrypt,
-    AsyncPrivateKeyDecryptComplete
+    AsyncPrivateKeyComplete,
 };
 
 template<typename T>