Add an external mu variant of ML-DSA verification

Support verification of ML-DSA signatures without needing the signed
data to be in a contiguous buffer, as is already supported for signing.

Change-Id: Ic0ab981a189d652eb3ef6d9a07ec9ad03ea6ceae
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/82947
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bcm_interface.h b/crypto/fipsmodule/bcm_interface.h
index eb106f5..67c8f33 100644
--- a/crypto/fipsmodule/bcm_interface.h
+++ b/crypto/fipsmodule/bcm_interface.h
@@ -291,6 +291,11 @@
     const MLDSA65_private_key *private_key,
     const uint8_t msg_rep[MLDSA_MU_BYTES]);
 
+OPENSSL_EXPORT bcm_status BCM_mldsa65_verify_message_representative(
+    const MLDSA65_public_key *public_key,
+    const uint8_t signature[MLDSA65_SIGNATURE_BYTES],
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 OPENSSL_EXPORT bcm_status
 BCM_mldsa65_marshal_public_key(CBB *out, const MLDSA65_public_key *public_key);
 
@@ -398,6 +403,11 @@
     const MLDSA87_private_key *private_key,
     const uint8_t msg_rep[MLDSA_MU_BYTES]);
 
+OPENSSL_EXPORT bcm_status BCM_mldsa87_verify_message_representative(
+    const MLDSA87_public_key *public_key,
+    const uint8_t signature[MLDSA87_SIGNATURE_BYTES],
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 OPENSSL_EXPORT bcm_status
 BCM_mldsa87_marshal_public_key(CBB *out, const MLDSA87_public_key *public_key);
 
@@ -504,6 +514,11 @@
     const MLDSA44_private_key *private_key,
     const uint8_t msg_rep[MLDSA_MU_BYTES]);
 
+OPENSSL_EXPORT bcm_status BCM_mldsa44_verify_message_representative(
+    const MLDSA44_public_key *public_key,
+    const uint8_t signature[MLDSA44_SIGNATURE_BYTES],
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 OPENSSL_EXPORT bcm_status
 BCM_mldsa44_marshal_public_key(CBB *out, const MLDSA44_public_key *public_key);
 
diff --git a/crypto/fipsmodule/mldsa/mldsa.cc.inc b/crypto/fipsmodule/mldsa/mldsa.cc.inc
index 2eb3247..abd9c97 100644
--- a/crypto/fipsmodule/mldsa/mldsa.cc.inc
+++ b/crypto/fipsmodule/mldsa/mldsa.cc.inc
@@ -1911,13 +1911,13 @@
                            kMuBytes);
 }
 
-// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
+// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`), using a pre-computed mu.
+// Returns 1 on success and 0 on failure.
 template <int K, int L>
-int mldsa_verify_internal_no_self_test(
+int mldsa_verify_mu_no_self_test(
     const public_key<K> *pub,
-    const uint8_t encoded_signature[signature_bytes<K>()], const uint8_t *msg,
-    size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
-    const uint8_t *context, size_t context_len) {
+    const uint8_t encoded_signature[signature_bytes<K>()],
+    const uint8_t mu[kMuBytes]) {
   // Intermediate values, allocated on the heap to allow use when there is a
   // limited amount of stack.
   struct Values {
@@ -1941,16 +1941,6 @@
 
   matrix_expand(&values->a_ntt, pub->rho);
 
-  uint8_t mu[kMuBytes];
-  BORINGSSL_keccak_st keccak_ctx;
-  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
-                          sizeof(pub->public_key_hash));
-  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
-  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
-
   scalar c_ntt;
   scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
                                 sizeof(values->sign.c_tilde), tau<K>());
@@ -1975,6 +1965,7 @@
   w1_encode(w1_encoded, w1);
 
   uint8_t c_tilde[2 * lambda_bytes<K>()];
+  BORINGSSL_keccak_st keccak_ctx;
   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
   BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
   BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, w1_bytes<K>());
