Add an external mu variant of the ML-DSA API (65 and 87 variants).

Change-Id: Ie637a0968cc008f8fd894113e21cb64d1ede1e97
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/76747
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bcm_interface.h b/crypto/fipsmodule/bcm_interface.h
index a70636d..b11a558 100644
--- a/crypto/fipsmodule/bcm_interface.h
+++ b/crypto/fipsmodule/bcm_interface.h
@@ -257,6 +257,9 @@
 // BCM_MLDSA_SEED_BYTES is the number of bytes in an ML-DSA seed value.
 #define BCM_MLDSA_SEED_BYTES 32
 
+// BCM_MLDSA_MU_BYTES is the number of bytes in a message representative (mu).
+#define BCM_MLDSA_MU_BYTES 64
+
 // BCM_MLDSA65_PRIVATE_KEY_BYTES is the number of bytes in an encoded ML-DSA-65
 // private key.
 #define BCM_MLDSA65_PRIVATE_KEY_BYTES 4032
@@ -283,6 +286,13 @@
   } opaque;
 };
 
+struct BCM_mldsa65_prehash {
+  union {
+    uint8_t bytes[200 + 4 + 4 + 4 * sizeof(size_t)];
+    uint64_t alignment;
+  } opaque;
+};
+
 OPENSSL_EXPORT bcm_status BCM_mldsa65_generate_key(
     uint8_t out_encoded_public_key[BCM_MLDSA65_PUBLIC_KEY_BYTES],
     uint8_t out_seed[BCM_MLDSA_SEED_BYTES],
@@ -318,6 +328,24 @@
     const uint8_t signature[BCM_MLDSA65_SIGNATURE_BYTES], const uint8_t *msg,
     size_t msg_len, const uint8_t *context, size_t context_len);
 
+OPENSSL_EXPORT void BCM_mldsa65_prehash_init(
+    struct BCM_mldsa65_prehash *out_prehash_ctx,
+    const struct BCM_mldsa65_public_key *public_key, const uint8_t *context,
+    size_t context_len);
+
+OPENSSL_EXPORT void BCM_mldsa65_prehash_update(
+    struct BCM_mldsa65_prehash *inout_prehash_ctx, const uint8_t *msg,
+    size_t msg_len);
+
+OPENSSL_EXPORT void BCM_mldsa65_prehash_finalize(
+    uint8_t out_msg_rep[BCM_MLDSA_MU_BYTES],
+    struct BCM_mldsa65_prehash *inout_prehash_ctx);
+
+OPENSSL_EXPORT bcm_status BCM_mldsa65_sign_message_representative(
+    uint8_t out_encoded_signature[BCM_MLDSA65_SIGNATURE_BYTES],
+    const struct BCM_mldsa65_private_key *private_key,
+    const uint8_t msg_rep[BCM_MLDSA_MU_BYTES]);
+
 OPENSSL_EXPORT bcm_status BCM_mldsa65_marshal_public_key(
     CBB *out, const struct BCM_mldsa65_public_key *public_key);
 
@@ -393,6 +421,13 @@
   } opaque;
 };
 
+struct BCM_mldsa87_prehash {
+  union {
+    uint8_t bytes[200 + 4 + 4 + 4 * sizeof(size_t)];
+    uint64_t alignment;
+  } opaque;
+};
+
 OPENSSL_EXPORT bcm_status BCM_mldsa87_generate_key(
     uint8_t out_encoded_public_key[BCM_MLDSA87_PUBLIC_KEY_BYTES],
     uint8_t out_seed[BCM_MLDSA_SEED_BYTES],
@@ -428,6 +463,24 @@
                    const uint8_t *signature, const uint8_t *msg, size_t msg_len,
                    const uint8_t *context, size_t context_len);
 
