Rust wrappers for external mu variant of ML-DSA

Change-Id: I7ae00968199c8297fa43bf14c7642d7d28f68836
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81847
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/rust/bssl-crypto/src/mldsa.rs b/rust/bssl-crypto/src/mldsa.rs
index 9ee28c5..655272f 100644
--- a/rust/bssl-crypto/src/mldsa.rs
+++ b/rust/bssl-crypto/src/mldsa.rs
@@ -36,7 +36,8 @@
 
 use crate::{
     as_cbs, cbb_to_vec, initialized_boxed_struct, initialized_boxed_struct_fallible,
-    with_output_vec, with_output_vec_fallible, FfiSlice, InvalidSignatureError,
+    initialized_struct, with_output_vec, with_output_vec_fallible, FfiMutSlice, FfiSlice,
+    InvalidSignatureError,
 };
 use alloc::{boxed::Box, vec::Vec};
 use core::mem::MaybeUninit;
@@ -56,6 +57,9 @@
 /// The number of bytes in an ML-DSA seed value.
 pub const SEED_BYTES: usize = bssl_sys::MLDSA_SEED_BYTES as usize;
 
+/// The number of bytes in an ML-DSA external mu.
+pub const MU_BYTES: usize = bssl_sys::MLDSA_MU_BYTES as usize;
+
 impl PrivateKey65 {
     /// Generates a random public/private key pair returning a serialized public
     /// key, a private key, and a private seed value that can be used to
@@ -164,6 +168,27 @@
             })
         }
     }
+
+    /// Sign pre-hashed data.
+    pub fn sign_prehashed(&self, prehash: Prehash65) -> Vec<u8> {
+        let representative = prehash.finalize();
+        unsafe {
+            // Safety: `signature` is the correct size via the type system and
+            // is always fully written; `representative` is an array of the
+            // correct size.
+            with_output_vec(SIGNATURE_BYTES_65, |signature| {
+                let ok = bssl_sys::MLDSA65_sign_message_representative(
+                    signature,
+                    &*self.0,
+                    representative.as_ffi_ptr(),
+                );
+                // This function can only fail if out of memory, which is not a
+                // case that this crate handles.
+                assert_eq!(ok, 1);
+                SIGNATURE_BYTES_65
+            })
+        }
+    }
 }
 
 impl PublicKey65 {
@@ -237,6 +262,61 @@
             }
         }
     }
+
+    /// Start a pre-hashing operation using this public key.
+    pub fn prehash(&self) -> Prehash65 {
+        unsafe {
+            // Safety: `self.0` is the correct size via the type system and
+            // is fully written if this function returns 1.
+            initialized_struct(|prehash: *mut Prehash65| {
+                let ok = bssl_sys::MLDSA65_prehash_init(
+                    &mut (*prehash).0,
+                    &*self.0,
+                    core::ptr::null(),
+                    0,
+                );
+                // This function can only fail if too much context is provided, but no context is
+                // used here.
+                assert_eq!(ok, 1);
+            })
+        }
+    }
+}
+
+/// An in-progress ML-DSA-65 pre-hashing operation.
+pub struct Prehash65(bssl_sys::MLDSA65_prehash);
+
+impl Prehash65 {
+    /// Add data to the pre-hashing operation.
+    pub fn update(&mut self, data: &[u8]) {
+        unsafe {
+            // Safety: `self.0` is the correct size via the type system and `data`
+            // is a valid Rust slice.
+            bssl_sys::MLDSA65_prehash_update(&mut self.0, data.as_ffi_ptr(), data.len());
+        }
+    }
+
+    /// Complete the pre-hashing operation.
+    fn finalize(mut self) -> [u8; MU_BYTES] {
+        let mut mu = [0u8; MU_BYTES];
+        unsafe {
+            // Safety: `self.0` is the correct size via the type system, as is `mu`.
+            bssl_sys::MLDSA65_prehash_finalize(mu.as_mut_ffi_ptr(), &mut self.0);
+        }
+        mu
+    }
+}
+
+#[cfg(feature = "std")]
+impl std::io::Write for Prehash65 {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        self.update(buf);
+        Ok(buf.len())
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        Ok(())
+    }
 }
 
 #[cfg(test)]
@@ -266,6 +346,43 @@
     }
 
     #[test]
+    fn prehashed() {
+        let (serialized_public_key, private_key, _private_seed) = PrivateKey65::generate();
+        let public_key = PublicKey65::parse(&serialized_public_key).unwrap();
+        let message = &[0u8, 1, 2, 3, 4, 5, 6];
+
+        let mut prehash = public_key.prehash();
+        prehash.update(&message[0..2]);
+        prehash.update(&message[2..4]);
+        prehash.update(&message[4..]);
+        let mut signature = private_key.sign_prehashed(prehash);
+
+        assert!(public_key.verify(message, &signature).is_ok());
+        signature[5] ^= 1;
+        assert!(public_key.verify(message, &signature).is_err());
+    }
+
+    #[cfg(feature = "std")]
+    #[test]
+    fn prehashed_write() {
+        use std::io::Write;
+        let (serialized_public_key, private_key, _private_seed) = PrivateKey65::generate();
+        let public_key = PublicKey65::parse(&serialized_public_key).unwrap();
+        let message = &[0u8, 1, 2, 3, 4, 5, 6];
+
+        let mut prehash = public_key.prehash();
+        prehash.write(&message[0..2]).unwrap();
+        prehash.write(&message[2..4]).unwrap();
+        prehash.write(&message[4..]).unwrap();
+        prehash.flush().unwrap();
+        let mut signature = private_key.sign_prehashed(prehash);
+
+        assert!(public_key.verify(message, &signature).is_ok());
+        signature[5] ^= 1;
+        assert!(public_key.verify(message, &signature).is_err());
+    }
+
+    #[test]
     fn marshal_public_key() {
         let (serialized_public_key, private_key, _) = PrivateKey65::generate();
         let public_key = PublicKey65::parse(&serialized_public_key).unwrap();