Stop using sqrt(2) in RSA key generation

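The sqrt(2) lower-bound check on candidate primes was always
unnecessary; it was only a consequence of FIPS 186-4's poorly-specified
algorithm. FIPS 186-5 fixes this issue by allowing us to set the top two
bits arbitrarily. Use this allowance to set the top two bits to 1s:
every candidate is then at least 1.5 * 2^(bits-1), and since
sqrt(2) < 1.5 the bound holds without an explicit comparison.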

Change-Id: I4e3dbedbc211df6a35801c75cca82ef96fdc656e
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/86590
Commit-Queue: Adam Langley <agl@google.com>
Auto-Submit: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/rsa/internal.h b/crypto/fipsmodule/rsa/internal.h
index 449e8fa..a739c58 100644
--- a/crypto/fipsmodule/rsa/internal.h
+++ b/crypto/fipsmodule/rsa/internal.h
@@ -135,11 +135,6 @@
 void rsa_invalidate_key(RSA *rsa);
 
 
-// This constant is exported for test purposes.
-extern const BN_ULONG kBoringSSLRSASqrtTwo[];
-extern const size_t kBoringSSLRSASqrtTwoLen;
-
-
 // Functions that avoid self-tests.
 //
 // Self-tests need to call functions that don't try and ensure that the
diff --git a/crypto/fipsmodule/rsa/rsa_impl.cc.inc b/crypto/fipsmodule/rsa/rsa_impl.cc.inc
index 80b9115..e029bcc 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.cc.inc
+++ b/crypto/fipsmodule/rsa/rsa_impl.cc.inc
@@ -598,84 +598,14 @@
   return *out != nullptr;
 }
 
-// kBoringSSLRSASqrtTwo is the BIGNUM representation of ⌊2²⁰⁴⁷×√2⌋. This is
-// chosen to give enough precision for 4096-bit RSA, the largest key size FIPS
-// specifies. Key sizes beyond this will round up.
-//
-// To calculate, use the following Haskell code:
-//
-// import Text.Printf (printf)
-// import Data.List (intercalate)
-//
-// pow2 = 4095
-// target = 2^pow2
-//
-// f x = x*x - (toRational target)
-//
-// fprime x = 2*x
-//
-// newtonIteration x = x - (f x) / (fprime x)
-//
-// converge x =
-//   let n = floor x in
-//   if n*n - target < 0 && (n+1)*(n+1) - target > 0
-//     then n
-//     else converge (newtonIteration x)
-//
-// divrem bits x = (x `div` (2^bits), x `rem` (2^bits))
-//
-// bnWords :: Integer -> [Integer]
-// bnWords x =
-//   if x == 0
-//     then []
-//     else let (high, low) = divrem 64 x in low : bnWords high
-//
-// showWord x = let (high, low) = divrem 32 x in printf "TOBN(0x%08x, 0x%08x)"
-// high low
-//
-// output :: String
-// output = intercalate ", " $ map showWord $ bnWords $ converge (2 ^ (pow2
-// `div` 2))
-//
-// To verify this number, check that n² < 2⁴⁰⁹⁵ < (n+1)², where n is value
-// represented here. Note the components are listed in little-endian order. Here
-// is some sample Python code to check:
-//
-//   >>> TOBN = lambda a, b: a << 32 | b
-//   >>> l = [ <paste the contents of kSqrtTwo> ]
-//   >>> n = sum(a * 2**(64*i) for i, a in enumerate(l))
-//   >>> n**2 < 2**4095 < (n+1)**2
-//   True
-const BN_ULONG kBoringSSLRSASqrtTwo[] = {
-    TOBN(0x4d7c60a5, 0xe633e3e1), TOBN(0x5fcf8f7b, 0xca3ea33b),
-    TOBN(0xc246785e, 0x92957023), TOBN(0xf9acce41, 0x797f2805),
-    TOBN(0xfdfe170f, 0xd3b1f780), TOBN(0xd24f4a76, 0x3facb882),
-    TOBN(0x18838a2e, 0xaff5f3b2), TOBN(0xc1fcbdde, 0xa2f7dc33),
-    TOBN(0xdea06241, 0xf7aa81c2), TOBN(0xf6a1be3f, 0xca221307),
-    TOBN(0x332a5e9f, 0x7bda1ebf), TOBN(0x0104dc01, 0xfe32352f),
-    TOBN(0xb8cf341b, 0x6f8236c7), TOBN(0x4264dabc, 0xd528b651),
-    TOBN(0xf4d3a02c, 0xebc93e0c), TOBN(0x81394ab6, 0xd8fd0efd),
-    TOBN(0xeaa4a089, 0x9040ca4a), TOBN(0xf52f120f, 0x836e582e),
-    TOBN(0xcb2a6343, 0x31f3c84d), TOBN(0xc6d5a8a3, 0x8bb7e9dc),
-    TOBN(0x460abc72, 0x2f7c4e33), TOBN(0xcab1bc91, 0x1688458a),
-    TOBN(0x53059c60, 0x11bc337b), TOBN(0xd2202e87, 0x42af1f4e),
-    TOBN(0x78048736, 0x3dfa2768), TOBN(0x0f74a85e, 0x439c7b4a),
-    TOBN(0xa8b1fe6f, 0xdc83db39), TOBN(0x4afc8304, 0x3ab8a2c3),
-    TOBN(0xed17ac85, 0x83339915), TOBN(0x1d6f60ba, 0x893ba84c),
-    TOBN(0x597d89b3, 0x754abe9f), TOBN(0xb504f333, 0xf9de6484),
-};
-const size_t kBoringSSLRSASqrtTwoLen = std::size(kBoringSSLRSASqrtTwo);
-
 // generate_prime sets |out| to a prime with length |bits| such that |out|-1 is
 // relatively prime to |e|. If |p| is non-NULL, |out| will also not be close to
-// |p|. |sqrt2| must be ⌊2^(bits-1)×√2⌋ (or a slightly overestimate for large
-// sizes), and |pow2_bits_100| must be 2^(bits-100).
+// |p|. |pow2_bits_100| must be 2^(bits-100).
 //
 // This function fails with probability around 2^-21.
 static int generate_prime(BIGNUM *out, int bits, const BIGNUM *e,
-                          const BIGNUM *p, const BIGNUM *sqrt2,
-                          const BIGNUM *pow2_bits_100, BN_CTX *ctx,
-                          BN_GENCB *cb) {
+                          const BIGNUM *p, const BIGNUM *pow2_bits_100,
+                          BN_CTX *ctx, BN_GENCB *cb) {
   if (bits < 128 || (bits % BN_BITS2) != 0) {
     OPENSSL_PUT_ERROR(RSA, ERR_R_INTERNAL_ERROR);
     return 0;
@@ -727,28 +657,19 @@
   }
 
   for (;;) {
-    // Generate a random number of length |bits| where the bottom bit is set
-    // (steps 4.2, 4.3, 5.2 and 5.3) and the top bit is set (implied by the
-    // bound checked below in steps 4.4 and 5.5).
-    if (!BN_rand(out, bits, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ODD) ||
+    // Generate a random number of length |bits| where the bottom bit is set and
+    // top two bits are set (steps 4.2–4.4 and 5.2–5.4):
+    //
+    // - Setting the top two bits is permitted by steps 4.2.1 and 5.2.1. Doing
+    //   so implements steps 4.4 and 5.4 by making this case impossible because
+    //   √2 < 1.5.
+    //
+    // - Setting the bottom bit implements steps 4.3 and 5.3.
+    if (!BN_rand(out, bits, BN_RAND_TOP_TWO, BN_RAND_BOTTOM_ODD) ||
         !BN_GENCB_call(cb, BN_GENCB_GENERATED, rand_tries++)) {
       return 0;
     }
 
-    // If out < 2^(bits-1)×√2, try again (steps 4.4 and 5.4). This is equivalent
-    // to out <= ⌊2^(bits-1)×√2⌋, or out <= sqrt2 for FIPS key sizes.
-    //
-    // For larger keys, the comparison is approximate, leaning towards
-    // retrying. That is, we reject a negligible fraction of primes that are
-    // within the FIPS bound, but we will never accept a prime outside the
-    // bound, ensuring the resulting RSA key is the right size.
-    //
-    // Values over the threshold are discarded, so it is safe to leak this
-    // comparison.
-    if (constant_time_declassify_int(BN_cmp(out, sqrt2) <= 0)) {
-      continue;
-    }
-
     if (p != nullptr) {
       // If |p| and |out| are too close, try again (step 5.5).
       if (!bn_abs_sub_consttime(tmp, out, p, ctx)) {
@@ -783,8 +704,7 @@
       }
     }
 
-    // If we've tried too many times to find a prime, abort (steps 4.7 and
-    // 5.8).
+    // If we've tried too many times to find a prime, abort (steps 4.7 and 5.8).
     tries++;
     if (tries >= limit) {
       OPENSSL_PUT_ERROR(RSA, RSA_R_TOO_MANY_ITERATIONS);
@@ -826,7 +746,6 @@
   }
 
   bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new());
-  int sqrt2_bits;
   if (ctx == nullptr) {
     OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
     return 0;
@@ -837,12 +756,10 @@
   BIGNUM *totient = BN_CTX_get(ctx.get());
   BIGNUM *pm1 = BN_CTX_get(ctx.get());
   BIGNUM *qm1 = BN_CTX_get(ctx.get());
-  BIGNUM *sqrt2 = BN_CTX_get(ctx.get());
   BIGNUM *pow2_prime_bits_100 = BN_CTX_get(ctx.get());
   BIGNUM *pow2_prime_bits = BN_CTX_get(ctx.get());
   if (totient == nullptr || pm1 == nullptr || qm1 == nullptr ||
-      sqrt2 == nullptr || pow2_prime_bits_100 == nullptr ||
-      pow2_prime_bits == nullptr ||
+      pow2_prime_bits_100 == nullptr || pow2_prime_bits == nullptr ||
       !BN_set_bit(pow2_prime_bits_100, prime_bits - 100) ||
       !BN_set_bit(pow2_prime_bits, prime_bits)) {
     OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
@@ -867,42 +784,17 @@
     return 0;
   }
 
-  // Compute sqrt2 >= ⌊2^(prime_bits-1)×√2⌋.
-  if (!bn_set_words(sqrt2, kBoringSSLRSASqrtTwo, kBoringSSLRSASqrtTwoLen)) {
-    OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
-    return 0;
-  }
-  sqrt2_bits = kBoringSSLRSASqrtTwoLen * BN_BITS2;
-  assert(sqrt2_bits == (int)BN_num_bits(sqrt2));
-  if (sqrt2_bits > prime_bits) {
-    // For key sizes up to 4096 (prime_bits = 2048), this is exactly
-    // ⌊2^(prime_bits-1)×√2⌋.
-    if (!BN_rshift(sqrt2, sqrt2, sqrt2_bits - prime_bits)) {
-      OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
-      return 0;
-    }
-  } else if (prime_bits > sqrt2_bits) {
-    // For key sizes beyond 4096, this is approximate. We err towards retrying
-    // to ensure our key is the right size and round up.
-    if (!BN_add_word(sqrt2, 1) ||
-        !BN_lshift(sqrt2, sqrt2, prime_bits - sqrt2_bits)) {
-      OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
-      return 0;
-    }
-  }
-  assert(prime_bits == (int)BN_num_bits(sqrt2));
-
   do {
     // Generate p and q, each of size |prime_bits|, using the steps outlined in
     // appendix FIPS 186-5 appendix C.3.3.
     //
     // Each call to |generate_prime| fails with probability p = 2^-21. The
     // probability that either call fails is 1 - (1-p)^2, which is around 2^-20.
-    if (!generate_prime(rsa->p, prime_bits, rsa->e, nullptr, sqrt2,
+    if (!generate_prime(rsa->p, prime_bits, rsa->e, nullptr,
                         pow2_prime_bits_100, ctx.get(), cb) ||
         !BN_GENCB_call(cb, 3, 0) ||
-        !generate_prime(rsa->q, prime_bits, rsa->e, rsa->p, sqrt2,
-                        pow2_prime_bits_100, ctx.get(), cb) ||
+        !generate_prime(rsa->q, prime_bits, rsa->e, rsa->p, pow2_prime_bits_100,
+                        ctx.get(), cb) ||
         !BN_GENCB_call(cb, 3, 1)) {
       OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
       return 0;
diff --git a/crypto/rsa/rsa_test.cc b/crypto/rsa/rsa_test.cc
index f235377..27506b4 100644
--- a/crypto/rsa/rsa_test.cc
+++ b/crypto/rsa/rsa_test.cc
@@ -1315,36 +1315,6 @@
   EXPECT_FALSE(read_public_key("crypto/rsa/test/rsa8193pub.pem"));
 }
 
-#if !defined(BORINGSSL_SHARED_LIBRARY)
-TEST(RSATest, SqrtTwo) {
-  bssl::UniquePtr<BIGNUM> sqrt(BN_new()), pow2(BN_new());
-  bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new());
-  ASSERT_TRUE(sqrt);
-  ASSERT_TRUE(pow2);
-  ASSERT_TRUE(ctx);
-
-  size_t bits = kBoringSSLRSASqrtTwoLen * BN_BITS2;
-  ASSERT_TRUE(BN_one(pow2.get()));
-  ASSERT_TRUE(BN_lshift(pow2.get(), pow2.get(), 2 * bits - 1));
-
-  // Check that sqrt² < pow2.
-  ASSERT_TRUE(
-      bn_set_words(sqrt.get(), kBoringSSLRSASqrtTwo, kBoringSSLRSASqrtTwoLen));
-  ASSERT_TRUE(BN_sqr(sqrt.get(), sqrt.get(), ctx.get()));
-  EXPECT_LT(BN_cmp(sqrt.get(), pow2.get()), 0);
-
-  // Check that pow2 < (sqrt + 1)².
-  ASSERT_TRUE(
-      bn_set_words(sqrt.get(), kBoringSSLRSASqrtTwo, kBoringSSLRSASqrtTwoLen));
-  ASSERT_TRUE(BN_add_word(sqrt.get(), 1));
-  ASSERT_TRUE(BN_sqr(sqrt.get(), sqrt.get(), ctx.get()));
-  EXPECT_LT(BN_cmp(pow2.get(), sqrt.get()), 0);
-
-  // Check the kBoringSSLRSASqrtTwo is sized for a 4096-bit RSA key.
-  EXPECT_EQ(4096u / 2u, bits);
-}
-#endif  // !BORINGSSL_SHARED_LIBRARY
-
 #if defined(OPENSSL_THREADS)
 TEST(RSATest, Threads) {
   bssl::UniquePtr<RSA> rsa_template(