Add bn_mod_exp_mont_small and bn_mod_inverse_prime_mont_small.

These can be used to invert values in ECDSA. Unlike their BIGNUM
counterparts, the caller is responsible for taking values in and out of
Montgomery domain. This will save some work later on in the ECDSA
computation.

Change-Id: Ib7292900a0fdeedce6cb3e9a9123c94863659043
Reviewed-on: https://boringssl-review.googlesource.com/23071
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index d2a875b..8e156ad 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -680,6 +680,28 @@
                                           ctx, NULL));
     EXPECT_BIGNUMS_EQUAL("A ^ E (mod M) (constant-time)", mod_exp.get(),
                          ret.get());
+
+#if !defined(BORINGSSL_SHARED_LIBRARY)
+    if (m->top <= BN_SMALL_MAX_WORDS) {
+      bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
+      ASSERT_TRUE(mont.get());
+      ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), m.get(), ctx));
+      ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
+      std::unique_ptr<BN_ULONG[]> r_words(new BN_ULONG[m->top]),
+          a_words(new BN_ULONG[m->top]);
+      OPENSSL_memset(a_words.get(), 0, m->top * sizeof(BN_ULONG));
+      OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
+      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m->top, a_words.get(),
+                                         m->top, mont.get()));
+      ASSERT_TRUE(bn_mod_exp_mont_small(r_words.get(), m->top, a_words.get(),
+                                        m->top, e->d, e->top, mont.get()));
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
+                                           m->top, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+      EXPECT_BIGNUMS_EQUAL("A ^ E (mod M) (Montgomery, words)", mod_exp.get(),
+                           ret.get());
+    }
+#endif
   }
 }
 
diff --git a/crypto/fipsmodule/bn/exponentiation.c b/crypto/fipsmodule/bn/exponentiation.c
index 2d40e8f..a5cb7da 100644
--- a/crypto/fipsmodule/bn/exponentiation.c
+++ b/crypto/fipsmodule/bn/exponentiation.c
@@ -434,6 +434,15 @@
 // value returned from |BN_window_bits_for_exponent_size|.
 #define TABLE_SIZE 32
 
+// TABLE_BITS_SMALL is the largest value returned from
+// |BN_window_bits_for_exponent_size| when |b| is at most |BN_BITS2| *
+// |BN_SMALL_MAX_WORDS| bits.
+#define TABLE_BITS_SMALL 5
+
+// TABLE_SIZE_SMALL is the same as |TABLE_SIZE|, but when |b| is at most
+// |BN_BITS2| * |BN_SMALL_MAX_WORDS| bits, i.e. 2^(TABLE_BITS_SMALL - 1).
+#define TABLE_SIZE_SMALL (1 << (TABLE_BITS_SMALL - 1))
+
 static int mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
                         const BIGNUM *m, BN_CTX *ctx) {
   int i, j, bits, ret = 0, wstart, window;
@@ -734,6 +743,152 @@
   return ret;
 }
 
