PWCT failures should clear the generated key.

It's insufficient to signal an error when the PWCT fails. We
additionally need to ensure that the invalid key material is not
returned.

Change-Id: Ic5ff719a688985a61c52540ce6d1ed279a493d27
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/44306
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/fipsmodule/ec/ec_key.c b/crypto/fipsmodule/ec/ec_key.c
index cd48c60..bc09e0e 100644
--- a/crypto/fipsmodule/ec/ec_key.c
+++ b/crypto/fipsmodule/ec/ec_key.c
@@ -440,7 +440,15 @@
 }
 
 int EC_KEY_generate_key_fips(EC_KEY *eckey) {
-  return EC_KEY_generate_key(eckey) && EC_KEY_check_fips(eckey);
+  if (EC_KEY_generate_key(eckey) && EC_KEY_check_fips(eckey)) {
+    return 1;
+  }
+
+  EC_POINT_free(eckey->pub_key);
+  ec_wrapped_scalar_free(eckey->priv_key);
+  eckey->pub_key = NULL;
+  eckey->priv_key = NULL;
+  return 0;
 }
 
 int EC_KEY_get_ex_new_index(long argl, void *argp, CRYPTO_EX_unused *unused,
diff --git a/crypto/fipsmodule/rsa/rsa_impl.c b/crypto/fipsmodule/rsa/rsa_impl.c
index 7343b75..2f76e9e 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.c
+++ b/crypto/fipsmodule/rsa/rsa_impl.c
@@ -1316,52 +1316,72 @@
   *in = NULL;
 }
 
-int RSA_generate_key_ex(RSA *rsa, int bits, const BIGNUM *e_value,
-                        BN_GENCB *cb) {
+static int RSA_generate_key_ex_maybe_fips(RSA *rsa, int bits,
+                                          const BIGNUM *e_value, BN_GENCB *cb,
+                                          int check_fips) {
+  RSA *tmp = NULL;
+  uint32_t err;
+  int ret = 0;
+
   // |rsa_generate_key_impl|'s 2^-20 failure probability is too high at scale,
   // so we run the FIPS algorithm four times, bringing it down to 2^-80. We
   // should just adjust the retry limit, but FIPS 186-4 prescribes that value
   // and thus results in unnecessary complexity.
-  for (int i = 0; i < 4; i++) {
+  int failures = 0;
+  do {
     ERR_clear_error();
     // Generate into scratch space, to avoid leaving partial work on failure.
-    RSA *tmp = RSA_new();
+    tmp = RSA_new();
     if (tmp == NULL) {
-      return 0;
+      goto out;
     }
+
     if (rsa_generate_key_impl(tmp, bits, e_value, cb)) {
-      replace_bignum(&rsa->n, &tmp->n);
-      replace_bignum(&rsa->e, &tmp->e);
-      replace_bignum(&rsa->d, &tmp->d);
-      replace_bignum(&rsa->p, &tmp->p);
-      replace_bignum(&rsa->q, &tmp->q);
-      replace_bignum(&rsa->dmp1, &tmp->dmp1);
-      replace_bignum(&rsa->dmq1, &tmp->dmq1);
-      replace_bignum(&rsa->iqmp, &tmp->iqmp);
-      replace_bn_mont_ctx(&rsa->mont_n, &tmp->mont_n);
-      replace_bn_mont_ctx(&rsa->mont_p, &tmp->mont_p);
-      replace_bn_mont_ctx(&rsa->mont_q, &tmp->mont_q);
-      replace_bignum(&rsa->d_fixed, &tmp->d_fixed);
-      replace_bignum(&rsa->dmp1_fixed, &tmp->dmp1_fixed);
-      replace_bignum(&rsa->dmq1_fixed, &tmp->dmq1_fixed);
-      replace_bignum(&rsa->inv_small_mod_large_mont,
-                     &tmp->inv_small_mod_large_mont);
-      rsa->private_key_frozen = tmp->private_key_frozen;
-      RSA_free(tmp);
-      return 1;
+      break;
     }
-    uint32_t err = ERR_peek_error();
+
+    err = ERR_peek_error();
     RSA_free(tmp);
     tmp = NULL;
+    failures++;
+
     // Only retry on |RSA_R_TOO_MANY_ITERATIONS|. This is so a caller-induced
     // failure in |BN_GENCB_call| is still fatal.
-    if (ERR_GET_LIB(err) != ERR_LIB_RSA ||
-        ERR_GET_REASON(err) != RSA_R_TOO_MANY_ITERATIONS) {
-      return 0;
-    }
+  } while (failures < 4 && ERR_GET_LIB(err) == ERR_LIB_RSA &&
+           ERR_GET_REASON(err) == RSA_R_TOO_MANY_ITERATIONS);
+
+  if (tmp == NULL || (check_fips && !RSA_check_fips(tmp))) {
+    goto out;
   }
 
-  return 0;
+  replace_bignum(&rsa->n, &tmp->n);
+  replace_bignum(&rsa->e, &tmp->e);
+  replace_bignum(&rsa->d, &tmp->d);
+  replace_bignum(&rsa->p, &tmp->p);
+  replace_bignum(&rsa->q, &tmp->q);
+  replace_bignum(&rsa->dmp1, &tmp->dmp1);
+  replace_bignum(&rsa->dmq1, &tmp->dmq1);
+  replace_bignum(&rsa->iqmp, &tmp->iqmp);
+  replace_bn_mont_ctx(&rsa->mont_n, &tmp->mont_n);
+  replace_bn_mont_ctx(&rsa->mont_p, &tmp->mont_p);
+  replace_bn_mont_ctx(&rsa->mont_q, &tmp->mont_q);
+  replace_bignum(&rsa->d_fixed, &tmp->d_fixed);
+  replace_bignum(&rsa->dmp1_fixed, &tmp->dmp1_fixed);
+  replace_bignum(&rsa->dmq1_fixed, &tmp->dmq1_fixed);
+  replace_bignum(&rsa->inv_small_mod_large_mont,
+                 &tmp->inv_small_mod_large_mont);
+  rsa->private_key_frozen = tmp->private_key_frozen;
+  ret = 1;
+
+out:
+  RSA_free(tmp);
+  return ret;
+}
+
+int RSA_generate_key_ex(RSA *rsa, int bits, const BIGNUM *e_value,
+                        BN_GENCB *cb) {
+  return RSA_generate_key_ex_maybe_fips(rsa, bits, e_value, cb,
+                                        /*check_fips=*/0);
 }
 
 int RSA_generate_key_fips(RSA *rsa, int bits, BN_GENCB *cb) {
@@ -1377,8 +1397,7 @@
   BIGNUM *e = BN_new();
   int ret = e != NULL &&
             BN_set_word(e, RSA_F4) &&
-            RSA_generate_key_ex(rsa, bits, e, cb) &&
-            RSA_check_fips(rsa);
+            RSA_generate_key_ex_maybe_fips(rsa, bits, e, cb, /*check_fips=*/1);
   BN_free(e);
   return ret;
 }