Add a hash_to_scalar variation of P-521's hash_to_field.

DLEQ proofs for PMBTokens need a random oracle over scalars as well as
field elements. (Interestingly, draft-irtf-cfrg-voprf-03 section 5.1
does not specify as strong of requirements, but then their reference
implementation does rejection sampling, so it's unclear.)

Reusing the hash_to_field operation so hash calls use the domain
separation tag consistently with other hash-to-curve operations seems
prudent, so implement a companion function until the actual construction
solidifies.

Change-Id: I92d807bfddcca26db690cce0a3da551143c25ff3
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/40646
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
diff --git a/crypto/ec_extra/hash_to_curve.c b/crypto/ec_extra/hash_to_curve.c
index 52dbd9c..5c0e2e5 100644
--- a/crypto/ec_extra/hash_to_curve.c
+++ b/crypto/ec_extra/hash_to_curve.c
@@ -133,26 +133,36 @@
   group->meth->felem_reduce(group, out, buf.words, num_words * 2);
 }
 
+// num_bytes_to_derive determines the number of bytes to derive when hashing to
+// a number modulo |modulus|. See the hash_to_field operation defined in
+// section 5.2 of draft-irtf-cfrg-hash-to-curve-06.
+static int num_bytes_to_derive(size_t *out, const BIGNUM *modulus, unsigned k) {
+  size_t bits = BN_num_bits(modulus);
+  size_t L = (bits + k + 7) / 8;
+  // We require 2^(8*L) < 2^(2*bits - 2) <= n^2 so to fit in bounds for
+  // |felem_reduce| and |ec_scalar_refuce|. All defined hash-to-curve suites
+  // define |k| to be well under this bound. (|k| is usually around half of
+  // |p_bits|.)
+  if (L * 8 >= 2 * bits - 2 ||
+      L > 2 * EC_MAX_BYTES) {
+    OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
+    return 0;
+  }
+
+  *out = L;
+  return 1;
+}
+
 // hash_to_field implements the operation described in section 5.2
 // of draft-irtf-cfrg-hash-to-curve-06, with count = 2.
 static int hash_to_field2(const EC_GROUP *group, const EVP_MD *md,
                           EC_FELEM *out1, EC_FELEM *out2, const uint8_t *dst,
                           size_t dst_len, unsigned k, const uint8_t *msg,
                           size_t msg_len) {
-  // Determine L, the number of bytes to derive per output element.
-  size_t p_bits = BN_num_bits(&group->field);
-  size_t L = (p_bits + k + 7) / 8;
-
-  // We require 2^(8*L) < 2^(2*p_bits - 2) <= p^2 so to fit in bounds for
-  // |felem_reduce|. All defined hash-to-curve suites define |k| to be well
-  // under this bound. (|k| is usually around half of |p_bits|.)
-  if (L * 8 >= 2 * p_bits - 2) {
-    OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
-    return 0;
-  }
-
+  size_t L;
   uint8_t buf[4 * EC_MAX_BYTES];
-  if (!expand_message_xmd(md, buf, 2 * L, msg, msg_len, dst, dst_len)) {
+  if (!num_bytes_to_derive(&L, &group->field, k) ||
+      !expand_message_xmd(md, buf, 2 * L, msg, msg_len, dst, dst_len)) {
     return 0;
   }
   reduce_to_felem(group, out1, buf, L);
@@ -160,6 +170,22 @@
   return 1;
 }
 
+// hash_to_scalar behaves like |hash_to_field2| but returns a value modulo the
+// group order rather than a field element.
+static int hash_to_scalar(const EC_GROUP *group, const EVP_MD *md,
+                          EC_SCALAR *out, const uint8_t *dst, size_t dst_len,
+                          unsigned k, const uint8_t *msg, size_t msg_len) {
+  size_t L;
+  BN_ULONG words[EC_MAX_WORDS * 2] = {0};
+  if (!num_bytes_to_derive(&L, &group->order, k) ||
+      !expand_message_xmd(md, (uint8_t *)words, L, msg, msg_len, dst,
+                          dst_len)) {
+    return 0;
+  }
+  ec_scalar_reduce(group, out, words, 2 * group->order.width);
+  return 1;
+}
+
 static inline void mul_A(const EC_GROUP *group, EC_FELEM *out,
                          const EC_FELEM *in) {
   assert(group->a_is_minus3);
@@ -354,3 +380,10 @@
   return hash_to_curve_p521_xmd_sswu(group, out, dst, dst_len, EVP_sha512(),
                                      /*k=*/240, msg, msg_len);
 }
+
+int ec_hash_to_scalar_p521_xmd_sha512(const EC_GROUP *group, EC_SCALAR *out,
+                                      const uint8_t *dst, size_t dst_len,
+                                      const uint8_t *msg, size_t msg_len) {
+  return hash_to_scalar(group, EVP_sha512(), out, dst, dst_len, /*k=*/256, msg,
+                        msg_len);
+}
diff --git a/crypto/ec_extra/internal.h b/crypto/ec_extra/internal.h
index 873726c..b05e9ea 100644
--- a/crypto/ec_extra/internal.h
+++ b/crypto/ec_extra/internal.h
@@ -48,6 +48,21 @@
     const EC_GROUP *group, EC_RAW_POINT *out, const uint8_t *dst,
     size_t dst_len, const uint8_t *msg, size_t msg_len);
 
