ML-DSA ACVP
Wire up ACVP tests for ML-DSA which is newly in the FIPS module.
Change-Id: Iea2edd3bc0a29c53419bd6df675c139337fbce2a
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/73367
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/util/fipstools/acvp/ACVP.md b/util/fipstools/acvp/ACVP.md
index db3aca2..9a1575c 100644
--- a/util/fipstools/acvp/ACVP.md
+++ b/util/fipstools/acvp/ACVP.md
@@ -130,6 +130,9 @@
| SHA3-512/MCT | Initial seed¹ | Digest |
| TLSKDF/1.2/<HASH> | Number output bytes, secret, label, seed1, seed2 | Output |
| PBKDF | HMAC name, key length (bits), salt, password, iteration count | Derived key |
+| ML-DSA-XX/keyGen | Seed | Public key, private key |
+| ML-DSA-XX/sigGen | Private key, message, randomizer | Signature |
+| ML-DSA-XX/sigVer | Public key, message, signature | Single-byte validity flag |
¹ The iterated tests would result in excessive numbers of round trips if the module wrapper handled only basic operations. Thus some ACVP logic is pushed down for these tests so that the inner loop can be handled locally. Either read the NIST documentation ([block-ciphers](https://pages.nist.gov/ACVP/draft-celi-acvp-symmetric.html#name-monte-carlo-tests-for-block) [hashes](https://pages.nist.gov/ACVP/draft-celi-acvp-sha.html#name-monte-carlo-tests-for-sha-1)) to understand the iteration count and return values or, probably more fruitfully, see how these functions are handled in the `modulewrapper` directory.
diff --git a/util/fipstools/acvp/acvptool/subprocess/mldsa.go b/util/fipstools/acvp/acvptool/subprocess/mldsa.go
new file mode 100644
index 0000000..d248ff6
--- /dev/null
+++ b/util/fipstools/acvp/acvptool/subprocess/mldsa.go
@@ -0,0 +1,294 @@
+package subprocess
+
+import (
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "strings"
+)
+
+const MLDSARandomizerLength = 32
+
+// Common top-level structure to parse mode
+type mldsaTestVectorSet struct {
+ Algorithm string `json:"algorithm"`
+ Mode string `json:"mode"`
+ Revision string `json:"revision"`
+}
+
+// Key generation specific structures
+type mldsaKeyGenTestVectorSet struct {
+ Algorithm string `json:"algorithm"`
+ Mode string `json:"mode"`
+ Revision string `json:"revision"`
+ Groups []mldsaKeyGenTestGroup `json:"testGroups"`
+}
+
+type mldsaKeyGenTestGroup struct {
+ ID uint64 `json:"tgId"`
+ TestType string `json:"testType"`
+ ParameterSet string `json:"parameterSet"`
+ Tests []mldsaKeyGenTest `json:"tests"`
+}
+
+type mldsaKeyGenTest struct {
+ ID uint64 `json:"tcId"`
+ Seed string `json:"seed"`
+}
+
+type mldsaKeyGenTestGroupResponse struct {
+ ID uint64 `json:"tgId"`
+ Tests []mldsaKeyGenTestResponse `json:"tests"`
+}
+
+type mldsaKeyGenTestResponse struct {
+ ID uint64 `json:"tcId"`
+ PublicKey string `json:"pk"`
+ PrivateKey string `json:"sk"`
+}
+
+// Signature generation specific structures
+type mldsaSigGenTestVectorSet struct {
+ Algorithm string `json:"algorithm"`
+ Mode string `json:"mode"`
+ Revision string `json:"revision"`
+ Groups []mldsaSigGenTestGroup `json:"testGroups"`
+}
+
+type mldsaSigGenTestGroup struct {
+ ID uint64 `json:"tgId"`
+ TestType string `json:"testType"`
+ ParameterSet string `json:"parameterSet"`
+ Deterministic bool `json:"deterministic"`
+ Tests []mldsaSigGenTest `json:"tests"`
+}
+
+type mldsaSigGenTest struct {
+ ID uint64 `json:"tcId"`
+ Message string `json:"message"`
+ PrivateKey string `json:"sk"`
+ Randomizer string `json:"rnd"`
+}
+
+type mldsaSigGenTestGroupResponse struct {
+ ID uint64 `json:"tgId"`
+ Tests []mldsaSigGenTestResponse `json:"tests"`
+}
+
+type mldsaSigGenTestResponse struct {
+ ID uint64 `json:"tcId"`
+ Signature string `json:"signature"`
+}
+
+// Signature verification specific structures
+type mldsaSigVerTestVectorSet struct {
+ Algorithm string `json:"algorithm"`
+ Mode string `json:"mode"`
+ Revision string `json:"revision"`
+ Groups []mldsaSigVerTestGroup `json:"testGroups"`
+}
+
+type mldsaSigVerTestGroup struct {
+ ID uint64 `json:"tgId"`
+ TestType string `json:"testType"`
+ ParameterSet string `json:"parameterSet"`
+ PublicKey string `json:"pk"`
+ Tests []mldsaSigVerTest `json:"tests"`
+}
+
+type mldsaSigVerTest struct {
+ ID uint64 `json:"tcId"`
+ Message string `json:"message"`
+ Signature string `json:"signature"`
+}
+
+type mldsaSigVerTestGroupResponse struct {
+ ID uint64 `json:"tgId"`
+ Tests []mldsaSigVerTestResponse `json:"tests"`
+}
+
+type mldsaSigVerTestResponse struct {
+ ID uint64 `json:"tcId"`
+ TestPassed bool `json:"testPassed"`
+}
+
+type mldsa struct{}
+
+func (m *mldsa) Process(vectorSet []byte, t Transactable) (any, error) {
+ // First parse just the common fields to get the mode
+ var common mldsaTestVectorSet
+ if err := json.Unmarshal(vectorSet, &common); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal vector set: %v", err)
+ }
+
+ switch common.Mode {
+ case "keyGen":
+ return m.processKeyGen(vectorSet, t)
+ case "sigGen":
+ return m.processSigGen(vectorSet, t)
+ case "sigVer":
+ return m.processSigVer(vectorSet, t)
+ default:
+ return nil, fmt.Errorf("unsupported ML-DSA mode: %s", common.Mode)
+ }
+}
+
+func (m *mldsa) processKeyGen(vectorSet []byte, t Transactable) (any, error) {
+ var parsed mldsaKeyGenTestVectorSet
+ if err := json.Unmarshal(vectorSet, &parsed); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal keyGen vector set: %v", err)
+ }
+
+ var ret []mldsaKeyGenTestGroupResponse
+
+ for _, group := range parsed.Groups {
+ response := mldsaKeyGenTestGroupResponse{
+ ID: group.ID,
+ }
+
+ if !strings.HasPrefix(group.ParameterSet, "ML-DSA-") {
+ return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
+ }
+ cmdName := group.ParameterSet + "/keyGen"
+
+ for _, test := range group.Tests {
+ seed, err := hex.DecodeString(test.Seed)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode seed in test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ result, err := t.Transact(cmdName, 2, seed)
+ if err != nil {
+ return nil, fmt.Errorf("key generation failed for test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ response.Tests = append(response.Tests, mldsaKeyGenTestResponse{
+ ID: test.ID,
+ PublicKey: hex.EncodeToString(result[0]),
+ PrivateKey: hex.EncodeToString(result[1]),
+ })
+ }
+
+ ret = append(ret, response)
+ }
+
+ return ret, nil
+}
+
+func (m *mldsa) processSigGen(vectorSet []byte, t Transactable) (any, error) {
+ var parsed mldsaSigGenTestVectorSet
+ if err := json.Unmarshal(vectorSet, &parsed); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal sigGen vector set: %v", err)
+ }
+
+ var ret []mldsaSigGenTestGroupResponse
+
+ for _, group := range parsed.Groups {
+ response := mldsaSigGenTestGroupResponse{
+ ID: group.ID,
+ }
+
+ if !strings.HasPrefix(group.ParameterSet, "ML-DSA-") {
+ return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
+ }
+ cmdName := group.ParameterSet + "/sigGen"
+
+ for _, test := range group.Tests {
+ sk, err := hex.DecodeString(test.PrivateKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode private key in test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ msg, err := hex.DecodeString(test.Message)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode message in test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ var randomizer []byte
+ if group.Deterministic {
+ randomizer = make([]byte, MLDSARandomizerLength)
+ } else {
+ randomizer, err = hex.DecodeString(test.Randomizer)
+ if err != nil || len(randomizer) != MLDSARandomizerLength {
+ return nil, fmt.Errorf("failed to parse randomizer in test case %d/%d: %s", group.ID, test.ID, err)
+ }
+ }
+
+ result, err := t.Transact(cmdName, 1, sk, msg, randomizer)
+ if err != nil {
+ return nil, fmt.Errorf("signature generation failed for test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ response.Tests = append(response.Tests, mldsaSigGenTestResponse{
+ ID: test.ID,
+ Signature: hex.EncodeToString(result[0]),
+ })
+ }
+
+ ret = append(ret, response)
+ }
+
+ return ret, nil
+}
+
+func (m *mldsa) processSigVer(vectorSet []byte, t Transactable) (any, error) {
+ var parsed mldsaSigVerTestVectorSet
+ if err := json.Unmarshal(vectorSet, &parsed); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal sigVer vector set: %v", err)
+ }
+
+ var ret []mldsaSigVerTestGroupResponse
+
+ for _, group := range parsed.Groups {
+ response := mldsaSigVerTestGroupResponse{
+ ID: group.ID,
+ }
+
+ if !strings.HasPrefix(group.ParameterSet, "ML-DSA-") {
+ return nil, fmt.Errorf("invalid parameter set: %s", group.ParameterSet)
+ }
+ cmdName := group.ParameterSet + "/sigVer"
+
+ pk, err := hex.DecodeString(group.PublicKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode public key in group %d: %s",
+ group.ID, err)
+ }
+
+ for _, test := range group.Tests {
+ msg, err := hex.DecodeString(test.Message)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode message in test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ sig, err := hex.DecodeString(test.Signature)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode signature in test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ result, err := t.Transact(cmdName, 1, pk, msg, sig)
+ if err != nil {
+ return nil, fmt.Errorf("signature verification failed for test case %d/%d: %s",
+ group.ID, test.ID, err)
+ }
+
+ // Result is a single byte: 0 for false, non-zero for true
+ testPassed := result[0][0] != 0
+ response.Tests = append(response.Tests, mldsaSigVerTestResponse{
+ ID: test.ID,
+ TestPassed: testPassed,
+ })
+ }
+
+ ret = append(ret, response)
+ }
+
+ return ret, nil
+}
diff --git a/util/fipstools/acvp/acvptool/subprocess/subprocess.go b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
index ac479ff..1f5fc60 100644
--- a/util/fipstools/acvp/acvptool/subprocess/subprocess.go
+++ b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
@@ -144,6 +144,7 @@
"KAS-ECC-SSC": &kas{},
"KAS-FFC-SSC": &kasDH{},
"PBKDF": &pbkdf{},
+ "ML-DSA": &mldsa{},
}
m.primitives["ECDSA"] = &ecdsa{"ECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m.primitives}
m.primitives["DetECDSA"] = &ecdsa{"DetECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m.primitives}
diff --git a/util/fipstools/acvp/acvptool/test/expected/ML-DSA.bz2 b/util/fipstools/acvp/acvptool/test/expected/ML-DSA.bz2
new file mode 100644
index 0000000..77290c1
--- /dev/null
+++ b/util/fipstools/acvp/acvptool/test/expected/ML-DSA.bz2
Binary files differ
diff --git a/util/fipstools/acvp/acvptool/test/tests.json b/util/fipstools/acvp/acvptool/test/tests.json
index d54f722..1559b98 100644
--- a/util/fipstools/acvp/acvptool/test/tests.json
+++ b/util/fipstools/acvp/acvptool/test/tests.json
@@ -25,6 +25,7 @@
{"Wrapper": "modulewrapper", "In": "vectors/KAS-ECC-SSC.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/KAS-FFC-SSC.bz2"},
{"Wrapper": "testmodulewrapper", "In": "vectors/KDF.bz2"},
+{"Wrapper": "modulewrapper", "In": "vectors/ML-DSA.bz2", "Out": "expected/ML-DSA.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/RSA.bz2", "Out": "expected/RSA.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/SHA-1.bz2", "Out": "expected/SHA-1.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/SHA2-224.bz2", "Out": "expected/SHA2-224.bz2"},
diff --git a/util/fipstools/acvp/acvptool/test/vectors/ML-DSA.bz2 b/util/fipstools/acvp/acvptool/test/vectors/ML-DSA.bz2
new file mode 100644
index 0000000..e4bcaa1
--- /dev/null
+++ b/util/fipstools/acvp/acvptool/test/vectors/ML-DSA.bz2
Binary files differ
diff --git a/util/fipstools/acvp/modulewrapper/modulewrapper.cc b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
index c6c361b..d42d127 100644
--- a/util/fipstools/acvp/modulewrapper/modulewrapper.cc
+++ b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
@@ -13,6 +13,7 @@
* CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
#include <map>
+#include <memory>
#include <string>
#include <vector>
@@ -27,6 +28,7 @@
#include <openssl/aead.h>
#include <openssl/aes.h>
#include <openssl/bn.h>
+#include <openssl/bytestring.h>
#include <openssl/cipher.h>
#include <openssl/cmac.h>
#include <openssl/ctrdrbg.h>
@@ -44,6 +46,7 @@
#include <openssl/sha.h>
#include <openssl/span.h>
+#include "../../../../crypto/fipsmodule/bcm_interface.h"
#include "../../../../crypto/fipsmodule/ec/internal.h"
#include "../../../../crypto/fipsmodule/rand/internal.h"
#include "../../../../crypto/fipsmodule/tls/internal.h"
@@ -267,7 +270,8 @@
return true;
}
-static bool GetConfig(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool GetConfig(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
static constexpr char kConfig[] =
R"([
{
@@ -940,6 +944,44 @@
"PSK",
"PSK-DHE"
]
+ },
+ {
+ "algorithm": "ML-DSA",
+ "mode": "keyGen",
+ "revision": "FIPS204",
+ "parameterSets": [
+ "ML-DSA-65",
+ "ML-DSA-87"
+ ]
+ },
+ {
+ "algorithm": "ML-DSA",
+ "mode": "sigGen",
+ "revision": "FIPS204",
+ "parameterSets": [
+ "ML-DSA-65",
+ "ML-DSA-87"
+ ],
+ "deterministic": [
+ true,
+ false
+ ],
+ "messageLength": [
+ {
+ "min": 8,
+ "max": 65536,
+ "increment": 8
+ }
+ ]
+ },
+ {
+ "algorithm": "ML-DSA",
+ "mode": "sigVer",
+ "revision": "FIPS204",
+ "parameterSets": [
+ "ML-DSA-65",
+ "ML-DSA-87"
+ ]
}
])";
return write_reply({Span<const uint8_t>(
@@ -1000,7 +1042,7 @@
memcpy(&iterations, iterations_bytes.data(), sizeof(iterations));
if (iterations == 0 || iterations == UINT32_MAX) {
LOG_ERROR("Invalid number of iterations: %x.\n",
- static_cast<unsigned>(iterations));
+ static_cast<unsigned>(iterations));
abort();
}
@@ -1037,7 +1079,8 @@
template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
int Direction>
-static bool AES_CBC(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AES_CBC(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
AES_KEY key;
if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
return false;
@@ -1084,7 +1127,8 @@
{Span<const uint8_t>(result), Span<const uint8_t>(prev_result)});
}
-static bool AES_CTR(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AES_CTR(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
static const uint32_t kOneIteration = 1;
if (args[3].size() != sizeof(kOneIteration) ||
memcmp(args[3].data(), &kOneIteration, sizeof(kOneIteration))) {
@@ -1230,7 +1274,8 @@
template <bool (*SetupFunc)(EVP_AEAD_CTX *ctx, Span<const uint8_t> tag_len_span,
Span<const uint8_t> key)>
-static bool AEADSeal(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AEADSeal(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
Span<const uint8_t> tag_len_span = args[0];
Span<const uint8_t> key = args[1];
Span<const uint8_t> plaintext = args[2];
@@ -1260,7 +1305,8 @@
template <bool (*SetupFunc)(EVP_AEAD_CTX *ctx, Span<const uint8_t> tag_len_span,
Span<const uint8_t> key)>
-static bool AEADOpen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AEADOpen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
Span<const uint8_t> tag_len_span = args[0];
Span<const uint8_t> key = args[1];
Span<const uint8_t> ciphertext = args[2];
@@ -1315,7 +1361,8 @@
return true;
}
-static bool AESKeyWrapSeal(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AESKeyWrapSeal(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
Span<const uint8_t> key = args[1];
Span<const uint8_t> plaintext = args[2];
@@ -1335,7 +1382,8 @@
return write_reply({Span<const uint8_t>(out)});
}
-static bool AESKeyWrapOpen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AESKeyWrapOpen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
Span<const uint8_t> key = args[1];
Span<const uint8_t> ciphertext = args[2];
@@ -1358,7 +1406,8 @@
{Span<const uint8_t>(success_flag), Span<const uint8_t>(out)});
}
-static bool AESPaddedKeyWrapSeal(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AESPaddedKeyWrapSeal(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
Span<const uint8_t> key = args[1];
Span<const uint8_t> plaintext = args[2];
@@ -1380,7 +1429,8 @@
return write_reply({Span<const uint8_t>(out)});
}
-static bool AESPaddedKeyWrapOpen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool AESPaddedKeyWrapOpen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
Span<const uint8_t> key = args[1];
Span<const uint8_t> ciphertext = args[2];
@@ -1452,7 +1502,8 @@
}
template <bool Encrypt>
-static bool TDES_CBC(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool TDES_CBC(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
const EVP_CIPHER *cipher = EVP_des_ede3_cbc();
if (args[0].size() != 24) {
@@ -1513,8 +1564,8 @@
}
return write_reply({Span<const uint8_t>(result),
- Span<const uint8_t>(prev_result),
- Span<const uint8_t>(prev_prev_result)});
+ Span<const uint8_t>(prev_result),
+ Span<const uint8_t>(prev_prev_result)});
}
template <const EVP_MD *HashFunc()>
@@ -1693,7 +1744,8 @@
return std::make_pair(std::move(x_bytes), std::move(y_bytes));
}
-static bool ECDSAKeyGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool ECDSAKeyGen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
if (!key || !EC_KEY_generate_key_fips(key.get())) {
return false;
@@ -1714,7 +1766,8 @@
return bn;
}
-static bool ECDSAKeyVer(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool ECDSAKeyVer(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
if (!key) {
return false;
@@ -1752,7 +1805,8 @@
}
}
-static bool ECDSASigGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool ECDSASigGen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
bssl::UniquePtr<BIGNUM> d = BytesToBIGNUM(args[1]);
const EVP_MD *hash = HashFromName(args[2]);
@@ -1777,7 +1831,8 @@
{Span<const uint8_t>(r_bytes), Span<const uint8_t>(s_bytes)});
}
-static bool ECDSASigVer(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool ECDSASigVer(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
const EVP_MD *hash = HashFromName(args[1]);
auto msg = args[2];
@@ -1809,7 +1864,8 @@
return write_reply({Span<const uint8_t>(reply)});
}
-static bool CMAC_AES(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool CMAC_AES(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
uint8_t mac[16];
if (!AES_CMAC(mac, args[1].data(), args[1].size(), args[2].data(),
args[2].size())) {
@@ -1828,7 +1884,8 @@
return write_reply({Span<const uint8_t>(mac, mac_len)});
}
-static bool CMAC_AESVerify(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool CMAC_AESVerify(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
// This function is just for testing since libcrypto doesn't do the
// verification itself. The regcap doesn't advertise "ver" support.
uint8_t mac[16];
@@ -1842,12 +1899,12 @@
return write_reply({Span<const uint8_t>(&ok, sizeof(ok))});
}
-static std::map<unsigned, bssl::UniquePtr<RSA>>& CachedRSAKeys() {
+static std::map<unsigned, bssl::UniquePtr<RSA>> &CachedRSAKeys() {
static std::map<unsigned, bssl::UniquePtr<RSA>> keys;
return keys;
}
-static RSA* GetRSAKey(unsigned bits) {
+static RSA *GetRSAKey(unsigned bits) {
auto it = CachedRSAKeys().find(bits);
if (it != CachedRSAKeys().end()) {
return it->second.get();
@@ -1864,7 +1921,8 @@
return ret;
}
-static bool RSAKeyGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool RSAKeyGen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
uint32_t bits;
if (args[0].size() != sizeof(bits)) {
return false;
@@ -1891,7 +1949,8 @@
}
template <const EVP_MD *(MDFunc)(), bool UsePSS>
-static bool RSASigGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool RSASigGen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
uint32_t bits;
if (args[0].size() != sizeof(bits)) {
return false;
@@ -1930,7 +1989,8 @@
}
template <const EVP_MD *(MDFunc)(), bool UsePSS>
-static bool RSASigVer(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool RSASigVer(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
const Span<const uint8_t> n_bytes = args[0];
const Span<const uint8_t> e_bytes = args[1];
const Span<const uint8_t> msg = args[2];
@@ -1966,7 +2026,8 @@
}
template <const EVP_MD *(MDFunc)()>
-static bool TLSKDF(const Span<const uint8_t> args[], ReplyCallback write_reply) {
+static bool TLSKDF(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
const Span<const uint8_t> out_len_bytes = args[0];
const Span<const uint8_t> secret = args[1];
const Span<const uint8_t> label = args[2];
@@ -2100,6 +2161,104 @@
return write_reply({BIGNUMBytes(DH_get0_pub_key(dh.get())), z});
}
+template <typename PrivateKey, size_t PublicKeyBytes,
+ bcm_status (*KeyGen)(uint8_t *, PrivateKey *, const uint8_t *),
+ bcm_status (*MarshalPrivateKey)(CBB *, const PrivateKey *)>
+static bool MLDSAKeyGen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
+ const Span<const uint8_t> seed = args[0];
+ if (seed.size() != BCM_MLDSA_SEED_BYTES) {
+ LOG_ERROR("Bad seed size.\n");
+ return false;
+ }
+
+ auto priv = std::make_unique<PrivateKey>();
+ uint8_t pub_key_bytes[PublicKeyBytes];
+ if (KeyGen(pub_key_bytes, priv.get(), seed.data()) != bcm_status::approved) {
+ LOG_ERROR("ML-DSA key gen failed.\n");
+ return false;
+ }
+
+ ScopedCBB cbb;
+ if (!CBB_init(cbb.get(), 1024) ||
+ MarshalPrivateKey(cbb.get(), priv.get()) != bcm_status::approved ||
+ !CBB_flush(cbb.get())) {
+ LOG_ERROR("ML-DSA marshal failed.\n");
+ return false;
+ }
+
+ return write_reply(
+ {pub_key_bytes, MakeConstSpan(CBB_data(cbb.get()), CBB_len(cbb.get()))});
+}
+
+template <typename PrivateKey, size_t SignatureBytes,
+ bcm_status (*ParsePrivateKey)(PrivateKey *, CBS *),
+ bcm_status (*SignInternal)(uint8_t *, const PrivateKey *,
+ const uint8_t *, size_t, const uint8_t *,
+ size_t, const uint8_t *, size_t,
+ const uint8_t *)>
+static bool MLDSASigGen(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
+ CBS cbs = bssl::MakeConstSpan(args[0]);
+ auto priv = std::make_unique<PrivateKey>();
+ if (ParsePrivateKey(priv.get(), &cbs) != bcm_status::approved) {
+ LOG_ERROR("Failed to parse ML-DSA private key.\n");
+ return false;
+ }
+
+ const Span<const uint8_t> msg = args[1];
+ const Span<const uint8_t> randomizer = args[2];
+
+ if (randomizer.size() != BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES) {
+ LOG_ERROR("Bad randomizer size.\n");
+ return false;
+ }
+
+ uint8_t signature[SignatureBytes];
+ if (SignInternal(signature, priv.get(), msg.data(), msg.size(),
+ // It's not just an empty context, the context prefix
+ // is omitted too.
+ nullptr, 0, nullptr, 0,
+ randomizer.data()) != bcm_status::approved) {
+ LOG_ERROR("ML-DSA signing failed.\n");
+ return false;
+ }
+
+ return write_reply({signature});
+}
+
+template <typename PublicKey, size_t SignatureBytes,
+ bcm_status (*ParsePublicKey)(PublicKey *, CBS *),
+ bcm_status (*VerifyInternal)(const PublicKey *, const uint8_t *,
+ const uint8_t *, size_t, const uint8_t *,
+ size_t, const uint8_t *, size_t)>
+static bool MLDSASigVer(const Span<const uint8_t> args[],
+ ReplyCallback write_reply) {
+ const Span<const uint8_t> pub_key_bytes = args[0];
+ const Span<const uint8_t> msg = args[1];
+ const Span<const uint8_t> signature = args[2];
+
+ CBS cbs = bssl::MakeConstSpan(pub_key_bytes);
+ auto pub = std::make_unique<PublicKey>();
+ if (ParsePublicKey(pub.get(), &cbs) != bcm_status::approved) {
+ LOG_ERROR("Failed to parse ML-DSA public key.\n");
+ return false;
+ }
+
+ if (signature.size() != SignatureBytes) {
+ LOG_ERROR("Bad signature size.\n");
+ return false;
+ }
+
+ const uint8_t ok = bcm_success(
+ VerifyInternal(pub.get(), signature.data(), msg.data(), msg.size(),
+ // It's not just an empty context, the context
+ // prefix is omitted too.
+ nullptr, 0, nullptr, 0));
+
+ return write_reply({Span<const uint8_t>(&ok, sizeof(ok))});
+}
+
static constexpr struct {
char name[kMaxNameLength + 1];
uint8_t num_expected_args;
@@ -2193,6 +2352,26 @@
{"ECDH/P-384", 3, ECDH<NID_secp384r1>},
{"ECDH/P-521", 3, ECDH<NID_secp521r1>},
{"FFDH", 6, FFDH},
+ {"ML-DSA-65/keyGen", 1,
+ MLDSAKeyGen<BCM_mldsa65_private_key, BCM_MLDSA65_PUBLIC_KEY_BYTES,
+ BCM_mldsa65_generate_key_external_entropy,
+ BCM_mldsa65_marshal_private_key>},
+ {"ML-DSA-87/keyGen", 1,
+ MLDSAKeyGen<BCM_mldsa87_private_key, BCM_MLDSA87_PUBLIC_KEY_BYTES,
+ BCM_mldsa87_generate_key_external_entropy,
+ BCM_mldsa87_marshal_private_key>},
+ {"ML-DSA-65/sigGen", 3,
+ MLDSASigGen<BCM_mldsa65_private_key, BCM_MLDSA65_SIGNATURE_BYTES,
+ BCM_mldsa65_parse_private_key, BCM_mldsa65_sign_internal>},
+ {"ML-DSA-87/sigGen", 3,
+ MLDSASigGen<BCM_mldsa87_private_key, BCM_MLDSA87_SIGNATURE_BYTES,
+ BCM_mldsa87_parse_private_key, BCM_mldsa87_sign_internal>},
+ {"ML-DSA-65/sigVer", 3,
+ MLDSASigVer<BCM_mldsa65_public_key, BCM_MLDSA65_SIGNATURE_BYTES,
+ BCM_mldsa65_parse_public_key, BCM_mldsa65_verify_internal>},
+ {"ML-DSA-87/sigVer", 3,
+ MLDSASigVer<BCM_mldsa87_public_key, BCM_MLDSA87_SIGNATURE_BYTES,
+ BCM_mldsa87_parse_public_key, BCM_mldsa87_verify_internal>},
};
Handler FindHandler(Span<const Span<const uint8_t>> args) {