Add `hpke::Kem::get_public_key`

This method is useful if only the private key is known because
it came from a source other than `Kem::generate_keypair`, such
as another library or a KDF.

Change-Id: I463d494048c74cac26f5cfb8660a19b20f7d7e3a
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/78027
Reviewed-by: Adam Langley <agl@google.com>
Auto-Submit: Brett McLarnon <bmclarnon@google.com>
Commit-Queue: Adam Langley <agl@google.com>
Commit-Queue: Brett McLarnon <bmclarnon@google.com>
diff --git a/rust/bssl-crypto/src/hpke.rs b/rust/bssl-crypto/src/hpke.rs
index 12010ac..82bb1c1 100644
--- a/rust/bssl-crypto/src/hpke.rs
+++ b/rust/bssl-crypto/src/hpke.rs
@@ -114,44 +114,70 @@
         // handled by this crate.
         assert_eq!(ret, 1);
 
-        fn get_value_from_key(
-            key: &scoped::EvpHpkeKey,
-            accessor: unsafe extern "C" fn(
-                *const bssl_sys::EVP_HPKE_KEY,
-                // Output buffer.
-                *mut u8,
-                // Number of bytes written.
-                *mut usize,
-                // Maximum output size.
-                usize,
-            ) -> core::ffi::c_int,
-            max_len: usize,
-        ) -> Vec<u8> {
-            unsafe {
-                with_output_vec(max_len, |out| {
-                    let mut out_len = 0usize;
-                    let ret = accessor(key.as_ffi_ptr(), out, &mut out_len, max_len);
-                    // If `max_len` is correct then these functions never fail.
-                    assert_eq!(ret, 1);
-                    assert!(out_len <= max_len);
-                    // Safety: `out_len` bytes have been written, as required.
-                    out_len
-                })
-            }
-        }
-
-        let pub_key = get_value_from_key(
+        let pub_key = Self::get_value_from_key(
             &key,
             bssl_sys::EVP_HPKE_KEY_public_key,
             bssl_sys::EVP_HPKE_MAX_PUBLIC_KEY_LENGTH as usize,
         );
-        let priv_key = get_value_from_key(
+        let priv_key = Self::get_value_from_key(
             &key,
             bssl_sys::EVP_HPKE_KEY_private_key,
             bssl_sys::EVP_HPKE_MAX_PRIVATE_KEY_LENGTH as usize,
         );
         (pub_key, priv_key)
     }
+
+    /// Get a private key's corresponding public key, or `None` if the private
+    /// key is invalid.
+    pub fn public_from_private(&self, priv_key: &[u8]) -> Option<Vec<u8>> {
+        let mut key = scoped::EvpHpkeKey::new();
+        // Safety: `key`, `self`, and `priv_key` must be valid and this function
+        // doesn't take ownership of any of them.
+        let ret = unsafe {
+            bssl_sys::EVP_HPKE_KEY_init(
+                key.as_mut_ffi_ptr(),
+                self.as_ffi_ptr(),
+                priv_key.as_ptr(),
+                priv_key.len(),
+            )
+        };
+        if ret != 1 {
+            return None;
+        }
+
+        let pub_key = Self::get_value_from_key(
+            &key,
+            bssl_sys::EVP_HPKE_KEY_public_key,
+            bssl_sys::EVP_HPKE_MAX_PUBLIC_KEY_LENGTH as usize,
+        );
+        Some(pub_key)
+    }
+
+    fn get_value_from_key(
+        key: &scoped::EvpHpkeKey,
+        accessor: unsafe extern "C" fn(
+            *const bssl_sys::EVP_HPKE_KEY,
+            // Output buffer.
+            *mut u8,
+            // Number of bytes written.
+            *mut usize,
+            // Maximum output size.
+            usize,
+        ) -> core::ffi::c_int,
+        max_len: usize,
+    ) -> Vec<u8> {
+        unsafe {
+            with_output_vec(max_len, |out| {
+                let mut out_len = 0usize;
+                let ret = accessor(key.as_ffi_ptr(), out, &mut out_len, max_len);
+                // If `max_len` is correct then these functions never fail.
+                assert_eq!(ret, 1);
+                assert!(out_len <= max_len);
+                // Safety: `out_len` bytes have been written, as required.
+                out_len
+            })
+        }
+    }
 }
 
 /// Supported KDF algorithms with values detailed in RFC 9180.
@@ -517,10 +543,10 @@
             recipient_priv_key: decode_hex("f3ce7fdae57e1a310d87f1ebbde6f328be0a99cdbcadf4d6589cf29de4b8ffd2"),
             encapsulated_key: decode_hex_into_vec("04a92719c6195d5085104f469a8b9814d5838ff72b60501e2c4466e5e67b325ac98536d7b61a1af4b78e5b7f951c0900be863c403ce65c9bfcb9382657222d18c4"),
             plaintext: decode_hex("4265617574792069732074727574682c20747275746820626561757479"),
-            associated_data: decode_hex("436f756e742d30"), 
-            ciphertext: decode_hex("5ad590bb8baa577f8619db35a36311226a896e7342a6d836d8b7bcd2f20b6c7f9076ac232e3ab2523f39513434"), 
-            exporter_context: decode_hex("54657374436f6e74657874"), 
-            exported_value: decode_hex("d8f1ea7942adbba7412c6d431c62d01371ea476b823eb697e1f6e6cae1dab85a"), 
+            associated_data: decode_hex("436f756e742d30"),
+            ciphertext: decode_hex("5ad590bb8baa577f8619db35a36311226a896e7342a6d836d8b7bcd2f20b6c7f9076ac232e3ab2523f39513434"),
+            exporter_context: decode_hex("54657374436f6e74657874"),
+            exported_value: decode_hex("d8f1ea7942adbba7412c6d431c62d01371ea476b823eb697e1f6e6cae1dab85a"),
         }
     }
 
@@ -562,6 +588,17 @@
         }
     }
 
+    #[test]
+    fn kem_public_from_private() {
+        let kems = vec![Kem::X25519HkdfSha256, Kem::P256HkdfSha256];
+        for kem in &kems {
+            let (pub_key, priv_key) = kem.generate_keypair();
+            assert_eq!(kem.public_from_private(&priv_key), Some(pub_key));
+
+            assert_eq!(kem.public_from_private(b"invalid"), None);
+        }
+    }
+
     fn new_sender_context_for_testing(
         params: &Params,
         recipient_pub_key: &[u8],