@@ -1985,6 +1976,39 @@
          OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * lambda_bytes<K>()) ==
              0;
 }
+
+// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`), using a pre-computed mu.
+// Returns 1 on success and 0 on failure.
+template <int K, int L>
+int mldsa_verify_mu(
+    const public_key<K> *pub,
+    const uint8_t encoded_signature[signature_bytes<K>()],
+    const uint8_t mu[kMuBytes]) {
+  fips::ensure_verify_self_test();
+  return mldsa_verify_mu_no_self_test<K, L>(pub, encoded_signature, mu);
+}
+
+// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
+template <int K, int L>
+int mldsa_verify_internal_no_self_test(
+    const public_key<K> *pub,
+    const uint8_t encoded_signature[signature_bytes<K>()], const uint8_t *msg,
+    size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
+    const uint8_t *context, size_t context_len) {
+  uint8_t mu[kMuBytes];
+  BORINGSSL_keccak_st keccak_ctx;
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
+                          sizeof(pub->public_key_hash));
+  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
+
+  return mldsa_verify_mu_no_self_test<K, L>(pub, encoded_signature, mu);
+}
+
+// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
 template <int K, int L>
 int mldsa_verify_internal(const public_key<K> *pub,
                           const uint8_t encoded_signature[signature_bytes<K>()],
@@ -2435,6 +2459,15 @@
       msg_rep, randomizer));
 }
 
+// ML-DSA pre-hashed API: verifying a message representative.
+bcm_status BCM_mldsa65_verify_message_representative(
+    const MLDSA65_public_key *public_key,
+    const uint8_t signature[MLDSA65_SIGNATURE_BYTES],
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  return bcm_as_approved_status(mldsa::mldsa_verify_mu<6, 5>(
+      mldsa::public_key_from_external_65(public_key), signature, msg_rep));
+}
+
 // FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
 bcm_status BCM_mldsa65_verify(const MLDSA65_public_key *public_key,
                               const uint8_t signature[MLDSA65_SIGNATURE_BYTES],
@@ -2653,6 +2686,15 @@
       msg_rep, randomizer));
 }
 
+// ML-DSA pre-hashed API: verifying a message representative.
+bcm_status BCM_mldsa87_verify_message_representative(
+    const MLDSA87_public_key *public_key,
+    const uint8_t signature[MLDSA87_SIGNATURE_BYTES],
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  return bcm_as_approved_status(mldsa::mldsa_verify_mu<8, 7>(
+      mldsa::public_key_from_external_87(public_key), signature, msg_rep));
+}
+
 // FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
 bcm_status BCM_mldsa87_verify(const MLDSA87_public_key *public_key,
                               const uint8_t *signature, const uint8_t *msg,
@@ -2871,6 +2913,15 @@
       msg_rep, randomizer));
 }
 
+// ML-DSA pre-hashed API: verifying a message representative.
+bcm_status BCM_mldsa44_verify_message_representative(
+    const MLDSA44_public_key *public_key,
+    const uint8_t signature[MLDSA44_SIGNATURE_BYTES],
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  return bcm_as_approved_status(mldsa::mldsa_verify_mu<4, 4>(
+      mldsa::public_key_from_external_44(public_key), signature, msg_rep));
+}
+
 // FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
 bcm_status BCM_mldsa44_verify(const MLDSA44_public_key *public_key,
                               const uint8_t *signature, const uint8_t *msg,
diff --git a/crypto/mldsa/mldsa.cc b/crypto/mldsa/mldsa.cc
index a1ad7ed..7b7c605 100644
--- a/crypto/mldsa/mldsa.cc
+++ b/crypto/mldsa/mldsa.cc
@@ -88,6 +88,17 @@
       out_encoded_signature, private_key, msg_rep));
 }
 
