Reworking bssl_crypto: AES

Change-Id: I4dc295906da0f0c7132a944176774c3472752c51
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/65173
Reviewed-by: Bob Beck <bbe@google.com>
diff --git a/rust/bssl-crypto/src/aes.rs b/rust/bssl-crypto/src/aes.rs
index 0900420..2527e8d 100644
--- a/rust/bssl-crypto/src/aes.rs
+++ b/rust/bssl-crypto/src/aes.rs
@@ -13,214 +13,165 @@
  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  */
 
-/// Block size in bytes for AES.
+//! Advanced Encryption Standard.
+//!
+//! AES is a 128-bit block cipher that supports key sizes of 128, 192, or 256
+//! bits. (Although 192 isn't supported here.)
+//!
+//! Each key defines a permutation of the set of 128-bit blocks and AES can
+//! perform the forward and reverse permutation. (These directions are
+//! arbitrarily labeled "encryption" and "decryption".)
+//!
+//! AES requires relatively expensive preprocessing of keys and thus the
+//! processed form of the key is represented here using [`EncryptKey`] and
+//! [`DecryptKey`].
+//!
+//! ```
+//! use bssl_crypto::aes;
+//!
+//! let key_bytes = [0u8; 32];
+//! let enc_key = aes::EncryptKey::new_256(&key_bytes);
+//! let block = [0u8; aes::BLOCK_SIZE];
+//! let mut transformed_block = enc_key.encrypt(&block);
+//!
+//! let dec_key = aes::DecryptKey::new_256(&key_bytes);
+//! dec_key.decrypt_in_place(&mut transformed_block);
+//! assert_eq!(block, transformed_block);
+//! ```
+//!
+//! AES is a low-level primitive and must be used in a more complex construction
+//! in nearly every case. See the `aead` crate for usable encryption and
+//! decryption primitives.
+
+use crate::{initialized_struct_fallible, FfiMutSlice, FfiSlice};
+use core::ffi::c_uint;
+
+/// AES block size in bytes.
 pub const BLOCK_SIZE: usize = bssl_sys::AES_BLOCK_SIZE as usize;
 
 /// A single AES block.
-pub type AesBlock = [u8; BLOCK_SIZE];
-
-/// AES implementation used for encrypting/decrypting a single `AesBlock` at a time.
-pub struct Aes;
-
-impl Aes {
-    /// Encrypts `block` in place.
-    pub fn encrypt(key: &AesEncryptKey, block: &mut AesBlock) {
-        let input = *block;
-        // Safety:
-        // - AesBlock is always a valid size and key is guaranteed to already be initialized.
-        unsafe { bssl_sys::AES_encrypt(input.as_ptr(), block.as_mut_ptr(), &key.0) }
-    }
-
-    /// Decrypts `block` in place.
-    pub fn decrypt(key: &AesDecryptKey, block: &mut AesBlock) {
-        let input = *block;
-        // Safety:
-        // - AesBlock is always a valid size and key is guaranteed to already be initialized.
-        unsafe { bssl_sys::AES_decrypt(input.as_ptr(), block.as_mut_ptr(), &key.0) }
-    }
-}
+pub type Block = [u8; BLOCK_SIZE];
 
 /// An initialized key which can be used for encrypting.
-pub struct AesEncryptKey(bssl_sys::AES_KEY);
+pub struct EncryptKey(bssl_sys::AES_KEY);
 
