Add a function which folds BN_MONT_CTX_{new,set} together.

These empty states aren't any use to either caller or implementor.

Change-Id: If0b748afeeb79e4a1386182e61c5b5ecf838de62
Reviewed-on: https://boringssl-review.googlesource.com/25254
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/dsa/dsa.c b/crypto/dsa/dsa.c
index f3d4f85..532ffec 100644
--- a/crypto/dsa/dsa.c
+++ b/crypto/dsa/dsa.c
@@ -239,11 +239,6 @@
   }
   BN_CTX_start(ctx);
 
-  mont = BN_MONT_CTX_new();
-  if (mont == NULL) {
-    goto err;
-  }
-
   r0 = BN_CTX_get(ctx);
   g = BN_CTX_get(ctx);
   W = BN_CTX_get(ctx);
@@ -401,8 +396,9 @@
     goto err;
   }
 
-  if (!BN_set_word(test, h) ||
-      !BN_MONT_CTX_set(mont, p, ctx)) {
+  mont = BN_MONT_CTX_new_for_modulus(p, ctx);
+  if (mont == NULL ||
+      !BN_set_word(test, h)) {
     goto err;
   }
 
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index 94b8973..f36656f 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -540,12 +540,12 @@
 
   if (BN_is_odd(m.get())) {
     // Reduce |a| and |b| and test the Montgomery version.
-    bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
+    bssl::UniquePtr<BN_MONT_CTX> mont(
+        BN_MONT_CTX_new_for_modulus(m.get(), ctx));
     bssl::UniquePtr<BIGNUM> a_tmp(BN_new()), b_tmp(BN_new());
     ASSERT_TRUE(mont);
     ASSERT_TRUE(a_tmp);
     ASSERT_TRUE(b_tmp);
-    ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), m.get(), ctx));
     ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
     ASSERT_TRUE(BN_nnmod(b.get(), b.get(), m.get(), ctx));
     ASSERT_TRUE(BN_to_montgomery(a_tmp.get(), a.get(), mont.get(), ctx));
@@ -603,11 +603,11 @@
 
   if (BN_is_odd(m.get())) {
     // Reduce |a| and test the Montgomery version.
-    bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
+    bssl::UniquePtr<BN_MONT_CTX> mont(
+        BN_MONT_CTX_new_for_modulus(m.get(), ctx));
     bssl::UniquePtr<BIGNUM> a_tmp(BN_new());
     ASSERT_TRUE(mont);
     ASSERT_TRUE(a_tmp);
-    ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), m.get(), ctx));
     ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
     ASSERT_TRUE(BN_to_montgomery(a_tmp.get(), a.get(), mont.get(), ctx));
     ASSERT_TRUE(BN_mod_mul_montgomery(ret.get(), a_tmp.get(), a_tmp.get(),
@@ -687,9 +687,9 @@
 #if !defined(BORINGSSL_SHARED_LIBRARY)
     size_t m_width = static_cast<size_t>(bn_minimal_width(m.get()));
     if (m_width <= BN_SMALL_MAX_WORDS) {
-      bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
+      bssl::UniquePtr<BN_MONT_CTX> mont(
+          BN_MONT_CTX_new_for_modulus(m.get(), ctx));
       ASSERT_TRUE(mont.get());
-      ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), m.get(), ctx));
       ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
       std::unique_ptr<BN_ULONG[]> r_words(new BN_ULONG[m_width]),
           a_words(new BN_ULONG[m_width]);
@@ -1280,11 +1280,9 @@
   bssl::UniquePtr<BIGNUM> a(BN_new());
   bssl::UniquePtr<BIGNUM> b(BN_new());
   bssl::UniquePtr<BIGNUM> zero(BN_new());
-  bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
   ASSERT_TRUE(a);
   ASSERT_TRUE(b);
   ASSERT_TRUE(zero);
-  ASSERT_TRUE(mont);
 
   BN_zero(zero.get());
 
@@ -1307,13 +1305,16 @@
       a.get(), BN_value_one(), BN_value_one(), zero.get(), ctx(), nullptr));
   ERR_clear_error();
 
-  EXPECT_FALSE(BN_MONT_CTX_set(mont.get(), zero.get(), ctx()));
+  bssl::UniquePtr<BN_MONT_CTX> mont(
+      BN_MONT_CTX_new_for_modulus(zero.get(), ctx()));
+  EXPECT_FALSE(mont);
   ERR_clear_error();
 
   // Some operations also may not be used with an even modulus.
   ASSERT_TRUE(BN_set_word(b.get(), 16));
 