+OPENSSL_EXPORT void BCM_mldsa87_prehash_init(
+    struct BCM_mldsa87_prehash *out_prehash_ctx,
+    const struct BCM_mldsa87_public_key *public_key, const uint8_t *context,
+    size_t context_len);
+
+OPENSSL_EXPORT void BCM_mldsa87_prehash_update(
+    struct BCM_mldsa87_prehash *inout_prehash_ctx, const uint8_t *msg,
+    size_t msg_len);
+
+OPENSSL_EXPORT void BCM_mldsa87_prehash_finalize(
+    uint8_t out_msg_rep[BCM_MLDSA_MU_BYTES],
+    struct BCM_mldsa87_prehash *inout_prehash_ctx);
+
+OPENSSL_EXPORT bcm_status BCM_mldsa87_sign_message_representative(
+    uint8_t out_encoded_signature[BCM_MLDSA87_SIGNATURE_BYTES],
+    const struct BCM_mldsa87_private_key *private_key,
+    const uint8_t msg_rep[BCM_MLDSA_MU_BYTES]);
+
 OPENSSL_EXPORT bcm_status BCM_mldsa87_marshal_public_key(
     CBB *out, const struct BCM_mldsa87_public_key *public_key);
 
diff --git a/crypto/fipsmodule/keccak/internal.h b/crypto/fipsmodule/keccak/internal.h
index ba76af9..f27f3a4 100644
--- a/crypto/fipsmodule/keccak/internal.h
+++ b/crypto/fipsmodule/keccak/internal.h
@@ -22,23 +22,26 @@
 #endif
 
 
-enum boringssl_keccak_config_t {
+enum boringssl_keccak_config_t : int32_t {
   boringssl_sha3_256,
   boringssl_sha3_512,
   boringssl_shake128,
   boringssl_shake256,
 };
 
-enum boringssl_keccak_phase_t {
+enum boringssl_keccak_phase_t : int32_t {
   boringssl_keccak_phase_absorb,
   boringssl_keccak_phase_squeeze,
 };
 
 struct BORINGSSL_keccak_st {
+  // Note: the state with 64-bit integers comes first so that the size of this
+  // struct is easy to compute on all architectures without padding surprises
+  // due to alignment.
+  uint64_t state[25];
   enum boringssl_keccak_config_t config;
   enum boringssl_keccak_phase_t phase;
   size_t required_out_len;
-  uint64_t state[25];
   size_t rate_bytes;
   size_t absorb_offset;
   size_t squeeze_offset;
diff --git a/crypto/fipsmodule/mldsa/mldsa.cc.inc b/crypto/fipsmodule/mldsa/mldsa.cc.inc
index e89d87c..2e59840 100644
--- a/crypto/fipsmodule/mldsa/mldsa.cc.inc
+++ b/crypto/fipsmodule/mldsa/mldsa.cc.inc
@@ -1512,26 +1512,15 @@
   return 1;
 }
 
-// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0
-// on failure.
+// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`), using a pre-computed mu.
+// Returns 1 on success and 0 on failure.
 template <int K, int L>
-int mldsa_sign_internal_no_self_test(
+int mldsa_sign_mu(
     uint8_t out_encoded_signature[signature_bytes<K>()],
-    const struct private_key<K, L> *priv, const uint8_t *msg, size_t msg_len,
-    const uint8_t *context_prefix, size_t context_prefix_len,
-    const uint8_t *context, size_t context_len,
+    const struct private_key<K, L> *priv, const uint8_t mu[kMuBytes],
     const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
-  uint8_t mu[kMuBytes];
-  struct BORINGSSL_keccak_st keccak_ctx;
-  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
-                          sizeof(priv->public_key_hash));
-  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
-  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
-
   uint8_t rho_prime[kRhoPrimeBytes];
+  struct BORINGSSL_keccak_st keccak_ctx;
   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
   BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
   BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
@@ -1663,6 +1652,30 @@
   }
 }
 
+// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0
+// on failure.
+template <int K, int L>
+int mldsa_sign_internal_no_self_test(
+    uint8_t out_encoded_signature[signature_bytes<K>()],
+    const struct private_key<K, L> *priv, const uint8_t *msg, size_t msg_len,
+    const uint8_t *context_prefix, size_t context_prefix_len,
+    const uint8_t *context, size_t context_len,
+    const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
+  uint8_t mu[kMuBytes];
+  struct BORINGSSL_keccak_st keccak_ctx;
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
+                          sizeof(priv->public_key_hash));
+  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
+
+  return mldsa_sign_mu(out_encoded_signature, priv, mu, randomizer);
+}
+
+// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0
+// on failure.
 template <int K, int L>
 int mldsa_sign_internal(
     uint8_t out_encoded_signature[signature_bytes<K>()],
@@ -1676,6 +1689,35 @@
       context_prefix_len, context, context_len, randomizer);
 }
 