+// ec_hash_to_scalar_p521_xmd_sha512 hashes |msg| to a scalar on |group| and
+// writes the result to |out|, using the hash_to_field operation from the
+// P521_XMD:SHA-512_SSWU_RO_ suite, but generating a value modulo the group
+// order rather than a field element. |dst| is the domain separation
+// tag and must be unique for each protocol. See section 3.1 of
+// draft-irtf-cfrg-hash-to-curve-06 for additional guidance on this parameter.
+//
+// Note the requirement to use a different tag for each encoding used in a
+// protocol extends to this function. Protocols which use both this function and
+// |ec_hash_to_scalar_p521_xmd_sha512| must use distinct values of |dst| for
+// each use.
+OPENSSL_EXPORT int ec_hash_to_scalar_p521_xmd_sha512(
+    const EC_GROUP *group, EC_SCALAR *out, const uint8_t *dst, size_t dst_len,
+    const uint8_t *msg, size_t msg_len);
+
 
 #if defined(__cplusplus)
 }  // extern C
diff --git a/crypto/fipsmodule/ec/ec_test.cc b/crypto/fipsmodule/ec/ec_test.cc
index 7ab7f5f..eb53b2f 100644
--- a/crypto/fipsmodule/ec/ec_test.cc
+++ b/crypto/fipsmodule/ec/ec_test.cc
@@ -1203,3 +1203,53 @@
                               1 + field_len, field_len)));
   }
 }
