Generalize bn_from_montgomery_small.

Montgomery reduction works when the input is at most N*R (N^2 is a
stricter bound that's easier to describe and usually suffices). This is
useful when reducing product-sized values. In particular,
hash-to-curve's hash_to_field function requires a reduction. Generalize
bn_from_montgomery_small to accept inputs up to twice the modulus width
so we can implement that reduction with Montgomery reduction.

Bug: chromium:1014199
Change-Id: I1a07f9b94823742384a98c0c6fecdedfe5240b7b
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/40588
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index d7db77e..9791437 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -659,11 +659,26 @@
       bn_mod_mul_montgomery_small(r_words.get(), a_words.get(), b_words.get(),
                                   m_width, mont.get());
       // Use the second half of |tmp| so ASan will catch out-of-bounds writes.
-      bn_from_montgomery_small(r_words.get(), r_words.get(), m_width,
+      bn_from_montgomery_small(r_words.get(), m_width, r_words.get(), m_width,
                                mont.get());
       ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A * B (mod M) (Montgomery, words)", mod_mul.get(),
                            ret.get());
+
+      // |bn_from_montgomery_small| must additionally work on double-width
+      // inputs. Test this by running |bn_from_montgomery_small| on the result
+      // of a product. Note |a_words| * |b_words| has an extra factor of R^2, so
+      // we must reduce twice.
+      std::unique_ptr<BN_ULONG[]> prod_words(new BN_ULONG[m_width * 2]);
+      bn_mul_small(prod_words.get(), m_width * 2, a_words.get(), m_width,
+                   b_words.get(), m_width);
+      bn_from_montgomery_small(r_words.get(), m_width, prod_words.get(),
+                               m_width * 2, mont.get());
+      bn_from_montgomery_small(r_words.get(), m_width, r_words.get(), m_width,
+                               mont.get());
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
+      EXPECT_BIGNUMS_EQUAL("A * B (mod M) (Montgomery, words)",
+                           mod_mul.get(), ret.get());
     }
 #endif
   }
@@ -721,7 +736,8 @@
       bn_to_montgomery_small(a_words.get(), a_words.get(), m_width, mont.get());
       bn_mod_mul_montgomery_small(r_words.get(), a_words.get(), a_words.get(),
                                   m_width, mont.get());
-      bn_from_montgomery_small(r_words.get(), r_words.get(), m_width, mont.get());
+      bn_from_montgomery_small(r_words.get(), m_width, r_words.get(), m_width,
+                               mont.get());
       ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A * A (mod M) (Montgomery, words)",
                            mod_square.get(), ret.get());
@@ -732,7 +748,7 @@
       bn_mod_mul_montgomery_small(r_words.get(), a_words.get(),
                                   a_copy_words.get(), m_width, mont.get());
       // Use the second half of |tmp| so ASan will catch out-of-bounds writes.
-      bn_from_montgomery_small(r_words.get(), r_words.get(), m_width,
+      bn_from_montgomery_small(r_words.get(), m_width, r_words.get(), m_width,
                                mont.get());
       ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A * A_copy (mod M) (Montgomery, words)",
@@ -783,7 +799,7 @@
       bn_to_montgomery_small(a_words.get(), a_words.get(), m_width, mont.get());
       bn_mod_exp_mont_small(r_words.get(), a_words.get(), m_width, e->d,
                             e->width, mont.get());
-      bn_from_montgomery_small(r_words.get(), r_words.get(), m_width,
+      bn_from_montgomery_small(r_words.get(), m_width, r_words.get(), m_width,
                                mont.get());
       ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A ^ E (mod M) (Montgomery, words)", mod_exp.get(),
diff --git a/crypto/fipsmodule/bn/exponentiation.c b/crypto/fipsmodule/bn/exponentiation.c
index f2c3f68..a0f2549 100644
--- a/crypto/fipsmodule/bn/exponentiation.c
+++ b/crypto/fipsmodule/bn/exponentiation.c
@@ -732,7 +732,7 @@
     num_p--;
   }
   if (num_p == 0) {
-    bn_from_montgomery_small(r, mont->RR.d, num, mont);
+    bn_from_montgomery_small(r, num, mont->RR.d, num, mont);
     return;
   }
   unsigned bits = BN_num_bits_word(p[num_p - 1]) + (num_p - 1) * BN_BITS2;
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index f27f3e7..af32b3d 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -647,10 +647,13 @@
                             const BN_MONT_CTX *mont);
 
 // bn_from_montgomery_small sets |r| to |a| translated out of the Montgomery
