rust: bssl-tls: Introduce preshared key support

Bug: 479599893

Signed-off-by: Xiangfei Ding <xfding@google.com>
Change-Id: I4073a044b3393e416544c3a0fc673b266a6a6964
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/92407
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/rust/bssl-tls/src/credentials.rs b/rust/bssl-tls/src/credentials.rs
index 3e3f1aa..329a2de 100644
--- a/rust/bssl-tls/src/credentials.rs
+++ b/rust/bssl-tls/src/credentials.rs
@@ -182,6 +182,33 @@
     }
 }
 
+/// Supported hash algorithms for TLS 1.3 PSK
+///
+/// See [RFC 9258] § 5.1.
+///
+/// [RFC 9258]: <https://datatracker.ietf.org/doc/html/rfc9258#section-5.1>
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
+pub enum PskHash {
+    /// SHA-256
+    Sha256,
+    /// SHA-384
+    Sha384,
+}
+
+impl PskHash {
+    pub(crate) fn as_evp_md(&self) -> *const bssl_sys::EVP_MD {
+        match self {
+            PskHash::Sha256 => unsafe {
+                // Safety: `EVP_sha256` returns a valid pointer to a static `EVP_MD`.
+                bssl_sys::EVP_sha256()
+            },
+            PskHash::Sha384 => unsafe {
+                // Safety: `EVP_sha384` returns a valid pointer to a static `EVP_MD`.
+                bssl_sys::EVP_sha384()
+            },
+        }
+    }
+}
 /// A completely constructed TLS credential.
 pub struct TlsCredential(NonNull<bssl_sys::SSL_CREDENTIAL>);
 
@@ -200,6 +227,39 @@
         forget(self);
         ptr
     }
+
+    /// Create a new pre-shared key credential for TLS 1.3.
+    ///
+    /// See [RFC 9258](https://datatracker.ietf.org/doc/html/rfc9258) for details.
+    pub fn new_pre_shared_key(
+        key: &[u8],
+        identity: &[u8],
+        hash: PskHash,
+        context: &[u8],
+    ) -> Result<Self, Error> {
+        let (key_ptr, key_len) = slice_into_ffi_raw_parts(key);
+        let (id_ptr, id_len) = slice_into_ffi_raw_parts(identity);
+        let (ctx_ptr, ctx_len) = slice_into_ffi_raw_parts(context);
+        let cred = unsafe {
+            // Safety:
+            // - `key_ptr` and `key_len` are valid for the duration of the call.
+            // - `id_ptr` and `id_len` are valid for the duration of the call.
+            // - `hash.as_ptr()` returns a valid static `EVP_MD` pointer.
+            // - `ctx_ptr` and `ctx_len` are valid for the duration of the call.
+            // - The function returns a newly allocated `SSL_CREDENTIAL` or NULL.
+            bssl_sys::SSL_CREDENTIAL_new_pre_shared_key(
+                key_ptr,
+                key_len,
+                id_ptr,
+                id_len,
+                hash.as_evp_md(),
+                ctx_ptr,
+                ctx_len,
+            )
+        };
+        let cred = NonNull::new(cred).ok_or_else(|| Error::extract_lib_err())?;
+        Ok(TlsCredential(cred))
+    }
 }
 
 impl Clone for TlsCredential {
diff --git a/rust/bssl-tls/src/credentials/tests.rs b/rust/bssl-tls/src/credentials/tests.rs
index abe40eb..ab10393 100644
--- a/rust/bssl-tls/src/credentials/tests.rs
+++ b/rust/bssl-tls/src/credentials/tests.rs
@@ -134,3 +134,103 @@
         Some(PrivateKeyAlgorithm::Rsa)
     ));
 }
+
+#[cfg(feature = "tokio_net")]
+#[tokio::test]
+async fn psk_tls13_handshake() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    use crate::credentials::{PskHash, TlsCredential};
+    use crate::io::tokio::TokioIo;
+
+    let _ = tracing_subscriber::fmt()
+        .with_max_level(tracing::Level::DEBUG)
+        .try_init();
+
+    let key = b"test-key-test-key-test-key-test-key";
+    let identity = b"test-identity";
+    let context = b"test-context";
+
+    let cred = TlsCredential::new_pre_shared_key(key, identity, PskHash::Sha256, context)?;
+
+    let mut server_ctx = crate::context::TlsContextBuilder::new_tls();
+    server_ctx.with_credential(cred.clone())?;
+
+    let mut client_ctx = crate::context::TlsContextBuilder::new_tls();
+    client_ctx.with_credential(cred)?;
+
+    let server_ctx = server_ctx.build();
+    let client_ctx = client_ctx.build();
+
+    let (client_io, server_io) = tokio::io::duplex(1024);
+
+    let mut client_conn = client_ctx.new_client_connection(None)?.build();
+    let mut server_conn = server_ctx.new_server_connection(None)?.build();
+
+    client_conn.set_io(TokioIo(client_io))?;
+    server_conn.set_io(TokioIo(server_io))?;
+
+    let client_task = tokio::spawn(async move {
+        let mut in_handshake = client_conn.in_handshake().unwrap();
+        in_handshake.async_handshake().await?;
+        Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
+    });
+
+    let server_task = tokio::spawn(async move {
+        let mut in_handshake = server_conn.in_handshake().unwrap();
+        in_handshake.async_handshake().await?;
+        Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
+    });
+
+    let (client_task, server_task) = tokio::try_join!(client_task, server_task)?;
+    client_task?;
+    server_task?;
+    Ok(())
+}
+
+#[cfg(all(unix, feature = "std"))]
+#[test]
+fn psk_tls13_handshake_sync() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    use crate::credentials::{PskHash, TlsCredential};
+    use crate::io::sync_io::{NoAsync, StdIoWithReactor};
+    use std::io::pipe;
+
+    let (server_rx, client_tx) = pipe().unwrap();
+    let (client_rx, server_tx) = pipe().unwrap();
+
+    let client_reader = StdIoWithReactor::new(client_rx, NoAsync);
+    let client_writer = StdIoWithReactor::new(client_tx, NoAsync);
+    let server_reader = StdIoWithReactor::new(server_rx, NoAsync);
+    let server_writer = StdIoWithReactor::new(server_tx, NoAsync);
+
+    let key = b"test-key-test-key-test-key-test-key";
+    let identity = b"test-identity";
+    let context = b"test-context";
+
+    let cred = TlsCredential::new_pre_shared_key(key, identity, PskHash::Sha256, context)?;
+
+    let mut server_ctx = crate::context::TlsContextBuilder::new_tls();
+    server_ctx.with_credential(cred.clone())?;
+    let server_ctx = server_ctx.build();
+
+    let mut client_ctx = crate::context::TlsContextBuilder::new_tls();
+    client_ctx.with_credential(cred)?;
+    let client_ctx = client_ctx.build();
+
+    let mut client_conn = client_ctx.new_client_connection(None)?.build();
+    let mut server_conn = server_ctx.new_server_connection(None)?.build();
+
+    client_conn.set_split_io(client_reader, client_writer)?;
+    server_conn.set_split_io(server_reader, server_writer)?;
+
+    let server_thread = std::thread::spawn(move || {
+        let mut in_handshake = server_conn.in_handshake().unwrap();
+        in_handshake.do_handshake().unwrap();
+        server_conn
+    });
+
+    let mut in_handshake = client_conn.in_handshake().unwrap();
+    in_handshake.do_handshake().unwrap();
+
+    let _server_conn = server_thread.join().unwrap();
+
+    Ok(())
+}