Require BN_mod_exp_mont* inputs be reduced.

If the caller asked for the base to be treated as secret, we should
provide that. Allowing unbounded inputs is not compatible with being
constant-time.

Additionally, this aligns with the guidance here:
https://github.com/HACS-workshop/spectre-mitigations/blob/master/crypto_guidelines.md#1-do-not-conditionally-choose-between-constant-and-non-constant-time

Update-Note: BN_mod_exp_mont_consttime and BN_mod_exp_mont now require
inputs be fully reduced. I believe current callers tolerate this.

Additionally, due to a quirk of how certain operations were ordered,
using a (publicly) zero exponent tolerated a NULL BN_CTX while other
exponents required a non-NULL BN_CTX. A non-NULL BN_CTX is now required
uniformly. This is unlikely to cause problems. Any call site where the
exponent is always zero should just be replaced with BN_value_one().

Change-Id: I7c941953ea05f36dc2754facb9f4cf83a6789c61
Reviewed-on: https://boringssl-review.googlesource.com/27665
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index 0c96f47..a25d487 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -763,6 +763,9 @@
   ASSERT_TRUE(BN_mod_exp(ret.get(), a.get(), e.get(), m.get(), ctx));
   EXPECT_BIGNUMS_EQUAL("A ^ E (mod M)", mod_exp.get(), ret.get());
 
+  // The other implementations require reduced inputs.
+  ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
+
   if (BN_is_odd(m.get())) {
     ASSERT_TRUE(
         BN_mod_exp_mont(ret.get(), a.get(), e.get(), m.get(), ctx, NULL));
@@ -780,7 +783,6 @@
       bssl::UniquePtr<BN_MONT_CTX> mont(
           BN_MONT_CTX_new_for_modulus(m.get(), ctx));
       ASSERT_TRUE(mont.get());
-      ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
       std::unique_ptr<BN_ULONG[]> r_words(new BN_ULONG[m_width]),
           a_words(new BN_ULONG[m_width]);
       ASSERT_TRUE(bn_copy_words(a_words.get(), m_width, a.get()));
@@ -1564,21 +1566,16 @@
   ASSERT_TRUE(BN_rand(a.get(), 1024, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY));
   BN_zero(zero.get());
 
-  ASSERT_TRUE(
-      BN_mod_exp(r.get(), a.get(), zero.get(), BN_value_one(), nullptr));
-  EXPECT_TRUE(BN_is_zero(r.get()));
-
-  ASSERT_TRUE(BN_mod_exp_mont(r.get(), a.get(), zero.get(), BN_value_one(),
-                              nullptr, nullptr));
-  EXPECT_TRUE(BN_is_zero(r.get()));
-
-  ASSERT_TRUE(BN_mod_exp_mont_consttime(r.get(), a.get(), zero.get(),
-                                        BN_value_one(), nullptr, nullptr));
+  ASSERT_TRUE(BN_mod_exp(r.get(), a.get(), zero.get(), BN_value_one(), ctx()));
   EXPECT_TRUE(BN_is_zero(r.get()));
 
   ASSERT_TRUE(BN_mod_exp_mont_word(r.get(), 42, zero.get(), BN_value_one(),
-                                   nullptr, nullptr));
+                                   ctx(), nullptr));
   EXPECT_TRUE(BN_is_zero(r.get()));
+
+  // The other modular exponentiation functions, |BN_mod_exp_mont| and
+  // |BN_mod_exp_mont_consttime|, require fully-reduced inputs, so 1**0 mod 1 is
+  // not a valid call.
 }
 
 TEST_F(BNTest, SmallPrime) {
diff --git a/crypto/fipsmodule/bn/exponentiation.c b/crypto/fipsmodule/bn/exponentiation.c
index 90cb8ce..b07111e 100644
--- a/crypto/fipsmodule/bn/exponentiation.c
+++ b/crypto/fipsmodule/bn/exponentiation.c
@@ -586,6 +586,13 @@
 
 int BN_mod_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, const BIGNUM *m,
                BN_CTX *ctx) {
+  if (a->neg || BN_ucmp(a, m) >= 0) {
+    if (!BN_nnmod(r, a, m, ctx)) {
+      return 0;
+    }
+    a = r;
+  }
+
   if (BN_is_odd(m)) {
     return BN_mod_exp_mont(r, a, p, m, ctx, NULL);
   }
@@ -599,6 +606,11 @@
     OPENSSL_PUT_ERROR(BN, BN_R_CALLED_WITH_EVEN_MODULUS);
     return 0;
   }
+  if (a->neg || BN_ucmp(a, m) >= 0) {
+    OPENSSL_PUT_ERROR(BN, BN_R_INPUT_NOT_REDUCED);
+    return 0;
+  }
+
   int bits = BN_num_bits(p);
   if (bits == 0) {
     // x**0 mod 1 is still zero.
@@ -630,22 +642,12 @@
     mont = new_mont;
   }
 