+// bn_mod_exp_mont_small computes |a|^|p| mod |mont->N| with |a| and |r| in the
+// Montgomery domain. It returns one on success and zero on error.
+int bn_mod_exp_mont_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                          size_t num_a, const BN_ULONG *p, size_t num_p,
+                          const BN_MONT_CTX *mont) {
+  const BN_ULONG *n = mont->N.d;
+  size_t num_n = mont->N.top;
+  if (num_n != num_a || num_n != num_r || num_n > BN_SMALL_MAX_WORDS) {
+    OPENSSL_PUT_ERROR(BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+    return 0;
+  }
+  if (!BN_is_odd(&mont->N)) {
+    OPENSSL_PUT_ERROR(BN, BN_R_CALLED_WITH_EVEN_MODULUS);
+    return 0;
+  }
+  unsigned bits = 0;
+  if (num_p != 0) {
+    bits = BN_num_bits_word(p[num_p - 1]) + (num_p - 1) * BN_BITS2;
+  }
+  if (bits == 0) {
+    // a^0 = 1. Outputs are in the Montgomery domain, so produce one in Montgomery
+    // form, matching what the loop below yields for a non-empty all-zero |p|.
+    return bn_from_montgomery_small(r, num_r, mont->RR.d, mont->RR.top, mont);
+  }
+
+  // We exponentiate by looking at sliding windows of the exponent and
+  // precomputing powers of |a|. Windows may be shifted so they always end on a
+  // set bit, so only precompute odd powers. We compute val[i] = a^(2*i + 1) for
+  // i = 0 to 2^(window-1) - 1, all in Montgomery form.
+  unsigned window = BN_window_bits_for_exponent_size(bits);
+  if (window > TABLE_BITS_SMALL) {
+    window = TABLE_BITS_SMALL;  // Tolerate excessively large |p|.
+  }
+  int ret = 0;
+  BN_ULONG val[TABLE_SIZE_SMALL][BN_SMALL_MAX_WORDS];
+  OPENSSL_memcpy(val[0], a, num_n * sizeof(BN_ULONG));
+  if (window > 1) {
+    BN_ULONG d[BN_SMALL_MAX_WORDS];
+    if (!bn_mod_mul_montgomery_small(d, num_n, val[0], num_n, val[0], num_n,
+                                     mont)) {
+      goto err;
+    }
+    for (unsigned i = 1; i < 1u << (window - 1); i++) {
+      if (!bn_mod_mul_montgomery_small(val[i], num_n, val[i - 1], num_n, d,
+                                       num_n, mont)) {
+        goto err;
+      }
+    }
+  }
+
+  // Set |r| to one in Montgomery form. If the high bit of |n| is set, |n| is
+  // close to R and we subtract rather than perform Montgomery reduction.
+  if (n[num_n - 1] & (((BN_ULONG)1) << (BN_BITS2 - 1))) {
+    // r = 2^(num_n*BN_BITS2) - n. |n| is odd, so negating n[0] never carries.
+    r[0] = 0 - n[0];
+    for (size_t i = 1; i < num_n; i++) {
+      r[i] = ~n[i];
+    }
+  } else if (!bn_from_montgomery_small(r, num_r, mont->RR.d, mont->RR.top,
+                                       mont)) {
+    goto err;
+  }
+
+  int r_is_one = 1;
+  unsigned wstart = bits - 1;  // The top bit of the window.
+  for (;;) {
+    if (!bn_is_bit_set_words(p, num_p, wstart)) {
+      if (!r_is_one &&
+          !bn_mod_mul_montgomery_small(r, num_r, r, num_r, r, num_r, mont)) {
+        goto err;
+      }
+      if (wstart == 0) {
+        break;
+      }
+      wstart--;
+      continue;
+    }
+
+    // We now have wstart on a set bit. Find the largest window we can use.
+    unsigned wvalue = 1;
+    unsigned wsize = 0;
+    for (unsigned i = 1; i < window && i <= wstart; i++) {
+      if (bn_is_bit_set_words(p, num_p, wstart - i)) {
+        wvalue <<= (i - wsize);
+        wvalue |= 1;
+        wsize = i;
+      }
+    }
+
+    // Shift |r| to the end of the window.
+    if (!r_is_one) {
+      for (unsigned i = 0; i < wsize + 1; i++) {
+        if (!bn_mod_mul_montgomery_small(r, num_r, r, num_r, r, num_r, mont)) {
+          goto err;
+        }
+      }
+    }
+
+    assert(wvalue & 1);
+    assert(wvalue < (1u << window));
+    if (!bn_mod_mul_montgomery_small(r, num_r, r, num_r, val[wvalue >> 1],
+                                     num_n, mont)) {
+      goto err;
+    }
+
+    r_is_one = 0;
+    if (wstart == wsize) {
+      break;
+    }
+    wstart -= wsize + 1;
+  }
+
+  ret = 1;
+
+err:
+  OPENSSL_cleanse(val, sizeof(val));
+  return ret;
+}
+
+int bn_mod_inverse_prime_mont_small(BN_ULONG *r, size_t num_r,
+                                    const BN_ULONG *a, size_t num_a,
+                                    const BN_MONT_CTX *mont) {
+  const size_t num_p = mont->N.top;
+  if (num_p == 0 || num_p > BN_SMALL_MAX_WORDS) {
+    OPENSSL_PUT_ERROR(BN, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED);
+    return 0;
+  }
+
+  // Per Fermat's Little Theorem, a^-1 = a^(p-2) (mod p) for p prime, so the
+  // inverse is a single exponentiation with exponent p - 2.
+  BN_ULONG exponent[BN_SMALL_MAX_WORDS];
+  OPENSSL_memcpy(exponent, mont->N.d, num_p * sizeof(BN_ULONG));
+
+  // Subtract two from the low word. If that wrapped around, propagate the
+  // borrow upwards until a word that was non-zero absorbs it.
+  const int borrow = exponent[0] < 2;
+  exponent[0] -= 2;
+  for (size_t i = 1; borrow && i < num_p; i++) {
+    if (exponent[i]-- != 0) {
+      break;
+    }
+  }
+
+  return bn_mod_exp_mont_small(r, num_r, a, num_a, exponent, num_p, mont);
+}
+
 
 // |BN_mod_exp_mont_consttime| stores the precomputed powers in a specific
 // layout so that accessing any of these table values shows the same access
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 2f5dbbb..8fa8ed2 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -307,6 +307,10 @@
 // -2 on error.
 int bn_jacobi(const BIGNUM *a, const BIGNUM *b, BN_CTX *ctx);
 
+// bn_is_bit_set_words returns one if bit |bit| is set in the |num|-word value
+// |a| and zero otherwise. Bits at or beyond |num| words are reported as zero.
+int bn_is_bit_set_words(const BN_ULONG *a, size_t num, unsigned bit);
+
 
 // Low-level operations for small numbers.
 //
@@ -367,6 +371,29 @@
                                 size_t num_a, const BN_ULONG *b, size_t num_b,
                                 const BN_MONT_CTX *mont);
 
