Convert ssl_privkey.c to message-based signing APIs.

This allows us to share some of the is_ecdsa mess between signing and
verifying in a way that will generalize to Ed25519. This makes it a lot
shorter and gets us closer to Ed25519.

Later work will tidy this up further.

BUG=187

Change-Id: Ibf3c07c48824061389b8c86294225d9ef25dd82d
Reviewed-on: https://boringssl-review.googlesource.com/14448
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/ssl_privkey.c b/ssl/ssl_privkey.c
index e988827..4006759 100644
--- a/ssl/ssl_privkey.c
+++ b/ssl/ssl_privkey.c
@@ -332,9 +332,6 @@
   return EVP_PKEY_size(ssl->cert->privatekey);
 }
 
-/* TODO(davidben): Forbid RSA-PKCS1 in TLS 1.3. For now we allow it because NSS
- * has yet to start doing RSA-PSS, so enforcing it would complicate interop
- * testing. */
 static int is_rsa_pkcs1(const EVP_MD **out_md, uint16_t sigalg) {
   switch (sigalg) {
     case SSL_SIGN_RSA_PKCS1_MD5_SHA1:
@@ -357,35 +354,20 @@
   }
 }
 
-static int ssl_sign_rsa_pkcs1(SSL *ssl, uint8_t *out, size_t *out_len,
-                              size_t max_out, const EVP_MD *md,
-                              const uint8_t *in, size_t in_len) {
-  EVP_MD_CTX ctx;
-  EVP_MD_CTX_init(&ctx);
-  *out_len = max_out;
-  int ret = EVP_DigestSignInit(&ctx, NULL, md, NULL, ssl->cert->privatekey) &&
-            EVP_DigestSignUpdate(&ctx, in, in_len) &&
-            EVP_DigestSignFinal(&ctx, out, out_len);
-  EVP_MD_CTX_cleanup(&ctx);
-  return ret;
-}
-
-static int ssl_verify_rsa_pkcs1(SSL *ssl, const uint8_t *signature,
-                                size_t signature_len, const EVP_MD *md,
-                                EVP_PKEY *pkey, const uint8_t *in,
-                                size_t in_len) {
-  if (pkey->type != EVP_PKEY_RSA) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-    return 0;
+static int is_rsa_pss(const EVP_MD **out_md, uint16_t sigalg) {
+  switch (sigalg) {
+    case SSL_SIGN_RSA_PSS_SHA256:
+      *out_md = EVP_sha256();
+      return 1;
+    case SSL_SIGN_RSA_PSS_SHA384:
+      *out_md = EVP_sha384();
+      return 1;
+    case SSL_SIGN_RSA_PSS_SHA512:
+      *out_md = EVP_sha512();
+      return 1;
+    default:
+      return 0;
   }
-
-  EVP_MD_CTX md_ctx;
-  EVP_MD_CTX_init(&md_ctx);
-  int ret = EVP_DigestVerifyInit(&md_ctx, NULL, md, NULL, pkey) &&
-            EVP_DigestVerifyUpdate(&md_ctx, in, in_len) &&
-            EVP_DigestVerifyFinal(&md_ctx, signature, signature_len);
-  EVP_MD_CTX_cleanup(&md_ctx);
-  return ret;
 }
 
 static int is_ecdsa(int *out_curve, const EVP_MD **out_md, uint16_t sigalg) {
@@ -411,112 +393,53 @@
   }
 }
 
