Include the public key in ML-DSA private keys

I went back and forth on this a bit, but I think this is the right
model. go/mldsa-mlkem-evp (internal) has some notes on the trade-offs
here.

Bug: 449751916
Change-Id: I8ba46107c5a3c0b0260538705bd2b20f28c1fb0c
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/82991
Reviewed-by: Lily Chen <chlily@google.com>
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/crypto/fipsmodule/mldsa/mldsa.cc.inc b/crypto/fipsmodule/mldsa/mldsa.cc.inc
index 9412b9c..2f14141 100644
--- a/crypto/fipsmodule/mldsa/mldsa.cc.inc
+++ b/crypto/fipsmodule/mldsa/mldsa.cc.inc
@@ -1436,9 +1436,8 @@
 
 template <int K, int L>
 struct private_key {
-  uint8_t rho[kRhoBytes];
+  public_key<K> pub;
   uint8_t k[kKBytes];
-  uint8_t public_key_hash[kTrBytes];
   vector<L> s1;
   vector<K> s2;
   vector<K> t0;
@@ -1492,10 +1491,10 @@
 // FIPS 204, Algorithm 24 (`skEncode`).
 template <int K, int L>
 int mldsa_marshal_private_key(CBB *out, const private_key<K, L> *priv) {
-  if (!CBB_add_bytes(out, priv->rho, sizeof(priv->rho)) ||
+  if (!CBB_add_bytes(out, priv->pub.rho, sizeof(priv->pub.rho)) ||
       !CBB_add_bytes(out, priv->k, sizeof(priv->k)) ||
-      !CBB_add_bytes(out, priv->public_key_hash,
-                     sizeof(priv->public_key_hash))) {
+      !CBB_add_bytes(out, priv->pub.public_key_hash,
+                     sizeof(priv->pub.public_key_hash))) {
     return 0;
   }
 
@@ -1524,18 +1523,16 @@
   return 1;
 }
 
-// FIPS 204, Algorithm 25 (`skDecode`).
+// FIPS 204, Algorithm 25 (`skDecode`). This is only used for testing. The
+// supported external way to construct ML-DSA keys is to use the input seed.
 template <int K, int L>
 int mldsa_parse_private_key(private_key<K, L> *priv, CBS *in) {
-  CBS s1_bytes;
-  CBS s2_bytes;
-  CBS t0_bytes;
+  CBS public_key_hash, s1_bytes, s2_bytes, t0_bytes;
   constexpr size_t scalar_bytes =
       (kDegree * plus_minus_eta_bitlen<K>() + 7) / 8;
-  if (!CBS_copy_bytes(in, priv->rho, sizeof(priv->rho)) ||
+  if (!CBS_copy_bytes(in, priv->pub.rho, sizeof(priv->pub.rho)) ||
       !CBS_copy_bytes(in, priv->k, sizeof(priv->k)) ||
-      !CBS_copy_bytes(in, priv->public_key_hash,
-                      sizeof(priv->public_key_hash)) ||
+      !CBS_get_bytes(in, &public_key_hash, kTrBytes) ||
       !CBS_get_bytes(in, &s1_bytes, scalar_bytes * L) ||
       !vector_decode_signed(&priv->s1, CBS_data(&s1_bytes),
                             plus_minus_eta_bitlen<K>(), eta<K>()) ||
@@ -1548,6 +1545,22 @@
     return 0;
   }
 
+  // Compute `t1`, which is not in the `skDecode` input.
+  uint8_t unused[public_key_bytes<K>()];
+  if (!mldsa_finish_keygen(unused, priv)) {
+    return 0;
+  }
+
+  // As a side effect of computing `t1`, we also compute `t0` and
+  // `public_key_hash`. Check they match the received bytes.
+  uint8_t t0_computed[416 * K];
+  vector_encode_signed(t0_computed, &priv->t0, 13, 1 << 12);
+  if (!CBS_mem_equal(&public_key_hash, priv->pub.public_key_hash,
+                     sizeof(priv->pub.public_key_hash)) ||
+      !CBS_mem_equal(&t0_bytes, t0_computed, sizeof(t0_computed))) {
+    return 0;
+  }
+
   return 1;
 }
 
@@ -1592,17 +1605,15 @@
   return 1;
 }
 
-// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
-// on failure.
+// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`), steps 3 and 5–11.
+// Returns 1 on success and 0 on failure.
 template <int K, int L>
-int mldsa_generate_key_external_entropy_no_self_test(
-    uint8_t out_encoded_public_key[public_key_bytes<K>()],
-    private_key<K, L> *priv, const uint8_t entropy[MLDSA_SEED_BYTES]) {
+int mldsa_finish_keygen(uint8_t out_encoded_public_key[public_key_bytes<K>()],
+                        private_key<K, L> *priv) {
   // Intermediate values, allocated on the heap to allow use when there is a
   // limited amount of stack.
   struct Values {
     enum { kAllowUniquePtr = true };
-    public_key<K> pub;
     matrix<K, L> a_ntt;
     vector<L> s1_ntt;
     vector<K> t;
@@ -1612,6 +1623,46 @@
     return 0;
   }
 
+  // Step 3.
+  matrix_expand(&values->a_ntt, priv->pub.rho);
+
+  // Step 5.
+  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+  vector_ntt(&values->s1_ntt);
+
+  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
+  vector_inverse_ntt(&values->t);
+  vector_add(&values->t, &values->t, &priv->s2);
+
+  // Step 6-7.
+  vector_power2_round(&priv->pub.t1, &priv->t0, &values->t);
+  // t1 is public.
+  CONSTTIME_DECLASSIFY(&priv->pub.t1, sizeof(priv->pub.t1));
+
+  // Step 8.
+  CBB cbb;
+  CBB_init_fixed(&cbb, out_encoded_public_key, public_key_bytes<K>());
+  if (!mldsa_marshal_public_key(&cbb, &priv->pub)) {
+    return 0;
+  }
+  assert(CBB_len(&cbb) == public_key_bytes<K>());
+
+  // Step 9-11.
+  BORINGSSL_keccak(priv->pub.public_key_hash, sizeof(priv->pub.public_key_hash),
+                   out_encoded_public_key, public_key_bytes<K>(),
+                   boringssl_shake256);
+
+  return 1;
+}
+
+// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
+// on failure.
+template <int K, int L>
+int mldsa_generate_key_external_entropy_no_self_test(
+    uint8_t out_encoded_public_key[public_key_bytes<K>()],
+    private_key<K, L> *priv,
+    const uint8_t entropy[MLDSA_SEED_BYTES]) {
+  // Step 1-2.
   uint8_t augmented_entropy[MLDSA_SEED_BYTES + 2];
   OPENSSL_memcpy(augmented_entropy, entropy, MLDSA_SEED_BYTES);
   // The k and l parameters are appended to the seed.
@@ -1625,36 +1676,12 @@
   const uint8_t *const k = expanded_seed + kRhoBytes + kSigmaBytes;
   // rho is public.
   CONSTTIME_DECLASSIFY(rho, kRhoBytes);
-  OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
-  OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
+  OPENSSL_memcpy(priv->pub.rho, rho, sizeof(priv->pub.rho));
   OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
-
-  matrix_expand(&values->a_ntt, rho);
+  // Step 4. This is independent of A (step 3) and can be done first.
   vector_expand_short(&priv->s1, &priv->s2, sigma);
-
-  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
-  vector_ntt(&values->s1_ntt);
-
-  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
-  vector_inverse_ntt(&values->t);
-  vector_add(&values->t, &values->t, &priv->s2);
-
-  vector_power2_round(&values->pub.t1, &priv->t0, &values->t);
-  // t1 is public.
-  CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));
-
-  CBB cbb;
-  CBB_init_fixed(&cbb, out_encoded_public_key, public_key_bytes<K>());
-  if (!mldsa_marshal_public_key(&cbb, &values->pub)) {
-    return 0;
-  }
-  assert(CBB_len(&cbb) == public_key_bytes<K>());
-
-  BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
-                   out_encoded_public_key, public_key_bytes<K>(),
-                   boringssl_shake256);
-
-  return 1;
+  // Steps 3 and 5-11.
+  return mldsa_finish_keygen(out_encoded_public_key, priv);
 }
 
 template <int K, int L>
