Compute mont->RR in constant-time.

Use the now constant-time modular arithmetic functions.

Bug: 236
Change-Id: I4567d67bfe62ca82ec295f2233d1a6c9b131e5d2
Reviewed-on: https://boringssl-review.googlesource.com/25285
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index addc4bb..b09388a 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -312,7 +312,13 @@
                 const BN_ULONG *np, const BN_ULONG *n0, int num);
 
 uint64_t bn_mont_n0(const BIGNUM *n);
-int bn_mod_exp_base_2_vartime(BIGNUM *r, unsigned p, const BIGNUM *n);
+
+// bn_mod_exp_base_2_consttime calculates r = 2**p (mod n). |p| must be larger
+// than log_2(n); i.e. 2**p must be larger than |n|. |n| must be positive and
+// odd. |p| and the bit width of |n| are assumed public, but |n| is otherwise
+// treated as secret.
+int bn_mod_exp_base_2_consttime(BIGNUM *r, unsigned p, const BIGNUM *n,
+                                BN_CTX *ctx);
 
 #if defined(OPENSSL_X86_64) && defined(_MSC_VER)
 #define BN_UMULT_LOHI(low, high, a, b) ((low) = _umul128((a), (b), &(high)))
diff --git a/crypto/fipsmodule/bn/montgomery.c b/crypto/fipsmodule/bn/montgomery.c
index 720cef7..c21a030 100644
--- a/crypto/fipsmodule/bn/montgomery.c
+++ b/crypto/fipsmodule/bn/montgomery.c
@@ -208,19 +208,24 @@
   mont->n0[1] = 0;
 #endif
 
+  BN_CTX *new_ctx = NULL;
+  if (ctx == NULL) {
+    new_ctx = BN_CTX_new();
+    if (new_ctx == NULL) {
+      return 0;
+    }
+    ctx = new_ctx;
+  }
+
   // Save RR = R**2 (mod N). R is the smallest power of 2**BN_BITS2 such that R
   // > mod. Even though the assembly on some 32-bit platforms works with 64-bit
   // values, using |BN_BITS2| here, rather than |BN_MONT_CTX_N0_LIMBS *
   // BN_BITS2|, is correct because R**2 will still be a multiple of the latter
   // as |BN_MONT_CTX_N0_LIMBS| is either one or two.
-  //
-  // XXX: This is not constant time with respect to |mont->N|, but it should be.
   unsigned lgBigR = mont->N.width * BN_BITS2;
-  if (!bn_mod_exp_base_2_vartime(&mont->RR, lgBigR * 2, &mont->N)) {
-    return 0;
-  }
-
-  return 1;
+  int ok = bn_mod_exp_base_2_consttime(&mont->RR, lgBigR * 2, &mont->N, ctx);
+  BN_CTX_free(new_ctx);
+  return ok;
 }
 
 BN_MONT_CTX *BN_MONT_CTX_new_for_modulus(const BIGNUM *mod, BN_CTX *ctx) {
diff --git a/crypto/fipsmodule/bn/montgomery_inv.c b/crypto/fipsmodule/bn/montgomery_inv.c
index f21d045..15e62e4 100644
--- a/crypto/fipsmodule/bn/montgomery_inv.c
+++ b/crypto/fipsmodule/bn/montgomery_inv.c
@@ -159,10 +159,8 @@
   return v;
 }
 
-// bn_mod_exp_base_2_vartime calculates r = 2**p (mod n). |p| must be larger
-// than log_2(n); i.e. 2**p must be larger than |n|. |n| must be positive and
-// odd.
-int bn_mod_exp_base_2_vartime(BIGNUM *r, unsigned p, const BIGNUM *n) {
+int bn_mod_exp_base_2_consttime(BIGNUM *r, unsigned p, const BIGNUM *n,
+                                BN_CTX *ctx) {
   assert(!BN_is_zero(n));
   assert(!BN_is_negative(n));
   assert(BN_is_odd(n));
@@ -171,37 +169,17 @@
 
   unsigned n_bits = BN_num_bits(n);
   assert(n_bits != 0);
+  assert(p > n_bits);
   if (n_bits == 1) {
     return 1;
   }
 
-  // Set |r| to the smallest power of two larger than |n|.
-  assert(p > n_bits);
-  if (!BN_set_bit(r, n_bits)) {
+  // Set |r| to the larger power of two smaller than |n|, then shift with
+  // reductions the rest of the way.
+  if (!BN_set_bit(r, n_bits - 1) ||
+      !bn_mod_lshift_quick_ctx(r, r, p - (n_bits - 1), n, ctx)) {
     return 0;
   }
 
-  // Unconditionally reduce |r|.
-  assert(BN_cmp(r, n) > 0);
-  if (!BN_usub(r, r, n)) {
-    return 0;
-  }
-  assert(BN_cmp(r, n) < 0);
-
-  for (unsigned i = n_bits; i < p; ++i) {
-    // This is like |BN_mod_lshift1_quick| except using |BN_usub|.
-    //
-    // TODO: Replace this with the use of a constant-time variant of
-    // |BN_mod_lshift1_quick|.
-    if (!BN_lshift1(r, r)) {
-      return 0;
-    }
-    if (BN_cmp(r, n) >= 0) {
-      if (!BN_usub(r, r, n)) {
-        return 0;
-      }
-    }
-  }
-
   return 1;
 }