Store the PSS parameters in the RSA object

In preparation for adding support for the other modes. OpenSSL similarly
stores it in there, but as a whole RSA_PSS_PARAMS object.

In principle we could implement RSA_get0_pss_params now, but I haven't
bothered filling that in yet.

Bug: 384818542
Change-Id: Ida824e0cf80d2bf233b76dbbea68a9e20980ddbe
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81988
Auto-Submit: David Benjamin <davidben@google.com>
Reviewed-by: Lily Chen <chlily@google.com>
Commit-Queue: Lily Chen <chlily@google.com>
diff --git a/crypto/evp/p_rsa.cc b/crypto/evp/p_rsa.cc
index 9692307..0cdfa91 100644
--- a/crypto/evp/p_rsa.cc
+++ b/crypto/evp/p_rsa.cc
@@ -65,14 +65,16 @@
 
   if (is_pss_only(ctx)) {
     rctx->pad_mode = RSA_PKCS1_PSS_PADDING;
-    // Pick up PSS parameters from the key. For now, we only support the SHA-256
-    // parameter set, so every key is necessarily SHA-256. If we ever support
-    // other parameters, we will need more state in |EVP_PKEY| and to translate
-    // that state into defaults here.
-    if (ctx->pkey != nullptr) {
-      rctx->md = rctx->mgf1md = EVP_sha256();
-      rctx->saltlen = EVP_MD_size(rctx->md);
-      rctx->restrict_pss_params = true;
+    // Pick up PSS parameters from the key.
+    if (ctx->pkey != nullptr && ctx->pkey->pkey != nullptr) {
+      RSA *rsa = static_cast<RSA *>(ctx->pkey->pkey);
+      const EVP_MD *md = rsa_pss_params_get_md(rsa->pss_params);
+      if (md != nullptr) {
+        rctx->md = rctx->mgf1md = md;
+        // All our supported modes use the digest length as the salt length.
+        rctx->saltlen = EVP_MD_size(rctx->md);
+        rctx->restrict_pss_params = true;
+      }
     }
   }
 
diff --git a/crypto/evp/p_rsa_asn1.cc b/crypto/evp/p_rsa_asn1.cc
index e0a9133..6598053 100644
--- a/crypto/evp/p_rsa_asn1.cc
+++ b/crypto/evp/p_rsa_asn1.cc
@@ -147,7 +147,7 @@
       !CBB_add_asn1_element(&algorithm, CBS_ASN1_OBJECT,
                             rsa_pss_sha256_asn1_meth.oid,
                             rsa_pss_sha256_asn1_meth.oid_len) ||
-      !rsa_marshal_pss_params(&algorithm, rsa_pss_sha256) ||
+      !rsa_marshal_pss_params(&algorithm, rsa->pss_params) ||
       !CBB_add_asn1(&spki, &key_bitstring, CBS_ASN1_BITSTRING) ||
       !CBB_add_u8(&key_bitstring, 0 /* padding */) ||
       !RSA_marshal_public_key(&key_bitstring, rsa) ||  //
@@ -174,6 +174,7 @@
     return evp_decode_error;
   }
 
+  rsa->pss_params = rsa_pss_sha256;
   evp_pkey_set0(out, &rsa_pss_sha256_asn1_meth, rsa.release());
   return evp_decode_ok;
 }
@@ -187,7 +188,7 @@
       !CBB_add_asn1_element(&algorithm, CBS_ASN1_OBJECT,
                             rsa_pss_sha256_asn1_meth.oid,
                             rsa_pss_sha256_asn1_meth.oid_len) ||
-      !rsa_marshal_pss_params(&algorithm, rsa_pss_sha256) ||
+      !rsa_marshal_pss_params(&algorithm, rsa->pss_params) ||
       !CBB_add_asn1(&pkcs8, &private_key, CBS_ASN1_OCTETSTRING) ||
       !RSA_marshal_private_key(&private_key, rsa) ||  //
       !CBB_flush(out)) {
@@ -213,6 +214,7 @@
     return evp_decode_error;
   }
 
+  rsa->pss_params = rsa_pss_sha256;
   evp_pkey_set0(out, &rsa_pss_sha256_asn1_meth, rsa.release());
   return evp_decode_ok;
 }
diff --git a/crypto/fipsmodule/rsa/internal.h b/crypto/fipsmodule/rsa/internal.h
index 94719ce..12fa09c 100644
--- a/crypto/fipsmodule/rsa/internal.h
+++ b/crypto/fipsmodule/rsa/internal.h
@@ -29,6 +29,20 @@
 
 typedef struct bn_blinding_st BN_BLINDING;
 
