acvp: RSA signature generation tests.

Change-Id: Ibc794a66ea9b04e2d48c2124d52234a0bed10aff
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/43625
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/util/fipstools/acvp/acvptool/subprocess/rsa.go b/util/fipstools/acvp/acvptool/subprocess/rsa.go
index 3133d91..0825c04 100644
--- a/util/fipstools/acvp/acvptool/subprocess/rsa.go
+++ b/util/fipstools/acvp/acvptool/subprocess/rsa.go
@@ -57,6 +57,36 @@
 	D  string `json:"d"`
 }
 
+type rsaSigGenTestVectorSet struct {
+	Groups []rsaSigGenGroup `json:"testGroups"`
+}
+
+type rsaSigGenGroup struct {
+	ID          uint64          `json:"tgId"`
+	Type        string          `json:"testType"`
+	SigType     string          `json:"sigType"`
+	ModulusBits uint32          `json:"modulo"`
+	Hash        string          `json:"hashAlg"`
+	Tests       []rsaSigGenTest `json:"tests"`
+}
+
+type rsaSigGenTest struct {
+	ID         uint64 `json:"tcId"`
+	MessageHex string `json:"message"`
+}
+
+type rsaSigGenTestGroupResponse struct {
+	ID    uint64                  `json:"tgId"`
+	N     string                  `json:"n"`
+	E     string                  `json:"e"`
+	Tests []rsaSigGenTestResponse `json:"tests"`
+}
+
+type rsaSigGenTestResponse struct {
+	ID  uint64 `json:"tcId"`
+	Sig string `json:"signature"`
+}
+
 func processKeyGen(vectorSet []byte, m Transactable) (interface{}, error) {
 	var parsed rsaKeyGenTestVectorSet
 	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
@@ -98,6 +128,57 @@
 	return ret, nil
 }
 