+
+TEST(ECTest, HashToScalar) {
+  bssl::UniquePtr<EC_GROUP> group(EC_GROUP_new_by_curve_name(NID_secp521r1));
+  ASSERT_TRUE(group);
+
+  struct HashToScalarTest {
+    const char *dst;
+    const char *msg;
+    const char *result_hex;
+  };
+  static const HashToScalarTest kTests[] = {
+      {"P521_XMD:SHA-512_SCALAR_TEST", "",
+       "01407998b20d948d6ef4e68c981d24f44ed3e65a49849a16296770"
+       "14b48d4664e150074ccf9afcdf791c6afc648e69b94989881f1f0b"
+       "4e2b86ce40b1dc2ce4bb20f0"},
+      {"P521_XMD:SHA-512_SCALAR_TEST", "abcdef0123456789",
+       "019fab7021eeae5476d7ae7352793025a9aed0193831a42cbcd183"
+       "e377a83d33ee178e11f34f9b6cffeffdee40c9260e5aff50ebf276"
+       "c992b78d086dd4475d7b098e"},
+      {"P521_XMD:SHA-512_SCALAR_TEST",
+       "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+       "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+       "00ab2c0feabe9bbd93d4570fe627fa88667bb8f2f117e76b77d41a"
+       "15bb5dd995f61c64cd70a96dc9cda1f70b426dfd7a1c11a2865272"
+       "f4698f501e57f8c4c2ed0008"},
+  };
+
+  for (const auto &test : kTests) {
+    SCOPED_TRACE(test.dst);
+    SCOPED_TRACE(test.msg);
+
+    EC_SCALAR scalar;
+    ASSERT_TRUE(ec_hash_to_scalar_p521_xmd_sha512(
+        group.get(), &scalar, reinterpret_cast<const uint8_t *>(test.dst),
+        strlen(test.dst), reinterpret_cast<const uint8_t *>(test.msg),
+        strlen(test.msg)));
+    uint8_t buf[EC_MAX_BYTES];
+    size_t len;
+    ec_scalar_to_bytes(group.get(), buf, &len, &scalar);
+    EXPECT_EQ(test.result_hex, EncodeHex(bssl::MakeConstSpan(buf, len)));
+  }
+}
diff --git a/crypto/fipsmodule/ec/internal.h b/crypto/fipsmodule/ec/internal.h
index 5a40d9d..e9a8623 100644
--- a/crypto/fipsmodule/ec/internal.h
+++ b/crypto/fipsmodule/ec/internal.h
@@ -113,8 +113,8 @@
 // ec_scalar_to_bytes serializes |in| as a big-endian bytestring to |out| and
 // sets |*out_len| to the number of bytes written. The number of bytes written
 // is |BN_num_bytes(&group->order)|, which is at most |EC_MAX_BYTES|.
-void ec_scalar_to_bytes(const EC_GROUP *group, uint8_t *out, size_t *out_len,
-                        const EC_SCALAR *in);
+OPENSSL_EXPORT void ec_scalar_to_bytes(const EC_GROUP *group, uint8_t *out,
+                                       size_t *out_len, const EC_SCALAR *in);
 
 // ec_scalar_from_bytes deserializes |in| and stores the resulting scalar over
 // group |group| to |out|. It returns one on success and zero if |in| is
@@ -122,6 +122,12 @@
 int ec_scalar_from_bytes(const EC_GROUP *group, EC_SCALAR *out,
                          const uint8_t *in, size_t len);
 
+// ec_scalar_reduce sets |out| to |words|, reduced modulo the group order.
+// |words| must be less than order^2. |num| must be at most twice the width of
+// group order. This function treats |words| as secret.
+void ec_scalar_reduce(const EC_GROUP *group, EC_SCALAR *out,
+                      const BN_ULONG *words, size_t num);
+
 // ec_random_nonzero_scalar sets |out| to a uniformly selected random value from
 // 1 to |group->order| - 1. It returns one on success and zero on error.
 int ec_random_nonzero_scalar(const EC_GROUP *group, EC_SCALAR *out,
diff --git a/crypto/fipsmodule/ec/scalar.c b/crypto/fipsmodule/ec/scalar.c
index aacefd2..2d49682 100644
--- a/crypto/fipsmodule/ec/scalar.c
+++ b/crypto/fipsmodule/ec/scalar.c
@@ -81,6 +81,15 @@
   return 1;
 }
 
+void ec_scalar_reduce(const EC_GROUP *group, EC_SCALAR *out,
+                      const BN_ULONG *words, size_t num) {
+  // Convert "from" Montgomery form so the value is reduced modulo the order.
+  bn_from_montgomery_small(out->words, group->order.width, words, num,
+                           group->order_mont);
+  // Convert "to" Montgomery form to remove the R^-1 factor added.
+  ec_scalar_to_montgomery(group, out, out);
+}
+
 void ec_scalar_add(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a,
                    const EC_SCALAR *b) {
   const BIGNUM *order = &group->order;