Add "small" variants of Montgomery logic.

These use the square and multiply functions added earlier.

Change-Id: I723834f9a227a9983b752504a2d7ce0223c43d24
Reviewed-on: https://boringssl-review.googlesource.com/23070
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index 7376989..d2a875b 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -542,15 +542,39 @@
     ASSERT_TRUE(a_tmp);
     ASSERT_TRUE(b_tmp);
     ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), m.get(), ctx));
-    ASSERT_TRUE(BN_nnmod(a_tmp.get(), a.get(), m.get(), ctx));
-    ASSERT_TRUE(BN_nnmod(b_tmp.get(), b.get(), m.get(), ctx));
-    ASSERT_TRUE(BN_to_montgomery(a_tmp.get(), a_tmp.get(), mont.get(), ctx));
-    ASSERT_TRUE(BN_to_montgomery(b_tmp.get(), b_tmp.get(), mont.get(), ctx));
+    ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
+    ASSERT_TRUE(BN_nnmod(b.get(), b.get(), m.get(), ctx));
+    ASSERT_TRUE(BN_to_montgomery(a_tmp.get(), a.get(), mont.get(), ctx));
+    ASSERT_TRUE(BN_to_montgomery(b_tmp.get(), b.get(), mont.get(), ctx));
     ASSERT_TRUE(BN_mod_mul_montgomery(ret.get(), a_tmp.get(), b_tmp.get(),
                                       mont.get(), ctx));
     ASSERT_TRUE(BN_from_montgomery(ret.get(), ret.get(), mont.get(), ctx));
     EXPECT_BIGNUMS_EQUAL("A * B (mod M) (Montgomery)", mod_mul.get(),
                          ret.get());
+
+#if !defined(BORINGSSL_SHARED_LIBRARY)
+    if (m->top <= BN_SMALL_MAX_WORDS) {
+      std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[m->top]),
+          b_words(new BN_ULONG[m->top]), r_words(new BN_ULONG[m->top]);
+      OPENSSL_memset(a_words.get(), 0, m->top * sizeof(BN_ULONG));
+      OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
+      OPENSSL_memset(b_words.get(), 0, m->top * sizeof(BN_ULONG));
+      OPENSSL_memcpy(b_words.get(), b->d, b->top * sizeof(BN_ULONG));
+      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m->top, a_words.get(),
+                                         m->top, mont.get()));
+      ASSERT_TRUE(bn_to_montgomery_small(b_words.get(), m->top, b_words.get(),
+                                         m->top, mont.get()));
+      ASSERT_TRUE(bn_mod_mul_montgomery_small(
+          r_words.get(), m->top, a_words.get(), m->top, b_words.get(), m->top,
+          mont.get()));
+      // Convert out of the Montgomery domain in place; |r| and |a| may alias.
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
+                                           m->top, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+      EXPECT_BIGNUMS_EQUAL("A * B (mod M) (Montgomery, words)", mod_mul.get(),
+                           ret.get());
+    }
+#endif
   }
 }
 
@@ -581,8 +605,8 @@
     ASSERT_TRUE(mont);
     ASSERT_TRUE(a_tmp);
     ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), m.get(), ctx));
-    ASSERT_TRUE(BN_nnmod(a_tmp.get(), a.get(), m.get(), ctx));
-    ASSERT_TRUE(BN_to_montgomery(a_tmp.get(), a_tmp.get(), mont.get(), ctx));
+    ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
+    ASSERT_TRUE(BN_to_montgomery(a_tmp.get(), a.get(), mont.get(), ctx));
     ASSERT_TRUE(BN_mod_mul_montgomery(ret.get(), a_tmp.get(), a_tmp.get(),
                                       mont.get(), ctx));
     ASSERT_TRUE(BN_from_montgomery(ret.get(), ret.get(), mont.get(), ctx));
@@ -596,6 +620,38 @@
     ASSERT_TRUE(BN_from_montgomery(ret.get(), ret.get(), mont.get(), ctx));
     EXPECT_BIGNUMS_EQUAL("A * A_copy (mod M) (Montgomery)", mod_square.get(),
                          ret.get());
+
+#if !defined(BORINGSSL_SHARED_LIBRARY)
+    if (m->top <= BN_SMALL_MAX_WORDS) {
+      std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[m->top]),
+          a_copy_words(new BN_ULONG[m->top]), r_words(new BN_ULONG[m->top]);
+      OPENSSL_memset(a_words.get(), 0, m->top * sizeof(BN_ULONG));
+      OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
+      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m->top, a_words.get(),
+                                         m->top, mont.get()));
+      ASSERT_TRUE(bn_mod_mul_montgomery_small(
+          r_words.get(), m->top, a_words.get(), m->top, a_words.get(), m->top,
+          mont.get()));
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
+                                           m->top, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+      EXPECT_BIGNUMS_EQUAL("A * A (mod M) (Montgomery, words)",
+                           mod_square.get(), ret.get());
+
+      // Repeat the operation with |a_copy_words|.
+      OPENSSL_memcpy(a_copy_words.get(), a_words.get(),
+                     m->top * sizeof(BN_ULONG));
+      ASSERT_TRUE(bn_mod_mul_montgomery_small(
+          r_words.get(), m->top, a_words.get(), m->top, a_copy_words.get(),
+          m->top, mont.get()));
+      // Convert out of the Montgomery domain in place; |r| and |a| may alias.
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
+                                           m->top, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+      EXPECT_BIGNUMS_EQUAL("A * A_copy (mod M) (Montgomery, words)",
+                           mod_square.get(), ret.get());
+    }
+#endif
   }
 }
 
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 634435f..2f5dbbb 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -335,6 +335,38 @@
 // one on success and zero on programmer error.
 int bn_sqr_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a, size_t num_a);
 