+struct prehash_context {
+  struct BORINGSSL_keccak_st keccak_ctx;
+};
+
+template <int K>
+void mldsa_prehash_init(struct prehash_context *out_prehash_ctx,
+                        const struct public_key<K> *pub,
+                        const uint8_t *context_prefix,
+                        size_t context_prefix_len, const uint8_t *context,
+                        size_t context_len) {
+  BORINGSSL_keccak_init(&out_prehash_ctx->keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&out_prehash_ctx->keccak_ctx, pub->public_key_hash,
+                          sizeof(pub->public_key_hash));
+  BORINGSSL_keccak_absorb(&out_prehash_ctx->keccak_ctx, context_prefix,
+                          context_prefix_len);
+  BORINGSSL_keccak_absorb(&out_prehash_ctx->keccak_ctx, context, context_len);
+}
+
+void mldsa_prehash_update(struct prehash_context *inout_prehash_ctx,
+                          const uint8_t *msg, size_t msg_len) {
+  BORINGSSL_keccak_absorb(&inout_prehash_ctx->keccak_ctx, msg, msg_len);
+}
+
+void mldsa_prehash_finalize(uint8_t out_msg_rep[kMuBytes],
+                            struct prehash_context *inout_prehash_ctx) {
+  BORINGSSL_keccak_squeeze(&inout_prehash_ctx->keccak_ctx, out_msg_rep,
+                           kMuBytes);
+}
+
 // FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
 template <int K, int L>
 int mldsa_verify_internal_no_self_test(
@@ -1785,6 +1827,17 @@
   return (struct public_key<6> *)external;
 }
 
+struct prehash_context *prehash_context_from_external_65(
+    const struct BCM_mldsa65_prehash *external) {
+  static_assert(
+      sizeof(struct BCM_mldsa65_prehash) == sizeof(struct prehash_context),
+      "MLDSA pre-hash context size incorrect");
+  static_assert(
+      alignof(struct BCM_mldsa65_prehash) == alignof(struct prehash_context),
+      "MLDSA pre-hash context align incorrect");
+  return (struct prehash_context *)external;
+}
+
 struct private_key<8, 7> *private_key_from_external_87(
     const struct BCM_mldsa87_private_key *external) {
   static_assert(sizeof(struct BCM_mldsa87_private_key) ==
@@ -1807,6 +1860,17 @@
   return (struct public_key<8> *)external;
 }
 
+struct prehash_context *prehash_context_from_external_87(
+    const struct BCM_mldsa87_prehash *external) {
+  static_assert(
+      sizeof(struct BCM_mldsa87_prehash) == sizeof(struct prehash_context),
+      "MLDSA pre-hash context size incorrect");
+  static_assert(
+      alignof(struct BCM_mldsa87_prehash) == alignof(struct prehash_context),
+      "MLDSA pre-hash context align incorrect");
+  return (struct prehash_context *)external;
+}
+
 namespace fips {
 
 #include "fips_known_values.inc"
@@ -2105,6 +2169,48 @@
       sizeof(context_prefix), context, context_len, randomizer);
 }
 