+int MLDSA65_verify_message_representative(
+    const struct MLDSA65_public_key *public_key,
+    const uint8_t *signature, size_t signature_len,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  if (signature_len != MLDSA65_SIGNATURE_BYTES) {
+    return 0;
+  }
+  return bcm_success(BCM_mldsa65_verify_message_representative(
+      public_key, signature, msg_rep));
+}
+
 int MLDSA65_marshal_public_key(CBB *out,
                                const struct MLDSA65_public_key *public_key) {
   return bcm_success(BCM_mldsa65_marshal_public_key(out, public_key));
@@ -169,6 +180,17 @@
       out_encoded_signature, private_key, msg_rep));
 }
 
+int MLDSA87_verify_message_representative(
+    const struct MLDSA87_public_key *public_key,
+    const uint8_t *signature, size_t signature_len,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  if (signature_len != MLDSA87_SIGNATURE_BYTES) {
+    return 0;
+  }
+  return bcm_success(BCM_mldsa87_verify_message_representative(
+      public_key, signature, msg_rep));
+}
+
 int MLDSA87_marshal_public_key(CBB *out,
                                const struct MLDSA87_public_key *public_key) {
   return bcm_success(BCM_mldsa87_marshal_public_key(out, public_key));
@@ -250,6 +272,17 @@
       out_encoded_signature, private_key, msg_rep));
 }
 
+int MLDSA44_verify_message_representative(
+    const struct MLDSA44_public_key *public_key,
+    const uint8_t *signature, size_t signature_len,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]) {
+  if (signature_len != MLDSA44_SIGNATURE_BYTES) {
+    return 0;
+  }
+  return bcm_success(BCM_mldsa44_verify_message_representative(
+      public_key, signature, msg_rep));
+}
+
 int MLDSA44_marshal_public_key(CBB *out,
                                const struct MLDSA44_public_key *public_key) {
   return bcm_success(BCM_mldsa44_marshal_public_key(out, public_key));
diff --git a/crypto/mldsa/mldsa_test.cc b/crypto/mldsa/mldsa_test.cc
index 65e307a..d2af772 100644
--- a/crypto/mldsa/mldsa_test.cc
+++ b/crypto/mldsa/mldsa_test.cc
@@ -235,6 +235,53 @@
   }
 }
 
+TEST(MLDSATest, SignatureVerifiesFromPrehash) {
+  std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
+  auto priv = std::make_unique<MLDSA65_private_key>();
+  uint8_t seed[MLDSA_SEED_BYTES];
+  EXPECT_TRUE(
+      MLDSA65_generate_key(encoded_public_key.data(), seed, priv.get()));
+
+  auto pub = std::make_unique<MLDSA65_public_key>();
+  CBS cbs = CBS(encoded_public_key);
+  ASSERT_TRUE(MLDSA65_parse_public_key(pub.get(), &cbs));
+
+  std::vector<uint8_t> encoded_signature(MLDSA65_SIGNATURE_BYTES);
+  static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
+                                     'w', 'o', 'r', 'l', 'd'};
+
+  EXPECT_TRUE(MLDSA65_sign(encoded_signature.data(), priv.get(), kMessage,
+                           sizeof(kMessage), nullptr, 0));
+
+  MLDSA65_prehash prehash_state;
+  EXPECT_TRUE(MLDSA65_prehash_init(&prehash_state, pub.get(), nullptr, 0));
+  MLDSA65_prehash_update(&prehash_state, kMessage, sizeof(kMessage));
+  uint8_t representative[MLDSA_MU_BYTES];
+  MLDSA65_prehash_finalize(representative, &prehash_state);
+  EXPECT_EQ(MLDSA65_verify_message_representative(pub.get(),
+                                                  encoded_signature.data(),
+                                                  encoded_signature.size(),
+                                                  representative),
+            1);
+
+  // Updating in multiple chunks also works.
+  for (size_t i = 0; i <= sizeof(kMessage); ++i) {
+    for (size_t j = i; j <= sizeof(kMessage); ++j) {
+      EXPECT_TRUE(MLDSA65_prehash_init(&prehash_state, pub.get(), nullptr, 0));
+      MLDSA65_prehash_update(&prehash_state, kMessage, i);
+      MLDSA65_prehash_update(&prehash_state, kMessage + i, j - i);
+      MLDSA65_prehash_update(&prehash_state, kMessage + j,
+                             sizeof(kMessage) - j);
+      MLDSA65_prehash_finalize(representative, &prehash_state);
+      EXPECT_EQ(MLDSA65_verify_message_representative(pub.get(),
+                                                      encoded_signature.data(),
+                                                      encoded_signature.size(),
+                                                      representative),
+                1);
+    }
+  }
+}
+
 TEST(MLDSATest, PublicFromPrivateIsConsistent) {
   std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
   auto priv = std::make_unique<MLDSA65_private_key>();
diff --git a/include/openssl/mldsa.h b/include/openssl/mldsa.h
index d15ad3c..3d6b04c 100644
--- a/include/openssl/mldsa.h
+++ b/include/openssl/mldsa.h
@@ -157,6 +157,17 @@
     const struct MLDSA65_private_key *private_key,
     const uint8_t msg_rep[MLDSA_MU_BYTES]);
 
