Replace rsa_greater_than_pow2 with BN_cmp.

It costs us a malloc, but it's one less function to test and implement
in constant time, now that BN_cmp and BIGNUM are okay.

Median of 29 RSA keygens: 0m0.207s -> 0m0.210s
(Accuracy beyond 0.1s is questionable.)

Bug: 238
Change-Id: Ic56f92f0dcf04da1f542290a7e8cdab8036699ed
Reviewed-on: https://boringssl-review.googlesource.com/26367
Reviewed-by: Adam Langley <alangley@gmail.com>
diff --git a/crypto/fipsmodule/rsa/internal.h b/crypto/fipsmodule/rsa/internal.h
index 0f0c763..f913058 100644
--- a/crypto/fipsmodule/rsa/internal.h
+++ b/crypto/fipsmodule/rsa/internal.h
@@ -114,15 +114,10 @@
                           size_t len);
 
 
-// The following utility functions are exported for test purposes.
-
+// This constant is exported for test purposes.
 extern const BN_ULONG kBoringSSLRSASqrtTwo[];
 extern const size_t kBoringSSLRSASqrtTwoLen;
 
-// rsa_greater_than_pow2 returns one if |b| is greater than 2^|n| and zero
-// otherwise.
-int rsa_greater_than_pow2(const BIGNUM *b, int n);
-
 
 #if defined(__cplusplus)
 }  // extern C
diff --git a/crypto/fipsmodule/rsa/rsa_impl.c b/crypto/fipsmodule/rsa/rsa_impl.c
index 88bcb5f..66b59f0 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.c
+++ b/crypto/fipsmodule/rsa/rsa_impl.c
@@ -924,25 +924,20 @@
 };
 const size_t kBoringSSLRSASqrtTwoLen = OPENSSL_ARRAY_SIZE(kBoringSSLRSASqrtTwo);
 
-int rsa_greater_than_pow2(const BIGNUM *b, int n) {
-  if (BN_is_negative(b) || n == INT_MAX) {
-    return 0;
-  }
-
-  int b_bits = BN_num_bits(b);
-  return b_bits > n + 1 || (b_bits == n + 1 && !BN_is_pow2(b));
-}
-
 // generate_prime sets |out| to a prime with length |bits| such that |out|-1 is
 // relatively prime to |e|. If |p| is non-NULL, |out| will also not be close to
-// |p|.
+// |p|. |sqrt2| must be ⌊2^(bits-1)×√2⌋ (or a slightly overestimate for large
+// sizes), and |pow2_bits_100| must be 2^(bits-100).
 static int generate_prime(BIGNUM *out, int bits, const BIGNUM *e,
-                          const BIGNUM *p, const BIGNUM *sqrt2, BN_CTX *ctx,
+                          const BIGNUM *p, const BIGNUM *sqrt2,
+                          const BIGNUM *pow2_bits_100, BN_CTX *ctx,
                           BN_GENCB *cb) {
   if (bits < 128 || (bits % BN_BITS2) != 0) {
     OPENSSL_PUT_ERROR(RSA, ERR_R_INTERNAL_ERROR);
     return 0;
   }
+  assert(BN_is_pow2(pow2_bits_100));
+  assert(BN_is_bit_set(pow2_bits_100, bits - 100));
 
   // See FIPS 186-4 appendix B.3.3, steps 4 and 5. Note |bits| here is nlen/2.
 
@@ -977,7 +972,7 @@
         goto err;
       }
       BN_set_negative(tmp, 0);
-      if (!rsa_greater_than_pow2(tmp, bits - 100)) {
+      if (BN_cmp(tmp, pow2_bits_100) <= 0) {
         continue;
       }
     }
@@ -1048,6 +1043,7 @@
   }
 
   int ret = 0;
+  int prime_bits = bits / 2;
   BN_CTX *ctx = BN_CTX_new();
   if (ctx == NULL) {
     goto bn_err;
@@ -1058,8 +1054,12 @@
   BIGNUM *qm1 = BN_CTX_get(ctx);
   BIGNUM *gcd = BN_CTX_get(ctx);
   BIGNUM *sqrt2 = BN_CTX_get(ctx);
+  BIGNUM *pow2_prime_bits_100 = BN_CTX_get(ctx);
+  BIGNUM *pow2_prime_bits = BN_CTX_get(ctx);
   if (totient == NULL || pm1 == NULL || qm1 == NULL || gcd == NULL ||
-      sqrt2 == NULL) {
+      sqrt2 == NULL || pow2_prime_bits_100 == NULL || pow2_prime_bits == NULL ||
+      !BN_set_bit(pow2_prime_bits_100, prime_bits - 100) ||
+      !BN_set_bit(pow2_prime_bits, prime_bits)) {
     goto bn_err;
   }
 
@@ -1078,8 +1078,6 @@
     goto bn_err;
   }
 