+// ML-DSA pre-hashed API: initializing a pre-hashing context.
+void BCM_mldsa65_prehash_init(struct BCM_mldsa65_prehash *out_prehash_ctx,
+                              const struct BCM_mldsa65_public_key *public_key,
+                              const uint8_t *context, size_t context_len) {
+  BSSL_CHECK(context_len <= 255);
+
+  const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context_len)};
+  mldsa_prehash_init(mldsa::prehash_context_from_external_65(out_prehash_ctx),
+                     mldsa::public_key_from_external_65(public_key),
+                     context_prefix, sizeof(context_prefix), context,
+                     context_len);
+}
+
+// ML-DSA pre-hashed API: updating a pre-hashing context with a message chunk.
+void BCM_mldsa65_prehash_update(struct BCM_mldsa65_prehash *inout_prehash_ctx,
+                                const uint8_t *msg, size_t msg_len) {
+  mldsa_prehash_update(
+      mldsa::prehash_context_from_external_65(inout_prehash_ctx), msg, msg_len);
+}
+
+// ML-DSA pre-hashed API: obtaining a message representative to sign.
+void BCM_mldsa65_prehash_finalize(
+    uint8_t out_msg_rep[BCM_MLDSA_MU_BYTES],
+    struct BCM_mldsa65_prehash *inout_prehash_ctx) {
+  mldsa_prehash_finalize(
+      out_msg_rep, mldsa::prehash_context_from_external_65(inout_prehash_ctx));
+}
+
+// ML-DSA pre-hashed API: signing a message representative.
+bcm_status BCM_mldsa65_sign_message_representative(
+    uint8_t out_encoded_signature[BCM_MLDSA65_SIGNATURE_BYTES],
+    const struct BCM_mldsa65_private_key *private_key,
+    const uint8_t msg_rep[BCM_MLDSA_MU_BYTES]) {
+  uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES];
+  BCM_rand_bytes(randomizer, sizeof(randomizer));
+  CONSTTIME_SECRET(randomizer, sizeof(randomizer));
+
+  return bcm_as_approved_status(mldsa_sign_mu(
+      out_encoded_signature, mldsa::private_key_from_external_65(private_key),
+      msg_rep, randomizer));
+}
+
 // FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
 bcm_status BCM_mldsa65_verify(
     const struct BCM_mldsa65_public_key *public_key,
@@ -2266,6 +2372,48 @@
       sizeof(context_prefix), context, context_len, randomizer);
 }
 
