Reworking bssl_crypto: HKDF

Change-Id: I10052a0bf922ba6f68effdcebeca2c4da97345af
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/65170
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/rust/bssl-crypto/src/hkdf.rs b/rust/bssl-crypto/src/hkdf.rs
index 83c1e65..973ed88 100644
--- a/rust/bssl-crypto/src/hkdf.rs
+++ b/rust/bssl-crypto/src/hkdf.rs
@@ -12,87 +12,225 @@
  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  */
-use crate::{
-    digest,
-    digest::{Sha256, Sha512},
-    sealed, CSlice, CSliceMut, ForeignTypeRef,
-};
-use alloc::vec::Vec;
+
+//! Implements the HMAC-based Key Derivation Function from
+//! <https://datatracker.ietf.org/doc/html/rfc5869>.
+//!
+//! One-shot operation:
+//!
+//! ```
+//! use bssl_crypto::{hkdf, hkdf::HkdfSha256};
+//!
+//! let key: [u8; 32] = HkdfSha256::derive(b"secret", hkdf::Salt::NonEmpty(b"salt"),
+//!                                        b"info");
+//! ```
+//!
+//! If deriving several keys that vary only in the `info` parameter, then part
+//! of the computation can be shared by calculating the "pseudo-random key".
+//! This is purely a performance optimisation.
+//!
+//! ```
+//! use bssl_crypto::{hkdf, hkdf::HkdfSha256};
+//!
+//! let prk = HkdfSha256::extract(b"secret", hkdf::Salt::NonEmpty(b"salt"));
+//! let key1 : [u8; 32] = prk.expand(b"info1");
+//! let key2 : [u8; 32] = prk.expand(b"info2");
+//!
+//! assert_eq!(key1, HkdfSha256::derive(b"secret", hkdf::Salt::NonEmpty(b"salt"),
+//!                                     b"info1"));
+//! assert_eq!(key2, HkdfSha256::derive(b"secret", hkdf::Salt::NonEmpty(b"salt"),
+//!                                     b"info2"));
+//! ```
+//!
+//! The above examples assume that the size of the outputs is known at compile
+//! time. (And only output lengths less than 256 bytes are supported.)
+//!
+//! ```compile_fail
+//! use bssl_crypto::{hkdf, hkdf::HkdfSha256};
+//!
+//! let key: [u8; 256] = HkdfSha256::derive(b"secret", hkdf::Salt::None, b"info");
+//! ```
+//!
+//! To use HKDF with longer, or run-time, lengths, use `derive_into` and
+//! `extract_into`:
+//!
+//! ```
+//! use bssl_crypto::{hkdf, hkdf::HkdfSha256};
+//!
+//! let mut out = [0u8; 50];
+//! HkdfSha256::derive_into(b"secret", hkdf::Salt::None, b"info", &mut out).expect(
+//!    "HKDF can't produce that much");
+//!
+//! assert_eq!(out, HkdfSha256::derive(b"secret", hkdf::Salt::None, b"info"));
+//! ```
+
+use crate::{digest, sealed, with_output_array, FfiMutSlice, FfiSlice, ForeignTypeRef};
 use core::marker::PhantomData;
 
 /// Implementation of HKDF-SHA-256
-pub type HkdfSha256 = Hkdf<Sha256>;
+pub type HkdfSha256 = Hkdf<digest::Sha256>;
 
 /// Implementation of HKDF-SHA-512
-pub type HkdfSha512 = Hkdf<Sha512>;
+pub type HkdfSha512 = Hkdf<digest::Sha512>;
 
-/// Error type returned from the HKDF-Expand operations when the output key material has
-/// an invalid length
+/// Error type returned when too much output is requested from an HKDF operation.
 #[derive(Debug)]
-pub struct InvalidLength;
+pub struct TooLong;
 