+// TODO(davidben): This is inside BCM because |RSA| is inside BCM, but BCM never
+// uses this. Split the RSA type in two.
+enum rsa_pss_params_t {
+  // No parameters.
+  // TODO(davidben): Remove this and use std::optional where appropriate.
+  rsa_pss_none = 0,
+  // RSA-PSS using SHA-256, MGF1 with SHA-256, salt length 32.
+  rsa_pss_sha256,
+  // RSA-PSS using SHA-384, MGF1 with SHA-384, salt length 48.
+  rsa_pss_sha384,
+  // RSA-PSS using SHA-512, MGF1 with SHA-512, salt length 64.
+  rsa_pss_sha512,
+};
+
 struct rsa_st {
   RSA_METHOD *meth;
 
@@ -75,6 +89,10 @@
   unsigned char *blindings_inuse;
   uint64_t blinding_fork_generation;
 
+  // pss_params is the RSA-PSS parameters associated with the key. This is not
+  // used by the low-level RSA implementation, just the EVP layer.
+  rsa_pss_params_t pss_params;
+
   // private_key_frozen is one if the key has been used for a private key
   // operation and may no longer be mutated.
   unsigned private_key_frozen:1;
diff --git a/crypto/rsa/internal.h b/crypto/rsa/internal.h
index 5fb4f17..887f491 100644
--- a/crypto/rsa/internal.h
+++ b/crypto/rsa/internal.h
@@ -17,6 +17,8 @@
 
 #include <openssl/base.h>
 
+#include "../fipsmodule/rsa/internal.h"
+
 #if defined(__cplusplus)
 extern "C" {
 #endif
@@ -28,28 +30,21 @@
                                       size_t param_len, const EVP_MD *md,
                                       const EVP_MD *mgf1md);
 
-enum rsa_pss_params_t {
-  // RSA-PSS using SHA-256, MGF1 with SHA-256, salt length 32.
-  rsa_pss_sha256,
-  // RSA-PSS using SHA-384, MGF1 with SHA-384, salt length 48.
-  rsa_pss_sha384,
-  // RSA-PSS using SHA-512, MGF1 with SHA-512, salt length 64.
-  rsa_pss_sha512,
-};
-
 // rsa_pss_params_get_md returns the hash function used with |params|. This also
 // specifies the MGF-1 hash and the salt length because we do not support other
 // configurations.
 const EVP_MD *rsa_pss_params_get_md(rsa_pss_params_t params);
 
 // rsa_marshal_pss_params marshals |params| as a DER-encoded RSASSA-PSS-params
-// (RFC 4055). It returns one on success and zero on error.
+// (RFC 4055). It returns one on success and zero on error. If |params| is
+// |rsa_pss_params_none|, this function gives an error.
 int rsa_marshal_pss_params(CBB *cbb, rsa_pss_params_t params);
 
 // rsa_marshal_pss_params decodes a DER-encoded RSASSA-PSS-params
 // (RFC 4055). It returns one on success and zero on error. On success, it sets
 // |*out| to the result. If |allow_explicit_trailer| is non-zero, an explicit
-// encoding of the trailerField is allowed, although it is not valid DER.
+// encoding of the trailerField is allowed, although it is not valid DER. This
+// function never outputs |rsa_pss_params_none|.
 int rsa_parse_pss_params(CBS *cbs, rsa_pss_params_t *out,
                          int allow_explicit_trailer);
 
diff --git a/crypto/rsa/rsa_asn1.cc b/crypto/rsa/rsa_asn1.cc
index a57ff0e..f2e76fe 100644
--- a/crypto/rsa/rsa_asn1.cc
+++ b/crypto/rsa/rsa_asn1.cc
@@ -274,6 +274,8 @@
 
 const EVP_MD *rsa_pss_params_get_md(rsa_pss_params_t params) {
   switch (params) {
+    case rsa_pss_none:
+      return nullptr;
     case rsa_pss_sha256:
       return EVP_sha256();
     case rsa_pss_sha384:
@@ -287,6 +289,9 @@
 int rsa_marshal_pss_params(CBB *cbb, rsa_pss_params_t params) {
   bssl::Span<const uint8_t> bytes;
   switch (params) {
+    case rsa_pss_none:
+      OPENSSL_PUT_ERROR(RSA, ERR_R_INTERNAL_ERROR);
+      return 0;
     case rsa_pss_sha256:
       bytes = kPSSParamsSHA256;
       break;
diff --git a/crypto/x509/rsa_pss.cc b/crypto/x509/rsa_pss.cc
index 77ec8d4..78398ce 100644
--- a/crypto/x509/rsa_pss.cc
+++ b/crypto/x509/rsa_pss.cc
@@ -152,6 +152,10 @@
   const char *hash_str = nullptr;
   uint32_t salt_len = 0;
   switch (params) {
+    case rsa_pss_none:
+      // |rsa_pss_decode| will never return this.
+      OPENSSL_PUT_ERROR(X509, ERR_R_INTERNAL_ERROR);
+      return 0;
     case rsa_pss_sha256:
       hash_str = "sha256";
       salt_len = 32;