Make primality testing mostly constant-time.

The extra details in Enhanced Rabin-Miller are only used in
RSA_check_key_fips, on the public RSA modulus, which the static linker
will drop in most of our consumers anyway. Implement normal Rabin-Miller
for RSA keygen and use Montgomery reduction so it runs in constant-time.

Note that we only need to avoid leaking information about the input if
it's a large prime. If the number ends up composite, or we find it in
our table of small primes, we can return immediately.

The leaks not addressed by this CL are:

- The difficulty of selecting |b| leaks information about |w|.
- The distribution of whether step 4.4 runs leaks information about w.
- We leak |a| (the exponent of the largest power of two which divides
  |w| - 1) everywhere.
- BN_mod_word in the trial division is not constant-time.

These will be resolved in follow-up changes.

Median of 29 RSA keygens: 0m0.521s -> 0m0.621s
(Accuracy beyond 0.1s is questionable.)

Bug: 238
Change-Id: I0cf0ff22079732a0a3ababfe352bb4327e95b879
Reviewed-on: https://boringssl-review.googlesource.com/25886
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index f73054d..e4faae9 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -1856,6 +1856,7 @@
   bssl::UniquePtr<BIGNUM> p(BN_new());
   ASSERT_TRUE(p);
   int is_probably_prime_1 = 0, is_probably_prime_2 = 0;
+  enum bn_primality_result_t result_3;
 
   const int max_prime = kPrimes[OPENSSL_ARRAY_SIZE(kPrimes)-1];
   size_t next_prime_index = 0;
@@ -1878,6 +1879,11 @@
         &is_probably_prime_2, p.get(), BN_prime_checks, ctx(),
         true /* do_trial_division */, nullptr /* callback */));
     EXPECT_EQ(is_prime ? 1 : 0, is_probably_prime_2);
+    if (i > 3 && i % 2 == 1) {
+      ASSERT_TRUE(BN_enhanced_miller_rabin_primality_test(
+          &result_3, p.get(), BN_prime_checks, ctx(), nullptr /* callback */));
+      EXPECT_EQ(is_prime, result_3 == bn_probably_prime);
+    }
   }
 
   // Negative numbers are not prime.
@@ -1920,6 +1926,10 @@
         &is_probably_prime_2, p.get(), BN_prime_checks, ctx(),
         true /* do_trial_division */, nullptr /* callback */));
     EXPECT_EQ(0, is_probably_prime_2);
+
+    ASSERT_TRUE(BN_enhanced_miller_rabin_primality_test(
+        &result_3, p.get(), BN_prime_checks, ctx(), nullptr /* callback */));
+    EXPECT_EQ(bn_composite, result_3);
   }
 }
 
diff --git a/crypto/fipsmodule/bn/prime.c b/crypto/fipsmodule/bn/prime.c
index a291f7a..d203522 100644
--- a/crypto/fipsmodule/bn/prime.c
+++ b/crypto/fipsmodule/bn/prime.c
@@ -461,20 +461,144 @@
   return found;
 }
 
