rust: Properly dispatch the key types to the right curve after parsing

This patch partially revert 90487 because of the potential curve type
mismatch.

Update-Note: To parse `ECPrivateKey` with inferred curve type, please
use `ParsedPrivateKey` in the respective EC modules in bssl-crypto.

Signed-off-by: Xiangfei Ding <xfding@google.com>
Change-Id: Ia30048686a971f1bdba804fb579e8f576a6a6964
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/90507
Reviewed-by: Rudolf Polzer <rpolzer@google.com>
diff --git a/rust/bssl-crypto/src/ecdh.rs b/rust/bssl-crypto/src/ecdh.rs
index d8cc136..ea1ce93 100644
--- a/rust/bssl-crypto/src/ecdh.rs
+++ b/rust/bssl-crypto/src/ecdh.rs
@@ -36,7 +36,10 @@
 //! assert_eq!(shared_key1, shared_key2);
 //! ```
 
-use crate::{ec, with_output_vec, Buffer};
+use crate::{
+    ec::{self, Group},
+    with_output_vec, Buffer,
+};
 use alloc::vec::Vec;
 use core::marker::PhantomData;
 
@@ -63,6 +66,36 @@
     }
 }
 
+/// Parsed `ECPrivateKey` dispatched into the corresponding curve types.
+pub enum ParsedPrivateKey {
+    /// A P-256 private key.
+    P256(PrivateKey<ec::P256>),
+    /// A P-384 private key.
+    P384(PrivateKey<ec::P384>),
+}
+
+impl ParsedPrivateKey {
+    /// Parses an ECPrivateKey structure froma DER encoded structure per [RFC 5915],
+    /// whose curve is specified by the `ECParameters`.
+    ///
+    /// Unless the curve group is one of the variants of [`Group`], this method returns [`None`].
+    ///
+    /// [RFC 5915]: <https://datatracker.ietf.org/doc/html/rfc5915>
+    pub fn from_der(der: &[u8]) -> Option<Self> {
+        let key = ec::Key::from_der_ec_private_key_with_curve_names(der)?;
+        match key.get_group()? {
+            Group::P256 => Some(ParsedPrivateKey::P256(PrivateKey {
+                key,
+                marker: PhantomData,
+            })),
+            Group::P384 => Some(ParsedPrivateKey::P384(PrivateKey {
+                key,
+                marker: PhantomData,
+            })),
+        }
+    }
+}
+
 /// An ECDH private key over the given curve.
 pub struct PrivateKey<C: ec::Curve> {
     key: ec::Key,
@@ -111,19 +144,6 @@
         })
     }
 
-    /// Parses an ECPrivateKey structure from [RFC 5915], whose curve is specified by
-    /// the `ECParameters`.
-    ///
-    /// Unless the curve group is one of the variants of [`Group`], this method returns [`None`].
-    ///
-    /// [RFC 5915]: <https://datatracker.ietf.org/doc/html/rfc5915>
-    pub fn from_der_ec_private_key_with_curve_names(der: &[u8]) -> Option<Self> {
-        ec::Key::from_der_ec_private_key_with_curve_names(der).map(|key| Self {
-            key,
-            marker: PhantomData,
-        })
-    }
-
     /// Serialize this private key as an ECPrivateKey structure (from RFC 5915).
     pub fn to_der_ec_private_key(&self) -> Buffer {
         self.key.to_der_ec_private_key()
@@ -199,10 +219,9 @@
         let alice_public_key = alice_private_key.to_public_key();
         let alice_private_key =
             PrivateKey::<C>::from_big_endian(alice_private_key.to_big_endian().as_ref()).unwrap();
-        let alice_private_key = PrivateKey::<C>::from_der_ec_private_key(
-            alice_private_key.to_der_ec_private_key().as_ref(),
-        )
-        .unwrap();
+        let alice_private_key_der = alice_private_key.to_der_ec_private_key();
+        let alice_private_key =
+            PrivateKey::<C>::from_der_ec_private_key(alice_private_key_der.as_ref()).unwrap();
 
         let bob_private_key = PrivateKey::<C>::generate();
         let bob_public_key = bob_private_key.to_public_key();
@@ -211,6 +230,11 @@
         let shared_key2 = bob_private_key.compute_shared_key(&alice_public_key);
 
         assert_eq!(shared_key1, shared_key2);
+
+        match ParsedPrivateKey::from_der(alice_private_key_der.as_ref()).unwrap() {
+            ParsedPrivateKey::P256(_) => assert!(matches!(C::group(), Group::P256)),
+            ParsedPrivateKey::P384(_) => assert!(matches!(C::group(), Group::P384)),
+        }
     }
 
     #[test]
