Make RSA_check_key more than 2x as fast.

The bulk of RSA_check_key is spent in bn_div_consttime, which is a naive
but constant-time long-division algorithm for the few places that divide
by a secret even divisor: RSA keygen and RSA import. RSA import is
somewhat performance-sensitive, so pick some low-hanging fruit:

The main observation is that, in all but one call site, the bit width of
the divisor is public. That means, for an N-bit divisor, we can skip the
first N-1 iterations of long division because an N-1-bit remainder
cannot exceed the N-bit divisor.

One minor nuisance is bn_lcm_consttime, used in RSA keygen has a case
that does *not* have a public bit width. Apply the optimization there
would leak information. I've implemented this as an optional public
lower bound on num_bits(divisor), which all but that call fills in.

Before:
Did 5060 RSA 2048 private key parse operations in 1058526us (4780.2 ops/sec)
Did 1551 RSA 4096 private key parse operations in 1082343us (1433.0 ops/sec)

After:
Did 11532 RSA 2048 private key parse operations in 1084145us (10637.0 ops/sec) [+122.5%]
Did 3542 RSA 4096 private key parse operations in 1036374us (3417.7 ops/sec) [+138.5%]

Bug: b/192484677
Change-Id: I893ebb8886aeb8200a1a365673b56c49774221a2
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/49106
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index 1b42d9c..72ec8c2 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -613,9 +613,17 @@
     }
   }
 
-  ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(), ctx));
+  ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(),
+                               /*divisor_min_bits=*/0, ctx));
   EXPECT_BIGNUMS_EQUAL("A / B (constant-time)", quotient.get(), ret.get());
   EXPECT_BIGNUMS_EQUAL("A % B (constant-time)", remainder.get(), ret2.get());
