Use UniquePtr in DHImpl

Also a few BN_MONT_CTXs that are all tied together via
BN_MONT_CTX_set_locked.

Change-Id: I2ab5e243163280952c6889053e2e2722b7e6f1bd
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/90248
Reviewed-by: Rudolf Polzer <rpolzer@google.com>
Commit-Queue: Rudolf Polzer <rpolzer@google.com>
Auto-Submit: David Benjamin <davidben@google.com>
diff --git a/crypto/dh/dh_asn1.cc b/crypto/dh/dh_asn1.cc
index 16ec7ba..01241db 100644
--- a/crypto/dh/dh_asn1.cc
+++ b/crypto/dh/dh_asn1.cc
@@ -27,13 +27,13 @@
 
 using namespace bssl;
 
-static int parse_integer(CBS *cbs, BIGNUM **out) {
+static int parse_integer(CBS *cbs, UniquePtr<BIGNUM> *out) {
   assert(*out == nullptr);
-  *out = BN_new();
+  out->reset(BN_new());
   if (*out == nullptr) {
     return 0;
   }
-  return BN_parse_asn1_unsigned(cbs, *out);
+  return BN_parse_asn1_unsigned(cbs, out->get());
 }
 
 static int marshal_integer(CBB *cbb, BIGNUM *bn) {
@@ -54,7 +54,8 @@
   CBS child;
   auto *impl = FromOpaque(ret.get());
   if (!CBS_get_asn1(cbs, &child, CBS_ASN1_SEQUENCE) ||
-      !parse_integer(&child, &impl->p) || !parse_integer(&child, &impl->g)) {
+      !parse_integer(&child, &impl->p) ||  //
+      !parse_integer(&child, &impl->g)) {
     OPENSSL_PUT_ERROR(DH, DH_R_DECODE_ERROR);
     return nullptr;
   }
@@ -86,7 +87,8 @@
   CBB child;
   auto *impl = FromOpaque(dh);
   if (!CBB_add_asn1(cbb, &child, CBS_ASN1_SEQUENCE) ||
-      !marshal_integer(&child, impl->p) || !marshal_integer(&child, impl->g) ||
+      !marshal_integer(&child, impl->p.get()) ||
+      !marshal_integer(&child, impl->g.get()) ||
       (impl->priv_length != 0 &&
        !CBB_add_asn1_uint64(&child, impl->priv_length)) ||
       !CBB_flush(cbb)) {
diff --git a/crypto/dh/params.cc b/crypto/dh/params.cc
index 37481f7..406a069 100644
--- a/crypto/dh/params.cc
+++ b/crypto/dh/params.cc
@@ -310,14 +310,14 @@
   // Make sure |dh| has the necessary elements
   auto *impl = FromOpaque(dh);
   if (impl->p == nullptr) {
-    impl->p = BN_new();
+    impl->p.reset(BN_new());
     if (impl->p == nullptr) {
       OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
       return 0;
     }
   }
   if (impl->g == nullptr) {
-    impl->g = BN_new();
+    impl->g.reset(BN_new());
     if (impl->g == nullptr) {
       OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
       return 0;
@@ -350,10 +350,10 @@
   if (t1_bn == nullptr || t2_bn == nullptr ||
       !BN_set_word(t1_bn.get(), t1) ||  //
       !BN_set_word(t2_bn.get(), t2) ||  //
-      !BN_generate_prime_ex(impl->p, prime_bits, 1, t1_bn.get(), t2_bn.get(),
-                            cb) ||
+      !BN_generate_prime_ex(impl->p.get(), prime_bits, 1, t1_bn.get(),
+                            t2_bn.get(), cb) ||
       !BN_GENCB_call(cb, 3, 0) ||  //
-      !BN_set_word(impl->g, g)) {
+      !BN_set_word(impl->g.get(), g)) {
     OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
     return 0;
   }
@@ -361,19 +361,16 @@
   return 1;
 }
 
-static int int_dh_bn_cpy(BIGNUM **dst, const BIGNUM *src) {
-  BIGNUM *a = nullptr;
-
+static bool copy_bn(UniquePtr<BIGNUM> *dst, const BIGNUM *src) {
+  UniquePtr<BIGNUM> copy;
   if (src) {
-    a = BN_dup(src);
-    if (!a) {
-      return 0;
+    copy.reset(BN_dup(src));
+    if (!copy) {
+      return false;
     }
   }
-
-  BN_free(*dst);
-  *dst = a;
-  return 1;
+  *dst = std::move(copy);
+  return true;
 }
 
 static int int_dh_param_copy(DH *to, const DH *from, int is_x942) {
@@ -383,8 +380,8 @@
   if (is_x942 == -1) {
     is_x942 = !!from_impl->q;
   }
-  if (!int_dh_bn_cpy(&to_impl->p, from_impl->p) ||
-      !int_dh_bn_cpy(&to_impl->g, from_impl->g)) {
+  if (!copy_bn(&to_impl->p, from_impl->p.get()) ||
+      !copy_bn(&to_impl->g, from_impl->g.get())) {
     return 0;
   }
 
@@ -392,7 +389,7 @@
     return 1;
   }
 
-  if (!int_dh_bn_cpy(&to_impl->q, from_impl->q)) {
+  if (!copy_bn(&to_impl->q, from_impl->q.get())) {
     return 0;
   }
 
diff --git a/crypto/dsa/dsa.cc b/crypto/dsa/dsa.cc
index dd6c02a..cf48dfb 100644
--- a/crypto/dsa/dsa.cc
+++ b/crypto/dsa/dsa.cc
@@ -62,8 +62,6 @@
   BN_clear_free(g);
   BN_clear_free(pub_key);
   BN_clear_free(priv_key);
-  BN_MONT_CTX_free(method_mont_p);
-  BN_MONT_CTX_free(method_mont_q);
 }
 
 void DSA_free(DSA *dsa) {
@@ -176,9 +174,7 @@
     impl->g = g;
   }
 
-  BN_MONT_CTX_free(impl->method_mont_p);
   impl->method_mont_p = nullptr;
-  BN_MONT_CTX_free(impl->method_mont_q);
   impl->method_mont_q = nullptr;
   return 1;
 }
@@ -480,7 +476,7 @@
   if (!BN_MONT_CTX_set_locked(&impl->method_mont_p, &impl->method_mont_lock,
                               impl->p, ctx.get()) ||
       !BN_mod_exp_mont_consttime(pub_key, impl->g, priv_key, impl->p, ctx.get(),
-                                 impl->method_mont_p)) {
+                                 impl->method_mont_p.get())) {
     goto err;
   }
 
@@ -616,9 +612,10 @@
 
     // Compute s = inv(k) (m + xr) mod q. Note |impl->method_mont_q| is
     // initialized by |dsa_sign_setup|.
-    if (!mod_mul_consttime(&xr, impl->priv_key, r, impl->method_mont_q, ctx) ||
+    if (!mod_mul_consttime(&xr, impl->priv_key, r, impl->method_mont_q.get(),
+                           ctx) ||
         !bn_mod_add_consttime(s, &xr, &m, impl->q, ctx) ||
-        !mod_mul_consttime(s, s, kinv, impl->method_mont_q, ctx)) {
+        !mod_mul_consttime(s, s, kinv, impl->method_mont_q.get(), ctx)) {
       goto err;
     }
 
@@ -715,7 +712,7 @@
     // Calculate W = inv(S) mod Q, in the Montgomery domain. This is slightly
     // more efficiently computed as FromMont(s)^-1 = (s * R^-1)^-1 = s^-1 * R,
     // instead of ToMont(s^-1) = s^-1 * R.
-    if (!BN_from_montgomery(&u2, sig->s, impl->method_mont_q, ctx) ||
+    if (!BN_from_montgomery(&u2, sig->s, impl->method_mont_q.get(), ctx) ||
         !BN_mod_inverse(&u2, &u2, impl->q, ctx)) {
       goto err;
     }
@@ -735,18 +732,19 @@
 
     // u1 = M * w mod q. w was stored in the Montgomery domain while M was not,
     // so the result will already be out of the Montgomery domain.
-    if (!BN_mod_mul_montgomery(&u1, &u1, &u2, impl->method_mont_q, ctx)) {
+    if (!BN_mod_mul_montgomery(&u1, &u1, &u2, impl->method_mont_q.get(), ctx)) {
       goto err;
     }
 
     // u2 = r * w mod q. w was stored in the Montgomery domain while r was not,
     // so the result will already be out of the Montgomery domain.
-    if (!BN_mod_mul_montgomery(&u2, sig->r, &u2, impl->method_mont_q, ctx)) {
+    if (!BN_mod_mul_montgomery(&u2, sig->r, &u2, impl->method_mont_q.get(),
+                               ctx)) {
       goto err;
     }
 
     if (!BN_mod_exp2_mont(&t1, impl->g, &u1, impl->pub_key, &u2, impl->p, ctx,
-                          impl->method_mont_p)) {
+                          impl->method_mont_p.get())) {
       goto err;
     }
 
@@ -886,7 +884,7 @@
                               dsa->q, ctx) ||
       // Compute r = (g^k mod p) mod q
       !BN_mod_exp_mont_consttime(r, dsa->g, &k, dsa->p, ctx,
-                                 dsa->method_mont_p)) {
+                                 dsa->method_mont_p.get())) {
     OPENSSL_PUT_ERROR(DSA, ERR_R_BN_LIB);
     goto err;
   }
@@ -901,7 +899,7 @@
   if (!BN_mod(r, r, dsa->q, ctx) ||
       // Compute part of 's = inv(k) (m + xr) mod q' using Fermat's Little
       // Theorem.
-      !bn_mod_inverse_prime(kinv, &k, dsa->q, ctx, dsa->method_mont_q)) {
+      !bn_mod_inverse_prime(kinv, &k, dsa->q, ctx, dsa->method_mont_q.get())) {
     OPENSSL_PUT_ERROR(DSA, ERR_R_BN_LIB);
     goto err;
   }
@@ -938,6 +936,18 @@
   return CRYPTO_get_ex_data(&impl->ex_data, idx);
 }
 
+static bool copy_bn(UniquePtr<BIGNUM> *dst, const BIGNUM *src) {
+  UniquePtr<BIGNUM> copy;
+  if (src) {
+    copy.reset(BN_dup(src));
+    if (!copy) {
+      return false;
+    }
+  }
+  *dst = std::move(copy);
+  return true;
+}
+
 DH *DSA_dup_DH(const DSA *dsa) {
   auto *impl = FromOpaque(dsa);
 
@@ -952,16 +962,14 @@
   }
   if (impl->q != nullptr) {
     dh->priv_length = BN_num_bits(impl->q);
-    if ((dh->q = BN_dup(impl->q)) == nullptr) {
+    if (!copy_bn(&dh->q, impl->q)) {
       return nullptr;
     }
   }
-  if ((impl->p != nullptr && (dh->p = BN_dup(impl->p)) == nullptr) ||
-      (impl->g != nullptr && (dh->g = BN_dup(impl->g)) == nullptr) ||
-      (impl->pub_key != nullptr &&
-       (dh->pub_key = BN_dup(impl->pub_key)) == nullptr) ||
-      (impl->priv_key != nullptr &&
-       (dh->priv_key = BN_dup(impl->priv_key)) == nullptr)) {
+  if (!copy_bn(&dh->p, impl->p) ||              //
+      !copy_bn(&dh->g, impl->g) ||              //
+      !copy_bn(&dh->pub_key, impl->pub_key) ||  //
+      !copy_bn(&dh->priv_key, impl->priv_key)) {
     return nullptr;
   }
 
diff --git a/crypto/dsa/internal.h b/crypto/dsa/internal.h
index bc6da5d..c00d05d 100644
--- a/crypto/dsa/internal.h
+++ b/crypto/dsa/internal.h
@@ -38,8 +38,8 @@
 
   // Normally used to cache montgomery values
   mutable Mutex method_mont_lock;
-  mutable BN_MONT_CTX *method_mont_p = nullptr;
-  mutable BN_MONT_CTX *method_mont_q = nullptr;
+  mutable UniquePtr<BN_MONT_CTX> method_mont_p;
+  mutable UniquePtr<BN_MONT_CTX> method_mont_q;
   CRYPTO_EX_DATA ex_data;
 
  private:
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 09d7b41..486fae1 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -700,8 +700,8 @@
 // this function assumes |mod| is public.
 //
 // If |*pmont| is already non-NULL then it does nothing and returns one.
-int BN_MONT_CTX_set_locked(BN_MONT_CTX **pmont, Mutex *lock, const BIGNUM *mod,
-                           BN_CTX *bn_ctx);
+int BN_MONT_CTX_set_locked(UniquePtr<BN_MONT_CTX> *pmont, Mutex *lock,
+                           const BIGNUM *mod, BN_CTX *bn_ctx);
 
 
 // Low-level operations for small numbers.
diff --git a/crypto/fipsmodule/bn/montgomery.cc.inc b/crypto/fipsmodule/bn/montgomery.cc.inc
index 9c13c5c..4bc6247 100644
--- a/crypto/fipsmodule/bn/montgomery.cc.inc
+++ b/crypto/fipsmodule/bn/montgomery.cc.inc
@@ -165,10 +165,10 @@
   return mont;
 }
 
-int bssl::BN_MONT_CTX_set_locked(BN_MONT_CTX **pmont, Mutex *lock,
+int bssl::BN_MONT_CTX_set_locked(UniquePtr<BN_MONT_CTX> *pmont, Mutex *lock,
                                  const BIGNUM *mod, BN_CTX *bn_ctx) {
   lock->LockRead();
-  BN_MONT_CTX *ctx = *pmont;
+  BN_MONT_CTX *ctx = pmont->get();
   lock->UnlockRead();
 
   if (ctx) {
@@ -177,7 +177,7 @@
 
   MutexWriteLock write_lock(lock);
   if (*pmont == nullptr) {
-    *pmont = BN_MONT_CTX_new_for_modulus(mod, bn_ctx);
+    pmont->reset(BN_MONT_CTX_new_for_modulus(mod, bn_ctx));
   }
   return *pmont != nullptr;
 }
diff --git a/crypto/fipsmodule/dh/check.cc.inc b/crypto/fipsmodule/dh/check.cc.inc
index 544d2eb..ec8e358 100644
--- a/crypto/fipsmodule/dh/check.cc.inc
+++ b/crypto/fipsmodule/dh/check.cc.inc
@@ -30,22 +30,22 @@
   auto *impl = FromOpaque(dh);
 
   // Most operations scale with p and q.
-  if (BN_is_negative(impl->p) || !BN_is_odd(impl->p) ||
-      BN_num_bits(impl->p) > OPENSSL_DH_MAX_MODULUS_BITS) {
+  if (BN_is_negative(impl->p.get()) || !BN_is_odd(impl->p.get()) ||
+      BN_num_bits(impl->p.get()) > OPENSSL_DH_MAX_MODULUS_BITS) {
     OPENSSL_PUT_ERROR(DH, DH_R_INVALID_PARAMETERS);
     return 0;
   }
 
   // q must be bounded by p.
-  if (impl->q != nullptr &&
-      (BN_is_negative(impl->q) || BN_ucmp(impl->q, impl->p) > 0)) {
+  if (impl->q != nullptr && (BN_is_negative(impl->q.get()) ||
+                             BN_ucmp(impl->q.get(), impl->p.get()) > 0)) {
     OPENSSL_PUT_ERROR(DH, DH_R_INVALID_PARAMETERS);
     return 0;
   }
 
   // g must be an element of p's multiplicative group.
-  if (BN_is_negative(impl->g) || BN_is_zero(impl->g) ||
-      BN_ucmp(impl->g, impl->p) >= 0) {
+  if (BN_is_negative(impl->g.get()) || BN_is_zero(impl->g.get()) ||
+      BN_ucmp(impl->g.get(), impl->p.get()) >= 0) {
     OPENSSL_PUT_ERROR(DH, DH_R_INVALID_PARAMETERS);
     return 0;
   }
@@ -74,7 +74,7 @@
 
   // Check |pub_key| is less than |impl->p| - 1.
   BIGNUM *tmp = BN_CTX_get(ctx.get());
-  if (tmp == nullptr || !BN_copy(tmp, impl->p) || !BN_sub_word(tmp, 1)) {
+  if (tmp == nullptr || !BN_copy(tmp, impl->p.get()) || !BN_sub_word(tmp, 1)) {
     return 0;
   }
   if (BN_cmp(pub_key, tmp) >= 0) {
@@ -85,7 +85,8 @@
     // Check |pub_key|^|impl->q| is 1 mod |impl->p|. This is necessary for RFC
     // 5114 groups which are not safe primes but pick a generator on a
     // prime-order subgroup of size |impl->q|.
-    if (!BN_mod_exp_mont(tmp, pub_key, impl->q, impl->p, ctx.get(), nullptr)) {
+    if (!BN_mod_exp_mont(tmp, pub_key, impl->q.get(), impl->p.get(), ctx.get(),
+                         nullptr)) {
       return 0;
     }
     if (!BN_is_one(tmp)) {
@@ -125,21 +126,22 @@
   }
 
   if (impl->q) {
-    if (BN_cmp(impl->g, BN_value_one()) <= 0) {
+    if (BN_cmp(impl->g.get(), BN_value_one()) <= 0) {
       *out_flags |= DH_CHECK_NOT_SUITABLE_GENERATOR;
-    } else if (BN_cmp(impl->g, impl->p) >= 0) {
+    } else if (BN_cmp(impl->g.get(), impl->p.get()) >= 0) {
       *out_flags |= DH_CHECK_NOT_SUITABLE_GENERATOR;
     } else {
       // Check g^q == 1 mod p
-      if (!BN_mod_exp_mont(t1, impl->g, impl->q, impl->p, ctx.get(), nullptr)) {
+      if (!BN_mod_exp_mont(t1, impl->g.get(), impl->q.get(), impl->p.get(),
+                           ctx.get(), nullptr)) {
         return 0;
       }
       if (!BN_is_one(t1)) {
         *out_flags |= DH_CHECK_NOT_SUITABLE_GENERATOR;
       }
     }
-    int r = BN_is_prime_ex(impl->q, BN_prime_checks_for_validation, ctx.get(),
-                           nullptr);
+    int r = BN_is_prime_ex(impl->q.get(), BN_prime_checks_for_validation,
+                           ctx.get(), nullptr);
     if (r < 0) {
       return 0;
     }
@@ -147,22 +149,22 @@
       *out_flags |= DH_CHECK_Q_NOT_PRIME;
     }
     // Check p == 1 mod q  i.e. q divides p - 1
-    if (!BN_div(t1, t2, impl->p, impl->q, ctx.get())) {
+    if (!BN_div(t1, t2, impl->p.get(), impl->q.get(), ctx.get())) {
       return 0;
     }
     if (!BN_is_one(t2)) {
       *out_flags |= DH_CHECK_INVALID_Q_VALUE;
     }
-  } else if (BN_is_word(impl->g, DH_GENERATOR_2)) {
-    BN_ULONG l = BN_mod_word(impl->p, 24);
+  } else if (BN_is_word(impl->g.get(), DH_GENERATOR_2)) {
+    BN_ULONG l = BN_mod_word(impl->p.get(), 24);
     if (l == (BN_ULONG)-1) {
       return 0;
     }
     if (l != 11) {
       *out_flags |= DH_CHECK_NOT_SUITABLE_GENERATOR;
     }
-  } else if (BN_is_word(impl->g, DH_GENERATOR_5)) {
-    BN_ULONG l = BN_mod_word(impl->p, 10);
+  } else if (BN_is_word(impl->g.get(), DH_GENERATOR_5)) {
+    BN_ULONG l = BN_mod_word(impl->p.get(), 10);
     if (l == (BN_ULONG)-1) {
       return 0;
     }
@@ -173,15 +175,15 @@
     *out_flags |= DH_CHECK_UNABLE_TO_CHECK_GENERATOR;
   }
 
-  int r = BN_is_prime_ex(impl->p, BN_prime_checks_for_validation, ctx.get(),
-                         nullptr);
+  int r = BN_is_prime_ex(impl->p.get(), BN_prime_checks_for_validation,
+                         ctx.get(), nullptr);
   if (r < 0) {
     return 0;
   }
   if (!r) {
     *out_flags |= DH_CHECK_P_NOT_PRIME;
   } else if (!impl->q) {
-    if (!BN_rshift1(t1, impl->p)) {
+    if (!BN_rshift1(t1, impl->p.get())) {
       return 0;
     }
     r = BN_is_prime_ex(t1, BN_prime_checks_for_validation, ctx.get(), nullptr);
diff --git a/crypto/fipsmodule/dh/dh.cc.inc b/crypto/fipsmodule/dh/dh.cc.inc
index dd75a35..9837c6c 100644
--- a/crypto/fipsmodule/dh/dh.cc.inc
+++ b/crypto/fipsmodule/dh/dh.cc.inc
@@ -32,78 +32,49 @@
 
 using namespace bssl;
 
-DHImpl::DHImpl() : RefCounted(CheckSubClass()) {}
-
 DH *DH_new() { return New<DHImpl>(); }
 
-DHImpl::~DHImpl() {
-  BN_MONT_CTX_free(method_mont_p);
-  BN_clear_free(p);
-  BN_clear_free(g);
-  BN_clear_free(q);
-  BN_clear_free(pub_key);
-  BN_clear_free(priv_key);
-}
-
 void DH_free(DH *dh) {
-  if (dh == nullptr) {
-    return;
+  if (dh != nullptr) {
+    FromOpaque(dh)->DecRefInternal();
   }
-  auto *impl = FromOpaque(dh);
-  impl->DecRefInternal();
 }
 
-unsigned DH_bits(const DH *dh) {
-  auto *impl = FromOpaque(dh);
-  return BN_num_bits(impl->p);
-}
+unsigned DH_bits(const DH *dh) { return BN_num_bits(FromOpaque(dh)->p.get()); }
 
 const BIGNUM *DH_get0_pub_key(const DH *dh) {
-  auto *impl = FromOpaque(dh);
-  return impl->pub_key;
+  return FromOpaque(dh)->pub_key.get();
 }
 
 const BIGNUM *DH_get0_priv_key(const DH *dh) {
-  auto *impl = FromOpaque(dh);
-  return impl->priv_key;
+  return FromOpaque(dh)->priv_key.get();
 }
 
-const BIGNUM *DH_get0_p(const DH *dh) {
-  auto *impl = FromOpaque(dh);
-  return impl->p;
-}
+const BIGNUM *DH_get0_p(const DH *dh) { return FromOpaque(dh)->p.get(); }
 
-const BIGNUM *DH_get0_q(const DH *dh) {
-  auto *impl = FromOpaque(dh);
-  return impl->q;
-}
+const BIGNUM *DH_get0_q(const DH *dh) { return FromOpaque(dh)->q.get(); }
 
-const BIGNUM *DH_get0_g(const DH *dh) {
-  auto *impl = FromOpaque(dh);
-  return impl->g;
-}
+const BIGNUM *DH_get0_g(const DH *dh) { return FromOpaque(dh)->g.get(); }
 
 void DH_get0_key(const DH *dh, const BIGNUM **out_pub_key,
                  const BIGNUM **out_priv_key) {
   auto *impl = FromOpaque(dh);
   if (out_pub_key != nullptr) {
-    *out_pub_key = impl->pub_key;
+    *out_pub_key = impl->pub_key.get();
   }
   if (out_priv_key != nullptr) {
-    *out_priv_key = impl->priv_key;
+    *out_priv_key = impl->priv_key.get();
   }
 }
 
 int DH_set0_key(DH *dh, BIGNUM *pub_key, BIGNUM *priv_key) {
   auto *impl = FromOpaque(dh);
   if (pub_key != nullptr) {
-    BN_free(impl->pub_key);
-    impl->pub_key = pub_key;
+    impl->pub_key.reset(pub_key);
   }
 
   if (priv_key != nullptr) {
-    BN_free(impl->priv_key);
-    impl->priv_key = priv_key;
+    impl->priv_key.reset(priv_key);
   }
 
   return 1;
@@ -113,13 +84,13 @@
                  const BIGNUM **out_g) {
   auto *impl = FromOpaque(dh);
   if (out_p != nullptr) {
-    *out_p = impl->p;
+    *out_p = impl->p.get();
   }
   if (out_q != nullptr) {
-    *out_q = impl->q;
+    *out_q = impl->q.get();
   }
   if (out_g != nullptr) {
-    *out_g = impl->g;
+    *out_g = impl->g.get();
   }
 }
 
@@ -131,22 +102,18 @@
   }
 
   if (p != nullptr) {
-    BN_free(impl->p);
-    impl->p = p;
+    impl->p.reset(p);
   }
 
   if (q != nullptr) {
-    BN_free(impl->q);
-    impl->q = q;
+    impl->q.reset(q);
   }
 
   if (g != nullptr) {
-    BN_free(impl->g);
-    impl->g = g;
+    impl->g.reset(g);
   }
 
   // Invalidate the cached Montgomery parameters.
-  BN_MONT_CTX_free(impl->method_mont_p);
   impl->method_mont_p = nullptr;
   return 1;
 }
@@ -164,41 +131,29 @@
     return 0;
   }
 
-  int ok = 0;
-  bool generate_new_key = false;
-  BIGNUM *pub_key = nullptr, *priv_key = nullptr;
   auto *impl = FromOpaque(dh);
-
   UniquePtr<BN_CTX> ctx(BN_CTX_new());
   if (ctx == nullptr) {
-    goto err;
-  }
-
-  if (impl->priv_key == nullptr) {
-    priv_key = BN_new();
-    if (priv_key == nullptr) {
-      goto err;
-    }
-    generate_new_key = true;
-  } else {
-    priv_key = impl->priv_key;
-  }
-
-  if (impl->pub_key == nullptr) {
-    pub_key = BN_new();
-    if (pub_key == nullptr) {
-      goto err;
-    }
-  } else {
-    pub_key = impl->pub_key;
+    OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+    return 0;
   }
 
   if (!BN_MONT_CTX_set_locked(&impl->method_mont_p, &impl->method_mont_p_lock,
-                              impl->p, ctx.get())) {
-    goto err;
+                              impl->p.get(), ctx.get())) {
+    OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+    return 0;
   }
 
-  if (generate_new_key) {
+  // Only generate a private key if there's already one. Otherwise,
+  // |DH_generate_key| recomputes the public key.
+  const BIGNUM *priv_key = impl->priv_key.get();
+  UniquePtr<BIGNUM> new_priv_key;
+  if (priv_key == nullptr) {
+    new_priv_key.reset(BN_new());
+    if (new_priv_key == nullptr) {
+      OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+      return 0;
+    }
     if (impl->q) {
       // Section 5.6.1.1.4 of SP 800-56A Rev3 generates a private key uniformly
       // from [1, min(2^N-1, q-1)].
@@ -206,8 +161,9 @@
       // Although SP 800-56A Rev3 now permits a private key length N,
       // |impl->priv_length| historically was ignored when q is available. We
       // continue to ignore it and interpret such a configuration as N = len(q).
-      if (!BN_rand_range_ex(priv_key, 1, impl->q)) {
-        goto err;
+      if (!BN_rand_range_ex(new_priv_key.get(), 1, impl->q.get())) {
+        OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+        return 0;
       }
     } else {
       // If q is unspecified, we expect p to be a safe prime, with g generating
@@ -222,49 +178,47 @@
       // Compute M = min(2^N, q).
       UniquePtr<BIGNUM> priv_key_limit(BN_new());
       if (priv_key_limit == nullptr) {
-        goto err;
+        OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+        return 0;
       }
       if (impl->priv_length == 0 ||
-          impl->priv_length >= BN_num_bits(impl->p) - 1) {
+          impl->priv_length >= BN_num_bits(impl->p.get()) - 1) {
         // M = q = (p - 1) / 2.
-        if (!BN_rshift1(priv_key_limit.get(), impl->p)) {
-          goto err;
+        if (!BN_rshift1(priv_key_limit.get(), impl->p.get())) {
+          OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+          return 0;
         }
       } else {
         // M = 2^N.
         if (!BN_set_bit(priv_key_limit.get(), impl->priv_length)) {
-          goto err;
+          OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+          return 0;
         }
       }
 
       // Choose a private key uniformly from [1, M-1].
-      if (!BN_rand_range_ex(priv_key, 1, priv_key_limit.get())) {
-        goto err;
+      if (!BN_rand_range_ex(new_priv_key.get(), 1, priv_key_limit.get())) {
+        OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+        return 0;
       }
     }
+    priv_key = new_priv_key.get();
   }
 
-  if (!BN_mod_exp_mont_consttime(pub_key, impl->g, priv_key, impl->p, ctx.get(),
-                                 impl->method_mont_p)) {
-    goto err;
-  }
-
-  impl->pub_key = pub_key;
-  impl->priv_key = priv_key;
-  ok = 1;
-
-err:
-  if (ok != 1) {
+  UniquePtr<BIGNUM> new_pub_key(BN_new());
+  if (new_pub_key == nullptr ||
+      !BN_mod_exp_mont_consttime(new_pub_key.get(), impl->g.get(), priv_key,
+                                 impl->p.get(), ctx.get(),
+                                 impl->method_mont_p.get())) {
     OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
+    return 0;
   }
 
-  if (impl->pub_key == nullptr) {
-    BN_free(pub_key);
+  impl->pub_key = std::move(new_pub_key);
+  if (new_priv_key != nullptr) {
+    impl->priv_key = std::move(new_priv_key);
   }
-  if (impl->priv_key == nullptr) {
-    BN_free(priv_key);
-  }
-  return ok;
+  return 1;
 }
 
 static int dh_compute_key(DH *dh, BIGNUM *out_shared_key,
@@ -290,13 +244,14 @@
   BIGNUM *p_minus_1 = BN_CTX_get(ctx);
   if (!p_minus_1 ||
       !BN_MONT_CTX_set_locked(&impl->method_mont_p, &impl->method_mont_p_lock,
-                              impl->p, ctx)) {
+                              impl->p.get(), ctx)) {
     return 0;
   }
 
-  if (!BN_mod_exp_mont_consttime(out_shared_key, peers_key, impl->priv_key,
-                                 impl->p, ctx, impl->method_mont_p) ||
-      !BN_copy(p_minus_1, impl->p) || !BN_sub_word(p_minus_1, 1)) {
+  if (!BN_mod_exp_mont_consttime(out_shared_key, peers_key,
+                                 impl->priv_key.get(), impl->p.get(), ctx,
+                                 impl->method_mont_p.get()) ||
+      !BN_copy(p_minus_1, impl->p.get()) || !BN_sub_word(p_minus_1, 1)) {
     OPENSSL_PUT_ERROR(DH, ERR_R_BN_LIB);
     return 0;
   }
@@ -390,10 +345,7 @@
   return ret;
 }
 
-int DH_size(const DH *dh) {
-  auto *impl = FromOpaque(dh);
-  return BN_num_bytes(impl->p);
-}
+int DH_size(const DH *dh) { return BN_num_bytes(FromOpaque(dh)->p.get()); }
 
 int DH_up_ref(DH *dh) {
   auto *impl = FromOpaque(dh);
diff --git a/crypto/fipsmodule/dh/internal.h b/crypto/fipsmodule/dh/internal.h
index 43ca555..337afa6 100644
--- a/crypto/fipsmodule/dh/internal.h
+++ b/crypto/fipsmodule/dh/internal.h
@@ -27,24 +27,24 @@
 
 class DHImpl : public dh_st, public RefCounted<DHImpl> {
  public:
-  DHImpl();
+  DHImpl() : RefCounted(CheckSubClass()) {}
 
-  BIGNUM *p = nullptr;
-  BIGNUM *g = nullptr;
-  BIGNUM *q = nullptr;
-  BIGNUM *pub_key = nullptr;   // g^x mod p
-  BIGNUM *priv_key = nullptr;  // x
+  UniquePtr<BIGNUM> p;
+  UniquePtr<BIGNUM> g;
+  UniquePtr<BIGNUM> q;
+  UniquePtr<BIGNUM> pub_key;   // g^x mod p
+  UniquePtr<BIGNUM> priv_key;  // x
 
   // priv_length contains the length, in bits, of the private value. If zero,
   // the private value will be the same length as |p|.
   unsigned priv_length = 0;
 
   mutable Mutex method_mont_p_lock;
-  mutable BN_MONT_CTX *method_mont_p = nullptr;
+  mutable UniquePtr<BN_MONT_CTX> method_mont_p;
 
  private:
   friend RefCounted;
-  ~DHImpl();
+  ~DHImpl() = default;
 };
 
 // dh_check_params_fast checks basic invariants on |dh|'s domain parameters. It
diff --git a/crypto/fipsmodule/rsa/internal.h b/crypto/fipsmodule/rsa/internal.h
index e9d8a36..c54e9bf 100644
--- a/crypto/fipsmodule/rsa/internal.h
+++ b/crypto/fipsmodule/rsa/internal.h
@@ -68,9 +68,9 @@
 
   // Used to cache montgomery values. The creation of these values is protected
   // by |lock|.
-  BN_MONT_CTX *mont_n = nullptr;
-  BN_MONT_CTX *mont_p = nullptr;
-  BN_MONT_CTX *mont_q = nullptr;
+  UniquePtr<BN_MONT_CTX> mont_n;
+  UniquePtr<BN_MONT_CTX> mont_p;
+  UniquePtr<BN_MONT_CTX> mont_q;
 
   // The following fields are copies of |d|, |dmp1|, and |dmq1|, respectively,
   // but with the correct widths to prevent side channels. These must use
diff --git a/crypto/fipsmodule/rsa/rsa_impl.cc.inc b/crypto/fipsmodule/rsa/rsa_impl.cc.inc
index 7bcb56f..e52ea47 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.cc.inc
+++ b/crypto/fipsmodule/rsa/rsa_impl.cc.inc
@@ -18,7 +18,7 @@
 #include <limits.h>
 #include <string.h>
 
-#include <iterator>
+#include <utility>
 
 #include <openssl/bn.h>
 #include <openssl/err.h>
@@ -153,7 +153,7 @@
   // |p|, and |q| with the correct minimal widths.
 
   if (rsa->mont_n == nullptr) {
-    rsa->mont_n = BN_MONT_CTX_new_for_modulus(rsa->n, ctx);
+    rsa->mont_n.reset(BN_MONT_CTX_new_for_modulus(rsa->n, ctx));
     if (rsa->mont_n == nullptr) {
       return 0;
     }
@@ -176,14 +176,14 @@
     // code.
 
     if (rsa->mont_p == nullptr) {
-      rsa->mont_p = BN_MONT_CTX_new_consttime(rsa->p, ctx);
+      rsa->mont_p.reset(BN_MONT_CTX_new_consttime(rsa->p, ctx));
       if (rsa->mont_p == nullptr) {
         return 0;
       }
     }
 
     if (rsa->mont_q == nullptr) {
-      rsa->mont_q = BN_MONT_CTX_new_consttime(rsa->q, ctx);
+      rsa->mont_q.reset(BN_MONT_CTX_new_consttime(rsa->q, ctx));
       if (rsa->mont_q == nullptr) {
         return 0;
       }
@@ -204,7 +204,7 @@
       if (rsa->iqmp_mont == nullptr) {
         BIGNUM *iqmp_mont = BN_new();
         if (iqmp_mont == nullptr ||
-            !BN_to_montgomery(iqmp_mont, rsa->iqmp, rsa->mont_p, ctx)) {
+            !BN_to_montgomery(iqmp_mont, rsa->iqmp, rsa->mont_p.get(), ctx)) {
           BN_free(iqmp_mont);
           return 0;
         }
@@ -223,11 +223,8 @@
 
   impl->private_key_frozen = 0;
 
-  BN_MONT_CTX_free(impl->mont_n);
   impl->mont_n = nullptr;
-  BN_MONT_CTX_free(impl->mont_p);
   impl->mont_p = nullptr;
-  BN_MONT_CTX_free(impl->mont_q);
   impl->mont_q = nullptr;
 
   BN_free(impl->d_fixed);
@@ -351,7 +348,7 @@
 
   if (!BN_MONT_CTX_set_locked(&impl->mont_n, &impl->lock, impl->n, ctx.get()) ||
       !BN_mod_exp_mont(result, f, impl->e, &impl->mont_n->N, ctx.get(),
-                       impl->mont_n)) {
+                       impl->mont_n.get())) {
     goto err;
   }
 
@@ -447,13 +444,13 @@
       // time, which requires primes be the same size, rounded to the Montgomery
       // coefficient. (See |mod_montgomery|.) This is not required by RFC 8017,
       // but it is true for keys generated by us and all common implementations.
-      bn_less_than_montgomery_R(impl->q, impl->mont_p) &&
-      bn_less_than_montgomery_R(impl->p, impl->mont_q)) {
+      bn_less_than_montgomery_R(impl->q, impl->mont_p.get()) &&
+      bn_less_than_montgomery_R(impl->p, impl->mont_q.get())) {
     if (!rsa_mod_exp_crt(result, f, impl, ctx.get())) {
       return 0;
     }
   } else if (!BN_mod_exp_mont_consttime(result, f, impl->d_fixed, impl->n,
-                                        ctx.get(), impl->mont_n)) {
+                                        ctx.get(), impl->mont_n.get())) {
     return 0;
   }
 
@@ -471,7 +468,7 @@
     BIGNUM *vrfy = BN_CTX_get(ctx.get());
     if (vrfy == nullptr ||
         !BN_mod_exp_mont(vrfy, result, impl->e, impl->n, ctx.get(),
-                         impl->mont_n) ||
+                         impl->mont_n.get()) ||
         !constant_time_declassify_int(BN_equal_consttime(vrfy, f))) {
       OPENSSL_PUT_ERROR(RSA, ERR_R_INTERNAL_ERROR);
       return 0;
@@ -556,25 +553,25 @@
   declassify_assert(BN_ucmp(I, n) < 0);
 
   if (  // |m1| is the result modulo |q|.
-      !mod_montgomery(r1, I, q, rsa->mont_q, p, ctx) ||
+      !mod_montgomery(r1, I, q, rsa->mont_q.get(), p, ctx) ||
       !BN_mod_exp_mont_consttime(m1, r1, rsa->dmq1_fixed, q, ctx,
-                                 rsa->mont_q) ||
+                                 rsa->mont_q.get()) ||
       // |r0| is the result modulo |p|.
-      !mod_montgomery(r1, I, p, rsa->mont_p, q, ctx) ||
+      !mod_montgomery(r1, I, p, rsa->mont_p.get(), q, ctx) ||
       !BN_mod_exp_mont_consttime(r0, r1, rsa->dmp1_fixed, p, ctx,
-                                 rsa->mont_p) ||
+                                 rsa->mont_p.get()) ||
       // Compute r0 = r0 - m1 mod p. |m1| is reduced mod |q|, not |p|, so we
       // just run |mod_montgomery| again for srsaicity. This could be more
       // efficient with more cases: if |p > q|, |m1| is already reduced. If
       // |p < q| but they have the same bit width, |bn_reduce_once| suffices.
       // However, compared to over 2048 Montgomery multiplications above, this
       // difference is not measurable.
-      !mod_montgomery(r1, m1, p, rsa->mont_p, q, ctx) ||
+      !mod_montgomery(r1, m1, p, rsa->mont_p.get(), q, ctx) ||
       !bn_mod_sub_consttime(r0, r0, r1, p, ctx) ||
       // r0 = r0 * iqmp mod p. We use Montgomery multiplication to compute this
       // in constant time. |iqmp_mont| is in Montgomery form and r0 is not, so
       // the result is taken out of Montgomery form.
-      !BN_mod_mul_montgomery(r0, r0, rsa->iqmp_mont, rsa->mont_p, ctx) ||
+      !BN_mod_mul_montgomery(r0, r0, rsa->iqmp_mont, rsa->mont_p.get(), ctx) ||
       // r0 = r0 * q + m1 gives the final result. Reducing modulo q gives m1, so
       // it is correct mod p. Reducing modulo p gives (r0-m1)*iqmp*q + m1 = r0,
       // so it is correct mod q. Finally, the result is bounded by [m1, n + m1),
@@ -859,10 +856,10 @@
   bn_declassify(rsa->n);
 
   // Calculate q^-1 mod p.
-  rsa->mont_p = BN_MONT_CTX_new_consttime(rsa->p, ctx.get());
+  rsa->mont_p.reset(BN_MONT_CTX_new_consttime(rsa->p, ctx.get()));
   if (rsa->mont_p == nullptr ||  //
       !bn_mod_inverse_secret_prime(rsa->iqmp, rsa->q, rsa->p, ctx.get(),
-                                   rsa->mont_p)) {
+                                   rsa->mont_p.get())) {
     OPENSSL_PUT_ERROR(RSA, ERR_LIB_BN);
     return 0;
   }
@@ -891,12 +888,6 @@
   *in = nullptr;
 }
 
-static void replace_bn_mont_ctx(BN_MONT_CTX **out, BN_MONT_CTX **in) {
-  BN_MONT_CTX_free(*out);
-  *out = *in;
-  *in = nullptr;
-}
-
 static int RSA_generate_key_ex_maybe_fips(RSAImpl *rsa, int bits,
                                           const BIGNUM *e_value, BN_GENCB *cb,
                                           int check_fips) {
@@ -948,9 +939,9 @@
   replace_bignum(&rsa->dmp1, &tmp->dmp1);
   replace_bignum(&rsa->dmq1, &tmp->dmq1);
   replace_bignum(&rsa->iqmp, &tmp->iqmp);
-  replace_bn_mont_ctx(&rsa->mont_n, &tmp->mont_n);
-  replace_bn_mont_ctx(&rsa->mont_p, &tmp->mont_p);
-  replace_bn_mont_ctx(&rsa->mont_q, &tmp->mont_q);
+  rsa->mont_n = std::move(tmp->mont_n);
+  rsa->mont_p = std::move(tmp->mont_p);
+  rsa->mont_q = std::move(tmp->mont_q);
   replace_bignum(&rsa->d_fixed, &tmp->d_fixed);
   replace_bignum(&rsa->dmp1_fixed, &tmp->dmp1_fixed);
   replace_bignum(&rsa->dmq1_fixed, &tmp->dmq1_fixed);
diff --git a/crypto/rsa/rsa_crypt.cc b/crypto/rsa/rsa_crypt.cc
index 68fb133..18d49e2 100644
--- a/crypto/rsa/rsa_crypt.cc
+++ b/crypto/rsa/rsa_crypt.cc
@@ -406,7 +406,7 @@
 
   if (!BN_MONT_CTX_set_locked(&impl->mont_n, &impl->lock, impl->n, ctx.get()) ||
       !BN_mod_exp_mont(result, f, impl->e, &impl->mont_n->N, ctx.get(),
-                       impl->mont_n)) {
+                       impl->mont_n.get())) {
     goto err;
   }