-  EXPECT_FALSE(BN_MONT_CTX_set(mont.get(), b.get(), ctx()));
+  mont.reset(BN_MONT_CTX_new_for_modulus(b.get(), ctx()));
+  EXPECT_FALSE(mont);
   ERR_clear_error();
 
   EXPECT_FALSE(BN_mod_exp_mont(a.get(), BN_value_one(), BN_value_one(), b.get(),
@@ -1979,14 +1980,14 @@
   bssl::UniquePtr<BIGNUM> p(BN_bin2bn(kP, sizeof(kP), nullptr));
   ASSERT_TRUE(p);
 
-  bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
+  bssl::UniquePtr<BN_MONT_CTX> mont(
+      BN_MONT_CTX_new_for_modulus(p.get(), ctx()));
   ASSERT_TRUE(mont);
-  ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), p.get(), ctx()));
 
   ASSERT_TRUE(bn_resize_words(p.get(), 32));
-  bssl::UniquePtr<BN_MONT_CTX> mont2(BN_MONT_CTX_new());
+  bssl::UniquePtr<BN_MONT_CTX> mont2(
+      BN_MONT_CTX_new_for_modulus(p.get(), ctx()));
   ASSERT_TRUE(mont2);
-  ASSERT_TRUE(BN_MONT_CTX_set(mont2.get(), p.get(), ctx()));
 
   EXPECT_EQ(mont->N.top, mont2->N.top);
   EXPECT_EQ(0, BN_cmp(&mont->RR, &mont2->RR));