-int BN_primality_test(int *is_probably_prime, const BIGNUM *candidate,
-                      int checks, BN_CTX *ctx, int do_trial_division,
+int BN_primality_test(int *is_probably_prime, const BIGNUM *w,
+                      int iterations, BN_CTX *ctx, int do_trial_division,
                       BN_GENCB *cb) {
-  switch (BN_is_prime_fasttest_ex(candidate, checks, ctx, do_trial_division, cb)) {
-    case 1:
-      *is_probably_prime = 1;
-      return 1;
-    case 0:
-      *is_probably_prime = 0;
-      return 1;
-    default:
-      *is_probably_prime = 0;
-      return 0;
+  *is_probably_prime = 0;
+
+  // To support RSA key generation, this function should treat |w| as secret if
+  // it is a large prime. Composite numbers are discarded, so they may return
+  // early.
+  //
+  // TODO(davidben): This function is getting better, but is not constant-time.
+
+  if (BN_cmp(w, BN_value_one()) <= 0) {
+    return 1;
   }
+
+  if (!BN_is_odd(w)) {
+    // The only even prime is two.
+    *is_probably_prime = BN_is_word(w, 2);
+    return 1;
+  }
+
+  // Miller-Rabin does not work for three.
+  if (BN_is_word(w, 3)) {
+    *is_probably_prime = 1;
+    return 1;
+  }
+
+  if (do_trial_division) {
+    // Perform trial division to discard candidates with small prime factors.
+    for (int i = 1; i < NUMPRIMES; i++) {
+      BN_ULONG mod = BN_mod_word(w, primes[i]);
+      if (mod == (BN_ULONG)-1) {
+        return 0;
+      }
+      if (mod == 0) {
+        *is_probably_prime = BN_is_word(w, primes[i]);
+        return 1;
+      }
+    }
+    if (!BN_GENCB_call(cb, 1, -1)) {
+      return 0;
+    }
+  }
+
+  if (iterations == BN_prime_checks) {
+    iterations = BN_prime_checks_for_size(BN_num_bits(w));
+  }
+
+  // See C.3.1 from FIPS 186-4.
+  int ret = 0;
+  BN_MONT_CTX *mont = NULL;
+  BN_CTX_start(ctx);
+  BIGNUM *w1 = BN_CTX_get(ctx);
+  if (w1 == NULL ||
+      !bn_usub_fixed(w1, w, BN_value_one())) {
+    goto err;
+  }
+
+  // Write w1 as m * 2^a (Steps 1 and 2).
+  int a = 0;
+  while (!BN_is_bit_set(w1, a)) {
+    a++;
+  }
+  BIGNUM *m = BN_CTX_get(ctx);
+  if (m == NULL ||
+      !BN_rshift(m, w1, a)) {
+    goto err;
+  }
+
+  // Montgomery setup for computations mod w. Additionally, compute 1 and w - 1
+  // in the Montgomery domain for later comparisons.
+  BIGNUM *b = BN_CTX_get(ctx);
+  BIGNUM *z = BN_CTX_get(ctx);
+  BIGNUM *one_mont = BN_CTX_get(ctx);
+  BIGNUM *w1_mont = BN_CTX_get(ctx);
+  mont = BN_MONT_CTX_new_for_modulus(w, ctx);
+  if (b == NULL || z == NULL || one_mont == NULL || w1_mont == NULL ||
+      mont == NULL ||
+      !bn_one_to_montgomery(one_mont, mont, ctx) ||
+      // w - 1 is -1 mod w, so we can compute it in the Montgomery domain, -R,
+      // with a subtraction. (|one_mont| cannot be zero.)
+      !bn_usub_fixed(w1_mont, w, one_mont)) {
+    goto err;
+  }
+
+  // The following loop performs an inner iteration of the Miller-Rabin
+  // Primality test (Step 4).
+  for (int i = 1; i <= iterations; i++) {
+    if (// Step 4.1-4.2
+        !BN_rand_range_ex(b, 2, w1) ||
+        // Step 4.3
+        !BN_mod_exp_mont_consttime(z, b, m, w, ctx, mont)) {
+      goto err;
+    }
+
+    // Step 4.4
+    if (BN_equal_consttime(z, BN_value_one()) ||
+        BN_equal_consttime(z, w1)) {
+      goto loop;
+    }
+
+    // Step 4.5. We use Montgomery-encoding for better performance and to avoid
+    // timing leaks.
+    if (!BN_to_montgomery(z, z, mont, ctx)) {
+      goto err;
+    }
+
+    for (int j = 1; j < a; j++) {
+      if (!BN_mod_mul_montgomery(z, z, z, mont, ctx)) {
+        goto err;
+      }
+      if (BN_equal_consttime(z, w1_mont)) {
+        goto loop;
+      }
+      if (BN_equal_consttime(z, one_mont)) {
+        break;
+      }
+    }
+
+    // Step 4.6
+    *is_probably_prime = 0;
+    ret = 1;
+    goto err;
+
+  loop:
+    // Step 4.7
+    if (!BN_GENCB_call(cb, 1, i)) {
+      goto err;
+    }
+  }
+
+  *is_probably_prime = 1;
+  ret = 1;
+
+err:
+  BN_MONT_CTX_free(mont);
+  BN_CTX_end(ctx);
+  return ret;
 }
 
 int BN_is_prime_ex(const BIGNUM *candidate, int checks, BN_CTX *ctx, BN_GENCB *cb) {
@@ -483,57 +607,12 @@
 
 int BN_is_prime_fasttest_ex(const BIGNUM *a, int checks, BN_CTX *ctx,
                             int do_trial_division, BN_GENCB *cb) {
-  if (BN_cmp(a, BN_value_one()) <= 0) {
-    return 0;
+  int is_probably_prime;
+  if (!BN_primality_test(&is_probably_prime, a, checks, ctx, do_trial_division,
+                         cb)) {
+    return -1;
   }
-
-  // first look for small factors
-  if (!BN_is_odd(a)) {
-    // a is even => a is prime if and only if a == 2
-    return BN_is_word(a, 2);
-  }
-
-  // Enhanced Miller-Rabin does not work for three.
-  if (BN_is_word(a, 3)) {
-    return 1;
-  }
-
-  if (do_trial_division) {
-    for (int i = 1; i < NUMPRIMES; i++) {
-      BN_ULONG mod = BN_mod_word(a, primes[i]);
-      if (mod == (BN_ULONG)-1) {
-        return -1;
-      }
-      if (mod == 0) {
-        return BN_is_word(a, primes[i]);
-      }
-    }
-
-    if (!BN_GENCB_call(cb, 1, -1)) {
-      return -1;
-    }
-  }
-
-  int ret = -1;
-  BN_CTX *ctx_allocated = NULL;
-  if (ctx == NULL) {
-    ctx_allocated = BN_CTX_new();
-    if (ctx_allocated == NULL) {
-      return -1;
-    }
-    ctx = ctx_allocated;
-  }
-
-  enum bn_primality_result_t result;
-  if (!BN_enhanced_miller_rabin_primality_test(&result, a, checks, ctx, cb)) {
-    goto err;
-  }
-
-  ret = (result == bn_probably_prime);
-
-err:
-  BN_CTX_free(ctx_allocated);
-  return ret;
+  return is_probably_prime;
 }
 
 int BN_enhanced_miller_rabin_primality_test(
@@ -585,7 +664,7 @@
     goto err;
   }
 
-  // Montgomery setup for computations mod A
+  // Montgomery setup for computations mod w
   mont = BN_MONT_CTX_new_for_modulus(w, ctx);
   if (mont == NULL) {
     goto err;
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index 0a844ed..eeb25a3 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -706,7 +706,7 @@
 // than the number-field sieve security level of |w| is used. It returns one on
 // success and zero on failure. If |cb| is not NULL, then it is called during
 // each iteration of the primality test.
-int BN_enhanced_miller_rabin_primality_test(
+OPENSSL_EXPORT int BN_enhanced_miller_rabin_primality_test(
     enum bn_primality_result_t *out_result, const BIGNUM *w, int iterations,
     BN_CTX *ctx, BN_GENCB *cb);