Add missing curve check to ec_hash_to_scalar_p521_xmd_sha512.

The bounds on k make this a little tricky to test, so stick an assert(0)
as that codepath should be impossible.

Change-Id: I03958ed36bff4f0b420a446c6d49eca944f45da2
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/40885
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/crypto/ec_extra/hash_to_curve.c b/crypto/ec_extra/hash_to_curve.c
index 5c454de..407fa5d 100644
--- a/crypto/ec_extra/hash_to_curve.c
+++ b/crypto/ec_extra/hash_to_curve.c
@@ -125,6 +125,7 @@
   // |p_bits|.)
   if (L * 8 >= 2 * bits - 2 ||
       L > 2 * EC_MAX_BYTES) {
+    assert(0);
     OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
     return 0;
   }
@@ -385,6 +386,11 @@
 int ec_hash_to_scalar_p521_xmd_sha512(const EC_GROUP *group, EC_SCALAR *out,
                                       const uint8_t *dst, size_t dst_len,
                                       const uint8_t *msg, size_t msg_len) {
+  if (EC_GROUP_get_curve_name(group) != NID_secp521r1) {
+    OPENSSL_PUT_ERROR(EC, EC_R_GROUP_MISMATCH);
+    return 0;
+  }
+
   return hash_to_scalar(group, EVP_sha512(), out, dst, dst_len, /*k=*/256, msg,
                         msg_len);
 }
diff --git a/crypto/fipsmodule/ec/ec_test.cc b/crypto/fipsmodule/ec/ec_test.cc
index 586b0b3..8697555 100644
--- a/crypto/fipsmodule/ec/ec_test.cc
+++ b/crypto/fipsmodule/ec/ec_test.cc
@@ -1256,6 +1256,15 @@
     EXPECT_EQ(test.y_hex, EncodeHex(bssl::MakeConstSpan(buf).subspan(
                               1 + field_len, field_len)));
   }
+
+  // hash-to-curve functions should check for the wrong group.
+  bssl::UniquePtr<EC_GROUP> p224(EC_GROUP_new_by_curve_name(NID_secp224r1));
+  ASSERT_TRUE(p224);
+  EC_RAW_POINT p;
+  static const uint8_t kDST[] = {0, 1, 2, 3};
+  static const uint8_t kMessage[] = {4, 5, 6, 7};
+  EXPECT_FALSE(ec_hash_to_curve_p521_xmd_sha512_sswu(
+      p224.get(), &p, kDST, sizeof(kDST), kMessage, sizeof(kMessage)));
 }
 
 TEST(ECTest, HashToScalar) {
@@ -1306,4 +1315,13 @@
     ec_scalar_to_bytes(group.get(), buf, &len, &scalar);
     EXPECT_EQ(test.result_hex, EncodeHex(bssl::MakeConstSpan(buf, len)));
   }
+
+  // hash-to-scalar functions should check for the wrong group.
+  bssl::UniquePtr<EC_GROUP> p224(EC_GROUP_new_by_curve_name(NID_secp224r1));
+  ASSERT_TRUE(p224);
+  EC_SCALAR scalar;
+  static const uint8_t kDST[] = {0, 1, 2, 3};
+  static const uint8_t kMessage[] = {4, 5, 6, 7};
+  EXPECT_FALSE(ec_hash_to_scalar_p521_xmd_sha512(
+      p224.get(), &scalar, kDST, sizeof(kDST), kMessage, sizeof(kMessage)));
 }