+// MLDSA65_verify_message_representative verifies that |signature| constitutes a
+// valid signature for the pre-hashed message |msg_rep| using |public_key|. The
+// |msg_rep| should be obtained via calls to |MLDSA65_prehash_init|,
+// |MLDSA65_prehash_update| and |MLDSA65_prehash_finalize| using |public key|
+// and the same context as when the signature was generated. Returns 1 on
+// success or 0 on error.
+OPENSSL_EXPORT int MLDSA65_verify_message_representative(
+    const struct MLDSA65_public_key *public_key,
+    const uint8_t *signature, size_t signature_len,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 // MLDSA65_marshal_public_key serializes |public_key| to |out| in the standard
 // format for ML-DSA-65 public keys. It returns 1 on success or 0 on
 // allocation error.
@@ -296,6 +307,17 @@
     const struct MLDSA87_private_key *private_key,
     const uint8_t msg_rep[MLDSA_MU_BYTES]);
 
+// MLDSA87_verify_message_representative verifies that |signature| constitutes a
+// valid signature for the pre-hashed message |msg_rep| using |public_key|. The
+// |msg_rep| should be obtained via calls to |MLDSA87_prehash_init|,
+// |MLDSA87_prehash_update| and |MLDSA87_prehash_finalize| using |public key|
+// and the same context as when the signature was generated. Returns 1 on
+// success or 0 on error.
+OPENSSL_EXPORT int MLDSA87_verify_message_representative(
+    const struct MLDSA87_public_key *public_key,
+    const uint8_t *signature, size_t signature_len,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 // MLDSA87_marshal_public_key serializes |public_key| to |out| in the standard
 // format for ML-DSA-87 public keys. It returns 1 on success or 0 on
 // allocation error.
@@ -432,6 +454,17 @@
     const struct MLDSA44_private_key *private_key,
     const uint8_t msg_rep[MLDSA_MU_BYTES]);
 
+// MLDSA44_verify_message_representative verifies that |signature| constitutes a
+// valid signature for the pre-hashed message |msg_rep| using |public_key|. The
+// |msg_rep| should be obtained via calls to |MLDSA44_prehash_init|,
+// |MLDSA44_prehash_update| and |MLDSA44_prehash_finalize| using |public key|
+// and the same context as when the signature was generated. Returns 1 on
+// success or 0 on error.
+OPENSSL_EXPORT int MLDSA44_verify_message_representative(
+    const struct MLDSA44_public_key *public_key,
+    const uint8_t *signature, size_t signature_len,
+    const uint8_t msg_rep[MLDSA_MU_BYTES]);
+
 // MLDSA44_marshal_public_key serializes |public_key| to |out| in the standard
 // format for ML-DSA-44 public keys. It returns 1 on success or 0 on
 // allocation error.