+
+  ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(),
+                               /*divisor_min_bits=*/BN_num_bits(b.get()), ctx));
+  EXPECT_BIGNUMS_EQUAL("A / B (constant-time, public width)", quotient.get(),
+                       ret.get());
+  EXPECT_BIGNUMS_EQUAL("A % B (constant-time, public width)", remainder.get(),
+                       ret2.get());
 }
 
 static void TestModMul(BIGNUMFileTest *t, BN_CTX *ctx) {
diff --git a/crypto/fipsmodule/bn/div.c b/crypto/fipsmodule/bn/div.c
index 6ee6dbd..02b9931 100644
--- a/crypto/fipsmodule/bn/div.c
+++ b/crypto/fipsmodule/bn/div.c
@@ -456,7 +456,7 @@
 
 int bn_div_consttime(BIGNUM *quotient, BIGNUM *remainder,
                      const BIGNUM *numerator, const BIGNUM *divisor,
-                     BN_CTX *ctx) {
+                     unsigned divisor_min_bits, BN_CTX *ctx) {
   if (BN_is_negative(numerator) || BN_is_negative(divisor)) {
     OPENSSL_PUT_ERROR(BN, BN_R_NEGATIVE_NUMBER);
     return 0;
@@ -496,8 +496,26 @@
   r->neg = 0;
 
   // Incorporate |numerator| into |r|, one bit at a time, reducing after each
-  // step. At the start of each loop iteration, |r| < |divisor|
-  for (int i = numerator->width - 1; i >= 0; i--) {
+  // step. We maintain the invariant that |0 <= r < divisor| and
+  // |q * divisor + r = n| where |n| is the portion of |numerator| incorporated
+  // so far.
+  //
+  // First, we short-circuit the loop: if we know |divisor| has at least
+  // |divisor_min_bits| bits, the top |divisor_min_bits - 1| can be incorporated
+  // without reductions. This significantly speeds up |RSA_check_key|. For
+  // simplicity, we round down to a whole number of words.
+  assert(divisor_min_bits <= BN_num_bits(divisor));
+  int initial_words = 0;
+  if (divisor_min_bits > 0) {
+    initial_words = (divisor_min_bits - 1) / BN_BITS2;
+    if (initial_words > numerator->width) {
+      initial_words = numerator->width;
+    }
+    OPENSSL_memcpy(r->d, numerator->d + numerator->width - initial_words,
+                   initial_words * sizeof(BN_ULONG));
+  }
+
+  for (int i = numerator->width - initial_words - 1; i >= 0; i--) {
     for (int bit = BN_BITS2 - 1; bit >= 0; bit--) {
       // Incorporate the next bit of the numerator, by computing
       // r = 2*r or 2*r + 1. Note the result fits in one more word. We store the
diff --git a/crypto/fipsmodule/bn/gcd_extra.c b/crypto/fipsmodule/bn/gcd_extra.c
index 30540e3..53ab170 100644
--- a/crypto/fipsmodule/bn/gcd_extra.c
+++ b/crypto/fipsmodule/bn/gcd_extra.c
@@ -157,10 +157,11 @@
   BN_CTX_start(ctx);
   unsigned shift;
   BIGNUM *gcd = BN_CTX_get(ctx);
-  int ret = gcd != NULL &&
+  int ret = gcd != NULL &&  //
             bn_mul_consttime(r, a, b, ctx) &&
             bn_gcd_consttime(gcd, &shift, a, b, ctx) &&
-            bn_div_consttime(r, NULL, r, gcd, ctx) &&
+            // |gcd| has a secret bit width.
+            bn_div_consttime(r, NULL, r, gcd, /*divisor_min_bits=*/0, ctx) &&
             bn_rshift_secret_shift(r, r, shift, ctx);
   BN_CTX_end(ctx);
   return ret;
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 3d368db..cab9a81 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -552,12 +552,15 @@
 // bn_div_consttime behaves like |BN_div|, but it rejects negative inputs and
 // treats both inputs, including their magnitudes, as secret. It is, as a
 // result, much slower than |BN_div| and should only be used for rare operations
-// where Montgomery reduction is not available.
+// where Montgomery reduction is not available. |divisor_min_bits| is a
+// public lower bound for |BN_num_bits(divisor)|. When |divisor|'s bit width is
+// public, this can speed up the operation.
 //
 // Note that |quotient->width| will be set pessimally to |numerator->width|.
 OPENSSL_EXPORT int bn_div_consttime(BIGNUM *quotient, BIGNUM *remainder,
                                     const BIGNUM *numerator,
-                                    const BIGNUM *divisor, BN_CTX *ctx);
+                                    const BIGNUM *divisor,
+                                    unsigned divisor_min_bits, BN_CTX *ctx);
 
 // bn_is_relatively_prime checks whether GCD(|x|, |y|) is one. On success, it
 // returns one and sets |*out_relatively_prime| to one if the GCD was one and
diff --git a/crypto/fipsmodule/rsa/rsa.c b/crypto/fipsmodule/rsa/rsa.c
index f6d3640..fd84cba 100644
--- a/crypto/fipsmodule/rsa/rsa.c
+++ b/crypto/fipsmodule/rsa/rsa.c
@@ -657,7 +657,8 @@
 }
 
 static int check_mod_inverse(int *out_ok, const BIGNUM *a, const BIGNUM *ainv,
-                             const BIGNUM *m, BN_CTX *ctx) {
+                             const BIGNUM *m, unsigned m_min_bits,
+                             BN_CTX *ctx) {
   if (BN_is_negative(ainv) || BN_cmp(ainv, m) >= 0) {
     *out_ok = 0;
     return 1;
@@ -670,7 +671,7 @@
   BIGNUM *tmp = BN_CTX_get(ctx);
   int ret = tmp != NULL &&
             bn_mul_consttime(tmp, a, ainv, ctx) &&
-            bn_div_consttime(NULL, tmp, tmp, m, ctx);
+            bn_div_consttime(NULL, tmp, tmp, m, m_min_bits, ctx);
   if (ret) {
     *out_ok = BN_is_one(tmp);
   }
@@ -750,10 +751,15 @@
   // simply check that d * e is one mod p-1 and mod q-1. Note d and e were bound
   // by earlier checks in this function.
   if (!bn_usub_consttime(&pm1, key->p, BN_value_one()) ||
-      !bn_usub_consttime(&qm1, key->q, BN_value_one()) ||
-      !bn_mul_consttime(&de, key->d, key->e, ctx) ||
-      !bn_div_consttime(NULL, &tmp, &de, &pm1, ctx) ||
-      !bn_div_consttime(NULL, &de, &de, &qm1, ctx)) {
+      !bn_usub_consttime(&qm1, key->q, BN_value_one())) {
+    OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
+    goto out;
+  }
+  const unsigned pm1_bits = BN_num_bits(&pm1);
+  const unsigned qm1_bits = BN_num_bits(&qm1);
+  if (!bn_mul_consttime(&de, key->d, key->e, ctx) ||
+      !bn_div_consttime(NULL, &tmp, &de, &pm1, pm1_bits, ctx) ||
+      !bn_div_consttime(NULL, &de, &de, &qm1, qm1_bits, ctx)) {
     OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
     goto out;
   }
@@ -772,9 +778,12 @@
 
   if (has_crt_values) {
     int dmp1_ok, dmq1_ok, iqmp_ok;
-    if (!check_mod_inverse(&dmp1_ok, key->e, key->dmp1, &pm1, ctx) ||
-        !check_mod_inverse(&dmq1_ok, key->e, key->dmq1, &qm1, ctx) ||
-        !check_mod_inverse(&iqmp_ok, key->q, key->iqmp, key->p, ctx)) {
+    if (!check_mod_inverse(&dmp1_ok, key->e, key->dmp1, &pm1, pm1_bits, ctx) ||
+        !check_mod_inverse(&dmq1_ok, key->e, key->dmq1, &qm1, qm1_bits, ctx) ||
+        // |p| is odd, so |pm1| and |p| have the same bit width. If they didn't,
+        // we only need a lower bound anyway.
+        !check_mod_inverse(&iqmp_ok, key->q, key->iqmp, key->p, pm1_bits,
+                           ctx)) {
       OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
       goto out;
     }
diff --git a/crypto/fipsmodule/rsa/rsa_impl.c b/crypto/fipsmodule/rsa/rsa_impl.c
index 6dd89b9..a6865c0 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.c
+++ b/crypto/fipsmodule/rsa/rsa_impl.c
@@ -1262,12 +1262,14 @@
     // values for d.
   } while (BN_cmp(rsa->d, pow2_prime_bits) <= 0);
 
+  assert(BN_num_bits(pm1) == (unsigned)prime_bits);
+  assert(BN_num_bits(qm1) == (unsigned)prime_bits);
   if (// Calculate n.
       !bn_mul_consttime(rsa->n, rsa->p, rsa->q, ctx) ||
       // Calculate d mod (p-1).
-      !bn_div_consttime(NULL, rsa->dmp1, rsa->d, pm1, ctx) ||
+      !bn_div_consttime(NULL, rsa->dmp1, rsa->d, pm1, prime_bits, ctx) ||
       // Calculate d mod (q-1)
-      !bn_div_consttime(NULL, rsa->dmq1, rsa->d, qm1, ctx)) {
+      !bn_div_consttime(NULL, rsa->dmq1, rsa->d, qm1, prime_bits, ctx)) {
     goto bn_err;
   }
   bn_set_minimal_width(rsa->n);