+// In the following functions, the modulus must be at most |BN_SMALL_MAX_WORDS|
+// words long.
+
+// bn_to_montgomery_small sets |r| to |a| translated to the Montgomery domain.
+// |num_a| and |num_r| must be the length of the modulus, which is
+// |mont->N.top|. |a| must be fully reduced. This function returns one on
+// success and zero if lengths are inconsistent. |r| and |a| may alias.
+int bn_to_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                           size_t num_a, const BN_MONT_CTX *mont);
+
+// bn_from_montgomery_small sets |r| to |a| translated out of the Montgomery
+// domain. |num_r| must be the length of the modulus, which is |mont->N.top|.
+// |a| must be at most |mont->N| * R and |num_a| must be at most 2 *
+// |mont->N.top|. This function returns one on success and zero if lengths are
+// inconsistent. |r| and |a| may alias.
+int bn_from_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                             size_t num_a, const BN_MONT_CTX *mont);
+
+// bn_mod_mul_montgomery_small sets |r| to |a| * |b| mod |mont->N|. Both inputs
+// and outputs are in the Montgomery domain. |num_r| must be the length of the
+// modulus, which is |mont->N.top|. This function returns one on success and
+// zero on internal error or inconsistent lengths. Any two of |r|, |a|, and |b|
+// may alias.
+//
+// This function requires |a| * |b| < N * R, where N is the modulus and R is the
+// Montgomery divisor, 2^(N.top * BN_BITS2). This should generally be satisfied
+// by ensuring |a| and |b| are fully reduced; however, ECDSA has one computation
+// which requires the more general bound.
+int bn_mod_mul_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                                size_t num_a, const BN_ULONG *b, size_t num_b,
+                                const BN_MONT_CTX *mont);
+
 
 #if defined(__cplusplus)
 }  // extern C
diff --git a/crypto/fipsmodule/bn/montgomery.c b/crypto/fipsmodule/bn/montgomery.c
index f09ada8..e8505da 100644
--- a/crypto/fipsmodule/bn/montgomery.c
+++ b/crypto/fipsmodule/bn/montgomery.c
@@ -416,3 +416,68 @@
   BN_CTX_end(ctx);
   return ret;
 }
+
+int bn_to_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                           size_t num_a, const BN_MONT_CTX *mont) {
+  return bn_mod_mul_montgomery_small(r, num_r, a, num_a, mont->RR.d,
+                                     mont->RR.top, mont);  // Multiply by |RR|.
+}
+
+int bn_from_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                             size_t num_a, const BN_MONT_CTX *mont) {
+  size_t num_n = mont->N.top;  // Modulus length in words.
+  if (num_a > 2 * num_n || num_r != num_n || num_n > BN_SMALL_MAX_WORDS) {
+    OPENSSL_PUT_ERROR(BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+    return 0;
+  }
+  BN_ULONG tmp[BN_SMALL_MAX_WORDS * 2];  // Scratch copy of |a|, zero-padded.
+  size_t num_tmp = 2 * num_n;
+  OPENSSL_memcpy(tmp, a, num_a * sizeof(BN_ULONG));
+  OPENSSL_memset(tmp + num_a, 0, (num_tmp - num_a) * sizeof(BN_ULONG));
+  int ret = bn_from_montgomery_in_place(r, num_r, tmp, num_tmp, mont);
+  OPENSSL_cleanse(tmp, num_tmp * sizeof(BN_ULONG));  // Wipe the scratch copy.
+  return ret;
+}
+
+int bn_mod_mul_montgomery_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                                size_t num_a, const BN_ULONG *b, size_t num_b,
+                                const BN_MONT_CTX *mont) {
+  size_t num_n = mont->N.top;  // Modulus length in words.
+  if (num_r != num_n || num_a + num_b > 2 * num_n ||
+      num_n > BN_SMALL_MAX_WORDS) {  // Ensures |tmp| below is big enough.
+    OPENSSL_PUT_ERROR(BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+    return 0;
+  }
+
+#if defined(OPENSSL_BN_ASM_MONT)
+  // |bn_mul_mont| requires at least 128 bits of limbs, at least for x86.
+  if (num_n >= (128 / BN_BITS2) &&
+      num_a == num_n &&
+      num_b == num_n) {  // Only taken for full-width operands.
+    if (!bn_mul_mont(r, a, b, mont->N.d, mont->n0, num_n)) {
+      assert(0);  // The check above ensures this won't happen.
+      OPENSSL_PUT_ERROR(BN, ERR_R_INTERNAL_ERROR);
+      return 0;
+    }
+    return 1;
+  }
+#endif
+
+  // Compute the product in |tmp|, squaring when |a| and |b| are the same.
+  BN_ULONG tmp[2 * BN_SMALL_MAX_WORDS];
+  size_t num_tmp = 2 * num_n;
+  size_t num_ab = num_a + num_b;
+  if (a == b && num_a == num_b) {
+    if (!bn_sqr_small(tmp, num_ab, a, num_a)) {
+      return 0;
+    }
+  } else if (!bn_mul_small(tmp, num_ab, a, num_a, b, num_b)) {
+    return 0;
+  }
+
+  // Zero-extend |tmp| to 2 * |num_n| words and reduce it into |r|.
+  OPENSSL_memset(tmp + num_ab, 0, (num_tmp - num_ab) * sizeof(BN_ULONG));
+  int ret = bn_from_montgomery_in_place(r, num_r, tmp, num_tmp, mont);
+  OPENSSL_cleanse(tmp, num_tmp * sizeof(BN_ULONG));  // Wipe the scratch copy.
+  return ret;
+}