Add crypto/rsa-level RSA-PSS functions.

This allows us to implement RSA-PSS in the FIPS module without pulling
in EVP_PKEY. It also allows people to use RSA-PSS on an RSA*.
Empirically folks seem to use the low-level padding functions a lot,
which is unfortunate.

This allows us to remove a now redundant length check in p_rsa.c.

Change-Id: I5270e01c6999d462d378865db2b858103c335485
Reviewed-on: https://boringssl-review.googlesource.com/15825
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/evp/evp_tests.txt b/crypto/evp/evp_tests.txt
index 1c87265..0d6dfdf 100644
--- a/crypto/evp/evp_tests.txt
+++ b/crypto/evp/evp_tests.txt
@@ -136,7 +136,7 @@
 Sign = RSA-2048
 Digest = SHA1
 Input = "0123456789ABCDEF12345"
-Error = INVALID_DIGEST_LENGTH
+Error = INVALID_MESSAGE_LENGTH
 
 Verify = RSA-2048
 Digest = SHA1
@@ -148,7 +148,7 @@
 Sign = RSA-2048
 Digest = SHA1
 Input = "0123456789ABCDEF123"
-Error = INVALID_DIGEST_LENGTH
+Error = INVALID_MESSAGE_LENGTH
 
 Verify = RSA-2048
 Digest = SHA1
@@ -254,7 +254,7 @@
 Digest = SHA256
 Input = "0123456789ABCDEF0123456789ABCDE"
 Output = 4de433d5844043ef08d354da03cb29068780d52706d7d1e4d50efb7d58c9d547d83a747ddd0635a96b28f854e50145518482cb49e963054621b53c60c498d07c16e9c2789c893cf38d4d86900de71bde463bd2761d1271e358c7480a1ac0bab930ddf39602ad1bc165b5d7436b516b7a7858e8eb7ab1c420eeb482f4d207f0e462b1724959320a084e13848d11d10fb593e66bf680bf6d3f345fc3e9c3de60abbac37e1c6ec80a268c8d9fc49626c679097aa690bc1aa662b95eb8db70390861aa0898229f9349b4b5fdd030d4928c47084708a933144be23bd3c6e661b85b2c0ef9ed36d498d5b7320e8194d363d4ad478c059bae804181965e0b81b663158a
-Error = INVALID_DIGEST_LENGTH
+Error = INVALID_MESSAGE_LENGTH
 
 # Digest too long
 Verify = RSA-2048-SPKI
@@ -263,7 +263,7 @@
 Digest = SHA256
 Input = "0123456789ABCDEF0123456789ABCDEF0"
 Output = 4de433d5844043ef08d354da03cb29068780d52706d7d1e4d50efb7d58c9d547d83a747ddd0635a96b28f854e50145518482cb49e963054621b53c60c498d07c16e9c2789c893cf38d4d86900de71bde463bd2761d1271e358c7480a1ac0bab930ddf39602ad1bc165b5d7436b516b7a7858e8eb7ab1c420eeb482f4d207f0e462b1724959320a084e13848d11d10fb593e66bf680bf6d3f345fc3e9c3de60abbac37e1c6ec80a268c8d9fc49626c679097aa690bc1aa662b95eb8db70390861aa0898229f9349b4b5fdd030d4928c47084708a933144be23bd3c6e661b85b2c0ef9ed36d498d5b7320e8194d363d4ad478c059bae804181965e0b81b663158a
-Error = INVALID_DIGEST_LENGTH
+Error = INVALID_MESSAGE_LENGTH
 
 # Wrong salt length
 Verify = RSA-2048
