Don't decompose signature algorithms in ssl_rsa.c.

This is a lot more verbose and looks the same between RSA and ECDSA for
now, but it gives us room to implement the various algorithm-specific
checks. ECDSA algorithms must match the curve, PKCS#1 is forbidden in
TLS 1.3, etc.

Change-Id: I348cfae664d7b08195a2ab1190820b410e74c5e9
Reviewed-on: https://boringssl-review.googlesource.com/8694
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/ssl_rsa.c b/ssl/ssl_rsa.c
index 6dcbcc9..3dd8ae0 100644
--- a/ssl/ssl_rsa.c
+++ b/ssl/ssl_rsa.c
@@ -399,77 +399,137 @@
   return EVP_PKEY_size(ssl->cert->privatekey);
 }
 
-/* tls12_get_hash returns the EVP_MD corresponding to the TLS signature
- * algorithm |sigalg|. It returns NULL if the type is unknown. */
-static const EVP_MD *tls12_get_hash(uint16_t sigalg) {
-  if (sigalg == SSL_SIGN_RSA_PKCS1_MD5_SHA1) {
-    return EVP_md5_sha1();
-  }
-
-  switch (sigalg >> 8) {
-    case TLSEXT_hash_sha1:
-      return EVP_sha1();
-
-    case TLSEXT_hash_sha256:
-      return EVP_sha256();
-
-    case TLSEXT_hash_sha384:
-      return EVP_sha384();
-
-    case TLSEXT_hash_sha512:
-      return EVP_sha512();
-
+static int is_rsa_pkcs1(const EVP_MD **out_md, uint16_t sigalg) {
+  switch (sigalg) {
+    case SSL_SIGN_RSA_PKCS1_MD5_SHA1:
+      *out_md = EVP_md5_sha1();
+      return 1;
+    case SSL_SIGN_RSA_PKCS1_SHA1:
+      *out_md = EVP_sha1();
+      return 1;
+    case SSL_SIGN_RSA_PKCS1_SHA256:
+      *out_md = EVP_sha256();
+      return 1;
+    case SSL_SIGN_RSA_PKCS1_SHA384:
+      *out_md = EVP_sha384();
+      return 1;
+    case SSL_SIGN_RSA_PKCS1_SHA512:
+      *out_md = EVP_sha512();
+      return 1;
     default:
-      return NULL;
+      return 0;
   }
 }
 
+static int ssl_sign_rsa_pkcs1(SSL *ssl, uint8_t *out, size_t *out_len,
+                              size_t max_out, const EVP_MD *md,
+                              const uint8_t *in, size_t in_len) {
+  EVP_MD_CTX ctx;
+  EVP_MD_CTX_init(&ctx);
+  *out_len = max_out;
+  int ret = EVP_DigestSignInit(&ctx, NULL, md, NULL, ssl->cert->privatekey) &&
+            EVP_DigestSignUpdate(&ctx, in, in_len) &&
+            EVP_DigestSignFinal(&ctx, out, out_len);
+  EVP_MD_CTX_cleanup(&ctx);
+  return ret;
+}
+
+static int ssl_verify_rsa_pkcs1(SSL *ssl, const uint8_t *signature,
+                                size_t signature_len, const EVP_MD *md,
+                                EVP_PKEY *pkey, const uint8_t *in,
+                                size_t in_len) {
+  EVP_MD_CTX md_ctx;
+  EVP_MD_CTX_init(&md_ctx);
+  int ret = EVP_DigestVerifyInit(&md_ctx, NULL, md, NULL, pkey) &&
+            EVP_DigestVerifyUpdate(&md_ctx, in, in_len) &&
+            EVP_DigestVerifyFinal(&md_ctx, signature, signature_len);
+  EVP_MD_CTX_cleanup(&md_ctx);
+  return ret;
+}
+
+static int is_ecdsa(const EVP_MD **out_md, uint16_t sigalg) {
+  switch (sigalg) {
+    case SSL_SIGN_ECDSA_SHA1:
+      *out_md = EVP_sha1();
+      return 1;
+    case SSL_SIGN_ECDSA_SECP256R1_SHA256:
+      *out_md = EVP_sha256();
+      return 1;
+    case SSL_SIGN_ECDSA_SECP384R1_SHA384:
+      *out_md = EVP_sha384();
+      return 1;
+    case SSL_SIGN_ECDSA_SECP521R1_SHA512:
+      *out_md = EVP_sha512();
+      return 1;
+    default:
+      return 0;
+  }
+}
+
+static int ssl_sign_ecdsa(SSL *ssl, uint8_t *out, size_t *out_len,
+                          size_t max_out, const EVP_MD *md, const uint8_t *in,
+                          size_t in_len) {
+  EVP_MD_CTX ctx;
+  EVP_MD_CTX_init(&ctx);
+  *out_len = max_out;
+  int ret = EVP_DigestSignInit(&ctx, NULL, md, NULL, ssl->cert->privatekey) &&
+            EVP_DigestSignUpdate(&ctx, in, in_len) &&
+            EVP_DigestSignFinal(&ctx, out, out_len);
+  EVP_MD_CTX_cleanup(&ctx);
+  return ret;
+}
+
+static int ssl_verify_ecdsa(SSL *ssl, const uint8_t *signature,
+                            size_t signature_len, const EVP_MD *md,
+                            EVP_PKEY *pkey, const uint8_t *in, size_t in_len) {
+  EVP_MD_CTX md_ctx;
+  EVP_MD_CTX_init(&md_ctx);
+  int ret = EVP_DigestVerifyInit(&md_ctx, NULL, md, NULL, pkey) &&
+            EVP_DigestVerifyUpdate(&md_ctx, in, in_len) &&
+            EVP_DigestVerifyFinal(&md_ctx, signature, signature_len);
+  EVP_MD_CTX_cleanup(&md_ctx);
+  return ret;
+}
+
 enum ssl_private_key_result_t ssl_private_key_sign(
     SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
     uint16_t signature_algorithm, const uint8_t *in, size_t in_len) {
-  const EVP_MD *md = tls12_get_hash(signature_algorithm);
-  if (md == NULL) {
-    return ssl_private_key_failure;
-  }
-
-  EVP_MD_CTX mctx;
-  uint8_t hash[EVP_MAX_MD_SIZE];
-  unsigned hash_len;
-
-  EVP_MD_CTX_init(&mctx);
-  if (!EVP_DigestInit_ex(&mctx, md, NULL) ||
-      !EVP_DigestUpdate(&mctx, in, in_len) ||
-      !EVP_DigestFinal(&mctx, hash, &hash_len)) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_EVP_LIB);
-    EVP_MD_CTX_cleanup(&mctx);
-    return 0;
-  }
-  EVP_MD_CTX_cleanup(&mctx);
-
-
   if (ssl->cert->key_method != NULL) {
+    /* For now, custom private keys can only handle pre-TLS-1.3 signature
+     * algorithms.
+     *
+     * TODO(davidben): Switch SSL_PRIVATE_KEY_METHOD to message-based APIs. */
+    const EVP_MD *md;
+    if (!is_rsa_pkcs1(&md, signature_algorithm) &&
+        !is_ecdsa(&md, signature_algorithm)) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_UNSUPPORTED_PROTOCOL_FOR_CUSTOM_KEY);
+      return ssl_private_key_failure;
+    }
+
+    uint8_t hash[EVP_MAX_MD_SIZE];
+    unsigned hash_len;
+    if (!EVP_Digest(in, in_len, hash, &hash_len, md, NULL)) {
+      return ssl_private_key_failure;
+    }
+
     return ssl->cert->key_method->sign(ssl, out, out_len, max_out, md, hash,
                                        hash_len);
   }
 