@@ -1666,42 +1693,6 @@
       out_encoded_public_key, priv, entropy);
 }
 
-template <int K, int L>
-int mldsa_public_from_private(public_key<K> *pub,
-                              const private_key<K, L> *priv) {
-  // Intermediate values, allocated on the heap to allow use when there is a
-  // limited amount of stack.
-  struct Values {
-    enum { kAllowUniquePtr = true };
-    matrix<K, L> a_ntt;
-    vector<L> s1_ntt;
-    vector<K> t;
-    vector<K> t0;
-  };
-  auto values = bssl::MakeUnique<Values>();
-  if (values == nullptr) {
-    return 0;
-  }
-
-  OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
-  OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
-                 sizeof(pub->public_key_hash));
-
-  matrix_expand(&values->a_ntt, priv->rho);
-
-  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
-  vector_ntt(&values->s1_ntt);
-
-  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
-  vector_inverse_ntt(&values->t);
-  vector_add(&values->t, &values->t, &priv->s2);
-
-  vector_power2_round(&pub->t1, &values->t0, &values->t);
-  // t1 is part of the public key and thus is public.
-  CONSTTIME_DECLASSIFY(&pub->t1, sizeof(pub->t1));
-  return 1;
-}
-
 // FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`), using a pre-computed mu.
 // Returns 1 on success and 0 on failure.
 template <int K, int L>
