Handle a modulus of -1 correctly.

Historically, OpenSSL's modular exponentiation functions tolerated negative
moduli by ignoring the sign bit. The special case for a modulus of 1 should do
the same. That said, this is ridiculous and the only reason I'm importing this
is BN_abs_is_word(1) is marginally more efficient than BN_is_one() and we
haven't gotten around to enforcing positive moduli yet.

Thanks to Guido Vranken and OSSFuzz for finding this issue and reporting to
OpenSSL.

(Imported from upstream's 235119f015e46a74040b78b10fd6e954f7f07774.)

Change-Id: I526889dfbe2356753aa1e6ecfd3aa3dc3a8cd2b8
Reviewed-on: https://boringssl-review.googlesource.com/31085
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index 27fd5c7..a932306 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -1557,25 +1557,56 @@
   ERR_clear_error();
 }
 
-// Test that 1**0 mod 1 == 0.
-TEST_F(BNTest, ExpModZero) {
-  bssl::UniquePtr<BIGNUM> zero(BN_new()), a(BN_new()), r(BN_new());
+// Test that a**0 mod 1 == 0.
+TEST_F(BNTest, ExpZeroModOne) {
+  bssl::UniquePtr<BIGNUM> zero(BN_new()), a(BN_new()), r(BN_new()),
+      minus_one(BN_new());
   ASSERT_TRUE(zero);
   ASSERT_TRUE(a);
   ASSERT_TRUE(r);
+  ASSERT_TRUE(minus_one);
+  ASSERT_TRUE(BN_set_word(minus_one.get(), 1));
+  BN_set_negative(minus_one.get(), 1);
   ASSERT_TRUE(BN_rand(a.get(), 1024, BN_RAND_TOP_ONE, BN_RAND_BOTTOM_ANY));
   BN_zero(zero.get());
 
   ASSERT_TRUE(BN_mod_exp(r.get(), a.get(), zero.get(), BN_value_one(), ctx()));
   EXPECT_TRUE(BN_is_zero(r.get()));
+  ASSERT_TRUE(
+      BN_mod_exp(r.get(), zero.get(), zero.get(), BN_value_one(), ctx()));
+  EXPECT_TRUE(BN_is_zero(r.get()));
 
   ASSERT_TRUE(BN_mod_exp_mont_word(r.get(), 42, zero.get(), BN_value_one(),
                                    ctx(), nullptr));
   EXPECT_TRUE(BN_is_zero(r.get()));
+  ASSERT_TRUE(BN_mod_exp_mont_word(r.get(), 0, zero.get(), BN_value_one(),
+                                   ctx(), nullptr));
+  EXPECT_TRUE(BN_is_zero(r.get()));
 
-  // The other modular exponentiation functions, |BN_mod_exp_mont| and
-  // |BN_mod_exp_mont_consttime|, require fully-reduced inputs, so 1**0 mod 1 is
-  // not a valid call.
+  // |BN_mod_exp_mont| and |BN_mod_exp_mont_consttime| require fully-reduced
+  // inputs, so a**0 mod 1 is not a valid call. 0**0 mod 1 is valid, however.
+  ASSERT_TRUE(BN_mod_exp_mont(r.get(), zero.get(), zero.get(), BN_value_one(),
+                              ctx(), nullptr));
+  EXPECT_TRUE(BN_is_zero(r.get()));
+
+  ASSERT_TRUE(BN_mod_exp_mont_consttime(r.get(), zero.get(), zero.get(),
+                                        BN_value_one(), ctx(), nullptr));
+  EXPECT_TRUE(BN_is_zero(r.get()));
+
+  // Historically, OpenSSL's modular exponentiation functions tolerated negative
+  // moduli by ignoring the sign bit. This logic should do the same.
+  ASSERT_TRUE(BN_mod_exp(r.get(), a.get(), zero.get(), minus_one.get(), ctx()));
+  EXPECT_TRUE(BN_is_zero(r.get()));
+  ASSERT_TRUE(BN_mod_exp_mont_word(r.get(), 0, zero.get(), minus_one.get(),
+                                   ctx(), nullptr));
+  EXPECT_TRUE(BN_is_zero(r.get()));
+  ASSERT_TRUE(BN_mod_exp_mont(r.get(), zero.get(), zero.get(), minus_one.get(),
+                              ctx(), nullptr));
+  EXPECT_TRUE(BN_is_zero(r.get()));
+
+  ASSERT_TRUE(BN_mod_exp_mont_consttime(r.get(), zero.get(), zero.get(),
+                                        minus_one.get(), ctx(), nullptr));
+  EXPECT_TRUE(BN_is_zero(r.get()));
 }
 
 TEST_F(BNTest, SmallPrime) {
diff --git a/crypto/fipsmodule/bn/exponentiation.c b/crypto/fipsmodule/bn/exponentiation.c
index 5187f4a..7035ea7 100644
--- a/crypto/fipsmodule/bn/exponentiation.c
+++ b/crypto/fipsmodule/bn/exponentiation.c
@@ -457,7 +457,7 @@
 
   if (bits == 0) {
     // x**0 mod 1 is still zero.
-    if (BN_is_one(m)) {
+    if (BN_abs_is_word(m, 1)) {
       BN_zero(r);
       return 1;
     }
@@ -614,7 +614,7 @@
   int bits = BN_num_bits(p);
   if (bits == 0) {
     // x**0 mod 1 is still zero.
-    if (BN_is_one(m)) {
+    if (BN_abs_is_word(m, 1)) {
       BN_zero(rr);
       return 1;
     }
@@ -981,7 +981,7 @@
   int bits = max_bits;
   if (bits == 0) {
     // x**0 mod 1 is still zero.
-    if (BN_is_one(m)) {
+    if (BN_abs_is_word(m, 1)) {
       BN_zero(rr);
       return 1;
     }