-static int ssl_sign_ecdsa(SSL *ssl, uint8_t *out, size_t *out_len,
-                          size_t max_out, int curve, const EVP_MD *md,
-                          const uint8_t *in, size_t in_len) {
-  EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(ssl->cert->privatekey);
-  if (ec_key == NULL) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-    return 0;
-  }
+static int setup_ctx(SSL *ssl, EVP_PKEY_CTX *ctx,
+                     uint16_t signature_algorithm) {
+  EVP_PKEY *pkey = EVP_PKEY_CTX_get0_pkey(ctx);
 
-  /* In TLS 1.3, the curve is also specified by the signature algorithm. */
-  if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION &&
-      (curve == NID_undef ||
-       EC_GROUP_get_curve_name(EC_KEY_get0_group(ec_key)) != curve)) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-    return 0;
-  }
-
-  EVP_MD_CTX ctx;
-  EVP_MD_CTX_init(&ctx);
-  *out_len = max_out;
-  int ret = EVP_DigestSignInit(&ctx, NULL, md, NULL, ssl->cert->privatekey) &&
-            EVP_DigestSignUpdate(&ctx, in, in_len) &&
-            EVP_DigestSignFinal(&ctx, out, out_len);
-  EVP_MD_CTX_cleanup(&ctx);
-  return ret;
-}
-
-static int ssl_verify_ecdsa(SSL *ssl, const uint8_t *signature,
-                            size_t signature_len, int curve, const EVP_MD *md,
-                            EVP_PKEY *pkey, const uint8_t *in, size_t in_len) {
-  EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(pkey);
-  if (ec_key == NULL) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-    return 0;
-  }
-
-  /* In TLS 1.3, the curve is also specified by the signature algorithm. */
-  if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION &&
-      (curve == NID_undef ||
-       EC_GROUP_get_curve_name(EC_KEY_get0_group(ec_key)) != curve)) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-    return 0;
-  }
-
-  EVP_MD_CTX md_ctx;
-  EVP_MD_CTX_init(&md_ctx);
-  int ret = EVP_DigestVerifyInit(&md_ctx, NULL, md, NULL, pkey) &&
-            EVP_DigestVerifyUpdate(&md_ctx, in, in_len) &&
-            EVP_DigestVerifyFinal(&md_ctx, signature, signature_len);
-  EVP_MD_CTX_cleanup(&md_ctx);
-  return ret;
-}
-
-static int is_rsa_pss(const EVP_MD **out_md, uint16_t sigalg) {
-  switch (sigalg) {
-    case SSL_SIGN_RSA_PSS_SHA256:
-      *out_md = EVP_sha256();
-      return 1;
-    case SSL_SIGN_RSA_PSS_SHA384:
-      *out_md = EVP_sha384();
-      return 1;
-    case SSL_SIGN_RSA_PSS_SHA512:
-      *out_md = EVP_sha512();
-      return 1;
-    default:
+  const EVP_MD *md;
+  if (is_rsa_pkcs1(&md, signature_algorithm) &&
+      ssl3_protocol_version(ssl) < TLS1_3_VERSION) {
+    if (pkey->type != EVP_PKEY_RSA) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
       return 0;
-  }
-}
+    }
 
-static int ssl_sign_rsa_pss(SSL *ssl, uint8_t *out, size_t *out_len,
-                              size_t max_out, const EVP_MD *md,
-                              const uint8_t *in, size_t in_len) {
-  EVP_MD_CTX ctx;
-  EVP_MD_CTX_init(&ctx);
-  *out_len = max_out;
-  EVP_PKEY_CTX *pctx;
-  int ret =
-      EVP_DigestSignInit(&ctx, &pctx, md, NULL, ssl->cert->privatekey) &&
-      EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) &&
-      EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1 /* salt len = hash len */) &&
-      EVP_DigestSignUpdate(&ctx, in, in_len) &&
-      EVP_DigestSignFinal(&ctx, out, out_len);
-  EVP_MD_CTX_cleanup(&ctx);
-  return ret;
-}
-
-static int ssl_verify_rsa_pss(SSL *ssl, const uint8_t *signature,
-                                size_t signature_len, const EVP_MD *md,
-                                EVP_PKEY *pkey, const uint8_t *in,
-                                size_t in_len) {
-  if (pkey->type != EVP_PKEY_RSA) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-    return 0;
+    return EVP_PKEY_CTX_set_signature_md(ctx, md);
   }
 
