Make RSA_check_key more than 2x as fast.
The bulk of RSA_check_key is spent in bn_div_consttime, which is a naive
but constant-time long-division algorithm for the few places that divide
by a secret even divisor: RSA keygen and RSA import. RSA import is
somewhat performance-sensitive, so pick some low-hanging fruit:
The main observation is that, in all but one call site, the bit width of
the divisor is public. That means, for an N-bit divisor, we can skip the
first N-1 iterations of long division, because a remainder of at most N-1
bits is always strictly less than an N-bit divisor and needs no reduction.
One minor nuisance is that bn_lcm_consttime, used in RSA keygen, has a
case that does *not* have a public bit width. Applying the optimization
there would leak information. I've implemented this as an optional
public lower bound on num_bits(divisor), which every call site except
that one fills in.
Before:
Did 5060 RSA 2048 private key parse operations in 1058526us (4780.2 ops/sec)
Did 1551 RSA 4096 private key parse operations in 1082343us (1433.0 ops/sec)
After:
Did 11532 RSA 2048 private key parse operations in 1084145us (10637.0 ops/sec) [+122.5%]
Did 3542 RSA 4096 private key parse operations in 1036374us (3417.7 ops/sec) [+138.5%]
Bug: b/192484677
Change-Id: I893ebb8886aeb8200a1a365673b56c49774221a2
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/49106
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index 1b42d9c..72ec8c2 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -613,9 +613,17 @@
}
}
- ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(), ctx));
+ ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(),
+ /*divisor_min_bits=*/0, ctx));
EXPECT_BIGNUMS_EQUAL("A / B (constant-time)", quotient.get(), ret.get());
EXPECT_BIGNUMS_EQUAL("A % B (constant-time)", remainder.get(), ret2.get());
+
+ ASSERT_TRUE(bn_div_consttime(ret.get(), ret2.get(), a.get(), b.get(),
+ /*divisor_min_bits=*/BN_num_bits(b.get()), ctx));
+ EXPECT_BIGNUMS_EQUAL("A / B (constant-time, public width)", quotient.get(),
+ ret.get());
+ EXPECT_BIGNUMS_EQUAL("A % B (constant-time, public width)", remainder.get(),
+ ret2.get());
}
static void TestModMul(BIGNUMFileTest *t, BN_CTX *ctx) {
diff --git a/crypto/fipsmodule/bn/div.c b/crypto/fipsmodule/bn/div.c
index 6ee6dbd..02b9931 100644
--- a/crypto/fipsmodule/bn/div.c
+++ b/crypto/fipsmodule/bn/div.c
@@ -456,7 +456,7 @@
int bn_div_consttime(BIGNUM *quotient, BIGNUM *remainder,
const BIGNUM *numerator, const BIGNUM *divisor,
- BN_CTX *ctx) {
+ unsigned divisor_min_bits, BN_CTX *ctx) {
if (BN_is_negative(numerator) || BN_is_negative(divisor)) {
OPENSSL_PUT_ERROR(BN, BN_R_NEGATIVE_NUMBER);
return 0;
@@ -496,8 +496,26 @@
r->neg = 0;
// Incorporate |numerator| into |r|, one bit at a time, reducing after each
- // step. At the start of each loop iteration, |r| < |divisor|
- for (int i = numerator->width - 1; i >= 0; i--) {
+ // step. We maintain the invariant that |0 <= r < divisor| and
+ // |q * divisor + r = n| where |n| is the portion of |numerator| incorporated
+ // so far.
+ //
+ // First, we short-circuit the loop: if we know |divisor| has at least
+ // |divisor_min_bits| bits, the top |divisor_min_bits - 1| can be incorporated
+ // without reductions. This significantly speeds up |RSA_check_key|. For
+ // simplicity, we round down to a whole number of words.
+ assert(divisor_min_bits <= BN_num_bits(divisor));
+ int initial_words = 0;
+ if (divisor_min_bits > 0) {
+ initial_words = (divisor_min_bits - 1) / BN_BITS2;
+ if (initial_words > numerator->width) {
+ initial_words = numerator->width;
+ }
+ OPENSSL_memcpy(r->d, numerator->d + numerator->width - initial_words,
+ initial_words * sizeof(BN_ULONG));
+ }
+
+ for (int i = numerator->width - initial_words - 1; i >= 0; i--) {
for (int bit = BN_BITS2 - 1; bit >= 0; bit--) {
// Incorporate the next bit of the numerator, by computing
// r = 2*r or 2*r + 1. Note the result fits in one more word. We store the
diff --git a/crypto/fipsmodule/bn/gcd_extra.c b/crypto/fipsmodule/bn/gcd_extra.c
index 30540e3..53ab170 100644
--- a/crypto/fipsmodule/bn/gcd_extra.c
+++ b/crypto/fipsmodule/bn/gcd_extra.c
@@ -157,10 +157,11 @@
BN_CTX_start(ctx);
unsigned shift;
BIGNUM *gcd = BN_CTX_get(ctx);
- int ret = gcd != NULL &&
+ int ret = gcd != NULL && //
bn_mul_consttime(r, a, b, ctx) &&
bn_gcd_consttime(gcd, &shift, a, b, ctx) &&
- bn_div_consttime(r, NULL, r, gcd, ctx) &&
+ // |gcd| has a secret bit width.
+ bn_div_consttime(r, NULL, r, gcd, /*divisor_min_bits=*/0, ctx) &&
bn_rshift_secret_shift(r, r, shift, ctx);
BN_CTX_end(ctx);
return ret;
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 3d368db..cab9a81 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -552,12 +552,15 @@
// bn_div_consttime behaves like |BN_div|, but it rejects negative inputs and
// treats both inputs, including their magnitudes, as secret. It is, as a
// result, much slower than |BN_div| and should only be used for rare operations
-// where Montgomery reduction is not available.
+// where Montgomery reduction is not available. |divisor_min_bits| is a
+// public lower bound for |BN_num_bits(divisor)|. When |divisor|'s bit width is
+// public, this can speed up the operation.
//
// Note that |quotient->width| will be set pessimally to |numerator->width|.
OPENSSL_EXPORT int bn_div_consttime(BIGNUM *quotient, BIGNUM *remainder,
const BIGNUM *numerator,
- const BIGNUM *divisor, BN_CTX *ctx);
+ const BIGNUM *divisor,
+ unsigned divisor_min_bits, BN_CTX *ctx);
// bn_is_relatively_prime checks whether GCD(|x|, |y|) is one. On success, it
// returns one and sets |*out_relatively_prime| to one if the GCD was one and
diff --git a/crypto/fipsmodule/rsa/rsa.c b/crypto/fipsmodule/rsa/rsa.c
index f6d3640..fd84cba 100644
--- a/crypto/fipsmodule/rsa/rsa.c
+++ b/crypto/fipsmodule/rsa/rsa.c
@@ -657,7 +657,8 @@
}
static int check_mod_inverse(int *out_ok, const BIGNUM *a, const BIGNUM *ainv,
- const BIGNUM *m, BN_CTX *ctx) {
+ const BIGNUM *m, unsigned m_min_bits,
+ BN_CTX *ctx) {
if (BN_is_negative(ainv) || BN_cmp(ainv, m) >= 0) {
*out_ok = 0;
return 1;
@@ -670,7 +671,7 @@
BIGNUM *tmp = BN_CTX_get(ctx);
int ret = tmp != NULL &&
bn_mul_consttime(tmp, a, ainv, ctx) &&
- bn_div_consttime(NULL, tmp, tmp, m, ctx);
+ bn_div_consttime(NULL, tmp, tmp, m, m_min_bits, ctx);
if (ret) {
*out_ok = BN_is_one(tmp);
}
@@ -750,10 +751,15 @@
// simply check that d * e is one mod p-1 and mod q-1. Note d and e were bound
// by earlier checks in this function.
if (!bn_usub_consttime(&pm1, key->p, BN_value_one()) ||
- !bn_usub_consttime(&qm1, key->q, BN_value_one()) ||
- !bn_mul_consttime(&de, key->d, key->e, ctx) ||
- !bn_div_consttime(NULL, &tmp, &de, &pm1, ctx) ||
- !bn_div_consttime(NULL, &de, &de, &qm1, ctx)) {
+ !bn_usub_consttime(&qm1, key->q, BN_value_one())) {
+ OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
+ goto out;
+ }
+ const unsigned pm1_bits = BN_num_bits(&pm1);
+ const unsigned qm1_bits = BN_num_bits(&qm1);
+ if (!bn_mul_consttime(&de, key->d, key->e, ctx) ||
+ !bn_div_consttime(NULL, &tmp, &de, &pm1, pm1_bits, ctx) ||
+ !bn_div_consttime(NULL, &de, &de, &qm1, qm1_bits, ctx)) {
OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
goto out;
}
@@ -772,9 +778,12 @@
if (has_crt_values) {
int dmp1_ok, dmq1_ok, iqmp_ok;
- if (!check_mod_inverse(&dmp1_ok, key->e, key->dmp1, &pm1, ctx) ||
- !check_mod_inverse(&dmq1_ok, key->e, key->dmq1, &qm1, ctx) ||
- !check_mod_inverse(&iqmp_ok, key->q, key->iqmp, key->p, ctx)) {
+ if (!check_mod_inverse(&dmp1_ok, key->e, key->dmp1, &pm1, pm1_bits, ctx) ||
+ !check_mod_inverse(&dmq1_ok, key->e, key->dmq1, &qm1, qm1_bits, ctx) ||
+ // |p| is odd, so |pm1| and |p| have the same bit width. If they didn't,
+ // we only need a lower bound anyway.
+ !check_mod_inverse(&iqmp_ok, key->q, key->iqmp, key->p, pm1_bits,
+ ctx)) {
OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
goto out;
}
diff --git a/crypto/fipsmodule/rsa/rsa_impl.c b/crypto/fipsmodule/rsa/rsa_impl.c
index 6dd89b9..a6865c0 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.c
+++ b/crypto/fipsmodule/rsa/rsa_impl.c
@@ -1262,12 +1262,14 @@
// values for d.
} while (BN_cmp(rsa->d, pow2_prime_bits) <= 0);
+ assert(BN_num_bits(pm1) == (unsigned)prime_bits);
+ assert(BN_num_bits(qm1) == (unsigned)prime_bits);
if (// Calculate n.
!bn_mul_consttime(rsa->n, rsa->p, rsa->q, ctx) ||
// Calculate d mod (p-1).
- !bn_div_consttime(NULL, rsa->dmp1, rsa->d, pm1, ctx) ||
+ !bn_div_consttime(NULL, rsa->dmp1, rsa->d, pm1, prime_bits, ctx) ||
// Calculate d mod (q-1)
- !bn_div_consttime(NULL, rsa->dmq1, rsa->d, qm1, ctx)) {
+ !bn_div_consttime(NULL, rsa->dmq1, rsa->d, qm1, prime_bits, ctx)) {
goto bn_err;
}
bn_set_minimal_width(rsa->n);