ML-DSA: focus the API on saving private keys as seeds.

There are two ways to save an ML-DSA or ML-KEM private key:

NIST specifies a partial serialization of the contents of the keys and this takes up several kilobytes.

But one can also save the seed that the key was generated from and simply regenerate the private key as needed.

* The seed is approximately two orders of magnitude smaller.
* It is fast to expand a private key from a seed.
* The NIST format requires validating several aspects of the partially expanded private key.

Because of this, seeds seem clearly better and having two different
serializations in the API is a bit weird when currently neither of them
are used anywhere.

Thus this change emphasizes using seeds to save private keys and moves
the marshalling function for the NIST format into the internal API.
ML-KEM already follows this pattern, although saving the seed is still
optional there because ephemeral keys are a major use case for ML-KEM.

Change-Id: I439224e745ad8747d26f57288f1d503593e0e52c
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/70407
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/mldsa/internal.h b/crypto/mldsa/internal.h
index 1c75761..08ff329 100644
--- a/crypto/mldsa/internal.h
+++ b/crypto/mldsa/internal.h
@@ -59,6 +59,12 @@
     const uint8_t *msg, size_t msg_len, const uint8_t *context_prefix,
     size_t context_prefix_len, const uint8_t *context, size_t context_len);
 
+// MLDSA65_marshal_private_key serializes |private_key| to |out| in the
+// NIST format for ML-DSA-65 private keys. It returns 1 on success or 0
+// on allocation error.
+OPENSSL_EXPORT int MLDSA65_marshal_private_key(
+    CBB *out, const struct MLDSA65_private_key *private_key);
+
 
 #if defined(__cplusplus)
 }  // extern C