+// ML-DSA pre-hashed API: initializing a pre-hashing context.
+void BCM_mldsa87_prehash_init(struct BCM_mldsa87_prehash *out_prehash_ctx,
+                              const struct BCM_mldsa87_public_key *public_key,
+                              const uint8_t *context, size_t context_len) {
+  BSSL_CHECK(context_len <= 255);
+
+  const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context_len)};
+  mldsa_prehash_init(mldsa::prehash_context_from_external_87(out_prehash_ctx),
+                     mldsa::public_key_from_external_87(public_key),
+                     context_prefix, sizeof(context_prefix), context,
+                     context_len);
+}
+
+// ML-DSA pre-hashed API: updating a pre-hashing context with a message chunk.
+void BCM_mldsa87_prehash_update(struct BCM_mldsa87_prehash *inout_prehash_ctx,
+                                const uint8_t *msg, size_t msg_len) {
+  mldsa_prehash_update(
+      mldsa::prehash_context_from_external_87(inout_prehash_ctx), msg, msg_len);
+}
+
+// ML-DSA pre-hashed API: obtaining a message representative to sign.
+void BCM_mldsa87_prehash_finalize(
+    uint8_t out_msg_rep[BCM_MLDSA_MU_BYTES],
+    struct BCM_mldsa87_prehash *inout_prehash_ctx) {
+  mldsa_prehash_finalize(
+      out_msg_rep, mldsa::prehash_context_from_external_87(inout_prehash_ctx));
+}
+
+// ML-DSA pre-hashed API: signing a message representative.
+bcm_status BCM_mldsa87_sign_message_representative(
+    uint8_t out_encoded_signature[BCM_MLDSA87_SIGNATURE_BYTES],
+    const struct BCM_mldsa87_private_key *private_key,
+    const uint8_t msg_rep[BCM_MLDSA_MU_BYTES]) {
+  uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES];
+  BCM_rand_bytes(randomizer, sizeof(randomizer));
+  CONSTTIME_SECRET(randomizer, sizeof(randomizer));
+
+  return bcm_as_approved_status(mldsa_sign_mu(
+      out_encoded_signature, mldsa::private_key_from_external_87(private_key),
+      msg_rep, randomizer));
+}
+
 // FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
 bcm_status BCM_mldsa87_verify(const struct BCM_mldsa87_public_key *public_key,
                               const uint8_t *signature, const uint8_t *msg,
diff --git a/crypto/mldsa/mldsa.cc b/crypto/mldsa/mldsa.cc
index ab125e7..721d2db 100644
--- a/crypto/mldsa/mldsa.cc
+++ b/crypto/mldsa/mldsa.cc
@@ -20,11 +20,16 @@
 static_assert(alignof(BCM_mldsa65_private_key) == alignof(MLDSA65_private_key));
 static_assert(sizeof(BCM_mldsa65_public_key) == sizeof(MLDSA65_public_key));
 static_assert(alignof(BCM_mldsa65_public_key) == alignof(MLDSA65_public_key));
+static_assert(sizeof(BCM_mldsa65_prehash) == sizeof(MLDSA65_prehash));
+static_assert(alignof(BCM_mldsa65_prehash) == alignof(MLDSA65_prehash));
 static_assert(sizeof(BCM_mldsa87_private_key) == sizeof(MLDSA87_private_key));
 static_assert(alignof(BCM_mldsa87_private_key) == alignof(MLDSA87_private_key));
 static_assert(sizeof(BCM_mldsa87_public_key) == sizeof(MLDSA87_public_key));
 static_assert(alignof(BCM_mldsa87_public_key) == alignof(MLDSA87_public_key));
+static_assert(sizeof(BCM_mldsa87_prehash) == sizeof(MLDSA87_prehash));
+static_assert(alignof(BCM_mldsa87_prehash) == alignof(MLDSA87_prehash));
 static_assert(MLDSA_SEED_BYTES == BCM_MLDSA_SEED_BYTES);
+static_assert(MLDSA_MU_BYTES == BCM_MLDSA_MU_BYTES);
 static_assert(MLDSA65_PRIVATE_KEY_BYTES == BCM_MLDSA65_PRIVATE_KEY_BYTES);
 static_assert(MLDSA65_PUBLIC_KEY_BYTES == BCM_MLDSA65_PUBLIC_KEY_BYTES);
 static_assert(MLDSA65_SIGNATURE_BYTES == BCM_MLDSA65_SIGNATURE_BYTES);
@@ -82,6 +87,40 @@
       msg, msg_len, context, context_len));
 }
 
+int MLDSA65_prehash_init(struct MLDSA65_prehash *out_state,
+                         const struct MLDSA65_public_key *public_key,
+                         const uint8_t *context, size_t context_len) {
+  if (context_len > 255) {
+    return 0;
+  }
+  BCM_mldsa65_prehash_init(
+      reinterpret_cast<BCM_mldsa65_prehash *>(out_state),
+      reinterpret_cast<const BCM_mldsa65_public_key *>(public_key), context,
+      context_len);
+  return 1;
+}
+
+void MLDSA65_prehash_update(struct MLDSA65_prehash *inout_state,
+                            const uint8_t *msg, size_t msg_len) {
+  BCM_mldsa65_prehash_update(
+      reinterpret_cast<BCM_mldsa65_prehash *>(inout_state), msg, msg_len);
+}
+
+void MLDSA65_prehash_finalize(uint8_t out_msg_rep[MLDSA_MU_BYTES],
+                              struct MLDSA65_prehash *inout_state) {
+  BCM_mldsa65_prehash_finalize(
+      out_msg_rep, reinterpret_cast<BCM_mldsa65_prehash *>(inout_state));
+}
+
+int MLDSA65_sign_message_representative(
+    uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
+    const struct MLDSA65_private_key *private_key,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  return bcm_success(BCM_mldsa65_sign_message_representative(
+      out_encoded_signature,
+      reinterpret_cast<const BCM_mldsa65_private_key *>(private_key), msg_rep));
+}
+
 int MLDSA65_marshal_public_key(CBB *out,
                                const struct MLDSA65_public_key *public_key) {
   return bcm_success(BCM_mldsa65_marshal_public_key(
@@ -143,6 +182,40 @@
       msg, msg_len, context, context_len));
 }
 