diff --git a/rust/bssl-crypto/src/ecdsa.rs b/rust/bssl-crypto/src/ecdsa.rs
index b528e17..922abc1 100644
--- a/rust/bssl-crypto/src/ecdsa.rs
+++ b/rust/bssl-crypto/src/ecdsa.rs
@@ -33,7 +33,10 @@
 //! assert!(public_key.verify(signed_message, sig.as_slice()).is_ok());
 //! ```
 
-use crate::{ec, with_output_vec, Buffer, FfiSlice, InvalidSignatureError};
+use crate::{
+    ec::{self, Group},
+    with_output_vec, Buffer, FfiSlice, InvalidSignatureError,
+};
 use alloc::vec::Vec;
 use core::marker::PhantomData;
 
@@ -158,6 +161,36 @@
     }
 }
 
+/// Parsed `ECPrivateKey` dispatched into the corresponding curve types.
+pub enum ParsedPrivateKey {
+    /// A P-256 private key.
+    P256(PrivateKey<ec::P256>),
+    /// A P-384 private key.
+    P384(PrivateKey<ec::P384>),
+}
+
+impl ParsedPrivateKey {
+    /// Parses an ECPrivateKey structure froma DER encoded structure per [RFC 5915],
+    /// whose curve is specified by the `ECParameters`.
+    ///
+    /// Unless the curve group is one of the variants of [`Group`], this method returns [`None`].
+    ///
+    /// [RFC 5915]: <https://datatracker.ietf.org/doc/html/rfc5915>
+    pub fn from_der(der: &[u8]) -> Option<Self> {
+        let key = ec::Key::from_der_ec_private_key_with_curve_names(der)?;
+        match key.get_group()? {
+            Group::P256 => Some(ParsedPrivateKey::P256(PrivateKey {
+                key,
+                marker: PhantomData,
+            })),
+            Group::P384 => Some(ParsedPrivateKey::P384(PrivateKey {
+                key,
+                marker: PhantomData,
+            })),
+        }
+    }
+}
+
 impl<C: ec::Curve> PrivateKey<C> {
     /// Generate a random private key.
     pub fn generate() -> Self {
@@ -191,19 +224,6 @@
         })
     }
 
-    /// Parses an ECPrivateKey structure from [RFC 5915], whose curve is specified by
-    /// the `ECParameters`.
-    ///
-    /// Unless the curve group is one of the variants of [`Group`], this method returns [`None`].
-    ///
-    /// [RFC 5915]: <https://datatracker.ietf.org/doc/html/rfc5915>
-    pub fn from_der_ec_private_key_with_curve_names(der: &[u8]) -> Option<Self> {
-        ec::Key::from_der_ec_private_key_with_curve_names(der).map(|key| Self {
-            key,
-            marker: PhantomData,
-        })
-    }
-
     /// Serialize this private key as an ECPrivateKey structure (from RFC 5915).
     pub fn to_der_ec_private_key(&self) -> Buffer {
         self.key.to_der_ec_private_key()
@@ -386,15 +406,27 @@
             .is_err());
     }
 
+    fn check_parsing<C: ec::Curve>() {
+        let key = PrivateKey::<C>::generate();
+        let der = key.to_der_ec_private_key();
+        let parsed = ParsedPrivateKey::from_der(der.as_ref()).unwrap();
+        match parsed {
+            ParsedPrivateKey::P256(_) => assert!(matches!(C::group(), Group::P256)),
+            ParsedPrivateKey::P384(_) => assert!(matches!(C::group(), Group::P384)),
+        }
+    }
+
     #[test]
     fn p256() {
         check_curve::<P256>();
         check_compressed::<P256>();
+        check_parsing::<P256>();
     }
 
     #[test]
     fn p384() {
         check_curve::<P384>();
         check_compressed::<P384>();
+        check_parsing::<P384>();
     }
 }