-// domain. |r| and |a| are |num| words long, which must be |mont->N.width|. |a|
-// must be fully-reduced and may alias |r|.
-void bn_from_montgomery_small(BN_ULONG *r, const BN_ULONG *a, size_t num,
-                              const BN_MONT_CTX *mont);
+// domain. |r| and |a| are |num_r| and |num_a| words long, respectively. |num_r|
+// must be |mont->N.width|. |a| must be at most |mont->N|^2 and may alias |r|.
+//
+// Unlike most of these functions, only |num_r| is bounded by
+// |BN_SMALL_MAX_WORDS|. |num_a| may exceed it, but must be at most 2 * |num_r|.
+void bn_from_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                              size_t num_a, const BN_MONT_CTX *mont);
 
 // bn_mod_mul_montgomery_small sets |r| to |a| * |b| mod |mont->N|. Both inputs
 // and outputs are in the Montgomery domain. Each array is |num| words long,
diff --git a/crypto/fipsmodule/bn/montgomery.c b/crypto/fipsmodule/bn/montgomery.c
index b6eaf6a..e9fa08f 100644
--- a/crypto/fipsmodule/bn/montgomery.c
+++ b/crypto/fipsmodule/bn/montgomery.c
@@ -455,18 +455,18 @@
   bn_mod_mul_montgomery_small(r, a, mont->RR.d, num, mont);
 }
 
-void bn_from_montgomery_small(BN_ULONG *r, const BN_ULONG *a, size_t num,
-                              const BN_MONT_CTX *mont) {
-  if (num != (size_t)mont->N.width || num > BN_SMALL_MAX_WORDS) {
+void bn_from_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                              size_t num_a, const BN_MONT_CTX *mont) {
+  if (num_r != (size_t)mont->N.width || num_r > BN_SMALL_MAX_WORDS ||
+      num_a > 2 * num_r) {
     abort();
   }
-  BN_ULONG tmp[BN_SMALL_MAX_WORDS * 2];
-  OPENSSL_memcpy(tmp, a, num * sizeof(BN_ULONG));
-  OPENSSL_memset(tmp + num, 0, num * sizeof(BN_ULONG));
-  if (!bn_from_montgomery_in_place(r, num, tmp, 2 * num, mont)) {
+  BN_ULONG tmp[BN_SMALL_MAX_WORDS * 2] = {0};
+  OPENSSL_memcpy(tmp, a, num_a * sizeof(BN_ULONG));
+  if (!bn_from_montgomery_in_place(r, num_r, tmp, 2 * num_r, mont)) {
     abort();
   }
-  OPENSSL_cleanse(tmp, 2 * num * sizeof(BN_ULONG));
+  OPENSSL_cleanse(tmp, 2 * num_r * sizeof(BN_ULONG));
 }
 
 void bn_mod_mul_montgomery_small(BN_ULONG *r, const BN_ULONG *a,
diff --git a/crypto/fipsmodule/ec/ec_montgomery.c b/crypto/fipsmodule/ec/ec_montgomery.c
index 1e281a5..316e9d5 100644
--- a/crypto/fipsmodule/ec/ec_montgomery.c
+++ b/crypto/fipsmodule/ec/ec_montgomery.c
@@ -117,8 +117,8 @@
 static void ec_GFp_mont_felem_from_montgomery(const EC_GROUP *group,
                                               EC_FELEM *out,
                                               const EC_FELEM *in) {
-  bn_from_montgomery_small(out->words, in->words, group->field.width,
-                           group->mont);
+  bn_from_montgomery_small(out->words, group->field.width, in->words,
+                           group->field.width, group->mont);
 }
 
 static void ec_GFp_mont_felem_inv0(const EC_GROUP *group, EC_FELEM *out,
diff --git a/crypto/fipsmodule/ec/scalar.c b/crypto/fipsmodule/ec/scalar.c
index 595d3ff..aacefd2 100644
--- a/crypto/fipsmodule/ec/scalar.c
+++ b/crypto/fipsmodule/ec/scalar.c
@@ -98,7 +98,8 @@
 void ec_scalar_from_montgomery(const EC_GROUP *group, EC_SCALAR *r,
                                const EC_SCALAR *a) {
   const BIGNUM *order = &group->order;
-  bn_from_montgomery_small(r->words, a->words, order->width, group->order_mont);
+  bn_from_montgomery_small(r->words, order->width, a->words, order->width,
+                           group->order_mont);
 }
 
 void ec_scalar_mul_montgomery(const EC_GROUP *group, EC_SCALAR *r,