Improve the RSA key generation failure probability.

The FIPS 186-4 algorithm we use includes a limit which hits a 2^-20
failure probability, assuming my math is right. We've observed roughly
2^-23. This is a little large at scale. (See b/77854769.)

To avoid modifying the FIPS algorithm, retry the whole thing four times
to bring the failure rate down to 2^-80. Along the way, now that I have
the derivation on hand, adjust
https://boringssl-review.googlesource.com/22584 to target the same
failure probability.

Along the way, fix an issue with RSA_generate_key where, if callers
don't check for failure, there may be half a key in there.

Change-Id: I0e1da98413ebd4ffa65fb74c67a58a0e0cd570ff
Reviewed-on: https://boringssl-review.googlesource.com/27288
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/rsa/rsa_impl.c b/crypto/fipsmodule/rsa/rsa_impl.c
index 86249ec..6d1206b 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.c
+++ b/crypto/fipsmodule/rsa/rsa_impl.c
@@ -934,6 +934,8 @@
 // relatively prime to |e|. If |p| is non-NULL, |out| will also not be close to
 // |p|. |sqrt2| must be ⌊2^(bits-1)×√2⌋ (or a slightly overestimate for large
 // sizes), and |pow2_bits_100| must be 2^(bits-100).
+//
+// This function fails with probability around 2^-21.
 static int generate_prime(BIGNUM *out, int bits, const BIGNUM *e,
                           const BIGNUM *p, const BIGNUM *sqrt2,
                           const BIGNUM *pow2_bits_100, BN_CTX *ctx,
@@ -950,11 +952,36 @@
   // Use the limit from steps 4.7 and 5.8 for most values of |e|. When |e| is 3,
   // the 186-4 limit is too low, so we use a higher one. Note this case is not
   // reachable from |RSA_generate_key_fips|.
+  //
+  // |limit| determines the failure probability. We must find a prime that is
+  // not 1 mod |e|. By the prime number theorem, we'll find one with probability
+  // p = (e-1)/e * 2/(ln(2)*bits). Note the second term is doubled because we
+  // discard even numbers.
+  //
+  // The failure probability is thus (1-p)^limit. To convert that to a power of
+  // two, we take logs. -log_2((1-p)^limit) = -limit * ln(1-p) / ln(2).
+  //
+  // >>> def f(bits, e, limit):
+  // ...   p = (e-1.0)/e * 2.0/(math.log(2)*bits)
+  // ...   return -limit * math.log(1 - p) / math.log(2)
+  // ...
+  // >>> f(1024, 65537, 5*1024)
+  // 20.842750558272634
+  // >>> f(1536, 65537, 5*1536)
+  // 20.83294549602474
+  // >>> f(2048, 65537, 5*2048)
+  // 20.828047576234948
+  // >>> f(1024, 3, 8*1024)
+  // 22.222147925962307
+  // >>> f(1536, 3, 8*1536)
+  // 22.21518251065506
+  // >>> f(2048, 3, 8*2048)
+  // 22.211701985875937
   if (bits >= INT_MAX/32) {
     OPENSSL_PUT_ERROR(RSA, RSA_R_MODULUS_TOO_LARGE);
     return 0;
   }
-  int limit = BN_is_word(e, 3) ? bits * 32 : bits * 5;
+  int limit = BN_is_word(e, 3) ? bits * 8 : bits * 5;
 
   int ret = 0, tries = 0, rand_tries = 0;
   BN_CTX_start(ctx);
@@ -1033,7 +1060,14 @@
   return ret;
 }
 