-impl AesEncryptKey {
-    /// Initializes an encryption key from an appropriately sized array of bytes for AES-128 operations.
-    pub fn new_aes_128(key: [u8; 16]) -> AesEncryptKey {
-        new_encrypt_key(key)
+impl EncryptKey {
+    /// Initializes an encryption key from an appropriately sized array of bytes
+    // for AES-128 operations.
+    pub fn new_128(key: &[u8; 16]) -> Self {
+        new_encrypt_key(key.as_slice())
     }
 
-    /// Initializes an encryption key from an appropriately sized array of bytes for AES-256 operations.
-    pub fn new_aes_256(key: [u8; 32]) -> AesEncryptKey {
-        new_encrypt_key(key)
+    /// Initializes an encryption key from an appropriately sized array of bytes
+    // for AES-256 operations.
+    pub fn new_256(key: &[u8; 32]) -> Self {
+        new_encrypt_key(key.as_slice())
+    }
+
+    /// Return the encrypted version of the given block.
+    pub fn encrypt(&self, block: &Block) -> Block {
+        let mut ret = *block;
+        self.encrypt_in_place(&mut ret);
+        ret
+    }
+
+    /// Replace `block` with its encrypted version.
+    pub fn encrypt_in_place(&self, block: &mut Block) {
+        // Safety:
+        // - block is always a valid size and key is guaranteed to already be initialized.
+        unsafe { bssl_sys::AES_encrypt(block.as_ffi_ptr(), block.as_mut_ffi_ptr(), &self.0) }
     }
 }
 
 /// An initialized key which can be used for decrypting
-pub struct AesDecryptKey(bssl_sys::AES_KEY);
+pub struct DecryptKey(bssl_sys::AES_KEY);
 
-impl AesDecryptKey {
+impl DecryptKey {
     /// Initializes a decryption key from an appropriately sized array of bytes for AES-128 operations.
-    pub fn new_aes_128(key: [u8; 16]) -> AesDecryptKey {
-        new_decrypt_key(key)
+    pub fn new_128(key: &[u8; 16]) -> DecryptKey {
+        new_decrypt_key(key.as_slice())
     }
 
     /// Initializes a decryption key from an appropriately sized array of bytes for AES-256 operations.
-    pub fn new_aes_256(key: [u8; 32]) -> AesDecryptKey {
-        new_decrypt_key(key)
+    pub fn new_256(key: &[u8; 32]) -> DecryptKey {
+        new_decrypt_key(key.as_slice())
+    }
+
+    /// Return the decrypted version of the given block.
+    pub fn decrypt(&self, block: &Block) -> Block {
+        let mut ret = *block;
+        self.decrypt_in_place(&mut ret);
+        ret
+    }
+
+    /// Replace `block` with its decrypted version.
+    pub fn decrypt_in_place(&self, block: &mut Block) {
+        // Safety:
+        // - block is always a valid size and key is guaranteed to already be initialized.
+        unsafe { bssl_sys::AES_decrypt(block.as_ffi_ptr(), block.as_mut_ffi_ptr(), &self.0) }
     }
 }
 
-/// Private generically implemented function for creating a new `AesEncryptKey` from an array of bytes.
 /// This should only be publicly exposed by wrapper types with the correct key lengths
-fn new_encrypt_key<const N: usize>(key: [u8; N]) -> AesEncryptKey {
-    let mut enc_key_uninit = core::mem::MaybeUninit::uninit();
-
-    // Safety:
-    // - key is guaranteed to point to bits/8 bytes determined by the len() * 8 used below.
-    // - bits is always a valid AES key size, as defined by the new_aes_* fns defined on the public
-    //   key structs.
-    let result = unsafe {
-        bssl_sys::AES_set_encrypt_key(
-            key.as_ptr(),
-            key.len() as core::ffi::c_uint * 8,
-            enc_key_uninit.as_mut_ptr(),
-        )
-    };
-    assert_eq!(result, 0, "Error occurred in bssl_sys::AES_set_encrypt_key");
-
-    // Safety:
-    // - since we have checked above that initialization succeeded, this will never be UB
-    let enc_key = unsafe { enc_key_uninit.assume_init() };
-
-    AesEncryptKey(enc_key)
+#[allow(clippy::unwrap_used)]
+fn new_encrypt_key(key: &[u8]) -> EncryptKey {
+    EncryptKey(
+        unsafe {
+            initialized_struct_fallible(|aes_key| {
+                // The return value of this function differs from the usual BoringSSL
+                // convention.
+                bssl_sys::AES_set_encrypt_key(key.as_ffi_ptr(), key.len() as c_uint * 8, aes_key)
+                    == 0
+            })
+        }
+        // unwrap: this function only fails if `key` is the wrong length, which
+        // must be prevented by the pub functions that call this.
+        .unwrap(),
+    )
 }
 
-/// Private generically implemented function for creating a new `AesDecryptKey` from an array of bytes.
 /// This should only be publicly exposed by wrapper types with the correct key lengths.
-fn new_decrypt_key<const N: usize>(key: [u8; N]) -> AesDecryptKey {
-    let mut dec_key_uninit = core::mem::MaybeUninit::uninit();
-
-    // Safety:
-    // - key is guaranteed to point to bits/8 bytes determined by the len() * 8 used below.
-    // - bits is always a valid AES key size, as defined by the new_aes_* fns defined on the public
-    //   key structs.
-    let result = unsafe {
-        bssl_sys::AES_set_decrypt_key(
-            key.as_ptr(),
-            key.len() as core::ffi::c_uint * 8,
-            dec_key_uninit.as_mut_ptr(),
-        )
-    };
-    assert_eq!(result, 0, "Error occurred in bssl_sys::AES_set_decrypt_key");
-
-    // Safety:
-    // - Since we have checked above that initialization succeeded, this will never be UB.
-    let dec_key = unsafe { dec_key_uninit.assume_init() };
-
-    AesDecryptKey(dec_key)
+#[allow(clippy::unwrap_used)]
+fn new_decrypt_key(key: &[u8]) -> DecryptKey {
+    DecryptKey(
+        unsafe {
+            initialized_struct_fallible(|aes_key| {
+                // The return value of this function differs from the usual BoringSSL
+                // convention.
+                bssl_sys::AES_set_decrypt_key(key.as_ffi_ptr(), key.len() as c_uint * 8, aes_key)
+                    == 0
+            })
+        }
+        // unwrap: this function only fails if `key` is the wrong length, which
+        // must be prevented by the pub functions that call this.
+        .unwrap(),
+    )
 }
 
 #[cfg(test)]
 mod tests {
     use crate::{
-        aes::{Aes, AesDecryptKey, AesEncryptKey},
+        aes::{DecryptKey, EncryptKey},
         test_helpers::decode_hex,
     };
 
-    // test data from https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf F.1.1
     #[test]
-    fn aes_128_test_encrypt() {
-        let key = AesEncryptKey::new_aes_128(decode_hex("2b7e151628aed2a6abf7158809cf4f3c"));
-        let mut block = [0_u8; 16];
-
-        block.copy_from_slice(&decode_hex::<16>("6bc1bee22e409f96e93d7e117393172a"));
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("3ad77bb40d7a3660a89ecaf32466ef97"), block);
-
-        block.copy_from_slice(&decode_hex::<16>("ae2d8a571e03ac9c9eb76fac45af8e51"));
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("f5d3d58503b9699de785895a96fdbaaf"), block);
-
-        block.copy_from_slice(&decode_hex::<16>("30c81c46a35ce411e5fbc1191a0a52ef"));
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("43b1cd7f598ece23881b00e3ed030688"), block);
-
-        block.copy_from_slice(&decode_hex::<16>("f69f2445df4f9b17ad2b417be66c3710"));
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("7b0c785e27e8ad3f8223207104725dd4"), block);
+    fn aes_128() {
+        // test data from https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf F.1.1
+        let key = decode_hex("2b7e151628aed2a6abf7158809cf4f3c");
+        let plaintext = decode_hex("6bc1bee22e409f96e93d7e117393172a");
+        let ciphertext = decode_hex("3ad77bb40d7a3660a89ecaf32466ef97");
+        assert_eq!(ciphertext, EncryptKey::new_128(&key).encrypt(&plaintext));
+        assert_eq!(plaintext, DecryptKey::new_128(&key).decrypt(&ciphertext));
     }
 
-    // test data from https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf F.1.2
     #[test]
-    fn aes_128_test_decrypt() {
-        let key = AesDecryptKey::new_aes_128(decode_hex("2b7e151628aed2a6abf7158809cf4f3c"));
-        let mut block = [0_u8; 16];
-
-        block.copy_from_slice(&decode_hex::<16>("3ad77bb40d7a3660a89ecaf32466ef97"));
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex::<16>("6bc1bee22e409f96e93d7e117393172a"), block);
-
-        block.copy_from_slice(&decode_hex::<16>("f5d3d58503b9699de785895a96fdbaaf"));
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex::<16>("ae2d8a571e03ac9c9eb76fac45af8e51"), block);
-
-        block.copy_from_slice(&decode_hex::<16>("43b1cd7f598ece23881b00e3ed030688"));
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex::<16>("30c81c46a35ce411e5fbc1191a0a52ef"), block);
-
-        block.copy_from_slice(&decode_hex::<16>("7b0c785e27e8ad3f8223207104725dd4").as_slice());
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex::<16>("f69f2445df4f9b17ad2b417be66c3710"), block);
-    }
-
-    // test data from https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf F.1.5
-    #[test]
-    pub fn aes_256_test_encrypt() {
-        let key = AesEncryptKey::new_aes_256(decode_hex(
-            "603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4",
-        ));
-        let mut block: [u8; 16];
-
-        block = decode_hex("6bc1bee22e409f96e93d7e117393172a");
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("f3eed1bdb5d2a03c064b5a7e3db181f8"), block);
-
-        block = decode_hex("ae2d8a571e03ac9c9eb76fac45af8e51");
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("591ccb10d410ed26dc5ba74a31362870"), block);
-
-        block = decode_hex("30c81c46a35ce411e5fbc1191a0a52ef");
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("b6ed21b99ca6f4f9f153e7b1beafed1d"), block);
-
-        block = decode_hex("f69f2445df4f9b17ad2b417be66c3710");
-        Aes::encrypt(&key, &mut block);
-        assert_eq!(decode_hex("23304b7a39f9f3ff067d8d8f9e24ecc7"), block);
-    }
-
-    // test data from https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf F.1.6
-    #[test]
-    fn aes_256_test_decrypt() {
-        let key = AesDecryptKey::new_aes_256(decode_hex(
-            "603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4",
-        ));
-
-        let mut block: [u8; 16];
-
-        block = decode_hex("f3eed1bdb5d2a03c064b5a7e3db181f8");
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex("6bc1bee22e409f96e93d7e117393172a"), block);
-
-        block = decode_hex("591ccb10d410ed26dc5ba74a31362870");
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex("ae2d8a571e03ac9c9eb76fac45af8e51"), block);
-
-        block = decode_hex("b6ed21b99ca6f4f9f153e7b1beafed1d");
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex("30c81c46a35ce411e5fbc1191a0a52ef"), block);
-
-        block = decode_hex("23304b7a39f9f3ff067d8d8f9e24ecc7");
-        Aes::decrypt(&key, &mut block);
-        assert_eq!(decode_hex("f69f2445df4f9b17ad2b417be66c3710"), block);
+    fn aes_256() {
+        // test data from https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf F.1.5
+        let key = decode_hex("603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4");
+        let plaintext = decode_hex("6bc1bee22e409f96e93d7e117393172a");
+        let ciphertext = decode_hex("f3eed1bdb5d2a03c064b5a7e3db181f8");
+        assert_eq!(ciphertext, EncryptKey::new_256(&key).encrypt(&plaintext));
+        assert_eq!(plaintext, DecryptKey::new_256(&key).decrypt(&ciphertext));
     }
 }
diff --git a/rust/bssl-crypto/src/lib.rs b/rust/bssl-crypto/src/lib.rs
index 993b049..21a2725 100644
--- a/rust/bssl-crypto/src/lib.rs
+++ b/rust/bssl-crypto/src/lib.rs
@@ -36,7 +36,6 @@
 
 pub mod aead;
 
-/// AES block operations.
 pub mod aes;
 
 /// Ciphers.
@@ -268,16 +267,16 @@
 
 /// Returns a BoringSSL structure that is initialized by some function.
 /// Requires that the given function completely initializes the value or else
-/// returns a value other than one.
+/// returns false.
 ///
 /// (Tagged `unsafe` because a no-op argument would otherwise expose
 /// uninitialized memory.)
 unsafe fn initialized_struct_fallible<T, F>(init: F) -> Option<T>
 where
-    F: FnOnce(*mut T) -> core::ffi::c_int,
+    F: FnOnce(*mut T) -> bool,
 {
     let mut out_uninit = core::mem::MaybeUninit::<T>::uninit();
-    if init(out_uninit.as_mut_ptr()) == 1 {
+    if init(out_uninit.as_mut_ptr()) {
         Some(unsafe { out_uninit.assume_init() })
     } else {
         None