Fix a ** 0 mod 1 = 0 for real this time.

Commit 2b0180c37fa6ffc48ee40caa831ca398b828e680 attempted to do this but
only hit one of many BN_mod_exp codepaths. Fix remaining variants and
add a test for each method.

Thanks to Hanno Boeck for reporting this issue.

(Imported from upstream's 44e4f5b04b43054571e278381662cebd3f3555e6.)

Change-Id: Ic691b354101c3e9c3565300836fb6d55c6f253ba
Reviewed-on: https://boringssl-review.googlesource.com/6820
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/bn/bn_test.cc b/crypto/bn/bn_test.cc
index 7636f30..e7e04f1 100644
--- a/crypto/bn/bn_test.cc
+++ b/crypto/bn/bn_test.cc
@@ -1316,23 +1316,23 @@
 
 // test_exp_mod_zero tests that 1**0 mod 1 == 0.
 static bool test_exp_mod_zero(void) {
-  ScopedBIGNUM zero(BN_new());
-  if (!zero) {
+  ScopedBIGNUM zero(BN_new()), a(BN_new()), r(BN_new());
+  if (!zero || !a || !r || !BN_rand(a.get(), 1024, 0, 0)) {
     return false;
   }
   BN_zero(zero.get());
 
-  ScopedBN_CTX ctx(BN_CTX_new());
-  ScopedBIGNUM r(BN_new());
-  if (!ctx || !r ||
-      !BN_mod_exp(r.get(), BN_value_one(), zero.get(), BN_value_one(), ctx.get())) {
-    return false;
-  }
-
-  if (!BN_is_zero(r.get())) {
-    fprintf(stderr, "1**0 mod 1 = ");
-    BN_print_fp(stderr, r.get());
-    fprintf(stderr, ", should be 0\n");
+  if (!BN_mod_exp(r.get(), a.get(), zero.get(), BN_value_one(), nullptr) ||
+      !BN_is_zero(r.get()) ||
+      !BN_mod_exp_mont(r.get(), a.get(), zero.get(), BN_value_one(), nullptr,
+                       nullptr) ||
+      !BN_is_zero(r.get()) ||
+      !BN_mod_exp_mont_consttime(r.get(), a.get(), zero.get(), BN_value_one(),
+                                 nullptr, nullptr) ||
+      !BN_is_zero(r.get()) ||
+      !BN_mod_exp_mont_word(r.get(), 42, zero.get(), BN_value_one(), nullptr,
+                            nullptr) ||
+      !BN_is_zero(r.get())) {
     return false;
   }
 
diff --git a/crypto/bn/exponentiation.c b/crypto/bn/exponentiation.c
index c580248..72a8db4 100644
--- a/crypto/bn/exponentiation.c
+++ b/crypto/bn/exponentiation.c
@@ -445,8 +445,12 @@
   bits = BN_num_bits(p);
 
   if (bits == 0) {
-    ret = BN_one(r);
-    return ret;
+    /* x**0 mod 1 is still zero. */
+    if (BN_is_one(m)) {
+      BN_zero(r);
+      return 1;
+    }
+    return BN_one(r);
   }
 
   BN_CTX_start(ctx);
@@ -632,8 +636,12 @@
   }
   bits = BN_num_bits(p);
   if (bits == 0) {
-    ret = BN_one(rr);
-    return ret;
+    /* x**0 mod 1 is still zero. */
+    if (BN_is_one(m)) {
+      BN_zero(rr);
+      return 1;
+    }
+    return BN_one(rr);
   }
 
   BN_CTX_start(ctx);
@@ -875,8 +883,12 @@
 
   bits = BN_num_bits(p);
   if (bits == 0) {
-    ret = BN_one(rr);
-    return ret;
+    /* x**0 mod 1 is still zero. */
+    if (BN_is_one(m)) {
+      BN_zero(rr);
+      return 1;
+    }
+    return BN_one(rr);
   }
 
   BN_CTX_start(ctx);
@@ -1230,17 +1242,14 @@
   if (bits == 0) {
     /* x**0 mod 1 is still zero. */
     if (BN_is_one(m)) {
-      ret = 1;
       BN_zero(rr);
-    } else {
-      ret = BN_one(rr);
+      return 1;
     }
-    return ret;
+    return BN_one(rr);
   }
   if (a == 0) {
     BN_zero(rr);
-    ret = 1;
-    return ret;
+    return 1;
   }
 
   BN_CTX_start(ctx);