+int MLDSA87_prehash_init(struct MLDSA87_prehash *out_state,
+                         const struct MLDSA87_public_key *public_key,
+                         const uint8_t *context, size_t context_len) {
+  if (context_len > 255) {
+    return 0;
+  }
+  BCM_mldsa87_prehash_init(
+      reinterpret_cast<BCM_mldsa87_prehash *>(out_state),
+      reinterpret_cast<const BCM_mldsa87_public_key *>(public_key), context,
+      context_len);
+  return 1;
+}
+
+void MLDSA87_prehash_update(struct MLDSA87_prehash *inout_state,
+                            const uint8_t *msg, size_t msg_len) {
+  BCM_mldsa87_prehash_update(
+      reinterpret_cast<BCM_mldsa87_prehash *>(inout_state), msg, msg_len);
+}
+
+void MLDSA87_prehash_finalize(uint8_t out_msg_rep[MLDSA_MU_BYTES],
+                              struct MLDSA87_prehash *inout_state) {
+  BCM_mldsa87_prehash_finalize(
+      out_msg_rep, reinterpret_cast<BCM_mldsa87_prehash *>(inout_state));
+}
+
+int MLDSA87_sign_message_representative(
+    uint8_t out_encoded_signature[MLDSA87_SIGNATURE_BYTES],
+    const struct MLDSA87_private_key *private_key,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  return bcm_success(BCM_mldsa87_sign_message_representative(
+      out_encoded_signature,
+      reinterpret_cast<const BCM_mldsa87_private_key *>(private_key), msg_rep));
+}
+
 int MLDSA87_marshal_public_key(CBB *out,
                                const struct MLDSA87_public_key *public_key) {
   return bcm_success(BCM_mldsa87_marshal_public_key(
diff --git a/crypto/mldsa/mldsa_test.cc b/crypto/mldsa/mldsa_test.cc
index 9f1da28..b4053de 100644
--- a/crypto/mldsa/mldsa_test.cc
+++ b/crypto/mldsa/mldsa_test.cc
@@ -185,6 +185,54 @@
             1);
 }
 
+TEST(MLDSATest, PrehashedSignatureVerifies) {
+  std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
+  auto priv = std::make_unique<MLDSA65_private_key>();
+  uint8_t seed[MLDSA_SEED_BYTES];
+  EXPECT_TRUE(
+      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
+
+  auto pub = std::make_unique<MLDSA65_public_key>();
+  CBS cbs = CBS(encoded_public_key);
+  ASSERT_TRUE(MLDSA65_parse_public_key(pub.get(), &cbs));
+
+  std::vector<uint8_t> encoded_signature(MLDSA65_SIGNATURE_BYTES);
+  static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
+                                     'w', 'o', 'r', 'l', 'd'};
+
+  MLDSA65_prehash prehash_state;
+  EXPECT_TRUE(MLDSA65_prehash_init(&prehash_state, pub.get(), nullptr, 0));
+  MLDSA65_prehash_update(&prehash_state, kMessage, sizeof(kMessage));
+  uint8_t representative[MLDSA_MU_BYTES];
+  MLDSA65_prehash_finalize(representative, &prehash_state);
+  EXPECT_TRUE(MLDSA65_sign_message_representative(encoded_signature.data(),
+                                                  priv.get(), representative));
+
+  EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature.data(),
+                           encoded_signature.size(), kMessage, sizeof(kMessage),
+                           nullptr, 0),
+            1);
+
+  // Updating in multiple chunks also works.
+  for (size_t i = 0; i <= sizeof(kMessage); ++i) {
+    for (size_t j = i; j <= sizeof(kMessage); ++j) {
+      EXPECT_TRUE(MLDSA65_prehash_init(&prehash_state, pub.get(), nullptr, 0));
+      MLDSA65_prehash_update(&prehash_state, kMessage, i);
+      MLDSA65_prehash_update(&prehash_state, kMessage + i, j - i);
+      MLDSA65_prehash_update(&prehash_state, kMessage + j,
+                             sizeof(kMessage) - j);
+      MLDSA65_prehash_finalize(representative, &prehash_state);
+      EXPECT_TRUE(MLDSA65_sign_message_representative(
+          encoded_signature.data(), priv.get(), representative));
+
+      EXPECT_EQ(MLDSA65_verify(pub.get(), encoded_signature.data(),
+                               encoded_signature.size(), kMessage,
+                               sizeof(kMessage), nullptr, 0),
+                1);
+    }
+  }
+}
+
 TEST(MLDSATest, PublicFromPrivateIsConsistent) {
   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<MLDSA65_private_key>();
diff --git a/include/openssl/mldsa.h b/include/openssl/mldsa.h
index 569c89c..8bc829b 100644
--- a/include/openssl/mldsa.h
+++ b/include/openssl/mldsa.h
@@ -31,6 +31,9 @@
 // MLDSA_SEED_BYTES is the number of bytes in an ML-DSA seed value.
 #define MLDSA_SEED_BYTES 32
 
+// MLDSA_MU_BYTES is the number of bytes in a message representative (mu).
+#define MLDSA_MU_BYTES 64
+
 
 // ML-DSA-65.
 
@@ -52,6 +55,16 @@
   } opaque;
 };
 