-  const BIGNUM *aa;
-  if (a->neg || BN_ucmp(a, m) >= 0) {
-    if (!BN_nnmod(val[0], a, m, ctx)) {
-      goto err;
-    }
-    aa = val[0];
-  } else {
-    aa = a;
-  }
-
   // We exponentiate by looking at sliding windows of the exponent and
-  // precomputing powers of |aa|. Windows may be shifted so they always end on a
-  // set bit, so only precompute odd powers. We compute val[i] = aa^(2*i + 1)
+  // precomputing powers of |a|. Windows may be shifted so they always end on a
+  // set bit, so only precompute odd powers. We compute val[i] = a^(2*i + 1)
   // for i = 0 to 2^(window-1), all in Montgomery form.
   int window = BN_window_bits_for_exponent_size(bits);
-  if (!BN_to_montgomery(val[0], aa, mont, ctx)) {
+  if (!BN_to_montgomery(val[0], a, mont, ctx)) {
     goto err;
   }
   if (window > 1) {
@@ -966,12 +968,15 @@
   int powerbufLen = 0;
   unsigned char *powerbuf = NULL;
   BIGNUM tmp, am;
-  BIGNUM *new_a = NULL;
 
   if (!BN_is_odd(m)) {
     OPENSSL_PUT_ERROR(BN, BN_R_CALLED_WITH_EVEN_MODULUS);
     return 0;
   }
+  if (a->neg || BN_ucmp(a, m) >= 0) {
+    OPENSSL_PUT_ERROR(BN, BN_R_INPUT_NOT_REDUCED);
+    return 0;
+  }
 
   // Use all bits stored in |p|, rather than |BN_num_bits|, so we do not leak
   // whether the top bits are zero.
@@ -999,15 +1004,6 @@
   // implementation assumes it can use |top| to size R.
   int top = mont->N.width;
 
-  if (a->neg || BN_ucmp(a, m) >= 0) {
-    new_a = BN_new();
-    if (new_a == NULL ||
-        !BN_nnmod(new_a, a, m, ctx)) {
-      goto err;
-    }
-    a = new_a;
-  }
-
 #ifdef RSAZ_ENABLED
   // If the size of the operands allow it, perform the optimized
   // RSAZ exponentiation. For further information see
@@ -1268,7 +1264,6 @@
 
 err:
   BN_MONT_CTX_free(new_mont);
-  BN_clear_free(new_a);
   OPENSSL_free(powerbufFree);
   return (ret);
 }
@@ -1281,6 +1276,11 @@
 
   int ret = 0;
 
+  // BN_mod_exp_mont requires reduced inputs.
+  if (bn_minimal_width(m) == 1) {
+    a %= m->d[0];
+  }
+
   if (!BN_set_word(&a_bignum, a)) {
     OPENSSL_PUT_ERROR(BN, ERR_R_INTERNAL_ERROR);
     goto err;
diff --git a/fuzz/bn_mod_exp.cc b/fuzz/bn_mod_exp.cc
index 997c3a6..0bfa5a8 100644
--- a/fuzz/bn_mod_exp.cc
+++ b/fuzz/bn_mod_exp.cc
@@ -109,6 +109,8 @@
     bssl::UniquePtr<BN_MONT_CTX> mont(
         BN_MONT_CTX_new_for_modulus(modulus.get(), ctx.get()));
     CHECK(mont);
+    // |BN_mod_exp_mont| and |BN_mod_exp_mont_consttime| require reduced inputs.
+    CHECK(BN_nnmod(base.get(), base.get(), modulus.get(), ctx.get()));
     CHECK(BN_mod_exp_mont(result.get(), base.get(), power.get(), modulus.get(),
                           ctx.get(), mont.get()));
     CHECK(BN_cmp(result.get(), expected.get()) == 0);
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index 90b4b36..e8cc70a 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -874,10 +874,14 @@
 OPENSSL_EXPORT int BN_mod_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
                               const BIGNUM *m, BN_CTX *ctx);
 
+// BN_mod_exp_mont behaves like |BN_mod_exp| but treats |a| as secret and
+// requires 0 <= |a| < |m|.
 OPENSSL_EXPORT int BN_mod_exp_mont(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
                                    const BIGNUM *m, BN_CTX *ctx,
                                    const BN_MONT_CTX *mont);
 
+// BN_mod_exp_mont_consttime behaves like |BN_mod_exp| but treats |a|, |p|, and
+// |m| as secret and requires 0 <= |a| < |m|.
 OPENSSL_EXPORT int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a,
                                              const BIGNUM *p, const BIGNUM *m,
                                              BN_CTX *ctx,