diff --git a/crypto/evp/p_rsa.c b/crypto/evp/p_rsa.c
index fcfca4b..b4598b6 100644
--- a/crypto/evp/p_rsa.c
+++ b/crypto/evp/p_rsa.c
@@ -180,18 +180,7 @@
   }
 
   if (rctx->md) {
-    unsigned int out_len;
-
-    if (tbslen != EVP_MD_size(rctx->md)) {
-      OPENSSL_PUT_ERROR(EVP, EVP_R_INVALID_DIGEST_LENGTH);
-      return 0;
-    }
-
-    if (EVP_MD_type(rctx->md) == NID_mdc2) {
-      OPENSSL_PUT_ERROR(EVP, EVP_R_NO_MDC2_SUPPORT);
-      return 0;
-    }
-
+    unsigned out_len;
     switch (rctx->pad_mode) {
       case RSA_PKCS1_PADDING:
         if (!RSA_sign(EVP_MD_type(rctx->md), tbs, tbslen, sig, &out_len, rsa)) {
@@ -201,14 +190,8 @@
         return 1;
 
       case RSA_PKCS1_PSS_PADDING:
-        if (!setup_tbuf(rctx, ctx) ||
-            !RSA_padding_add_PKCS1_PSS_mgf1(rsa, rctx->tbuf, tbs, rctx->md,
-                                            rctx->mgf1md, rctx->saltlen) ||
-            !RSA_sign_raw(rsa, siglen, sig, *siglen, rctx->tbuf, key_len,
-                          RSA_NO_PADDING)) {
-          return 0;
-        }
-        return 1;
+        return RSA_sign_pss_mgf1(rsa, siglen, sig, *siglen, tbs, tbslen,
+                                 rctx->md, rctx->mgf1md, rctx->saltlen);
 
       default:
         return 0;
@@ -223,8 +206,6 @@
                            size_t tbslen) {
   RSA_PKEY_CTX *rctx = ctx->data;
   RSA *rsa = ctx->pkey->pkey.rsa;
-  size_t rslen;
-  const size_t key_len = EVP_PKEY_size(ctx->pkey);
 
   if (rctx->md) {
     switch (rctx->pad_mode) {
@@ -232,25 +213,16 @@
         return RSA_verify(EVP_MD_type(rctx->md), tbs, tbslen, sig, siglen, rsa);
 
       case RSA_PKCS1_PSS_PADDING:
-        if (tbslen != EVP_MD_size(rctx->md)) {
-          OPENSSL_PUT_ERROR(EVP, EVP_R_INVALID_DIGEST_LENGTH);
-          return 0;
-        }
-
-        if (!setup_tbuf(rctx, ctx) ||
-            !RSA_verify_raw(rsa, &rslen, rctx->tbuf, key_len, sig, siglen,
-                            RSA_NO_PADDING) ||
-            !RSA_verify_PKCS1_PSS_mgf1(rsa, tbs, rctx->md, rctx->mgf1md,
-                                       rctx->tbuf, rctx->saltlen)) {
-          return 0;
-        }
-        return 1;
+        return RSA_verify_pss_mgf1(rsa, tbs, tbslen, rctx->md, rctx->mgf1md,
+                                   rctx->saltlen, sig, siglen);
 
       default:
         return 0;
     }
   }
 
+  size_t rslen;
+  const size_t key_len = EVP_PKEY_size(ctx->pkey);
   if (!setup_tbuf(rctx, ctx) ||
       !RSA_verify_raw(rsa, &rslen, rctx->tbuf, key_len, sig, siglen,
                       rctx->pad_mode) ||
diff --git a/crypto/rsa/rsa.c b/crypto/rsa/rsa.c
index e6cdce9..843e757 100644
--- a/crypto/rsa/rsa.c
+++ b/crypto/rsa/rsa.c
@@ -60,6 +60,7 @@
 #include <string.h>
 
 #include <openssl/bn.h>
+#include <openssl/digest.h>
 #include <openssl/engine.h>
 #include <openssl/err.h>
 #include <openssl/ex_data.h>
@@ -473,6 +474,29 @@
   return ret;
 }
 
+int RSA_sign_pss_mgf1(RSA *rsa, size_t *out_len, uint8_t *out, size_t max_out,
+                      const uint8_t *in, size_t in_len, const EVP_MD *md,
+                      const EVP_MD *mgf1_md, int salt_len) {
+  if (in_len != EVP_MD_size(md)) {
+    OPENSSL_PUT_ERROR(RSA, RSA_R_INVALID_MESSAGE_LENGTH);
+    return 0;
+  }
+
+  size_t padded_len = RSA_size(rsa);
+  uint8_t *padded = OPENSSL_malloc(padded_len);
+  if (padded == NULL) {
+    OPENSSL_PUT_ERROR(RSA, ERR_R_MALLOC_FAILURE);
+    return 0;
+  }
+
+  int ret =
+      RSA_padding_add_PKCS1_PSS_mgf1(rsa, padded, in, md, mgf1_md, salt_len) &&
+      RSA_sign_raw(rsa, out_len, out, max_out, padded, padded_len,
+                   RSA_NO_PADDING);
+  OPENSSL_free(padded);
+  return ret;
+}
+
 int RSA_verify(int hash_nid, const uint8_t *msg, size_t msg_len,
                const uint8_t *sig, size_t sig_len, RSA *rsa) {
   if (rsa->n == NULL || rsa->e == NULL) {
@@ -525,6 +549,38 @@
   return ret;
 }
 
+int RSA_verify_pss_mgf1(RSA *rsa, const uint8_t *msg, size_t msg_len,
+                        const EVP_MD *md, const EVP_MD *mgf1_md, int salt_len,
+                        const uint8_t *sig, size_t sig_len) {
+  if (msg_len != EVP_MD_size(md)) {
+    OPENSSL_PUT_ERROR(RSA, RSA_R_INVALID_MESSAGE_LENGTH);
+    return 0;
+  }
+
+  size_t em_len = RSA_size(rsa);
+  uint8_t *em = OPENSSL_malloc(em_len);
+  if (em == NULL) {
+    OPENSSL_PUT_ERROR(RSA, ERR_R_MALLOC_FAILURE);
+    return 0;
+  }
+
+  int ret = 0;
+  if (!RSA_verify_raw(rsa, &em_len, em, em_len, sig, sig_len, RSA_NO_PADDING)) {
+    goto err;
+  }
+
+  if (em_len != RSA_size(rsa)) {
+    OPENSSL_PUT_ERROR(RSA, ERR_R_INTERNAL_ERROR);
+    goto err;
+  }
+
+  ret = RSA_verify_PKCS1_PSS_mgf1(rsa, msg, md, mgf1_md, em, salt_len);
+
+err:
+  OPENSSL_free(em);
+  return ret;
+}
+
 static void bn_free_and_null(BIGNUM **bn) {
   BN_free(*bn);
   *bn = NULL;
diff --git a/include/openssl/rsa.h b/include/openssl/rsa.h
index ee1bdde..c5510c8 100644
--- a/include/openssl/rsa.h
+++ b/include/openssl/rsa.h
@@ -206,6 +206,24 @@
                             unsigned int in_len, uint8_t *out,
                             unsigned int *out_len, RSA *rsa);
 
+/* RSA_sign_pss_mgf1 signs |in_len| bytes from |in| with the public key from
+ * |rsa| using RSASSA-PSS with MGF1 as the mask generation function. It writes,
+ * at most, |max_out| bytes of signature data to |out|. The |max_out| argument
+ * must be, at least, |RSA_size| in order to ensure success. It returns 1 on
+ * success or zero on error.
+ *
+ * The |md| and |mgf1_md| arguments identify the hash used to calculate |msg|
+ * and the MGF1 hash, respectively. If |mgf1_md| is NULL, |md| is
+ * used.
+ *
+ * |salt_len| specifies the expected salt length in bytes. If |salt_len| is -1,
+ * then the salt length is the same as the hash length. If -2, then the salt
+ * length is maximal given the size of |rsa|. If unsure, use -1. */
+OPENSSL_EXPORT int RSA_sign_pss_mgf1(RSA *rsa, size_t *out_len, uint8_t *out,
+                                     size_t max_out, const uint8_t *in,
+                                     size_t in_len, const EVP_MD *md,
+                                     const EVP_MD *mgf1_md, int salt_len);
+
 /* RSA_sign_raw signs |in_len| bytes from |in| with the public key from |rsa|
  * and writes, at most, |max_out| bytes of signature data to |out|. The
  * |max_out| argument must be, at least, |RSA_size| in order to ensure success.
@@ -222,7 +240,7 @@
 /* RSA_verify verifies that |sig_len| bytes from |sig| are a valid,
  * RSASSA-PKCS1-v1_5 signature of |msg_len| bytes at |msg| by |rsa|.
  *
- * The |hash_nid| argument identifies the hash function used to calculate |in|
+ * The |hash_nid| argument identifies the hash function used to calculate |msg|
  * and is embedded in the resulting signature in order to prevent hash
  * confusion attacks. For example, it might be |NID_sha256|.
  *
@@ -233,6 +251,23 @@
 OPENSSL_EXPORT int RSA_verify(int hash_nid, const uint8_t *msg, size_t msg_len,
                               const uint8_t *sig, size_t sig_len, RSA *rsa);
 
+/* RSA_verify_pss_mgf1 verifies that |sig_len| bytes from |sig| are a valid,
+ * RSASSA-PSS signature of |msg_len| bytes at |msg| by |rsa|. It returns one if
+ * the signature is valid and zero otherwise. MGF1 is used as the mask
+ * generation function.
+ *
+ * The |md| and |mgf1_md| arguments identify the hash used to calculate |msg|
+ * and the MGF1 hash, respectively. If |mgf1_md| is NULL, |md| is
+ * used. |salt_len| specifies the expected salt length in bytes.
+ *
+ * If |salt_len| is -1, then the salt length is the same as the hash length. If
+ * -2, then the salt length is recovered and all values accepted. If unsure, use
+ * -1. */
+OPENSSL_EXPORT int RSA_verify_pss_mgf1(RSA *rsa, const uint8_t *msg,
+                                       size_t msg_len, const EVP_MD *md,
+                                       const EVP_MD *mgf1_md, int salt_len,
+                                       const uint8_t *sig, size_t sig_len);
+
 /* RSA_verify_raw verifies |in_len| bytes of signature from |in| using the
  * public key from |rsa| and writes, at most, |max_out| bytes of plaintext to
  * |out|. The |max_out| argument must be, at least, |RSA_size| in order to
@@ -318,7 +353,10 @@
  *
  * If unsure, use -1.
  *
- * It returns one on success or zero on error. */
+ * It returns one on success or zero on error.
+ *
+ * This function implements only the low-level padding logic. Use
+ * |RSA_verify_pss_mgf1| instead. */
 OPENSSL_EXPORT int RSA_verify_PKCS1_PSS_mgf1(RSA *rsa, const uint8_t *mHash,
                                              const EVP_MD *Hash,
                                              const EVP_MD *mgf1Hash,
@@ -332,7 +370,10 @@
  * the salt length is the same as the hash length. If -2, then the salt length
  * is maximal given the space in |EM|.
  *
- * It returns one on success or zero on error. */
+ * It returns one on success or zero on error.
+ *
+ * This function implements only the low-level padding logic. Use
+ * |RSA_sign_pss_mgf1| instead. */
 OPENSSL_EXPORT int RSA_padding_add_PKCS1_PSS_mgf1(RSA *rsa, uint8_t *EM,
                                                   const uint8_t *mHash,
                                                   const EVP_MD *Hash,
@@ -497,13 +538,19 @@
 OPENSSL_EXPORT int i2d_RSAPrivateKey(const RSA *in, uint8_t **outp);
 
 /* RSA_padding_add_PKCS1_PSS acts like |RSA_padding_add_PKCS1_PSS_mgf1| but the
- * |mgf1Hash| parameter of the latter is implicitly set to |Hash|. */
+ * |mgf1Hash| parameter of the latter is implicitly set to |Hash|.
+ *
+ * This function implements only the low-level padding logic. Use
+ * |RSA_sign_pss_mgf1| instead. */
 OPENSSL_EXPORT int RSA_padding_add_PKCS1_PSS(RSA *rsa, uint8_t *EM,
                                              const uint8_t *mHash,
                                              const EVP_MD *Hash, int sLen);
 
 /* RSA_verify_PKCS1_PSS acts like |RSA_verify_PKCS1_PSS_mgf1| but the
- * |mgf1Hash| parameter of the latter is implicitly set to |Hash|. */
+ * |mgf1Hash| parameter of the latter is implicitly set to |Hash|.
+ *
+ * This function implements only the low-level padding logic. Use
+ * |RSA_verify_pss_mgf1| instead. */
 OPENSSL_EXPORT int RSA_verify_PKCS1_PSS(RSA *rsa, const uint8_t *mHash,
                                         const EVP_MD *Hash, const uint8_t *EM,
                                         int sLen);