+// MLDSA65_prehash contains a pre-hash context for ML-DSA-65. The contents of
+// this object should never leave the address space since the format is
+// unstable.
+struct MLDSA65_prehash {
+  union {
+    uint8_t bytes[200 + 4 + 4 + 4 * sizeof(size_t)];
+    uint64_t alignment;
+  } opaque;
+};
+
 // MLDSA65_PRIVATE_KEY_BYTES is the number of bytes in an encoded ML-DSA-65
 // private key.
 #define MLDSA65_PRIVATE_KEY_BYTES 4032
@@ -109,6 +122,45 @@
                                   size_t msg_len, const uint8_t *context,
                                   size_t context_len);
 
+// MLDSA65_prehash_init initializes a pre-hashing state using |public_key|. The
+// |context| argument can be used to include implicit contextual information
+// that isn't included in the message. The same value of |context| must be
+// presented to |MLDSA65_verify| in order for the generated signature to be
+// considered valid. |context| and |context_len| may be |NULL| and 0 to use an
+// empty context (this is common). Returns 1 on success and 0 on failure (if
+// |context_len| is greater than 255).
+OPENSSL_EXPORT int MLDSA65_prehash_init(
+    struct MLDSA65_prehash *out_state,
+    const struct MLDSA65_public_key *public_key, const uint8_t *context,
+    size_t context_len);
+
+// MLDSA65_prehash_update incorporates the given |msg| of length |msg_len| into
+// the pre-hashing state. This can be called multiple times on successive chunks
+// of the message. This should be called after |MLDSA65_prehash_init| and before
+// |MLDSA65_prehash_finalize|.
+OPENSSL_EXPORT void MLDSA65_prehash_update(struct MLDSA65_prehash *inout_state,
+                                           const uint8_t *msg, size_t msg_len);
+
+// MLDSA65_prehash_finalize extracts a pre-hashed message representative from
+// the given pre-hashing state. This should be called after
+// |MLDSA65_prehash_init| and |MLDSA65_prehash_update|. The resulting
+// |out_msg_rep| should then be passed to |MLDSA65_sign_message_representative|
+// to obtain a signature.
+OPENSSL_EXPORT void MLDSA65_prehash_finalize(
+    uint8_t out_msg_rep[MLDSA_MU_BYTES], struct MLDSA65_prehash *inout_state);
+
+// MLDSA65_sign_message_representative generates a signature for the pre-hashed
+// message |msg_rep| using |private_key| (following the randomized algorithm),
+// and writes the encoded signature to |out_encoded_signature|. The |msg_rep|
+// should be obtained via calls to |MLDSA65_prehash_init|,
+// |MLDSA65_prehash_update| and |MLDSA65_prehash_finalize| using the public key
+// from the same key pair, otherwise the signature will not verify. Returns 1 on
+// success and 0 on failure.
+OPENSSL_EXPORT int MLDSA65_sign_message_representative(
+    uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
+    const struct MLDSA65_private_key *private_key,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 // MLDSA65_marshal_public_key serializes |public_key| to |out| in the standard
 // format for ML-DSA-65 public keys. It returns 1 on success or 0 on
 // allocation error.
@@ -146,6 +198,16 @@
   } opaque;
 };
 