-  enum ssl_private_key_result_t ret = ssl_private_key_failure;
-  EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(ssl->cert->privatekey, NULL);
-  if (ctx == NULL) {
-    goto end;
+  const EVP_MD *md;
+  if (is_rsa_pkcs1(&md, signature_algorithm)) {
+    return ssl_sign_rsa_pkcs1(ssl, out, out_len, max_out, md, in, in_len)
+               ? ssl_private_key_success
+               : ssl_private_key_failure;
+  }
+  if (is_ecdsa(&md, signature_algorithm)) {
+    return ssl_sign_ecdsa(ssl, out, out_len, max_out, md, in, in_len)
+               ? ssl_private_key_success
+               : ssl_private_key_failure;
   }
 
-  size_t len = max_out;
-  if (!EVP_PKEY_sign_init(ctx) ||
-      !EVP_PKEY_CTX_set_signature_md(ctx, md) ||
-      !EVP_PKEY_sign(ctx, out, &len, hash, hash_len)) {
-    goto end;
-  }
-  *out_len = len;
-  ret = ssl_private_key_success;
-
-end:
-  EVP_PKEY_CTX_free(ctx);
-  return ret;
+  OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
+  return ssl_private_key_failure;
 }
 
 enum ssl_private_key_result_t ssl_private_key_sign_complete(
@@ -481,18 +541,18 @@
 int ssl_public_key_verify(SSL *ssl, const uint8_t *signature,
                           size_t signature_len, uint16_t signature_algorithm,
                           EVP_PKEY *pkey, const uint8_t *in, size_t in_len) {
-  const EVP_MD *md = tls12_get_hash(signature_algorithm);
-  if (md == NULL) {
-    return 0;
+  const EVP_MD *md;
+  if (is_rsa_pkcs1(&md, signature_algorithm)) {
+    return ssl_verify_rsa_pkcs1(ssl, signature, signature_len, md, pkey, in,
+                                in_len);
+  }
+  if (is_ecdsa(&md, signature_algorithm)) {
+    return ssl_verify_ecdsa(ssl, signature, signature_len, md, pkey, in,
+                            in_len);
   }
 
-  EVP_MD_CTX md_ctx;
-  EVP_MD_CTX_init(&md_ctx);
-  int ret = EVP_DigestVerifyInit(&md_ctx, NULL, md, NULL, pkey) &&
-            EVP_DigestVerifyUpdate(&md_ctx, in, in_len) &&
-            EVP_DigestVerifyFinal(&md_ctx, signature, signature_len);
-  EVP_MD_CTX_cleanup(&md_ctx);
-  return ret;
+  OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SIGNATURE_TYPE);
+  return 0;
 }
 
 enum ssl_private_key_result_t ssl_private_key_decrypt(