+// bn_mod_exp_mont_small sets |r| to |a|^|p| mod |mont->N|. It returns one on
+// success and zero on programmer or internal error. |a| and |r| are in the
+// Montgomery domain; the exponent |p| is a plain |num_p|-word integer, least
+// significant word first. |mont->N| must be odd. |num_r| and |num_a| must be
+// |mont->N.top|, which must be at most |BN_SMALL_MAX_WORDS|. |a| must be
+// fully-reduced. This function runs in time independent of |a|, but |p| and
+// |mont->N| are public values.
+//
+// Unlike |BN_mod_exp_mont|, whose input and output are outside the Montgomery
+// domain, this function performs no conversion. Combine it with
+// |bn_to_montgomery_small| and |bn_from_montgomery_small| if necessary.
+int bn_mod_exp_mont_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a,
+                          size_t num_a, const BN_ULONG *p, size_t num_p,
+                          const BN_MONT_CTX *mont);
+
+// bn_mod_inverse_prime_mont_small sets |r| to |a|^-1 mod |mont->N|, both in
+// Montgomery form, returning one on success and zero on error. |mont->N| must
+// be prime. |num_r| and |num_a| must be |mont->N.top| <= |BN_SMALL_MAX_WORDS|.
+// |a| must be fully-reduced. Time is independent of |a|; |mont->N| is public.
+int bn_mod_inverse_prime_mont_small(BN_ULONG *r, size_t num_r,
+                                    const BN_ULONG *a, size_t num_a,
+                                    const BN_MONT_CTX *mont);
+
 
 #if defined(__cplusplus)
 }  // extern C
diff --git a/crypto/fipsmodule/bn/shift.c b/crypto/fipsmodule/bn/shift.c
index d4528e6..d4ed79e 100644
--- a/crypto/fipsmodule/bn/shift.c
+++ b/crypto/fipsmodule/bn/shift.c
@@ -267,17 +267,20 @@
   return 1;
 }
 
+// Returns one if bit |bit| of the |num|-word value |a| is set, else zero.
+int bn_is_bit_set_words(const BN_ULONG *a, size_t num, unsigned bit) {
+  // Split the bit index into a word index and an offset inside that word.
+  const unsigned word = bit / BN_BITS2;
+  const unsigned offset = bit % BN_BITS2;
+  // Bits past the end of the |num| supplied words read as zero.
+  return word < num ? (a[word] >> offset) & 1 : 0;
+}
+
 int BN_is_bit_set(const BIGNUM *a, int n) {
   if (n < 0) {
     return 0;
   }
-  int i = n / BN_BITS2;
-  int j = n % BN_BITS2;
-  if (a->top <= i) {
-    return 0;
-  }
-
-  return (a->d[i]>>j)&1;
+  return bn_is_bit_set_words(a->d, a->top, n);
 }
 
 int BN_mask_bits(BIGNUM *a, int n) {