-int RSA_generate_key_ex(RSA *rsa, int bits, BIGNUM *e_value, BN_GENCB *cb) {
+// rsa_generate_key_impl generates an RSA key using a generalized version of
+// FIPS 186-4 appendix B.3. |RSA_generate_key_fips| performs additional checks
+// for FIPS-compliant key generation.
+//
+// This function returns one on success and zero on failure. It has a failure
+// probability of about 2^-20.
+static int rsa_generate_key_impl(RSA *rsa, int bits, BIGNUM *e_value,
+                                 BN_GENCB *cb) {
   // See FIPS 186-4 appendix B.3. This function implements a generalized version
   // of the FIPS algorithm. |RSA_generate_key_fips| performs additional checks
   // for FIPS-compliant key generation.
@@ -1119,6 +1153,9 @@
   do {
     // Generate p and q, each of size |prime_bits|, using the steps outlined in
     // appendix FIPS 186-4 appendix B.3.3.
+    //
+    // Each call to |generate_prime| fails with probability p = 2^-21. The
+    // probability that either call fails is 1 - (1-p)^2, which is around 2^-20.
     if (!generate_prime(rsa->p, prime_bits, rsa->e, NULL, sqrt2,
                         pow2_prime_bits_100, ctx, cb) ||
         !BN_GENCB_call(cb, 3, 0) ||
@@ -1198,6 +1235,65 @@
   return ret;
 }
 
+static void replace_bignum(BIGNUM **out, BIGNUM **in) {
+  BN_free(*out);
+  *out = *in;
+  *in = NULL;
+}
+
+static void replace_bn_mont_ctx(BN_MONT_CTX **out, BN_MONT_CTX **in) {
+  BN_MONT_CTX_free(*out);
+  *out = *in;
+  *in = NULL;
+}
+
+int RSA_generate_key_ex(RSA *rsa, int bits, BIGNUM *e_value, BN_GENCB *cb) {
+  // |rsa_generate_key_impl|'s 2^-20 failure probability is too high at scale,
+  // so we run the FIPS algorithm four times, bringing it down to 2^-80. We
+  // should just adjust the retry limit, but FIPS 186-4 prescribes that value
+  // and thus results in unnecessary complexity.
+  for (int i = 0; i < 4; i++) {
+    ERR_clear_error();
+    // Generate into scratch space, to avoid leaving partial work on failure.
+    RSA *tmp = RSA_new();
+    if (tmp == NULL) {
+      return 0;
+    }
+    if (rsa_generate_key_impl(tmp, bits, e_value, cb)) {
+      replace_bignum(&rsa->n, &tmp->n);
+      replace_bignum(&rsa->e, &tmp->e);
+      replace_bignum(&rsa->d, &tmp->d);
+      replace_bignum(&rsa->p, &tmp->p);
+      replace_bignum(&rsa->q, &tmp->q);
+      replace_bignum(&rsa->dmp1, &tmp->dmp1);
+      replace_bignum(&rsa->dmq1, &tmp->dmq1);
+      replace_bignum(&rsa->iqmp, &tmp->iqmp);
+      replace_bn_mont_ctx(&rsa->mont_n, &tmp->mont_n);
+      replace_bn_mont_ctx(&rsa->mont_p, &tmp->mont_p);
+      replace_bn_mont_ctx(&rsa->mont_q, &tmp->mont_q);
+      replace_bignum(&rsa->d_fixed, &tmp->d_fixed);
+      replace_bignum(&rsa->dmp1_fixed, &tmp->dmp1_fixed);
+      replace_bignum(&rsa->dmq1_fixed, &tmp->dmq1_fixed);
+      replace_bignum(&rsa->inv_small_mod_large_mont,
+                     &tmp->inv_small_mod_large_mont);
+      rsa->private_key_frozen = tmp->private_key_frozen;
+      RSA_free(tmp);
+      return 1;
+    }
+    uint32_t err = ERR_peek_error();
+    RSA_free(tmp);
+    tmp = NULL;
+    // Only retry on |RSA_R_TOO_MANY_ITERATIONS|. This is so a caller-induced
+    // failure in |BN_GENCB_call| is still fatal.
+    if (ERR_GET_LIB(err) != ERR_LIB_RSA ||
+        ERR_GET_REASON(err) != RSA_R_TOO_MANY_ITERATIONS) {
+      return 0;
+    }
+  }
+
+  return 0;
+}
+
 int RSA_generate_key_fips(RSA *rsa, int bits, BN_GENCB *cb) {
   // FIPS 186-4 allows 2048-bit and 3072-bit RSA keys (1024-bit and 1536-bit
   // primes, respectively) with the prime generation method we use.
diff --git a/crypto/rsa_extra/rsa_test.cc b/crypto/rsa_extra/rsa_test.cc
index a6bfb87..211b690 100644
--- a/crypto/rsa_extra/rsa_test.cc
+++ b/crypto/rsa_extra/rsa_test.cc
@@ -501,12 +501,12 @@
   ERR_clear_error();
 
   // Test that we can generate 2048-bit and 3072-bit RSA keys.
-  EXPECT_TRUE(RSA_generate_key_fips(rsa.get(), 2048, nullptr));
+  ASSERT_TRUE(RSA_generate_key_fips(rsa.get(), 2048, nullptr));
   EXPECT_EQ(2048u, BN_num_bits(rsa->n));
 
   rsa.reset(RSA_new());
   ASSERT_TRUE(rsa);
-  EXPECT_TRUE(RSA_generate_key_fips(rsa.get(), 3072, nullptr));
+  ASSERT_TRUE(RSA_generate_key_fips(rsa.get(), 3072, nullptr));
   EXPECT_EQ(3072u, BN_num_bits(rsa->n));
 }
 
@@ -653,22 +653,22 @@
 
   bssl::UniquePtr<RSA> rsa(RSA_new());
   ASSERT_TRUE(rsa);
-  EXPECT_TRUE(RSA_generate_key_ex(rsa.get(), 1025, e.get(), nullptr));
+  ASSERT_TRUE(RSA_generate_key_ex(rsa.get(), 1025, e.get(), nullptr));
   EXPECT_EQ(1024u, BN_num_bits(rsa->n));
 
   rsa.reset(RSA_new());
   ASSERT_TRUE(rsa);
-  EXPECT_TRUE(RSA_generate_key_ex(rsa.get(), 1027, e.get(), nullptr));
+  ASSERT_TRUE(RSA_generate_key_ex(rsa.get(), 1027, e.get(), nullptr));
   EXPECT_EQ(1024u, BN_num_bits(rsa->n));
 
   rsa.reset(RSA_new());
   ASSERT_TRUE(rsa);
-  EXPECT_TRUE(RSA_generate_key_ex(rsa.get(), 1151, e.get(), nullptr));
+  ASSERT_TRUE(RSA_generate_key_ex(rsa.get(), 1151, e.get(), nullptr));
   EXPECT_EQ(1024u, BN_num_bits(rsa->n));
 
   rsa.reset(RSA_new());
   ASSERT_TRUE(rsa);
-  EXPECT_TRUE(RSA_generate_key_ex(rsa.get(), 1152, e.get(), nullptr));
+  ASSERT_TRUE(RSA_generate_key_ex(rsa.get(), 1152, e.get(), nullptr));
   EXPECT_EQ(1152u, BN_num_bits(rsa->n));
 }
 
@@ -896,6 +896,123 @@
   ASSERT_TRUE(BN_sub(rsa->iqmp, rsa->iqmp, rsa->p));
 }
 
+TEST(RSATest, KeygenFail) {
+  bssl::UniquePtr<RSA> rsa(RSA_new());
+  ASSERT_TRUE(rsa);
+
+  // Cause RSA key generation after a prime has been generated, to test that
+  // |rsa| is left alone.
+  BN_GENCB cb;
+  BN_GENCB_set(&cb,
+               [](int event, int, BN_GENCB *) -> int { return event != 3; },
+               nullptr);
+
+  bssl::UniquePtr<BIGNUM> e(BN_new());
+  ASSERT_TRUE(e);
+  ASSERT_TRUE(BN_set_word(e.get(), RSA_F4));
+
+  // Key generation should fail.
+  EXPECT_FALSE(RSA_generate_key_ex(rsa.get(), 2048, e.get(), &cb));
+
+  // Failed key generations do not leave garbage in |rsa|.
+  EXPECT_FALSE(rsa->n);
+  EXPECT_FALSE(rsa->e);
+  EXPECT_FALSE(rsa->d);
+  EXPECT_FALSE(rsa->p);
+  EXPECT_FALSE(rsa->q);
+  EXPECT_FALSE(rsa->dmp1);
+  EXPECT_FALSE(rsa->dmq1);
+  EXPECT_FALSE(rsa->iqmp);
+  EXPECT_FALSE(rsa->mont_n);
+  EXPECT_FALSE(rsa->mont_p);
+  EXPECT_FALSE(rsa->mont_q);
+  EXPECT_FALSE(rsa->d_fixed);
+  EXPECT_FALSE(rsa->dmp1_fixed);
+  EXPECT_FALSE(rsa->dmq1_fixed);
+  EXPECT_FALSE(rsa->inv_small_mod_large_mont);
+  EXPECT_FALSE(rsa->private_key_frozen);
+
+  // Failed key generations leave the previous contents alone.
+  EXPECT_TRUE(RSA_generate_key_ex(rsa.get(), 2048, e.get(), nullptr));
+  uint8_t *der;
+  size_t der_len;
+  ASSERT_TRUE(RSA_private_key_to_bytes(&der, &der_len, rsa.get()));
+  bssl::UniquePtr<uint8_t> delete_der(der);
+
+  EXPECT_FALSE(RSA_generate_key_ex(rsa.get(), 2048, e.get(), &cb));
+
+  uint8_t *der2;
+  size_t der2_len;
+  ASSERT_TRUE(RSA_private_key_to_bytes(&der2, &der2_len, rsa.get()));
+  bssl::UniquePtr<uint8_t> delete_der2(der2);
+  EXPECT_EQ(Bytes(der, der_len), Bytes(der2, der2_len));
+
+  // Generating a key over an existing key works, despite any cached state.
+  EXPECT_TRUE(RSA_generate_key_ex(rsa.get(), 2048, e.get(), nullptr));
+  EXPECT_TRUE(RSA_check_key(rsa.get()));
+  uint8_t *der3;
+  size_t der3_len;
+  ASSERT_TRUE(RSA_private_key_to_bytes(&der3, &der3_len, rsa.get()));
+  bssl::UniquePtr<uint8_t> delete_der3(der3);
+  EXPECT_NE(Bytes(der, der_len), Bytes(der3, der3_len));
+}
+
+TEST(RSATest, KeygenFailOnce) {
+  bssl::UniquePtr<RSA> rsa(RSA_new());
+  ASSERT_TRUE(rsa);
+
+  // Cause only the first iteration of RSA key generation to fail.
+  bool failed = false;
+  BN_GENCB cb;
+  BN_GENCB_set(&cb,
+               [](int event, int n, BN_GENCB *cb_ptr) -> int {
+                 bool *failed_ptr = static_cast<bool *>(cb_ptr->arg);
+                 if (*failed_ptr) {
+                   ADD_FAILURE() << "Callback called multiple times.";
+                   return 1;
+                 }
+                 *failed_ptr = true;
+                 return 0;
+               },
+               &failed);
+
+  // Although key generation internally retries, the external behavior of
+  // |BN_GENCB| is preserved.
+  bssl::UniquePtr<BIGNUM> e(BN_new());
+  ASSERT_TRUE(e);
+  ASSERT_TRUE(BN_set_word(e.get(), RSA_F4));
+  EXPECT_FALSE(RSA_generate_key_ex(rsa.get(), 2048, e.get(), &cb));
+}
+
+TEST(RSATest, KeygenInternalRetry) {
+  bssl::UniquePtr<RSA> rsa(RSA_new());
+  ASSERT_TRUE(rsa);
+
+  // Simulate one internal attempt at key generation failing.
+  bool failed = false;
+  BN_GENCB cb;
+  BN_GENCB_set(&cb,
+               [](int event, int n, BN_GENCB *cb_ptr) -> int {
+                 bool *failed_ptr = static_cast<bool *>(cb_ptr->arg);
+                 if (*failed_ptr) {
+                   return 1;
+                 }
+                 *failed_ptr = true;
+                 // This test does not test any public API behavior. It is just
+                 // a hack to exercise the retry codepath and make sure it
+                 // works.
+                 OPENSSL_PUT_ERROR(RSA, RSA_R_TOO_MANY_ITERATIONS);
+                 return 0;
+               },
+               &failed);
+
+  // Key generation internally retries on RSA_R_TOO_MANY_ITERATIONS.
+  bssl::UniquePtr<BIGNUM> e(BN_new());
+  ASSERT_TRUE(e);
+  ASSERT_TRUE(BN_set_word(e.get(), RSA_F4));
+  EXPECT_TRUE(RSA_generate_key_ex(rsa.get(), 2048, e.get(), &cb));
+}
+
 #if !defined(BORINGSSL_SHARED_LIBRARY)
 TEST(RSATest, SqrtTwo) {
   bssl::UniquePtr<BIGNUM> sqrt(BN_new()), pow2(BN_new());
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index 1ed9864..b0f8a2d 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -659,8 +659,7 @@
 // BN_GENCB_set configures |callback| to call |f| and sets |callout->arg| to
 // |arg|.
 OPENSSL_EXPORT void BN_GENCB_set(BN_GENCB *callback,
-                                 int (*f)(int event, int n,
-                                          struct bn_gencb_st *),
+                                 int (*f)(int event, int n, BN_GENCB *),
                                  void *arg);
 
 // BN_GENCB_call calls |callback|, if not NULL, and returns the return value of