diff --git a/crypto/fipsmodule/bn/exponentiation.c b/crypto/fipsmodule/bn/exponentiation.c
index bdcc9d3..9e0ddfb 100644
--- a/crypto/fipsmodule/bn/exponentiation.c
+++ b/crypto/fipsmodule/bn/exponentiation.c
@@ -622,8 +622,8 @@
 
   // Allocate a montgomery context if it was not supplied by the caller.
   if (mont == NULL) {
-    new_mont = BN_MONT_CTX_new();
-    if (new_mont == NULL || !BN_MONT_CTX_set(new_mont, m, ctx)) {
+    new_mont = BN_MONT_CTX_new_for_modulus(m, ctx);
+    if (new_mont == NULL) {
       goto err;
     }
     mont = new_mont;
@@ -1014,8 +1014,8 @@
 
   // Allocate a montgomery context if it was not supplied by the caller.
   if (mont == NULL) {
-    new_mont = BN_MONT_CTX_new();
-    if (new_mont == NULL || !BN_MONT_CTX_set(new_mont, m, ctx)) {
+    new_mont = BN_MONT_CTX_new_for_modulus(m, ctx);
+    if (new_mont == NULL) {
       goto err;
     }
     mont = new_mont;
@@ -1331,8 +1331,8 @@
 
   // Allocate a montgomery context if it was not supplied by the caller.
   if (mont == NULL) {
-    new_mont = BN_MONT_CTX_new();
-    if (new_mont == NULL || !BN_MONT_CTX_set(new_mont, m, ctx)) {
+    new_mont = BN_MONT_CTX_new_for_modulus(m, ctx);
+    if (new_mont == NULL) {
       goto err;
     }
     mont = new_mont;
diff --git a/crypto/fipsmodule/bn/montgomery.c b/crypto/fipsmodule/bn/montgomery.c
index eaf2ba0..a51725c 100644
--- a/crypto/fipsmodule/bn/montgomery.c
+++ b/crypto/fipsmodule/bn/montgomery.c
@@ -223,6 +223,16 @@
   return 1;
 }
 
+BN_MONT_CTX *BN_MONT_CTX_new_for_modulus(const BIGNUM *mod, BN_CTX *ctx) {
+  BN_MONT_CTX *mont = BN_MONT_CTX_new();
+  if (mont == NULL ||
+      !BN_MONT_CTX_set(mont, mod, ctx)) {
+    BN_MONT_CTX_free(mont);
+    return NULL;
+  }
+  return mont;
+}
+
 int BN_MONT_CTX_set_locked(BN_MONT_CTX **pmont, CRYPTO_MUTEX *lock,
                            const BIGNUM *mod, BN_CTX *bn_ctx) {
   CRYPTO_MUTEX_lock_read(lock);
@@ -234,25 +244,12 @@
   }
 
   CRYPTO_MUTEX_lock_write(lock);
-  ctx = *pmont;
-  if (ctx) {
-    goto out;
+  if (*pmont == NULL) {
+    *pmont = BN_MONT_CTX_new_for_modulus(mod, bn_ctx);
   }
-
-  ctx = BN_MONT_CTX_new();
-  if (ctx == NULL) {
-    goto out;
-  }
-  if (!BN_MONT_CTX_set(ctx, mod, bn_ctx)) {
-    BN_MONT_CTX_free(ctx);
-    ctx = NULL;
-    goto out;
-  }
-  *pmont = ctx;
-
-out:
+  const int ok = *pmont != NULL;
   CRYPTO_MUTEX_unlock_write(lock);
-  return ctx != NULL;
+  return ok;
 }
 
 int BN_to_montgomery(BIGNUM *ret, const BIGNUM *a, const BN_MONT_CTX *mont,
diff --git a/crypto/fipsmodule/bn/prime.c b/crypto/fipsmodule/bn/prime.c
index 691d0cb..a291f7a 100644
--- a/crypto/fipsmodule/bn/prime.c
+++ b/crypto/fipsmodule/bn/prime.c
@@ -586,9 +586,8 @@
   }
 
   // Montgomery setup for computations mod A
-  mont = BN_MONT_CTX_new();
-  if (mont == NULL ||
-      !BN_MONT_CTX_set(mont, w, ctx)) {
+  mont = BN_MONT_CTX_new_for_modulus(w, ctx);
+  if (mont == NULL) {
     goto err;
   }
 
diff --git a/crypto/fipsmodule/ec/ec.c b/crypto/fipsmodule/ec/ec.c
index f002ccd..b55beb6 100644
--- a/crypto/fipsmodule/ec/ec.c
+++ b/crypto/fipsmodule/ec/ec.c
@@ -389,9 +389,8 @@
   }
 
   BN_MONT_CTX_free(group->order_mont);
-  group->order_mont = BN_MONT_CTX_new();
-  if (group->order_mont == NULL ||
-      !BN_MONT_CTX_set(group->order_mont, &group->order, NULL)) {
+  group->order_mont = BN_MONT_CTX_new_for_modulus(&group->order, NULL);
+  if (group->order_mont == NULL) {
     return 0;
   }
 
@@ -448,9 +447,8 @@
     goto err;
   }
 
-  group->order_mont = BN_MONT_CTX_new();
-  if (group->order_mont == NULL ||
-      !BN_MONT_CTX_set(group->order_mont, &group->order, ctx)) {
+  group->order_mont = BN_MONT_CTX_new_for_modulus(&group->order, ctx);
+  if (group->order_mont == NULL) {
     OPENSSL_PUT_ERROR(EC, ERR_R_BN_LIB);
     goto err;
   }
diff --git a/crypto/fipsmodule/ec/ec_montgomery.c b/crypto/fipsmodule/ec/ec_montgomery.c
index 898cf07..165c06f 100644
--- a/crypto/fipsmodule/ec/ec_montgomery.c
+++ b/crypto/fipsmodule/ec/ec_montgomery.c
@@ -93,7 +93,6 @@
 int ec_GFp_mont_group_set_curve(EC_GROUP *group, const BIGNUM *p,
                                 const BIGNUM *a, const BIGNUM *b, BN_CTX *ctx) {
   BN_CTX *new_ctx = NULL;
-  BN_MONT_CTX *mont = NULL;
   int ret = 0;
 
   BN_MONT_CTX_free(group->mont);
@@ -106,18 +105,12 @@
     }
   }
 
-  mont = BN_MONT_CTX_new();
-  if (mont == NULL) {
-    goto err;
-  }
-  if (!BN_MONT_CTX_set(mont, p, ctx)) {
+  group->mont = BN_MONT_CTX_new_for_modulus(p, ctx);
+  if (group->mont == NULL) {
     OPENSSL_PUT_ERROR(EC, ERR_R_BN_LIB);
     goto err;
   }
 
-  group->mont = mont;
-  mont = NULL;
-
   ret = ec_GFp_simple_group_set_curve(group, p, a, b, ctx);
 
   if (!ret) {
@@ -127,7 +120,6 @@
 
 err:
   BN_CTX_free(new_ctx);
-  BN_MONT_CTX_free(mont);
   return ret;
 }
 
diff --git a/crypto/fipsmodule/ec/p256-x86_64_test.cc b/crypto/fipsmodule/ec/p256-x86_64_test.cc
index 4f33de0..5cd701b 100644
--- a/crypto/fipsmodule/ec/p256-x86_64_test.cc
+++ b/crypto/fipsmodule/ec/p256-x86_64_test.cc
@@ -167,9 +167,9 @@
   }
 
   bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new());
-  bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
+  bssl::UniquePtr<BN_MONT_CTX> mont(
+      BN_MONT_CTX_new_for_modulus(p.get(), ctx.get()));
   if (!ctx || !mont ||
-      !BN_MONT_CTX_set(mont.get(), p.get(), ctx.get()) ||
       // Invert Z.
       !BN_from_montgomery(z.get(), z.get(), mont.get(), ctx.get()) ||
       !BN_mod_inverse(z.get(), z.get(), p.get(), ctx.get()) ||
diff --git a/fuzz/bn_mod_exp.cc b/fuzz/bn_mod_exp.cc
index 46e3c88..997c3a6 100644
--- a/fuzz/bn_mod_exp.cc
+++ b/fuzz/bn_mod_exp.cc
@@ -93,11 +93,9 @@
   }
 
   bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new());
-  bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
   bssl::UniquePtr<BIGNUM> result(BN_new());
   bssl::UniquePtr<BIGNUM> expected(BN_new());
   CHECK(ctx);
-  CHECK(mont);
   CHECK(result);
   CHECK(expected);
 
@@ -108,7 +106,9 @@
   CHECK(BN_cmp(result.get(), expected.get()) == 0);
 
   if (BN_is_odd(modulus.get())) {
-    CHECK(BN_MONT_CTX_set(mont.get(), modulus.get(), ctx.get()));
+    bssl::UniquePtr<BN_MONT_CTX> mont(
+        BN_MONT_CTX_new_for_modulus(modulus.get(), ctx.get()));
+    CHECK(mont);
     CHECK(BN_mod_exp_mont(result.get(), base.get(), power.get(), modulus.get(),
                           ctx.get(), mont.get()));
     CHECK(BN_cmp(result.get(), expected.get()) == 0);
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index d9a7f08..7f7d479 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -793,8 +793,10 @@
 // BN_MONT_CTX contains the precomputed values needed to work in a specific
 // Montgomery domain.
 
-// BN_MONT_CTX_new returns a fresh BN_MONT_CTX or NULL on allocation failure.
-OPENSSL_EXPORT BN_MONT_CTX *BN_MONT_CTX_new(void);
+// BN_MONT_CTX_new_for_modulus returns a fresh |BN_MONT_CTX| given the modulus,
+// |mod| or NULL on error.
+OPENSSL_EXPORT BN_MONT_CTX *BN_MONT_CTX_new_for_modulus(const BIGNUM *mod,
+                                                        BN_CTX *ctx);
 
 // BN_MONT_CTX_free frees memory associated with |mont|.
 OPENSSL_EXPORT void BN_MONT_CTX_free(BN_MONT_CTX *mont);
@@ -804,11 +806,6 @@
 OPENSSL_EXPORT BN_MONT_CTX *BN_MONT_CTX_copy(BN_MONT_CTX *to,
                                              const BN_MONT_CTX *from);
 
-// BN_MONT_CTX_set sets up a Montgomery context given the modulus, |mod|. It
-// returns one on success and zero on error.
-OPENSSL_EXPORT int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod,
-                                   BN_CTX *ctx);
-
 // BN_MONT_CTX_set_locked takes |lock| and checks whether |*pmont| is NULL. If
 // so, it creates a new |BN_MONT_CTX| and sets the modulus for it to |mod|. It
 // then stores it as |*pmont|. It returns one on success and zero on error.
@@ -896,6 +893,16 @@
                                     const BIGNUM *p2, const BIGNUM *m,
                                     BN_CTX *ctx, const BN_MONT_CTX *mont);
 
+// BN_MONT_CTX_new returns a fresh |BN_MONT_CTX| or NULL on allocation failure.
+// Use |BN_MONT_CTX_new_for_modulus| instead.
+OPENSSL_EXPORT BN_MONT_CTX *BN_MONT_CTX_new(void);
+
+// BN_MONT_CTX_set sets up a Montgomery context given the modulus, |mod|. It
+// returns one on success and zero on error. Use |BN_MONT_CTX_new_for_modulus|
+// instead.
+OPENSSL_EXPORT int BN_MONT_CTX_set(BN_MONT_CTX *mont, const BIGNUM *mod,
+                                   BN_CTX *ctx);
+
 
 // Private functions