-  EVP_MD_CTX md_ctx;
-  EVP_MD_CTX_init(&md_ctx);
-  EVP_PKEY_CTX *pctx;
-  int ret =
-      EVP_DigestVerifyInit(&md_ctx, &pctx, md, NULL, pkey) &&
-      EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) &&
-      EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, -1 /* salt len = hash len */) &&
-      EVP_DigestVerifyUpdate(&md_ctx, in, in_len) &&
-      EVP_DigestVerifyFinal(&md_ctx, signature, signature_len);
-  EVP_MD_CTX_cleanup(&md_ctx);
-  return ret;
+  if (is_rsa_pss(&md, signature_algorithm)) {
+    if (pkey->type != EVP_PKEY_RSA) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
+      return 0;
+    }
+
+    return EVP_PKEY_CTX_set_signature_md(ctx, md) &&
+           EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_PSS_PADDING) &&
+           EVP_PKEY_CTX_set_rsa_pss_saltlen(ctx, -1 /* salt len = hash len */);
+  }
+
+  int curve;
+  if (is_ecdsa(&curve, &md, signature_algorithm)) {
+    EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(pkey);
+    if (ec_key == NULL) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
+      return 0;
+    }
+
+    /* In TLS 1.3, the curve is also specified by the signature algorithm. */
+    if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION &&
+        (curve == NID_undef ||
+         EC_GROUP_get_curve_name(EC_KEY_get0_group(ec_key)) != curve)) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
+      return 0;
+    }
+
+    return EVP_PKEY_CTX_set_signature_md(ctx, md);
+  }
+
+  OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
+  return 0;
 }
 
 enum ssl_private_key_result_t ssl_private_key_sign(
@@ -548,54 +471,26 @@
                                               hash, hash_len);
   }
 
-  const EVP_MD *md;
-  if (is_rsa_pkcs1(&md, signature_algorithm) &&
-      ssl3_protocol_version(ssl) < TLS1_3_VERSION) {
-    return ssl_sign_rsa_pkcs1(ssl, out, out_len, max_out, md, in, in_len)
-               ? ssl_private_key_success
-               : ssl_private_key_failure;
-  }
-
-  int curve;
-  if (is_ecdsa(&curve, &md, signature_algorithm)) {
-    return ssl_sign_ecdsa(ssl, out, out_len, max_out, curve, md, in, in_len)
-               ? ssl_private_key_success
-               : ssl_private_key_failure;
-  }
-
-  if (is_rsa_pss(&md, signature_algorithm)) {
-    return ssl_sign_rsa_pss(ssl, out, out_len, max_out, md, in, in_len)
-               ? ssl_private_key_success
-               : ssl_private_key_failure;
-  }
-
-  OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-  return ssl_private_key_failure;
+  *out_len = max_out;
+  EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(ssl->cert->privatekey, NULL);
+  int ret = ctx != NULL &&
+            EVP_PKEY_sign_init(ctx) &&
+            setup_ctx(ssl, ctx, signature_algorithm) &&
+            EVP_PKEY_sign_message(ctx, out, out_len, in, in_len);
+  EVP_PKEY_CTX_free(ctx);
+  return ret ? ssl_private_key_success : ssl_private_key_failure;
 }
 
 int ssl_public_key_verify(SSL *ssl, const uint8_t *signature,
                           size_t signature_len, uint16_t signature_algorithm,
                           EVP_PKEY *pkey, const uint8_t *in, size_t in_len) {
-  const EVP_MD *md;
-  if (is_rsa_pkcs1(&md, signature_algorithm) &&
-      ssl3_protocol_version(ssl) < TLS1_3_VERSION) {
-    return ssl_verify_rsa_pkcs1(ssl, signature, signature_len, md, pkey, in,
-                                in_len);
-  }
-
-  int curve;
-  if (is_ecdsa(&curve, &md, signature_algorithm)) {
-    return ssl_verify_ecdsa(ssl, signature, signature_len, curve, md, pkey, in,
-                            in_len);
-  }
-
-  if (is_rsa_pss(&md, signature_algorithm)) {
-    return ssl_verify_rsa_pss(ssl, signature, signature_len, md, pkey, in,
-                              in_len);
-  }
-
-  OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
-  return 0;
+  EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(pkey, NULL);
+  int ret = ctx != NULL &&
+            EVP_PKEY_verify_init(ctx) &&
+            setup_ctx(ssl, ctx, signature_algorithm) &&
+            EVP_PKEY_verify_message(ctx, signature, signature_len, in, in_len);
+  EVP_PKEY_CTX_free(ctx);
+  return ret;
 }
 
 enum ssl_private_key_result_t ssl_private_key_decrypt(