-  int prime_bits = bits / 2;
-
   // Compute sqrt2 >= ⌊2^(prime_bits-1)×√2⌋.
   if (!bn_set_words(sqrt2, kBoringSSLRSASqrtTwo, kBoringSSLRSASqrtTwoLen)) {
     goto bn_err;
@@ -1105,9 +1103,11 @@
   do {
     // Generate p and q, each of size |prime_bits|, using the steps outlined in
     // appendix FIPS 186-4 appendix B.3.3.
-    if (!generate_prime(rsa->p, prime_bits, rsa->e, NULL, sqrt2, ctx, cb) ||
+    if (!generate_prime(rsa->p, prime_bits, rsa->e, NULL, sqrt2,
+                        pow2_prime_bits_100, ctx, cb) ||
         !BN_GENCB_call(cb, 3, 0) ||
-        !generate_prime(rsa->q, prime_bits, rsa->e, rsa->p, sqrt2, ctx, cb) ||
+        !generate_prime(rsa->q, prime_bits, rsa->e, rsa->p, sqrt2,
+                        pow2_prime_bits_100, ctx, cb) ||
         !BN_GENCB_call(cb, 3, 1)) {
       goto bn_err;
     }
@@ -1134,9 +1134,9 @@
       goto bn_err;
     }
 
-    // Check that |rsa->d| > 2^|prime_bits| and try again if it fails. See
-    // appendix B.3.1's guidance on values for d.
-  } while (!rsa_greater_than_pow2(rsa->d, prime_bits));
+    // Retry if |rsa->d| <= 2^|prime_bits|. See appendix B.3.1's guidance on
+    // values for d.
+  } while (BN_cmp(rsa->d, pow2_prime_bits) <= 0);
 
   if (// Calculate n.
       !BN_mul(rsa->n, rsa->p, rsa->q, ctx) ||
diff --git a/crypto/rsa_extra/rsa_test.cc b/crypto/rsa_extra/rsa_test.cc
index fdd5e49..87eabf8 100644
--- a/crypto/rsa_extra/rsa_test.cc
+++ b/crypto/rsa_extra/rsa_test.cc
@@ -909,54 +909,4 @@
   // Check the kBoringSSLRSASqrtTwo is sized for a 3072-bit RSA key.
   EXPECT_EQ(3072u / 2u, bits);
 }
-
-TEST(RSATest, GreaterThanPow2) {
-  bssl::UniquePtr<BIGNUM> b(BN_new());
-  BN_zero(b.get());
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 0));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 1));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 20));
-
-  ASSERT_TRUE(BN_set_word(b.get(), 1));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 0));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 1));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 20));
-
-  ASSERT_TRUE(BN_set_word(b.get(), 2));
-  EXPECT_TRUE(rsa_greater_than_pow2(b.get(), 0));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 1));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 20));
-
-  ASSERT_TRUE(BN_set_word(b.get(), 3));
-  EXPECT_TRUE(rsa_greater_than_pow2(b.get(), 0));
-  EXPECT_TRUE(rsa_greater_than_pow2(b.get(), 1));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 2));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 20));
-
-  BN_set_negative(b.get(), 1);
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 0));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 1));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 2));
-  EXPECT_FALSE(rsa_greater_than_pow2(b.get(), 20));
-
-  // Check all bit lengths mod 64.
-  for (int n = 1024; n < 1024 + 64; n++) {
-    SCOPED_TRACE(n);
-    ASSERT_TRUE(BN_set_word(b.get(), 1));
-    ASSERT_TRUE(BN_lshift(b.get(), b.get(), n));
-    EXPECT_TRUE(rsa_greater_than_pow2(b.get(), n - 1));
-    EXPECT_FALSE(rsa_greater_than_pow2(b.get(), n));
-    EXPECT_FALSE(rsa_greater_than_pow2(b.get(), n + 1));
-
-    ASSERT_TRUE(BN_sub_word(b.get(), 1));
-    EXPECT_TRUE(rsa_greater_than_pow2(b.get(), n - 1));
-    EXPECT_FALSE(rsa_greater_than_pow2(b.get(), n));
-    EXPECT_FALSE(rsa_greater_than_pow2(b.get(), n + 1));
-
-    ASSERT_TRUE(BN_add_word(b.get(), 2));
-    EXPECT_TRUE(rsa_greater_than_pow2(b.get(), n - 1));
-    EXPECT_TRUE(rsa_greater_than_pow2(b.get(), n));
-    EXPECT_FALSE(rsa_greater_than_pow2(b.get(), n + 1));
-  }
-}
 #endif  // !BORINGSSL_SHARED_LIBRARY