diff --git a/crypto/mldsa/mldsa.c b/crypto/mldsa/mldsa.c
index 08ef50c..0a75864 100644
--- a/crypto/mldsa/mldsa.c
+++ b/crypto/mldsa/mldsa.c
@@ -1263,15 +1263,11 @@
 // |RAND_bytes|. Returns 1 on success and 0 on failure.
 int MLDSA65_generate_key(
     uint8_t out_encoded_public_key[MLDSA65_PUBLIC_KEY_BYTES],
-    uint8_t optional_out_seed[MLDSA_SEED_BYTES],
+    uint8_t out_seed[MLDSA_SEED_BYTES],
     struct MLDSA65_private_key *out_private_key) {
-  uint8_t entropy[MLDSA_SEED_BYTES];
-  RAND_bytes(entropy, sizeof(entropy));
-  if (optional_out_seed) {
-    OPENSSL_memcpy(optional_out_seed, entropy, MLDSA_SEED_BYTES);
-  }
+  RAND_bytes(out_seed, MLDSA_SEED_BYTES);
   return MLDSA65_generate_key_external_entropy(out_encoded_public_key,
-                                               out_private_key, entropy);
+                                               out_private_key, out_seed);
 }
 
 int MLDSA65_private_key_from_seed(struct MLDSA65_private_key *out_private_key,
diff --git a/crypto/mldsa/mldsa_test.cc b/crypto/mldsa/mldsa_test.cc
index e17cf05..c9b0828 100644
--- a/crypto/mldsa/mldsa_test.cc
+++ b/crypto/mldsa/mldsa_test.cc
@@ -51,8 +51,9 @@
 TEST(MLDSATest, DISABLED_BitFlips) {
   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<MLDSA65_private_key>();
+  uint8_t seed[MLDSA_SEED_BYTES];
   EXPECT_TRUE(
-      MLDSA65_generate_key(encoded_public_key.data(), nullptr, priv.get()));
+      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
 
   std::vector<uint8_t> encoded_signature(MLDSA65_SIGNATURE_BYTES);
   static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
@@ -116,8 +117,9 @@
 TEST(MLDSATest, SignatureIsRandomized) {
   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<MLDSA65_private_key>();
+  uint8_t seed[MLDSA_SEED_BYTES];
   EXPECT_TRUE(
-      MLDSA65_generate_key(encoded_public_key.data(), nullptr, priv.get()));
+      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
 
   auto pub = std::make_unique<MLDSA65_public_key>();
   CBS cbs = bssl::MakeConstSpan(encoded_public_key);
@@ -148,8 +150,9 @@
 TEST(MLDSATest, PublicFromPrivateIsConsistent) {
   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<MLDSA65_private_key>();
+  uint8_t seed[MLDSA_SEED_BYTES];
   EXPECT_TRUE(
-      MLDSA65_generate_key(encoded_public_key.data(), nullptr, priv.get()));
+      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
 
   auto pub = std::make_unique<MLDSA65_public_key>();
   EXPECT_TRUE(MLDSA65_public_from_private(pub.get(), priv.get()));
@@ -167,8 +170,9 @@
   // Encode a public key with a trailing 0 at the end.
   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES + 1);
   auto priv = std::make_unique<MLDSA65_private_key>();
+  uint8_t seed[MLDSA_SEED_BYTES];
   EXPECT_TRUE(
-      MLDSA65_generate_key(encoded_public_key.data(), nullptr, priv.get()));
+      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
 
   // Public key is 1 byte too short.
   CBS cbs = bssl::MakeConstSpan(encoded_public_key)
@@ -188,8 +192,9 @@
 TEST(MLDSATest, InvalidPrivateKeyEncodingLength) {
   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<MLDSA65_private_key>();
+  uint8_t seed[MLDSA_SEED_BYTES];
   EXPECT_TRUE(
-      MLDSA65_generate_key(encoded_public_key.data(), nullptr, priv.get()));
+      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
 
   CBB cbb;
   std::vector<uint8_t> malformed_private_key(MLDSA65_PRIVATE_KEY_BYTES + 1, 0);
diff --git a/crypto/mlkem/mlkem_test.cc b/crypto/mlkem/mlkem_test.cc
index 9868dd9..581eea5 100644
--- a/crypto/mlkem/mlkem_test.cc
+++ b/crypto/mlkem/mlkem_test.cc
@@ -154,16 +154,16 @@
           int (*MARSHAL_PRIVATE)(CBB *, const PRIVATE_KEY *),
           void (*GENERATE)(uint8_t *, PRIVATE_KEY *, const uint8_t *)>
 void MLKEMKeyGenFileTest(FileTest *t) {
-  std::vector<uint8_t> expected_pub_key_bytes, entropy, expected_priv_key_bytes;
-  ASSERT_TRUE(t->GetBytes(&entropy, "seed"));
+  std::vector<uint8_t> expected_pub_key_bytes, seed, expected_priv_key_bytes;
+  ASSERT_TRUE(t->GetBytes(&seed, "seed"));
   ASSERT_TRUE(t->GetBytes(&expected_pub_key_bytes, "public_key"));
   ASSERT_TRUE(t->GetBytes(&expected_priv_key_bytes, "private_key"));
 
-  ASSERT_EQ(entropy.size(), size_t{MLKEM_SEED_BYTES});
+  ASSERT_EQ(seed.size(), size_t{MLKEM_SEED_BYTES});
 
   std::vector<uint8_t> pub_key_bytes(PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<PRIVATE_KEY>();
-  GENERATE(pub_key_bytes.data(), priv.get(), entropy.data());
+  GENERATE(pub_key_bytes.data(), priv.get(), seed.data());
   const std::vector<uint8_t> priv_key_bytes(
       Marshal(MARSHAL_PRIVATE, priv.get()));
 
@@ -200,12 +200,12 @@
   ASSERT_EQ(z.size(), size_t{MLKEM_SEED_BYTES} / 2);
   ASSERT_EQ(d.size(), size_t{MLKEM_SEED_BYTES} / 2);
 
-  uint8_t entropy[MLKEM_SEED_BYTES];
-  OPENSSL_memcpy(&entropy[0], d.data(), d.size());
-  OPENSSL_memcpy(&entropy[MLKEM_SEED_BYTES / 2], z.data(), z.size());
+  uint8_t seed[MLKEM_SEED_BYTES];
+  OPENSSL_memcpy(&seed[0], d.data(), d.size());
+  OPENSSL_memcpy(&seed[MLKEM_SEED_BYTES / 2], z.data(), z.size());
   std::vector<uint8_t> pub_key_bytes(PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<PRIVATE_KEY>();
-  GENERATE(pub_key_bytes.data(), priv.get(), entropy);
+  GENERATE(pub_key_bytes.data(), priv.get(), seed);
   const std::vector<uint8_t> priv_key_bytes(
       Marshal(MARSHAL_PRIVATE, priv.get()));
 
@@ -376,11 +376,10 @@
   auto priv = std::make_unique<PRIVATE_KEY>();
   auto pub = std::make_unique<PUBLIC_KEY>();
   for (int i = 0; i < 10000; i++) {
-    uint8_t generate_entropy[MLKEM_SEED_BYTES];
-    BORINGSSL_keccak_squeeze(&generate_st, generate_entropy,
-                             sizeof(generate_entropy));
+    uint8_t seed[MLKEM_SEED_BYTES];
+    BORINGSSL_keccak_squeeze(&generate_st, seed, sizeof(seed));
     uint8_t encoded_pub[PUBLIC_KEY_BYTES];
-    GENERATE(encoded_pub, priv.get(), generate_entropy);
+    GENERATE(encoded_pub, priv.get(), seed);
     TO_PUBLIC(pub.get(), priv.get());
 
     BORINGSSL_keccak_absorb(&results_st, encoded_pub, sizeof(encoded_pub));
diff --git a/include/openssl/mldsa.h b/include/openssl/mldsa.h
index 521a155..a0a7560 100644
--- a/include/openssl/mldsa.h
+++ b/include/openssl/mldsa.h
@@ -63,11 +63,11 @@
 
 // MLDSA65_generate_key generates a random public/private key pair, writes the
 // encoded public key to |out_encoded_public_key|, writes the seed to
-// |optional_out_seed| (if not NULL), and sets |out_private_key| to the private
-// key. Returns 1 on success and 0 on allocation failure.
+// |out_seed|, and sets |out_private_key| to the private key. Returns 1 on
+// success and 0 on allocation failure.
 OPENSSL_EXPORT int MLDSA65_generate_key(
     uint8_t out_encoded_public_key[MLDSA65_PUBLIC_KEY_BYTES],
-    uint8_t optional_out_seed[MLDSA_SEED_BYTES],
+    uint8_t out_seed[MLDSA_SEED_BYTES],
     struct MLDSA65_private_key *out_private_key);
 
 // MLDSA65_private_key_from_seed regenerates a private key from a seed value
@@ -122,16 +122,9 @@
 OPENSSL_EXPORT int MLDSA65_parse_public_key(
     struct MLDSA65_public_key *public_key, CBS *in);
 
-// MLDSA65_marshal_private_key serializes |private_key| to |out| in the
-// standard format for ML-DSA-65 private keys. It returns 1 on success or 0
-// on allocation error.
-OPENSSL_EXPORT int MLDSA65_marshal_private_key(
-    CBB *out, const struct MLDSA65_private_key *private_key);
-
-// MLDSA65_parse_private_key parses a private key, in the format generated by
-// |MLDSA65_marshal_private_key|, from |in| and writes the result to
-// |out_private_key|. It returns 1 on success or 0 on parse error or if
-// there are trailing bytes in |in|.
+// MLDSA65_parse_private_key parses a private key, in the NIST format, from |in|
+// and writes the result to |out_private_key|. It returns 1 on success or 0 on
+// parse error or if there are trailing bytes in |in|.
 OPENSSL_EXPORT int MLDSA65_parse_private_key(
     struct MLDSA65_private_key *private_key, CBS *in);
 
diff --git a/tool/speed.cc b/tool/speed.cc
index a1e74ad..1bba9fc 100644
--- a/tool/speed.cc
+++ b/tool/speed.cc
@@ -71,6 +71,7 @@
 #include "../crypto/ec_extra/internal.h"
 #include "../crypto/fipsmodule/ec/internal.h"
 #include "../crypto/internal.h"
+#include "../crypto/mldsa/internal.h"
 #include "../crypto/trust_token/internal.h"
 #include "internal.h"
 
@@ -1142,8 +1143,8 @@
       std::make_unique<uint8_t[]>(MLDSA65_PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<MLDSA65_private_key>();
   if (!TimeFunctionParallel(&results, [&]() -> bool {
-        if (!MLDSA65_generate_key(encoded_public_key.get(), nullptr,
-                                  priv.get())) {
+        uint8_t seed[MLDSA_SEED_BYTES];
+        if (!MLDSA65_generate_key(encoded_public_key.get(), seed, priv.get())) {
           fprintf(stderr, "Failure in MLDSA65_generate_key.\n");
           return false;
         }