diff --git a/rust/bssl-crypto/src/mldsa.rs b/rust/bssl-crypto/src/mldsa.rs
index 655272f..309c3da 100644
--- a/rust/bssl-crypto/src/mldsa.rs
+++ b/rust/bssl-crypto/src/mldsa.rs
@@ -263,6 +263,28 @@
         }
     }
 
+    /// Verify pre-hashed data.
+    pub fn verify_prehashed(
+        &self,
+        prehash: Prehash65,
+        signature: &[u8],
+    ) -> Result<(), InvalidSignatureError> {
+        let representative = prehash.finalize();
+        unsafe {
+            let ok = bssl_sys::MLDSA65_verify_message_representative(
+                &*self.0,
+                signature.as_ffi_ptr(),
+                signature.len(),
+                representative.as_ffi_ptr(),
+            );
+            if ok == 1 {
+                Ok(())
+            } else {
+                Err(InvalidSignatureError)
+            }
+        }
+    }
+
     /// Start a pre-hashing operation using this public key.
     pub fn prehash(&self) -> Prehash65 {
         unsafe {
@@ -346,7 +368,7 @@
     }
 
     #[test]
-    fn prehashed() {
+    fn sign_prehashed() {
         let (serialized_public_key, private_key, _private_seed) = PrivateKey65::generate();
         let public_key = PublicKey65::parse(&serialized_public_key).unwrap();
         let message = &[0u8, 1, 2, 3, 4, 5, 6];
@@ -364,7 +386,7 @@
 
     #[cfg(feature = "std")]
     #[test]
-    fn prehashed_write() {
+    fn sign_prehashed_write() {
         use std::io::Write;
         let (serialized_public_key, private_key, _private_seed) = PrivateKey65::generate();
         let public_key = PublicKey65::parse(&serialized_public_key).unwrap();
@@ -383,6 +405,50 @@
     }
 
     #[test]
+    fn verify_prehashed() {
+        let (serialized_public_key, private_key, _private_seed) = PrivateKey65::generate();
+        let public_key = PublicKey65::parse(&serialized_public_key).unwrap();
+        let message = &[0u8, 1, 2, 3, 4, 5, 6];
+
+        let signature = private_key.sign(message);
+
+        let mut prehash = public_key.prehash();
+        prehash.update(&message[0..2]);
+        prehash.update(&message[2..4]);
+        assert!(public_key.verify_prehashed(prehash, &signature).is_err());
+
+        let mut prehash = public_key.prehash();
+        prehash.update(&message[0..2]);
+        prehash.update(&message[2..4]);
+        prehash.update(&message[4..]);
+        assert!(public_key.verify_prehashed(prehash, &signature).is_ok());
+    }
+
+    #[cfg(feature = "std")]
+    #[test]
+    fn verify_prehashed_write() {
+        use std::io::Write;
+        let (serialized_public_key, private_key, _private_seed) = PrivateKey65::generate();
+        let public_key = PublicKey65::parse(&serialized_public_key).unwrap();
+        let message = &[0u8, 1, 2, 3, 4, 5, 6];
+
+        let signature = private_key.sign(message);
+
+        let mut prehash = public_key.prehash();
+        prehash.write(&message[0..2]).unwrap();
+        prehash.write(&message[2..4]).unwrap();
+        prehash.flush().unwrap();
+        assert!(public_key.verify_prehashed(prehash, &signature).is_err());
+
+        let mut prehash = public_key.prehash();
+        prehash.write(&message[0..2]).unwrap();
+        prehash.write(&message[2..4]).unwrap();
+        prehash.write(&message[4..]).unwrap();
+        prehash.flush().unwrap();
+        assert!(public_key.verify_prehashed(prehash, &signature).is_ok());
+    }
+
+    #[test]
     fn marshal_public_key() {
         let (serialized_public_key, private_key, _) = PrivateKey65::generate();
         let public_key = PublicKey65::parse(&serialized_public_key).unwrap();