-/// Implementation of HKDF operations which are generic over a provided hashing functions. Type
-/// aliases are provided above for convenience of commonly used hashes
-pub struct Hkdf<MD: digest::Algorithm> {
-    salt: Option<Vec<u8>>,
-    ikm: Vec<u8>,
-    _marker: PhantomData<MD>,
+/// HKDF's optional salt values. See <https://datatracker.ietf.org/doc/html/rfc5869#section-3.1>
+pub enum Salt<'a> {
+    /// No salt.
+    None,
+    /// An explicit salt. Note that an empty value here is interpreted the same
+    /// as if passing `None`.
+    NonEmpty(&'a [u8]),
 }
 
-impl<MD: digest::Algorithm> Hkdf<MD> {
-    /// The max length of the output key material used for expanding
-    pub const MAX_OUTPUT_LENGTH: usize = MD::OUTPUT_LEN * 255;
-
-    /// Creates a new instance of HKDF from a salt and key material
-    pub fn new(salt: Option<&[u8]>, ikm: &[u8]) -> Self {
-        Self {
-            salt: salt.map(Vec::from),
-            ikm: Vec::from(ikm),
-            _marker: PhantomData,
+impl Salt<'_> {
+    fn as_ffi_ptr(&self) -> *const u8 {
+        match self {
+            Salt::None => core::ptr::null(),
+            Salt::NonEmpty(salt) => salt.as_ffi_ptr(),
         }
     }
 
-    /// Computes HKDF-Expand operation from RFC 5869. The info argument for the expand is set to
-    /// the concatenation of all the elements of info_components. Returns InvalidLength if the
-    /// output is too large.
-    pub fn expand_multi_info(
-        &self,
-        info_components: &[&[u8]],
-        okm: &mut [u8],
-    ) -> Result<(), InvalidLength> {
-        self.expand(&info_components.concat(), okm)
+    fn len(&self) -> usize {
+        match self {
+            Salt::None => 0,
+            Salt::NonEmpty(salt) => salt.len(),
+        }
+    }
+}
+
+/// HKDF for any of the implemented hash functions. The aliases [`HkdfSha256`]
+/// and [`HkdfSha512`] are provided for the most common cases.
+pub struct Hkdf<MD: digest::Algorithm>(PhantomData<MD>);
+
+impl<MD: digest::Algorithm> Hkdf<MD> {
+    /// The maximum number of bytes of key material that can be produced.
+    pub const MAX_OUTPUT_LEN: usize = MD::OUTPUT_LEN * 255;
+
+    /// Derive key material from the given secret, salt, and info. Attempting
+    /// to derive more than 255 bytes is a compile-time error, see `derive_into`
+    /// for longer outputs.
+    ///
+    /// The semantics of the arguments are complex. See
+    /// <https://datatracker.ietf.org/doc/html/rfc5869#section-3>.
+    pub fn derive<const N: usize>(secret: &[u8], salt: Salt, info: &[u8]) -> [u8; N] {
+        Self::extract(secret, salt).expand(info)
     }
 
-    /// Computes HKDF-Expand operation from RFC 5869. Returns InvalidLength if the output is too large.
-    pub fn expand(&self, info: &[u8], okm: &mut [u8]) -> Result<(), InvalidLength> {
-        // extract the salt bytes from the option, or empty slice if option is None
-        let salt = self.salt.as_deref().unwrap_or_default();
+    /// Derive key material from the given secret, salt, and info. Attempting
+    /// to derive more than `MAX_OUTPUT_LEN` bytes is a run-time error.
+    ///
+    /// The semantics of the arguments are complex. See
+    /// <https://datatracker.ietf.org/doc/html/rfc5869#section-3>.
+    pub fn derive_into(
+        secret: &[u8],
+        salt: Salt,
+        info: &[u8],
+        out: &mut [u8],
+    ) -> Result<(), TooLong> {
+        Self::extract(secret, salt).expand_into(info, out)
+    }
 
-        //validate the output size
-        (okm.len() <= Self::MAX_OUTPUT_LENGTH && !okm.is_empty())
-            .then(|| {
-                let mut okm_cslice = CSliceMut::from(okm);
+    /// Extract a pseudo-random key from the given secret and salt. This can
+    /// be used to avoid redoing computation when computing several keys that
+    /// vary only in the `info` parameter.
+    pub fn extract(secret: &[u8], salt: Salt) -> Prk {
+        let mut prk = [0u8; bssl_sys::EVP_MAX_MD_SIZE as usize];
+        let mut prk_len = 0usize;
+        let evp_md = MD::get_md(sealed::Sealed).as_ptr();
+        unsafe {
+            // Safety: `EVP_MAX_MD_SIZE` is the maximum output size of
+            // `HKDF_extract` so it'll never overrun the buffer.
+            bssl_sys::HKDF_extract(
+                prk.as_mut_ffi_ptr(),
+                &mut prk_len,
+                evp_md,
+                secret.as_ffi_ptr(),
+                secret.len(),
+                salt.as_ffi_ptr(),
+                salt.len(),
+            );
+        }
+        // This is documented to be always be true.
+        assert!(prk_len <= prk.len());
+        Prk {
+            prk,
+            len: prk_len,
+            evp_md,
+        }
+    }
+}
 
-                // Safety:
-                // - We validate the output length above, so invalid length errors will never be hit
-                // which leaves allocation failures as the only possible error case, in which case
-                // we panic immediately
-                let result = unsafe {
-                    bssl_sys::HKDF(
-                        okm_cslice.as_mut_ptr(),
-                        okm_cslice.len(),
-                        MD::get_md(sealed::Sealed).as_ptr(),
-                        CSlice::from(self.ikm.as_slice()).as_ptr(),
-                        self.ikm.as_slice().len(),
-                        CSlice::from(salt).as_ptr(),
-                        salt.len(),
-                        CSlice::from(info).as_ptr(),
-                        info.len(),
-                    )
-                };
-                assert!(result > 0, "Allocation failure in bssl_sys::HKDF");
+/// A pseudo-random key, an intermediate value in the HKDF computation.
+pub struct Prk {
+    prk: [u8; bssl_sys::EVP_MAX_MD_SIZE as usize],
+    len: usize,
+    evp_md: *const bssl_sys::EVP_MD,
+}
+
+#[allow(clippy::let_unit_value)]
+impl Prk {
+    /// Derive key material for the given info parameter. Attempting
+    /// to derive more than 255 bytes is a compile-time error, see `expand_into`
+    /// for longer outputs.
+    pub fn expand<const N: usize>(&self, info: &[u8]) -> [u8; N] {
+        // This is the odd way to write a static assertion that uses a const
+        // parameter in Rust. Even then, Rust cannot reference `MAX_OUTPUT_LEN`.
+        // But if we safely assume that all hash functions output at least a
+        // byte then 255 is a safe lower bound on `MAX_OUTPUT_LEN`.
+        // A doctest at the top of the module checks that this assert is effective.
+        struct StaticAssert<const N: usize, const BOUND: usize>;
+        impl<const N: usize, const BOUND: usize> StaticAssert<N, BOUND> {
+            const BOUNDS_CHECK: () = assert!(N < BOUND, "Large outputs not supported");
+        }
+        let _ = StaticAssert::<N, 256>::BOUNDS_CHECK;
+
+        unsafe {
+            with_output_array(|out, out_len| {
+                // Safety: `HKDF_expand` writes exactly `out_len` bytes or else
+                // returns zero. `evp_md` is valid by construction.
+                let result = bssl_sys::HKDF_expand(
+                    out,
+                    out_len,
+                    self.evp_md,
+                    self.prk.as_ffi_ptr(),
+                    self.len,
+                    info.as_ffi_ptr(),
+                    info.len(),
+                );
+                // The output length is known to be within bounds so the only other
+                // possibily is an allocation failure, which we don't attempt to
+                // handle.
+                assert_eq!(result, 1);
             })
-            .ok_or(InvalidLength)
+        }
+    }
+
+    /// Derive key material from the given info parameter. Attempting
+    /// to derive more than the HKDF's `MAX_OUTPUT_LEN` bytes is a run-time
+    /// error.
+    pub fn expand_into(&self, info: &[u8], out: &mut [u8]) -> Result<(), TooLong> {
+        // Safety: writes at most `out.len()` bytes into `out`.
+        // `evp_md` is valid by construction.
+        let result = unsafe {
+            bssl_sys::HKDF_expand(
+                out.as_mut_ffi_ptr(),
+                out.len(),
+                self.evp_md,
+                self.prk.as_ffi_ptr(),
+                self.len,
+                info.as_ffi_ptr(),
+                info.len(),
+            )
+        };
+        if result == 1 {
+            Ok(())
+        } else {
+            Err(TooLong)
+        }
     }
 }
 
@@ -105,29 +243,17 @@
 )]
 mod tests {
     use crate::{
-        hkdf::{HkdfSha256, HkdfSha512},
+        hkdf::{HkdfSha256, HkdfSha512, Salt},
         test_helpers::{decode_hex, decode_hex_into_vec},
     };
-    use core::iter;
-
-    struct Test {
-        ikm: Vec<u8>,
-        salt: Vec<u8>,
-        info: Vec<u8>,
-        okm: Vec<u8>,
-    }
 
     #[test]
-    fn hkdf_sha_256_test() {
+    fn sha256() {
         let ikm = decode_hex_into_vec("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
-        let salt = decode_hex_into_vec("000102030405060708090a0b0c");
+        let salt_vec = decode_hex_into_vec("000102030405060708090a0b0c");
+        let salt = Salt::NonEmpty(&salt_vec);
         let info = decode_hex_into_vec("f0f1f2f3f4f5f6f7f8f9");
-
-        let hk = HkdfSha256::new(Some(salt.as_slice()), ikm.as_slice());
-        let mut okm = [0u8; 42];
-        hk.expand(&info, &mut okm)
-            .expect("42 is a valid length for Sha256 to output");
-
+        let okm: [u8; 42] = HkdfSha256::derive(ikm.as_slice(), salt, info.as_slice());
         let expected = decode_hex(
             "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865",
         );
@@ -135,15 +261,12 @@
     }
 
     #[test]
-    fn hkdf_sha512_test() {
+    fn sha512() {
         let ikm = decode_hex_into_vec("5d3db20e8238a90b62a600fa57fdb318");
-        let salt = decode_hex_into_vec("1d6f3b38a1e607b5e6bcd4af1800a9d3");
+        let salt_vec = decode_hex_into_vec("1d6f3b38a1e607b5e6bcd4af1800a9d3");
+        let salt = Salt::NonEmpty(&salt_vec);
         let info = decode_hex_into_vec("2bc5f39032b6fc87da69ba8711ce735b169646fd");
-
-        let hk = HkdfSha512::new(Some(salt.as_slice()), ikm.as_slice());
-        let mut okm = [0u8; 42];
-        hk.expand(&info, &mut okm).expect("Should succeed");
-
+        let okm: [u8; 42] = HkdfSha512::derive(ikm.as_slice(), salt, info.as_slice());
         let expected = decode_hex(
             "8c3cf7122dcb5eb7efaf02718f1faf70bca20dcb75070e9d0871a413a6c05fc195a75aa9ffc349d70aae",
         );
@@ -152,7 +275,13 @@
 
     // Test Vectors from https://tools.ietf.org/html/rfc5869.
     #[test]
-    fn test_rfc5869_sha256() {
+    fn rfc5869_sha256() {
+        struct Test {
+            ikm: Vec<u8>,
+            salt: Vec<u8>,
+            info: Vec<u8>,
+            okm: Vec<u8>,
+        }
         let tests = [
             Test {
                 // Test Case 1
@@ -202,6 +331,7 @@
                     "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8"),
             },
         ];
+
         for Test {
             ikm,
             salt,
@@ -210,90 +340,25 @@
         } in tests.iter()
         {
             let salt = if salt.is_empty() {
-                None
+                Salt::None
             } else {
-                Some(salt.as_slice())
+                Salt::NonEmpty(&salt)
             };
-            let hkdf = HkdfSha256::new(salt, ikm.as_slice());
             let mut okm2 = vec![0u8; okm.len()];
-            assert!(hkdf.expand(info.as_slice(), &mut okm2).is_ok());
+            assert!(
+                HkdfSha256::derive_into(ikm.as_slice(), salt, info.as_slice(), &mut okm2).is_ok()
+            );
             assert_eq!(okm2.as_slice(), okm.as_slice());
         }
     }
 
     #[test]
-    fn test_lengths() {
-        let hkdf = HkdfSha256::new(None, &[]);
-        let mut longest = vec![0u8; HkdfSha256::MAX_OUTPUT_LENGTH];
-        assert!(hkdf.expand(&[], &mut longest).is_ok());
-        // start at 1 since 0 is an invalid length
-        let lengths = 1..HkdfSha256::MAX_OUTPUT_LENGTH + 1;
+    fn max_output() {
+        let hkdf = HkdfSha256::extract(b"", Salt::None);
+        let mut longest = vec![0u8; HkdfSha256::MAX_OUTPUT_LEN];
+        assert!(hkdf.expand_into(b"", &mut longest).is_ok());
 
-        for length in lengths {
-            let mut okm = vec![0u8; length];
-
-            assert!(hkdf.expand(&[], &mut okm).is_ok());
-            assert_eq!(okm.len(), length);
-            assert_eq!(okm[..], longest[..length]);
-        }
-    }
-
-    #[test]
-    fn test_max_length() {
-        let hkdf = HkdfSha256::new(Some(&[]), &[]);
-        let mut okm = vec![0u8; HkdfSha256::MAX_OUTPUT_LENGTH];
-        assert!(hkdf.expand(&[], &mut okm).is_ok());
-    }
-
-    #[test]
-    fn test_max_length_exceeded() {
-        let hkdf = HkdfSha256::new(Some(&[]), &[]);
-        let mut okm = vec![0u8; HkdfSha256::MAX_OUTPUT_LENGTH + 1];
-        assert!(hkdf.expand(&[], &mut okm).is_err());
-    }
-
-    #[test]
-    fn test_unsupported_length() {
-        let hkdf = HkdfSha256::new(Some(&[]), &[]);
-        let mut okm = vec![0u8; 90000];
-        assert!(hkdf.expand(&[], &mut okm).is_err());
-    }
-
-    #[test]
-    fn test_expand_multi_info() {
-        let info_components = &[
-            &b"09090909090909090909090909090909090909090909"[..],
-            &b"8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a"[..],
-            &b"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0"[..],
-            &b"4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4"[..],
-            &b"1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d"[..],
-        ];
-
-        let hkdf = HkdfSha256::new(None, b"some ikm here");
-
-        // Compute HKDF-Expand on the concatenation of all the info components
-        let mut oneshot_res = [0u8; 16];
-        hkdf.expand(&info_components.concat(), &mut oneshot_res)
-            .unwrap();
-
-        // Now iteratively join the components of info_components until it's all 1 component. The value
-        // of HKDF-Expand should be the same throughout
-        let mut num_concatted = 0;
-        let mut info_head = Vec::new();
-
-        while num_concatted < info_components.len() {
-            info_head.extend(info_components[num_concatted]);
-
-            // Build the new input to be the info head followed by the remaining components
-            let input: Vec<&[u8]> = iter::once(info_head.as_slice())
-                .chain(info_components.iter().cloned().skip(num_concatted + 1))
-                .collect();
-
-            // Compute and compare to the one-shot answer
-            let mut multipart_res = [0u8; 16];
-            hkdf.expand_multi_info(&input, &mut multipart_res).unwrap();
-            assert_eq!(multipart_res, oneshot_res);
-            num_concatted += 1;
-        }
+        let mut too_long = vec![0u8; HkdfSha256::MAX_OUTPUT_LEN + 1];
+        assert!(hkdf.expand_into(b"", &mut too_long).is_err());
     }
 }
diff --git a/rust/bssl-crypto/src/lib.rs b/rust/bssl-crypto/src/lib.rs
index dea522f..c16a68b 100644
--- a/rust/bssl-crypto/src/lib.rs
+++ b/rust/bssl-crypto/src/lib.rs
@@ -47,7 +47,6 @@
 /// Ed25519, a signature scheme.
 pub mod ed25519;
 
-/// HKDF, a hash-based key derivation function.
 pub mod hkdf;
 
 /// HMAC, a hash-based message authentication code.