| /* Copyright (c) 2024, Google LLC |
| * |
| * Permission to use, copy, modify, and/or distribute this software for any |
| * purpose with or without fee is hereby granted, provided that the above |
| * copyright notice and this permission notice appear in all copies. |
| * |
| * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES |
| * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF |
| * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY |
| * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
| * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION |
| * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN |
| * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ |
| |
| #include <openssl/mldsa.h> |
| |
| #include <memory> |
| #include <vector> |
| |
| #include <gtest/gtest.h> |
| |
| #include <openssl/bytestring.h> |
| #include <openssl/mem.h> |
| #include <openssl/span.h> |
| |
| #include "../test/file_test.h" |
| #include "../test/test_util.h" |
| #include "./internal.h" |
| |
| |
| namespace { |
| |
| template <typename T> |
| std::vector<uint8_t> Marshal(int (*marshal_func)(CBB *, const T *), |
| const T *t) { |
| bssl::ScopedCBB cbb; |
| uint8_t *encoded; |
| size_t encoded_len; |
| if (!CBB_init(cbb.get(), 1) || // |
| !marshal_func(cbb.get(), t) || // |
| !CBB_finish(cbb.get(), &encoded, &encoded_len)) { |
| abort(); |
| } |
| |
| std::vector<uint8_t> ret(encoded, encoded + encoded_len); |
| OPENSSL_free(encoded); |
| return ret; |
| } |
| |
| // This test is very slow, so it is disabled by default. |
| TEST(MLDSATest, DISABLED_BitFlips) { |
| std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| uint8_t seed[MLDSA_SEED_BYTES]; |
| EXPECT_TRUE( |
| MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get())); |
| |
| std::vector<uint8_t> encoded_signature(MLDSA65_SIGNATURE_BYTES); |
| static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ', |
| 'w', 'o', 'r', 'l', 'd'}; |
| EXPECT_TRUE(MLDSA65_sign(encoded_signature.data(), priv.get(), kMessage, |
| sizeof(kMessage), nullptr, 0)); |
| |
| auto pub = std::make_unique<MLDSA65_public_key>(); |
| CBS cbs = bssl::MakeConstSpan(encoded_public_key); |
| ASSERT_TRUE(MLDSA65_parse_public_key(pub.get(), &cbs)); |
| |
| EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature.data(), |
| encoded_signature.size(), kMessage, sizeof(kMessage), |
| nullptr, 0), |
| 1); |
| |
| for (size_t i = 0; i < MLDSA65_SIGNATURE_BYTES; i++) { |
| for (int j = 0; j < 8; j++) { |
| encoded_signature[i] ^= 1 << j; |
| EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature.data(), |
| encoded_signature.size(), kMessage, |
| sizeof(kMessage), nullptr, 0), |
| 0) |
| << "Bit flip in signature at byte " << i << " bit " << j |
| << " didn't cause a verification failure"; |
| encoded_signature[i] ^= 1 << j; |
| } |
| } |
| } |
| |
| TEST(MLDSATest, Basic) { |
| std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| uint8_t seed[MLDSA_SEED_BYTES]; |
| EXPECT_TRUE( |
| MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get())); |
| |
| std::vector<uint8_t> encoded_signature(MLDSA65_SIGNATURE_BYTES); |
| static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ', |
| 'w', 'o', 'r', 'l', 'd'}; |
| static const uint8_t kContext[] = {'c', 't', 'x'}; |
| EXPECT_TRUE(MLDSA65_sign(encoded_signature.data(), priv.get(), kMessage, |
| sizeof(kMessage), kContext, sizeof(kContext))); |
| |
| auto pub = std::make_unique<MLDSA65_public_key>(); |
| CBS cbs = bssl::MakeConstSpan(encoded_public_key); |
| ASSERT_TRUE(MLDSA65_parse_public_key(pub.get(), &cbs)); |
| |
| EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature.data(), |
| encoded_signature.size(), kMessage, sizeof(kMessage), |
| kContext, sizeof(kContext)), |
| 1); |
| |
| auto priv2 = std::make_unique<MLDSA65_private_key>(); |
| EXPECT_TRUE(MLDSA65_private_key_from_seed(priv2.get(), seed, sizeof(seed))); |
| |
| EXPECT_EQ(Bytes(Marshal(MLDSA65_marshal_private_key, priv.get())), |
| Bytes(Marshal(MLDSA65_marshal_private_key, priv2.get()))); |
| } |
| |
| TEST(MLDSATest, SignatureIsRandomized) { |
| std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| uint8_t seed[MLDSA_SEED_BYTES]; |
| EXPECT_TRUE( |
| MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get())); |
| |
| auto pub = std::make_unique<MLDSA65_public_key>(); |
| CBS cbs = bssl::MakeConstSpan(encoded_public_key); |
| ASSERT_TRUE(MLDSA65_parse_public_key(pub.get(), &cbs)); |
| |
| std::vector<uint8_t> encoded_signature1(MLDSA65_SIGNATURE_BYTES); |
| std::vector<uint8_t> encoded_signature2(MLDSA65_SIGNATURE_BYTES); |
| static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ', |
| 'w', 'o', 'r', 'l', 'd'}; |
| EXPECT_TRUE(MLDSA65_sign(encoded_signature1.data(), priv.get(), kMessage, |
| sizeof(kMessage), nullptr, 0)); |
| EXPECT_TRUE(MLDSA65_sign(encoded_signature2.data(), priv.get(), kMessage, |
| sizeof(kMessage), nullptr, 0)); |
| |
| EXPECT_NE(Bytes(encoded_signature1), Bytes(encoded_signature2)); |
| |
| // Even though the signatures are different, they both verify. |
| EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature1.data(), |
| encoded_signature1.size(), kMessage, |
| sizeof(kMessage), nullptr, 0), |
| 1); |
| EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature2.data(), |
| encoded_signature2.size(), kMessage, |
| sizeof(kMessage), nullptr, 0), |
| 1); |
| } |
| |
| TEST(MLDSATest, PublicFromPrivateIsConsistent) { |
| std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| uint8_t seed[MLDSA_SEED_BYTES]; |
| EXPECT_TRUE( |
| MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get())); |
| |
| auto pub = std::make_unique<MLDSA65_public_key>(); |
| EXPECT_TRUE(MLDSA65_public_from_private(pub.get(), priv.get())); |
| |
| std::vector<uint8_t> encoded_public_key2(MLDSA65_PUBLIC_KEY_BYTES); |
| |
| CBB cbb; |
| CBB_init_fixed(&cbb, encoded_public_key2.data(), encoded_public_key2.size()); |
| ASSERT_TRUE(MLDSA65_marshal_public_key(&cbb, pub.get())); |
| |
| EXPECT_EQ(Bytes(encoded_public_key2), Bytes(encoded_public_key)); |
| } |
| |
| TEST(MLDSATest, InvalidPublicKeyEncodingLength) { |
| // Encode a public key with a trailing 0 at the end. |
| std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES + 1); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| uint8_t seed[MLDSA_SEED_BYTES]; |
| EXPECT_TRUE( |
| MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get())); |
| |
| // Public key is 1 byte too short. |
| CBS cbs = bssl::MakeConstSpan(encoded_public_key) |
| .first(MLDSA65_PUBLIC_KEY_BYTES - 1); |
| auto parsed_pub = std::make_unique<MLDSA65_public_key>(); |
| EXPECT_FALSE(MLDSA65_parse_public_key(parsed_pub.get(), &cbs)); |
| |
| // Public key has the correct length. |
| cbs = bssl::MakeConstSpan(encoded_public_key).first(MLDSA65_PUBLIC_KEY_BYTES); |
| EXPECT_TRUE(MLDSA65_parse_public_key(parsed_pub.get(), &cbs)); |
| |
| // Public key is 1 byte too long. |
| cbs = bssl::MakeConstSpan(encoded_public_key); |
| EXPECT_FALSE(MLDSA65_parse_public_key(parsed_pub.get(), &cbs)); |
| } |
| |
| TEST(MLDSATest, InvalidPrivateKeyEncodingLength) { |
| std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| uint8_t seed[MLDSA_SEED_BYTES]; |
| EXPECT_TRUE( |
| MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get())); |
| |
| CBB cbb; |
| std::vector<uint8_t> malformed_private_key(MLDSA65_PRIVATE_KEY_BYTES + 1, 0); |
| CBB_init_fixed(&cbb, malformed_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES); |
| ASSERT_TRUE(MLDSA65_marshal_private_key(&cbb, priv.get())); |
| |
| CBS cbs; |
| auto parsed_priv = std::make_unique<MLDSA65_private_key>(); |
| |
| // Private key is 1 byte too short. |
| CBS_init(&cbs, malformed_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES - 1); |
| EXPECT_FALSE(MLDSA65_parse_private_key(parsed_priv.get(), &cbs)); |
| |
| // Private key has the correct length. |
| CBS_init(&cbs, malformed_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES); |
| EXPECT_TRUE(MLDSA65_parse_private_key(parsed_priv.get(), &cbs)); |
| |
| // Private key is 1 byte too long. |
| CBS_init(&cbs, malformed_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES + 1); |
| EXPECT_FALSE(MLDSA65_parse_private_key(parsed_priv.get(), &cbs)); |
| } |
| |
| static void MLDSASigGenTest(FileTest *t) { |
| std::vector<uint8_t> private_key_bytes, msg, expected_signature; |
| ASSERT_TRUE(t->GetBytes(&private_key_bytes, "sk")); |
| ASSERT_TRUE(t->GetBytes(&msg, "message")); |
| ASSERT_TRUE(t->GetBytes(&expected_signature, "signature")); |
| |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| CBS cbs; |
| CBS_init(&cbs, private_key_bytes.data(), private_key_bytes.size()); |
| EXPECT_TRUE(MLDSA65_parse_private_key(priv.get(), &cbs)); |
| |
| const uint8_t zero_randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0}; |
| std::vector<uint8_t> signature(MLDSA65_SIGNATURE_BYTES); |
| EXPECT_TRUE(MLDSA65_sign_internal(signature.data(), priv.get(), msg.data(), |
| msg.size(), nullptr, 0, nullptr, 0, |
| zero_randomizer)); |
| |
| EXPECT_EQ(Bytes(signature), Bytes(expected_signature)); |
| |
| auto pub = std::make_unique<MLDSA65_public_key>(); |
| ASSERT_TRUE(MLDSA65_public_from_private(pub.get(), priv.get())); |
| EXPECT_TRUE(MLDSA65_verify_internal(pub.get(), signature.data(), msg.data(), |
| msg.size(), nullptr, 0, nullptr, 0)); |
| } |
| |
| TEST(MLDSATest, SigGenTests) { |
| FileTestGTest("crypto/mldsa/mldsa_nist_siggen_tests.txt", MLDSASigGenTest); |
| } |
| |
| static void MLDSAKeyGenTest(FileTest *t) { |
| std::vector<uint8_t> seed, expected_public_key, expected_private_key; |
| ASSERT_TRUE(t->GetBytes(&seed, "seed")); |
| ASSERT_TRUE(t->GetBytes(&expected_public_key, "pub")); |
| ASSERT_TRUE(t->GetBytes(&expected_private_key, "priv")); |
| |
| std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| ASSERT_TRUE(MLDSA65_generate_key_external_entropy(encoded_public_key.data(), |
| priv.get(), seed.data())); |
| |
| EXPECT_EQ(Bytes(encoded_public_key), Bytes(expected_public_key)); |
| } |
| |
| TEST(MLDSATest, KeyGenTests) { |
| FileTestGTest("crypto/mldsa/mldsa_nist_keygen_tests.txt", MLDSAKeyGenTest); |
| } |
| |
| static void MLDSAWycheproofSignTest(FileTest *t) { |
| std::vector<uint8_t> private_key_bytes, msg, expected_signature, context; |
| ASSERT_TRUE(t->GetInstructionBytes(&private_key_bytes, "privateKey")); |
| ASSERT_TRUE(t->GetBytes(&msg, "msg")); |
| ASSERT_TRUE(t->GetBytes(&expected_signature, "sig")); |
| if (t->HasAttribute("ctx")) { |
| t->GetBytes(&context, "ctx"); |
| } |
| std::string result; |
| ASSERT_TRUE(t->GetAttribute(&result, "result")); |
| t->IgnoreAttribute("flags"); |
| |
| CBS cbs; |
| CBS_init(&cbs, private_key_bytes.data(), private_key_bytes.size()); |
| auto priv = std::make_unique<MLDSA65_private_key>(); |
| const int priv_ok = MLDSA65_parse_private_key(priv.get(), &cbs); |
| |
| if (!priv_ok) { |
| ASSERT_TRUE(result != "valid"); |
| return; |
| } |
| |
| // Unfortunately we need to reimplement the context length check here because |
| // we are using the internal function in order to pass in an all-zero |
| // randomizer. |
| if (context.size() > 255) { |
| ASSERT_TRUE(result != "valid"); |
| return; |
| } |
| |
| const uint8_t zero_randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0}; |
| std::vector<uint8_t> signature(MLDSA65_SIGNATURE_BYTES); |
| const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context.size())}; |
| EXPECT_TRUE(MLDSA65_sign_internal( |
| signature.data(), priv.get(), msg.data(), msg.size(), context_prefix, |
| sizeof(context_prefix), context.data(), context.size(), zero_randomizer)); |
| |
| EXPECT_EQ(Bytes(signature), Bytes(expected_signature)); |
| } |
| |
| TEST(MLDSATest, WycheproofSignTests) { |
| FileTestGTest( |
| "third_party/wycheproof_testvectors/mldsa_65_standard_sign_test.txt", |
| MLDSAWycheproofSignTest); |
| } |
| |
| static void MLDSAWycheproofVerifyTest(FileTest *t) { |
| std::vector<uint8_t> public_key_bytes, msg, signature, context; |
| ASSERT_TRUE(t->GetInstructionBytes(&public_key_bytes, "publicKey")); |
| ASSERT_TRUE(t->GetBytes(&msg, "msg")); |
| ASSERT_TRUE(t->GetBytes(&signature, "sig")); |
| if (t->HasAttribute("ctx")) { |
| t->GetBytes(&context, "ctx"); |
| } |
| std::string result, flags; |
| ASSERT_TRUE(t->GetAttribute(&result, "result")); |
| ASSERT_TRUE(t->GetAttribute(&flags, "flags")); |
| |
| CBS cbs; |
| CBS_init(&cbs, public_key_bytes.data(), public_key_bytes.size()); |
| auto pub = std::make_unique<MLDSA65_public_key>(); |
| const int pub_ok = MLDSA65_parse_public_key(pub.get(), &cbs); |
| |
| if (!pub_ok) { |
| EXPECT_EQ(flags, "IncorrectPublicKeyLength"); |
| return; |
| } |
| |
| const int sig_ok = |
| MLDSA65_verify(pub.get(), signature.data(), signature.size(), msg.data(), |
| msg.size(), context.data(), context.size()); |
| if (!sig_ok) { |
| EXPECT_EQ(result, "invalid"); |
| } else { |
| EXPECT_EQ(result, "valid"); |
| } |
| } |
| |
| TEST(MLDSATest, WycheproofVerifyTests) { |
| FileTestGTest( |
| "third_party/wycheproof_testvectors/mldsa_65_standard_verify_test.txt", |
| MLDSAWycheproofVerifyTest); |
| } |
| |
| } // namespace |