+func processSigGen(vectorSet []byte, m Transactable) (interface{}, error) {
+	var parsed rsaSigGenTestVectorSet
+	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
+		return nil, err
+	}
+
+	var ret []rsaSigGenTestGroupResponse
+
+	for _, group := range parsed.Groups {
+		// GDT means "Generated data test", i.e. "please generate an RSA signature".
+		const expectedType = "GDT"
+		if group.Type != expectedType {
+			return nil, fmt.Errorf("RSA SigGen test group has type %q, but only generation tests (%q) are supported", group.Type, expectedType)
+		}
+
+		response := rsaSigGenTestGroupResponse{
+			ID: group.ID,
+		}
+
+		operation := "RSA/sigGen/" + group.Hash + "/" + group.SigType
+
+		for _, test := range group.Tests {
+			msg, err := hex.DecodeString(test.MessageHex)
+			if err != nil {
+				return nil, fmt.Errorf("test case %d/%d contains invalid hex: %s", group.ID, test.ID, err)
+			}
+
+			results, err := m.Transact(operation, 3, uint32le(group.ModulusBits), msg)
+			if err != nil {
+				return nil, err
+			}
+
+			if len(response.N) == 0 {
+				response.N = hex.EncodeToString(results[0])
+				response.E = hex.EncodeToString(results[1])
+			} else if response.N != hex.EncodeToString(results[0]) {
+				return nil, fmt.Errorf("module wrapper returned different RSA keys for the same SigGen configuration")
+			}
+
+			response.Tests = append(response.Tests, rsaSigGenTestResponse{
+				ID:  test.ID,
+				Sig: hex.EncodeToString(results[2]),
+			})
+		}
+
+		ret = append(ret, response)
+	}
+
+	return ret, nil
+}
+
 type rsa struct{}
 
 func (*rsa) Process(vectorSet []byte, m Transactable) (interface{}, error) {
@@ -109,6 +190,8 @@
 	switch parsed.Mode {
 	case "keyGen":
 		return processKeyGen(vectorSet, m)
+	case "sigGen":
+		return processSigGen(vectorSet, m)
 	default:
 		return nil, fmt.Errorf("Unknown RSA mode %q", parsed.Mode)
 	}
diff --git a/util/fipstools/acvp/modulewrapper/modulewrapper.cc b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
index 48b96da..aad2a3c 100644
--- a/util/fipstools/acvp/modulewrapper/modulewrapper.cc
+++ b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
@@ -12,6 +12,7 @@
  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
 
+#include <map>
 #include <string>
 #include <vector>
 
@@ -425,6 +426,123 @@
         }]
       },
       {
+        "algorithm": "RSA",
+        "mode": "sigGen",
+        "revision": "FIPS186-4",
+        "capabilities": [{
+          "sigType": "pkcs1v1.5",
+          "properties": [{
+            "modulo": 2048,
+            "hashPair": [{
+              "hashAlg": "SHA2-224"
+            }, {
+              "hashAlg": "SHA2-256"
+            }, {
+              "hashAlg": "SHA2-384"
+            }, {
+              "hashAlg": "SHA2-512"
+            }, {
+              "hashAlg": "SHA-1"
+            }]
+          }]
+        },{
+          "sigType": "pkcs1v1.5",
+          "properties": [{
+            "modulo": 3072,
+            "hashPair": [{
+              "hashAlg": "SHA2-224"
+            }, {
+              "hashAlg": "SHA2-256"
+            }, {
+              "hashAlg": "SHA2-384"
+            }, {
+              "hashAlg": "SHA2-512"
+            }, {
+              "hashAlg": "SHA-1"
+            }]
+          }]
+        },{
+          "sigType": "pkcs1v1.5",
+          "properties": [{
+            "modulo": 4096,
+            "hashPair": [{
+              "hashAlg": "SHA2-224"
+            }, {
+              "hashAlg": "SHA2-256"
+            }, {
+              "hashAlg": "SHA2-384"
+            }, {
+              "hashAlg": "SHA2-512"
+            }, {
+              "hashAlg": "SHA-1"
+            }]
+          }]
+        },{
+          "sigType": "pss",
+          "properties": [{
+            "modulo": 2048,
+            "hashPair": [{
+              "hashAlg": "SHA2-224",
+              "saltLen": 28
+            }, {
+              "hashAlg": "SHA2-256",
+              "saltLen": 32
+            }, {
+              "hashAlg": "SHA2-384",
+              "saltLen": 48
+            }, {
+              "hashAlg": "SHA2-512",
+              "saltLen": 64
+            }, {
+              "hashAlg": "SHA-1",
+              "saltLen": 20
+            }]
+          }]
+        },{
+          "sigType": "pss",
+          "properties": [{
+            "modulo": 3072,
+            "hashPair": [{
+              "hashAlg": "SHA2-224",
+              "saltLen": 28
+            }, {
+              "hashAlg": "SHA2-256",
+              "saltLen": 32
+            }, {
+              "hashAlg": "SHA2-384",
+              "saltLen": 48
+            }, {
+              "hashAlg": "SHA2-512",
+              "saltLen": 64
+            }, {
+              "hashAlg": "SHA-1",
+              "saltLen": 20
+            }]
+          }]
+        },{
+          "sigType": "pss",
+          "properties": [{
+            "modulo": 4096,
+            "hashPair": [{
+              "hashAlg": "SHA2-224",
+              "saltLen": 28
+            }, {
+              "hashAlg": "SHA2-256",
+              "saltLen": 32
+            }, {
+              "hashAlg": "SHA2-384",
+              "saltLen": 48
+            }, {
+              "hashAlg": "SHA2-512",
+              "saltLen": 64
+            }, {
+              "hashAlg": "SHA-1",
+              "saltLen": 20
+            }]
+          }]
+        }]
+      },
+      {
         "algorithm": "CMAC-AES",
         "revision": "1.0",
         "capabilities": [{
@@ -1032,6 +1150,28 @@
   return WriteReply(STDOUT_FILENO, Span<const uint8_t>(mac, mac_len));
 }
 
+static std::map<unsigned, bssl::UniquePtr<RSA>>& CachedRSAKeys() {
+  static std::map<unsigned, bssl::UniquePtr<RSA>> keys;
+  return keys;
+}
+
+static RSA* GetRSAKey(unsigned bits) {
+  auto it = CachedRSAKeys().find(bits);
+  if (it != CachedRSAKeys().end()) {
+    return it->second.get();
+  }
+
+  bssl::UniquePtr<RSA> key(RSA_new());
+  if (!RSA_generate_key_fips(key.get(), bits, nullptr)) {
+    abort();
+  }
+
+  RSA *const ret = key.get();
+  CachedRSAKeys().emplace(static_cast<unsigned>(bits), std::move(key));
+
+  return ret;
+}
+
 static bool RSAKeyGen(const Span<const uint8_t> args[]) {
   uint32_t bits;
   if (args[0].size() != sizeof(bits)) {
@@ -1050,8 +1190,52 @@
   RSA_get0_key(key.get(), &n, &e, &d);
   RSA_get0_factors(key.get(), &p, &q);
 
-  return WriteReply(STDOUT_FILENO, BIGNUMBytes(e), BIGNUMBytes(p),
-                    BIGNUMBytes(q), BIGNUMBytes(n), BIGNUMBytes(d));
+  if (!WriteReply(STDOUT_FILENO, BIGNUMBytes(e), BIGNUMBytes(p), BIGNUMBytes(q),
+                  BIGNUMBytes(n), BIGNUMBytes(d))) {
+    return false;
+  }
+
+  CachedRSAKeys().emplace(static_cast<unsigned>(bits), std::move(key));
+  return true;
+}
+
+template<const EVP_MD *(MDFunc)(), bool UsePSS>
+static bool RSASigGen(const Span<const uint8_t> args[]) {
+  uint32_t bits;
+  if (args[0].size() != sizeof(bits)) {
+    return false;
+  }
+  memcpy(&bits, args[0].data(), sizeof(bits));
+  const Span<const uint8_t> msg = args[1];
+
+  RSA *const key = GetRSAKey(bits);
+  const EVP_MD *const md = MDFunc();
+  uint8_t digest_buf[EVP_MAX_MD_SIZE];
+  unsigned digest_len;
+  if (!EVP_Digest(msg.data(), msg.size(), digest_buf, &digest_len, md, NULL)) {
+    return false;
+  }
+
+  std::vector<uint8_t> sig(RSA_size(key));
+  size_t sig_len;
+  if (UsePSS) {
+    if (!RSA_sign_pss_mgf1(key, &sig_len, sig.data(), sig.size(), digest_buf,
+                           digest_len, md, md, -1)) {
+      return false;
+    }
+  } else {
+    unsigned sig_len_u;
+    if (!RSA_sign(EVP_MD_type(md), digest_buf, digest_len, sig.data(),
+                  &sig_len_u, key)) {
+      return false;
+    }
+    sig_len = sig_len_u;
+  }
+
+  sig.resize(sig_len);
+
+  return WriteReply(STDOUT_FILENO, BIGNUMBytes(RSA_get0_n(key)),
+                    BIGNUMBytes(RSA_get0_e(key)), sig);
 }
 
 static constexpr struct {
@@ -1095,6 +1279,16 @@
     {"ECDSA/sigVer", 7, ECDSASigVer},
     {"CMAC-AES", 3, CMAC_AES},
     {"RSA/keyGen", 1, RSAKeyGen},
+    {"RSA/sigGen/SHA2-224/pkcs1v1.5", 2, RSASigGen<EVP_sha224, false>},
+    {"RSA/sigGen/SHA2-256/pkcs1v1.5", 2, RSASigGen<EVP_sha256, false>},
+    {"RSA/sigGen/SHA2-384/pkcs1v1.5", 2, RSASigGen<EVP_sha384, false>},
+    {"RSA/sigGen/SHA2-512/pkcs1v1.5", 2, RSASigGen<EVP_sha512, false>},
+    {"RSA/sigGen/SHA-1/pkcs1v1.5", 2, RSASigGen<EVP_sha1, false>},
+    {"RSA/sigGen/SHA2-224/pss", 2, RSASigGen<EVP_sha224, true>},
+    {"RSA/sigGen/SHA2-256/pss", 2, RSASigGen<EVP_sha256, true>},
+    {"RSA/sigGen/SHA2-384/pss", 2, RSASigGen<EVP_sha384, true>},
+    {"RSA/sigGen/SHA2-512/pss", 2, RSASigGen<EVP_sha512, true>},
+    {"RSA/sigGen/SHA-1/pss", 2, RSASigGen<EVP_sha1, true>},
 };
 
 int main() {