Align the hash-to-curve formulation with draft-16.

draft-07 to draft-16 is mostly editorial, but there were a few notable
changes:

- Empty DST values are forbidden.

- The sample implementation for map_to_curve_simple_swu has completely
  changed. The new formulation has the same performance (if not a smidge
  faster), and aligning with the spec seems generally useful.

- P-384 is now paired with SHA-384, not SHA-512. As this would be a
  breaking change for the trust tokens code, I've left that in. A
  follow-up CL will add implementations of draft-16, which is expected
  to match the final draft.

Before:
Did 77000 hash-to-curve P384_XMD:SHA-512_SSWU_RO_ operations in 4025677us (19127.2 ops/sec)
Did 7156000 hash-to-scalar P384_XMD:SHA-512 operations in 4000385us (1788827.8 ops/sec)

After:
Did 77000 hash-to-curve P384_XMD:SHA-512_SSWU_RO_ operations in 4009708us (19203.4 ops/sec) [+0.4%]
Did 7327000 hash-to-scalar P384_XMD:SHA-512 operations in 4000477us (1831531.6 ops/sec) [+2.4%]

Bug: 1414562
Change-Id: Ic3c37061e325250d5d8723fd9aa263930c6023cf
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/57146
Auto-Submit: David Benjamin <davidben@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
diff --git a/crypto/ec_extra/hash_to_curve.c b/crypto/ec_extra/hash_to_curve.c
index fa7ff59..dca4c24 100644
--- a/crypto/ec_extra/hash_to_curve.c
+++ b/crypto/ec_extra/hash_to_curve.c
@@ -27,7 +27,7 @@
 
 
 // This file implements hash-to-curve, as described in
-// draft-irtf-cfrg-hash-to-curve-07.
+// draft-irtf-cfrg-hash-to-curve-16.
 //
 // This hash-to-curve implementation is written generically with the
 // expectation that we will eventually wish to support other curves. If it
@@ -48,11 +48,17 @@
 //   templates to make specializing more convenient.
 
 // expand_message_xmd implements the operation described in section 5.3.1 of