+// MLDSA87_prehash contains a pre-hash context for ML-DSA-87. The contents of
+// this object should never leave the address space since the format is
+// unstable.
+struct MLDSA87_prehash {
+  union {
+    uint8_t bytes[200 + 4 + 4 + 4 * sizeof(size_t)];
+    uint64_t alignment;
+  } opaque;
+};
+
 // MLDSA87_PRIVATE_KEY_BYTES is the number of bytes in an encoded ML-DSA-87
 // private key.
 #define MLDSA87_PRIVATE_KEY_BYTES 4896
@@ -203,6 +265,45 @@
                                   size_t msg_len, const uint8_t *context,
                                   size_t context_len);
 
+// MLDSA87_prehash_init initializes a pre-hashing state using |public_key|. The
+// |context| argument can be used to include implicit contextual information
+// that isn't included in the message. The same value of |context| must be
+// presented to |MLDSA87_verify| in order for the generated signature to be
+// considered valid. |context| and |context_len| may be |NULL| and 0 to use an
+// empty context (this is common). Returns 1 on success and 0 on failure (if
+// |context_len| is greater than 255).
+OPENSSL_EXPORT int MLDSA87_prehash_init(
+    struct MLDSA87_prehash *out_state,
+    const struct MLDSA87_public_key *public_key, const uint8_t *context,
+    size_t context_len);
+
+// MLDSA87_prehash_update incorporates the given |msg| of length |msg_len| into
+// the pre-hashing state. This can be called multiple times on successive chunks
+// of the message. This should be called after |MLDSA87_prehash_init| and before
+// |MLDSA87_prehash_finalize|.
+OPENSSL_EXPORT void MLDSA87_prehash_update(struct MLDSA87_prehash *inout_state,
+                                           const uint8_t *msg, size_t msg_len);
+
+// MLDSA87_prehash_finalize extracts a pre-hashed message representative from
+// the given pre-hashing state. This should be called after
+// |MLDSA87_prehash_init| and |MLDSA87_prehash_update|. The resulting
+// |out_msg_rep| should then be passed to |MLDSA87_sign_message_representative|
+// to obtain a signature.
+OPENSSL_EXPORT void MLDSA87_prehash_finalize(
+    uint8_t out_msg_rep[MLDSA_MU_BYTES], struct MLDSA87_prehash *inout_state);
+
+// MLDSA87_sign_message_representative generates a signature for the pre-hashed
+// message |msg_rep| using |private_key| (following the randomized algorithm),
+// and writes the encoded signature to |out_encoded_signature|. The |msg_rep|
+// should be obtained via calls to |MLDSA87_prehash_init|,
+// |MLDSA87_prehash_update| and |MLDSA87_prehash_finalize| using the public key
+// from the same key pair, otherwise the signature will not verify. Returns 1 on
+// success and 0 on failure.
+OPENSSL_EXPORT int MLDSA87_sign_message_representative(
+    uint8_t out_encoded_signature[MLDSA87_SIGNATURE_BYTES],
+    const struct MLDSA87_private_key *private_key,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 // MLDSA87_marshal_public_key serializes |public_key| to |out| in the standard
 // format for ML-DSA-87 public keys. It returns 1 on success or 0 on
 // allocation error.