@@ -1746,7 +1737,7 @@
   OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
   vector_ntt(&values->t0_ntt);
 
-  matrix_expand(&values->a_ntt, priv->rho);
+  matrix_expand(&values->a_ntt, priv->pub.rho);
 
   // kappa must not exceed 2**16/L = 13107. But the probability of it
   // exceeding even 1000 iterations is vanishingly small.
@@ -1854,8 +1845,8 @@
   uint8_t mu[kMuBytes];
   BORINGSSL_keccak_st keccak_ctx;
   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
-                          sizeof(priv->public_key_hash));
+  BORINGSSL_keccak_absorb(&keccak_ctx, priv->pub.public_key_hash,
+                          sizeof(priv->pub.public_key_hash));
   BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
   BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
   BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
@@ -2082,25 +2073,32 @@
 #include "fips_known_values.inc"
 
 static int keygen_self_test() {
-  private_key<6, 5> priv;
-  uint8_t pub_bytes[MLDSA65_PUBLIC_KEY_BYTES];
-  if (!mldsa_generate_key_external_entropy_no_self_test(pub_bytes, &priv,
-                                                        kGenerateKeyEntropy)) {
+  struct Values {
+    enum { kAllowUniquePtr = true };
+    private_key<6, 5> priv;
+    uint8_t pub_bytes[MLDSA65_PUBLIC_KEY_BYTES];
+    uint8_t priv_bytes[BCM_MLDSA65_PRIVATE_KEY_BYTES];
+  };
+  auto values = bssl::MakeUnique<Values>();
+  if (values == nullptr ||
+      !mldsa_generate_key_external_entropy_no_self_test(
+          values->pub_bytes, &values->priv, kGenerateKeyEntropy)) {
     return 0;
   }
 
-  uint8_t priv_bytes[BCM_MLDSA65_PRIVATE_KEY_BYTES];
   CBB cbb;
-  CBB_init_fixed(&cbb, priv_bytes, sizeof(priv_bytes));
-  if (!mldsa_marshal_private_key(&cbb, &priv)) {
+  CBB_init_fixed(&cbb, values->priv_bytes, sizeof(values->priv_bytes));
+  if (!mldsa_marshal_private_key(&cbb, &values->priv)) {
     return 0;
   }
 
-  static_assert(sizeof(pub_bytes) == sizeof(kExpectedPublicKey));
-  static_assert(sizeof(priv_bytes) == sizeof(kExpectedPrivateKey));
-  if (!BORINGSSL_check_test(kExpectedPublicKey, pub_bytes, sizeof(pub_bytes),
+  static_assert(sizeof(values->pub_bytes) == sizeof(kExpectedPublicKey));
+  static_assert(sizeof(values->priv_bytes) == sizeof(kExpectedPrivateKey));
+  if (!BORINGSSL_check_test(kExpectedPublicKey, values->pub_bytes,
+                            sizeof(values->pub_bytes),
                             "ML-DSA keygen public key") ||
-      !BORINGSSL_check_test(kExpectedPrivateKey, priv_bytes, sizeof(priv_bytes),
+      !BORINGSSL_check_test(kExpectedPrivateKey, values->priv_bytes,
+                            sizeof(values->priv_bytes),
                             "ML-DSA keygen private key")) {
     return 0;
   }
@@ -2109,36 +2107,44 @@
 }
 
 static int sign_self_test() {
-  private_key<6, 5> priv;
-  uint8_t pub_bytes[MLDSA65_PUBLIC_KEY_BYTES];
-  if (!mldsa_generate_key_external_entropy(pub_bytes, &priv, kSignEntropy)) {
+  struct Values {
+    enum { kAllowUniquePtr = true };
+    private_key<6, 5> priv;
+    uint8_t pub_bytes[MLDSA65_PUBLIC_KEY_BYTES];
+    uint8_t sig[MLDSA65_SIGNATURE_BYTES];
+  };
+  auto values = bssl::MakeUnique<Values>();
+  if (values == nullptr ||
+      !mldsa_generate_key_external_entropy(values->pub_bytes, &values->priv,
+                                           kSignEntropy)) {
     return 0;
   }
 
   const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {};
-  uint8_t sig[MLDSA65_SIGNATURE_BYTES];
 
   // This message triggers the first restart case for signing.
   uint8_t message[4] = {0};
-  if (!mldsa_sign_internal_no_self_test(sig, &priv, message, sizeof(message),
-                                        nullptr, 0, nullptr, 0, randomizer)) {
+  if (!mldsa_sign_internal_no_self_test(values->sig, &values->priv, message,
+                                        sizeof(message), nullptr, 0, nullptr, 0,
+                                        randomizer)) {
     return 0;
   }
-  static_assert(sizeof(kExpectedCase1Signature) == sizeof(sig));
-  if (!BORINGSSL_check_test(kExpectedCase1Signature, sig, sizeof(sig),
-                            "ML-DSA sign case 1")) {
+  static_assert(sizeof(kExpectedCase1Signature) == sizeof(values->sig));
+  if (!BORINGSSL_check_test(kExpectedCase1Signature, values->sig,
+                            sizeof(values->sig), "ML-DSA sign case 1")) {
     return 0;
   }
 
   // This message triggers the second restart case for signing.
   message[0] = 123;
-  if (!mldsa_sign_internal_no_self_test(sig, &priv, message, sizeof(message),
-                                        nullptr, 0, nullptr, 0, randomizer)) {
+  if (!mldsa_sign_internal_no_self_test(values->sig, &values->priv, message,
+                                        sizeof(message), nullptr, 0, nullptr, 0,
+                                        randomizer)) {
     return 0;
   }
-  static_assert(sizeof(kExpectedCase2Signature) == sizeof(sig));
-  if (!BORINGSSL_check_test(kExpectedCase2Signature, sig, sizeof(sig),
-                            "ML-DSA sign case 2")) {
+  static_assert(sizeof(kExpectedCase2Signature) == sizeof(values->sig));
+  if (!BORINGSSL_check_test(kExpectedCase2Signature, values->sig,
+                            sizeof(values->sig), "ML-DSA sign case 2")) {
     return 0;
   }
 
@@ -2149,7 +2155,6 @@
   struct Values {
     enum { kAllowUniquePtr = true };
     private_key<6, 5> priv;
-    public_key<6> pub;
     uint8_t pub_bytes[MLDSA65_PUBLIC_KEY_BYTES];
   };
   auto values = bssl::MakeUnique<Values>();
@@ -2163,9 +2168,8 @@
   }
 
   const uint8_t message[4] = {1, 0};
-  if (!mldsa_public_from_private(&values->pub, &values->priv) ||
-      !mldsa_verify_internal_no_self_test<6, 5>(
-          &values->pub, kExpectedVerifySignature, message, sizeof(message),
+  if (!mldsa_verify_internal_no_self_test<6, 5>(
+          &values->priv.pub, kExpectedVerifySignature, message, sizeof(message),
           nullptr, 0, nullptr, 0)) {
     return 0;
   }
@@ -2174,12 +2178,10 @@
 }
 
 template <int K, int L>
-int check_key(private_key<K, L> *priv) {
+int check_key(const private_key<K, L> *priv) {
   uint8_t sig[signature_bytes<K>()];
   uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {};
-  mldsa::public_key<K> pub;
-  if (!mldsa_public_from_private(&pub, priv) ||
-      !mldsa_sign_internal_no_self_test(sig, priv, nullptr, 0, nullptr, 0,
+  if (!mldsa_sign_internal_no_self_test(sig, priv, nullptr, 0, nullptr, 0,
                                         nullptr, 0, randomizer)) {
     return 0;
   }
@@ -2188,8 +2190,8 @@
     sig[0] ^= 1;
   }
 
-  if (!mldsa_verify_internal_no_self_test<K, L>(&pub, sig, nullptr, 0, nullptr,
-                                                0, nullptr, 0)) {
+  if (!mldsa_verify_internal_no_self_test<K, L>(&priv->pub, sig, nullptr, 0,
+                                                nullptr, 0, nullptr, 0)) {
     return 0;
   }
   return 1;
@@ -2338,9 +2340,10 @@
 bcm_status BCM_mldsa65_public_from_private(
     MLDSA65_public_key *out_public_key,
     const MLDSA65_private_key *private_key) {
-  return bcm_as_approved_status(mldsa_public_from_private(
-      mldsa::public_key_from_external_65(out_public_key),
-      mldsa::private_key_from_external_65(private_key)));
+  const auto *priv = mldsa::private_key_from_external_65(private_key);
+  auto *out_pub = mldsa::public_key_from_external_65(out_public_key);
+  *out_pub = priv->pub;
+  return bcm_status::approved;
 }
 
 bcm_status BCM_mldsa65_sign_internal(
@@ -2538,9 +2541,10 @@
 bcm_status BCM_mldsa87_public_from_private(
     MLDSA87_public_key *out_public_key,
     const MLDSA87_private_key *private_key) {
-  return bcm_as_approved_status(mldsa_public_from_private(
-      mldsa::public_key_from_external_87(out_public_key),
-      mldsa::private_key_from_external_87(private_key)));
+  const auto *priv = mldsa::private_key_from_external_87(private_key);
+  auto *out_pub = mldsa::public_key_from_external_87(out_public_key);
+  *out_pub = priv->pub;
+  return bcm_status::approved;
 }
 
 bcm_status BCM_mldsa87_sign_internal(
@@ -2737,9 +2741,10 @@
 bcm_status BCM_mldsa44_public_from_private(
     MLDSA44_public_key *out_public_key,
     const MLDSA44_private_key *private_key) {
-  return bcm_as_approved_status(mldsa_public_from_private(
-      mldsa::public_key_from_external_44(out_public_key),
-      mldsa::private_key_from_external_44(private_key)));
+  const auto *priv = mldsa::private_key_from_external_44(private_key);
+  auto *out_pub = mldsa::public_key_from_external_44(out_public_key);
+  *out_pub = priv->pub;
+  return bcm_status::approved;
 }
 
 bcm_status BCM_mldsa44_sign_internal(
diff --git a/include/openssl/mldsa.h b/include/openssl/mldsa.h
index decb5b3..d15ad3c 100644
--- a/include/openssl/mldsa.h
+++ b/include/openssl/mldsa.h
@@ -41,7 +41,7 @@
 // object should never leave the address space since the format is unstable.
 struct MLDSA65_private_key {
   union {
-    uint8_t bytes[32 + 32 + 64 + 256 * 4 * (5 + 6 + 6)];
+    uint8_t bytes[(32 + 64 + 256 * 4 * 6) + 32 + 256 * 4 * (5 + 6 + 6)];
     uint32_t alignment;
   } opaque;
 };
@@ -180,7 +180,7 @@
 // object should never leave the address space since the format is unstable.
 struct MLDSA87_private_key {
   union {
-    uint8_t bytes[32 + 32 + 64 + 256 * 4 * (7 + 8 + 8)];
+    uint8_t bytes[(32 + 64 + 256 * 4 * 8) + 32 + 256 * 4 * (7 + 8 + 8)];
     uint32_t alignment;
   } opaque;
 };
@@ -316,7 +316,7 @@
 // object should never leave the address space since the format is unstable.
 struct MLDSA44_private_key {
   union {
-    uint8_t bytes[32 + 32 + 64 + 256 * 4 * (4 + 4 + 4)];
+    uint8_t bytes[(32 + 64 + 256 * 4 * 4) + 32 + 256 * 4 * (4 + 4 + 4)];
     uint32_t alignment;
   } opaque;
 };