-// draft-irtf-cfrg-hash-to-curve-07. It returns one on success and zero on
-// allocation failure or if |out_len| was too large.
+// draft-irtf-cfrg-hash-to-curve-16. It returns one on success and zero on
+// error.
 static int expand_message_xmd(const EVP_MD *md, uint8_t *out, size_t out_len,
                               const uint8_t *msg, size_t msg_len,
                               const uint8_t *dst, size_t dst_len) {
+  // See https://github.com/cfrg/draft-irtf-cfrg-hash-to-curve/issues/352
+  if (dst_len == 0) {
+    OPENSSL_PUT_ERROR(EC, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+    return 0;
+  }
+
   int ret = 0;
   const size_t block_size = EVP_MD_block_size(md);
   const size_t md_size = EVP_MD_size(md);
@@ -132,7 +138,7 @@
 
 // num_bytes_to_derive determines the number of bytes to derive when hashing to
 // a number modulo |modulus|. See the hash_to_field operation defined in
-// section 5.2 of draft-irtf-cfrg-hash-to-curve-07.
+// section 5.2 of draft-irtf-cfrg-hash-to-curve-16.
 static int num_bytes_to_derive(size_t *out, const BIGNUM *modulus, unsigned k) {
   size_t bits = BN_num_bits(modulus);
   size_t L = (bits + k + 7) / 8;
@@ -165,7 +171,7 @@
 }
 
 // hash_to_field implements the operation described in section 5.2
-// of draft-irtf-cfrg-hash-to-curve-07, with count = 2. |k| is the security
+// of draft-irtf-cfrg-hash-to-curve-16, with count = 2. |k| is the security
 // factor.
 static int hash_to_field2(const EC_GROUP *group, const EVP_MD *md,
                           EC_FELEM *out1, EC_FELEM *out2, const uint8_t *dst,
@@ -214,90 +220,126 @@
   ec_felem_sub(group, out, in, &tmp);     // out = -3*in
 }
 
-static inline void mul_minus_A(const EC_GROUP *group, EC_FELEM *out,
-                               const EC_FELEM *in) {
-  assert(group->a_is_minus3);
-  EC_FELEM tmp;
-  ec_felem_add(group, &tmp, in, in);   // tmp = 2*in
-  ec_felem_add(group, out, &tmp, in);  // out = 3*in
-}
-
-// sgn0_le implements the operation described in section 4.1.2 of
-// draft-irtf-cfrg-hash-to-curve-07.
-static BN_ULONG sgn0_le(const EC_GROUP *group, const EC_FELEM *a) {
+// sgn0 implements the operation described in section 4.1.2 of
+// draft-irtf-cfrg-hash-to-curve-16.
+static BN_ULONG sgn0(const EC_GROUP *group, const EC_FELEM *a) {
   uint8_t buf[EC_MAX_BYTES];
   size_t len;
   ec_felem_to_bytes(group, buf, &len, a);
   return buf[len - 1] & 1;
 }
 
-// map_to_curve_simple_swu implements the operation described in section 6.6.2
-// of draft-irtf-cfrg-hash-to-curve-07, using the optimization in appendix
-// D.2.1. It returns one on success and zero on error.
-static int map_to_curve_simple_swu(const EC_GROUP *group, const EC_FELEM *Z,
-                                   const BN_ULONG *c1, size_t num_c1,
-                                   const EC_FELEM *c2, EC_RAW_POINT *out,
-                                   const EC_FELEM *u) {
+OPENSSL_UNUSED static int is_3mod4(const EC_GROUP *group) {
+  return group->field.width > 0 && (group->field.d[0] & 3) == 3;
+}
+
+// sqrt_ratio_3mod4 implements the operation described in appendix F.2.1.2
+// of draft-irtf-cfrg-hash-to-curve-16.
+static BN_ULONG sqrt_ratio_3mod4(const EC_GROUP *group, const EC_FELEM *Z,
+                                 const BN_ULONG *c1, size_t num_c1,
+                                 const EC_FELEM *c2, EC_FELEM *out_y,
+                                 const EC_FELEM *u, const EC_FELEM *v) {
+  assert(is_3mod4(group));
+
   void (*const felem_mul)(const EC_GROUP *, EC_FELEM *r, const EC_FELEM *a,
                           const EC_FELEM *b) = group->meth->felem_mul;
   void (*const felem_sqr)(const EC_GROUP *, EC_FELEM *r, const EC_FELEM *a) =
       group->meth->felem_sqr;
 
+  EC_FELEM tv1, tv2, tv3, y1, y2;
+  felem_sqr(group, &tv1, v);                             // 1. tv1 = v^2
+  felem_mul(group, &tv2, u, v);                          // 2. tv2 = u * v
+  felem_mul(group, &tv1, &tv1, &tv2);                    // 3. tv1 = tv1 * tv2
+  group->meth->felem_exp(group, &y1, &tv1, c1, num_c1);  // 4. y1 = tv1^c1
+  felem_mul(group, &y1, &y1, &tv2);                      // 5. y1 = y1 * tv2
+  felem_mul(group, &y2, &y1, c2);                        // 6. y2 = y1 * c2
+  felem_sqr(group, &tv3, &y1);                           // 7. tv3 = y1^2
+  felem_mul(group, &tv3, &tv3, v);                       // 8. tv3 = tv3 * v
+
+  // 9. isQR = tv3 == u
+  // 10. y = CMOV(y2, y1, isQR)
+  // 11. return (isQR, y)
+  //
+  // Note the specification's CMOV function and our |ec_felem_select| have the
+  // opposite argument order.
+  ec_felem_sub(group, &tv1, &tv3, u);
+  const BN_ULONG isQR = ~ec_felem_non_zero_mask(group, &tv1);
+  ec_felem_select(group, out_y, isQR, &y1, &y2);
+  return isQR;
+}
+
+// map_to_curve_simple_swu implements the operation described in section 6.6.2
+// of draft-irtf-cfrg-hash-to-curve-16, using the straight-line implementation
+// in appendix F.2.
+static void map_to_curve_simple_swu(const EC_GROUP *group, const EC_FELEM *Z,
+                                    const BN_ULONG *c1, size_t num_c1,
+                                    const EC_FELEM *c2, EC_RAW_POINT *out,
+                                    const EC_FELEM *u) {
   // This function requires the prime be 3 mod 4, and that A = -3.
-  if (group->field.width == 0 || (group->field.d[0] & 3) != 3 ||
-      !group->a_is_minus3) {
-    OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
-    return 0;
-  }
+  assert(is_3mod4(group));
+  assert(group->a_is_minus3);
 
-  EC_FELEM tv1, tv2, tv3, tv4, xd, x1n, x2n, tmp, gxd, gx1, y1, y2;
-  felem_sqr(group, &tv1, u);                         // tv1 = u^2
-  felem_mul(group, &tv3, Z, &tv1);                   // tv3 = Z * tv1
-  felem_sqr(group, &tv2, &tv3);                      // tv2 = tv3^2
-  ec_felem_add(group, &xd, &tv2, &tv3);              // xd = tv2 + tv3
-  ec_felem_add(group, &x1n, &xd, &group->one);       // x1n = xd + 1
-  felem_mul(group, &x1n, &x1n, &group->b);           // x1n = x1n * B
-  mul_minus_A(group, &xd, &xd);                      // xd = -A * xd
-  BN_ULONG e1 = ec_felem_non_zero_mask(group, &xd);  // e1 = xd == 0 [flipped]
-  mul_A(group, &tmp, Z);
-  ec_felem_select(group, &xd, e1, &xd, &tmp);  // xd = CMOV(xd, Z * A, e1)
-  felem_sqr(group, &tv2, &xd);                 // tv2 = xd^2
-  felem_mul(group, &gxd, &tv2, &xd);           // gxd = tv2 * xd = xd^3
-  mul_A(group, &tv2, &tv2);                    // tv2 = A * tv2
-  felem_sqr(group, &gx1, &x1n);                // gx1 = x1n^2
-  ec_felem_add(group, &gx1, &gx1, &tv2);       // gx1 = gx1 + tv2
-  felem_mul(group, &gx1, &gx1, &x1n);          // gx1 = gx1 * x1n
-  felem_mul(group, &tv2, &group->b, &gxd);     // tv2 = B * gxd
-  ec_felem_add(group, &gx1, &gx1, &tv2);       // gx1 = gx1 + tv2
-  felem_sqr(group, &tv4, &gxd);                // tv4 = gxd^2
-  felem_mul(group, &tv2, &gx1, &gxd);          // tv2 = gx1 * gxd
-  felem_mul(group, &tv4, &tv4, &tv2);          // tv4 = tv4 * tv2
-  group->meth->felem_exp(group, &y1, &tv4, c1, num_c1);  // y1 = tv4^c1
-  felem_mul(group, &y1, &y1, &tv2);                      // y1 = y1 * tv2
-  felem_mul(group, &x2n, &tv3, &x1n);                    // x2n = tv3 * x1n
-  felem_mul(group, &y2, &y1, c2);                        // y2 = y1 * c2
-  felem_mul(group, &y2, &y2, &tv1);                      // y2 = y2 * tv1
-  felem_mul(group, &y2, &y2, u);                         // y2 = y2 * u
-  felem_sqr(group, &tv2, &y1);                           // tv2 = y1^2
-  felem_mul(group, &tv2, &tv2, &gxd);                    // tv2 = tv2 * gxd
-  ec_felem_sub(group, &tv3, &tv2, &gx1);
-  BN_ULONG e2 =
-      ec_felem_non_zero_mask(group, &tv3);       // e2 = tv2 == gx1 [flipped]
-  ec_felem_select(group, &x1n, e2, &x2n, &x1n);  // xn = CMOV(x2n, x1n, e2)
-  ec_felem_select(group, &y1, e2, &y2, &y1);     // y = CMOV(y2, y1, e2)
-  BN_ULONG sgn0_u = sgn0_le(group, u);
-  BN_ULONG sgn0_y = sgn0_le(group, &y1);
-  BN_ULONG e3 = sgn0_u ^ sgn0_y;
-  e3 = ((BN_ULONG)0) - e3;  // e3 = sgn0(u) == sgn0(y) [flipped]
-  ec_felem_neg(group, &y2, &y1);
-  ec_felem_select(group, &y1, e3, &y2, &y1);  // y = CMOV(-y, y, e3)
+  void (*const felem_mul)(const EC_GROUP *, EC_FELEM *r, const EC_FELEM *a,
+                          const EC_FELEM *b) = group->meth->felem_mul;
+  void (*const felem_sqr)(const EC_GROUP *, EC_FELEM *r, const EC_FELEM *a) =
+      group->meth->felem_sqr;
 
-  // Appendix D.1 describes how to convert (x1n, xd, y1, 1) to Jacobian
-  // coordinates. Note yd = 1. Also note that gxd computed above is xd^3.
-  felem_mul(group, &out->X, &x1n, &xd);     // X = xn * xd
-  felem_mul(group, &out->Y, &y1, &gxd);     // Y = yn * gxd = yn * xd^3
-  out->Z = xd;                              // Z = xd
-  return 1;
+  EC_FELEM tv1, tv2, tv3, tv4, tv5, tv6, x, y, y1;
+  felem_sqr(group, &tv1, u);                     // 1. tv1 = u^2
+  felem_mul(group, &tv1, Z, &tv1);               // 2. tv1 = Z * tv1
+  felem_sqr(group, &tv2, &tv1);                  // 3. tv2 = tv1^2
+  ec_felem_add(group, &tv2, &tv2, &tv1);         // 4. tv2 = tv2 + tv1
+  ec_felem_add(group, &tv3, &tv2, &group->one);  // 5. tv3 = tv2 + 1
+  felem_mul(group, &tv3, &group->b, &tv3);       // 6. tv3 = B * tv3
+
+  // 7. tv4 = CMOV(Z, -tv2, tv2 != 0)
+  const BN_ULONG tv2_non_zero = ec_felem_non_zero_mask(group, &tv2);
+  ec_felem_neg(group, &tv4, &tv2);
+  ec_felem_select(group, &tv4, tv2_non_zero, &tv4, Z);
+
+  mul_A(group, &tv4, &tv4);                 // 8. tv4 = A * tv4
+  felem_sqr(group, &tv2, &tv3);             // 9. tv2 = tv3^2
+  felem_sqr(group, &tv6, &tv4);             // 10. tv6 = tv4^2
+  mul_A(group, &tv5, &tv6);                 // 11. tv5 = A * tv6
+  ec_felem_add(group, &tv2, &tv2, &tv5);    // 12. tv2 = tv2 + tv5
+  felem_mul(group, &tv2, &tv2, &tv3);       // 13. tv2 = tv2 * tv3
+  felem_mul(group, &tv6, &tv6, &tv4);       // 14. tv6 = tv6 * tv4
+  felem_mul(group, &tv5, &group->b, &tv6);  // 15. tv5 = B * tv6
+  ec_felem_add(group, &tv2, &tv2, &tv5);    // 16. tv2 = tv2 + tv5
+  felem_mul(group, &x, &tv1, &tv3);         // 17. x = tv1 * tv3
+
+  // 18. (is_gx1_square, y1) = sqrt_ratio(tv2, tv6)
+  const BN_ULONG is_gx1_square =
+      sqrt_ratio_3mod4(group, Z, c1, num_c1, c2, &y1, &tv2, &tv6);
+
+  felem_mul(group, &y, &tv1, u);  // 19. y = tv1 * u
+  felem_mul(group, &y, &y, &y1);  // 20. y = y * y1
+
+  // 21. x = CMOV(x, tv3, is_gx1_square)
+  ec_felem_select(group, &x, is_gx1_square, &tv3, &x);
+  // 22. y = CMOV(y, y1, is_gx1_square)
+  ec_felem_select(group, &y, is_gx1_square, &y1, &y);
+
+  // 23. e1 = sgn0(u) == sgn0(y)
+  BN_ULONG sgn0_u = sgn0(group, u);
+  BN_ULONG sgn0_y = sgn0(group, &y);
+  BN_ULONG not_e1 = sgn0_u ^ sgn0_y;
+  not_e1 = ((BN_ULONG)0) - not_e1;
+
+  // 24. y = CMOV(-y, y, e1)
+  ec_felem_neg(group, &tv1, &y);
+  ec_felem_select(group, &y, not_e1, &tv1, &y);
+
+  // 25. x = x / tv4
+  //
+  // Our output is in projective coordinates, so rather than inverting |tv4|
+  // now, represent (x / tv4, y) as (x * tv4, y * tv4^3, tv4). This is much more
+  // efficient if the caller will do further computation on the output. (If the
+  // caller will immediately convert to affine coordinates, it is slightly less
+  // efficient, but only by a few field multiplications.)
+  felem_mul(group, &out->X, &x, &tv4);
+  felem_mul(group, &out->Y, &y, &tv6);
+  out->Z = tv4;
 }
 
 static int hash_to_curve(const EC_GROUP *group, const EVP_MD *md,
@@ -318,10 +360,8 @@
   bn_rshift_words(c1, c1, /*shift=*/2, /*num=*/num_c1);
 
   EC_RAW_POINT Q0, Q1;
-  if (!map_to_curve_simple_swu(group, Z, c1, num_c1, c2, &Q0, &u0) ||
-      !map_to_curve_simple_swu(group, Z, c1, num_c1, c2, &Q1, &u1)) {
-    return 0;
-  }
+  map_to_curve_simple_swu(group, Z, c1, num_c1, c2, &Q0, &u0);
+  map_to_curve_simple_swu(group, Z, c1, num_c1, c2, &Q1, &u1);
 
   group->meth->add(group, out, &Q0, &Q1);  // R = Q0 + Q1
   // All our curves have cofactor one, so |clear_cofactor| is a no-op.
@@ -335,6 +375,19 @@
   return ec_felem_from_bytes(group, out, bytes, len);
 }
 
+// kP384Sqrt12 is sqrt(12) in P-384's field. It was computed as follows in
+// python3:
+//
+// p = 2**384 - 2**128 - 2**96 + 2**32 - 1
+// c2 = pow(12, (p+1)//4, p)
+// assert pow(c2, 2, p) == 12
+// ", ".join("0x%02x" % b for b in c2.to_bytes(384//8, 'big'))
+static const uint8_t kP384Sqrt12[] = {
+    0x2a, 0xcc, 0xb4, 0xa6, 0x56, 0xb0, 0x24, 0x9c, 0x71, 0xf0, 0x50, 0x0e,
+    0x83, 0xda, 0x2f, 0xdd, 0x7f, 0x98, 0xe3, 0x83, 0xd6, 0x8b, 0x53, 0x87,
+    0x1f, 0x87, 0x2f, 0xcb, 0x9c, 0xcb, 0x80, 0xc5, 0x3c, 0x0d, 0xe1, 0xf8,
+    0xa8, 0x0f, 0x7e, 0x19, 0x14, 0xe2, 0xec, 0x69, 0xf5, 0xa6, 0x26, 0xb3};
+
 int ec_hash_to_curve_p384_xmd_sha512_sswu_draft07(
     const EC_GROUP *group, EC_RAW_POINT *out, const uint8_t *dst,
     size_t dst_len, const uint8_t *msg, size_t msg_len) {
@@ -344,25 +397,10 @@
     return 0;
   }
 
-  // kSqrt1728 was computed as follows in python3:
-  //
-  // p = 2**384 - 2**128 - 2**96 + 2**32 - 1
-  // z3 = 12**3
-  // c2 = pow(z3, (p+1)//4, p)
-  // assert z3 == pow(c2, 2, p)
-  // ", ".join("0x%02x" % b for b in c2.to_bytes(384//8, 'big')
-
-  static const uint8_t kSqrt1728[] = {
-      0x01, 0x98, 0x77, 0xcc, 0x10, 0x41, 0xb7, 0x55, 0x57, 0x43, 0xc0, 0xae,
-      0x2e, 0x3a, 0x3e, 0x61, 0xfb, 0x2a, 0xaa, 0x2e, 0x0e, 0x87, 0xea, 0x55,
-      0x7a, 0x56, 0x3d, 0x8b, 0x59, 0x8a, 0x09, 0x40, 0xd0, 0xa6, 0x97, 0xa9,
-      0xe0, 0xb9, 0xe9, 0x2c, 0xfa, 0xa3, 0x14, 0xf5, 0x83, 0xc9, 0xd0, 0x66
-  };
-
-  // Z = -12, c2 = sqrt(1728)
+  // Z = -12, c2 = sqrt(12)
   EC_FELEM Z, c2;
   if (!felem_from_u8(group, &Z, 12) ||
-      !ec_felem_from_bytes(group, &c2, kSqrt1728, sizeof(kSqrt1728))) {
+      !ec_felem_from_bytes(group, &c2, kP384Sqrt12, sizeof(kP384Sqrt12))) {
     return 0;
   }
   ec_felem_neg(group, &Z, &Z);
diff --git a/crypto/ec_extra/internal.h b/crypto/ec_extra/internal.h
index 55314ac..ef93b56 100644
--- a/crypto/ec_extra/internal.h
+++ b/crypto/ec_extra/internal.h
@@ -36,6 +36,8 @@
 // |group| and writes the result to |out|, implementing the
 // P384_XMD:SHA-512_SSWU_RO_ suite from draft-irtf-cfrg-hash-to-curve-07. It
 // returns one on success and zero on error.
+//
+// TODO(https://crbug.com/1414562): Migrate this to the final version.
 OPENSSL_EXPORT int ec_hash_to_curve_p384_xmd_sha512_sswu_draft07(
     const EC_GROUP *group, EC_RAW_POINT *out, const uint8_t *dst,
     size_t dst_len, const uint8_t *msg, size_t msg_len);
@@ -44,6 +46,8 @@
 // and writes the result to |out|, using the hash_to_field operation from the
 // P384_XMD:SHA-512_SSWU_RO_ suite from draft-irtf-cfrg-hash-to-curve-07, but
 // generating a value modulo the group order rather than a field element.
+//
+// TODO(https://crbug.com/1414562): Migrate this to the final version.
 OPENSSL_EXPORT int ec_hash_to_scalar_p384_xmd_sha512_draft07(
     const EC_GROUP *group, EC_SCALAR *out, const uint8_t *dst, size_t dst_len,
     const uint8_t *msg, size_t msg_len);
diff --git a/crypto/fipsmodule/ec/ec_test.cc b/crypto/fipsmodule/ec/ec_test.cc
index 88665b2..bb93e55 100644
--- a/crypto/fipsmodule/ec/ec_test.cc
+++ b/crypto/fipsmodule/ec/ec_test.cc
@@ -1286,6 +1286,12 @@
   static const uint8_t kMessage[] = {4, 5, 6, 7};
   EXPECT_FALSE(ec_hash_to_curve_p384_xmd_sha512_sswu_draft07(
       p224.get(), &p, kDST, sizeof(kDST), kMessage, sizeof(kMessage)));
+
+  // Zero-length DSTs are not allowed.
+  bssl::UniquePtr<EC_GROUP> p384(EC_GROUP_new_by_curve_name(NID_secp384r1));
+  ASSERT_TRUE(p384);
+  EXPECT_FALSE(ec_hash_to_curve_p384_xmd_sha512_sswu_draft07(
+      p384.get(), &p, nullptr, 0, kMessage, sizeof(kMessage)));
 }
 
 TEST(ECTest, HashToScalar) {