Add BN_is_pow2, BN_mod_pow2, and BN_nnmod_pow2.

These are meant to make Android libcore's usage of BIGNUMs for java
BigIntegers faster and nicer (specifically, so that it doesn't need
to malloc a bunch of temporary BIGNUMs).

BUG=97
Change-Id: I5f30e14c6d8c66a9848d4935ce27d030829f6923
Reviewed-on: https://boringssl-review.googlesource.com/13387
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/crypto/bn/bn_test.cc b/crypto/bn/bn_test.cc
index a152cdf..c0af58d 100644
--- a/crypto/bn/bn_test.cc
+++ b/crypto/bn/bn_test.cc
@@ -1668,6 +1668,93 @@
   return true;
 }
 
+static bool TestBNPow2(BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM>
+      power_of_two(BN_new()),
+      random(BN_new()),
+      expected(BN_new()),
+      actual(BN_new());
+
+  if (!power_of_two.get() ||
+      !random.get() ||
+      !expected.get() ||
+      !actual.get()) {
+    return false;
+  }
+
+  // Choose an exponent.
+  for (size_t e = 3; e < 512; e += 11) {
+    // Choose a bit length for our randoms.
+    for (int len = 3; len < 512; len += 23) {
+      // Set power_of_two = 2^e.
+      if (!BN_lshift(power_of_two.get(), BN_value_one(), (int) e)) {
+        fprintf(stderr, "Failed to shiftl.\n");
+        return false;
+      }
+
+      // Test BN_is_pow2 on power_of_two.
+      if (!BN_is_pow2(power_of_two.get())) {
+        fprintf(stderr, "BN_is_pow2 returned false for a power of two.\n");
+        hexdump(stderr, "Arg: ", power_of_two->d,
+                power_of_two->top * sizeof(BN_ULONG));
+        return false;
+      }
+
+      // Pick a large random value, ensuring it isn't a power of two.
+      if (!BN_rand(random.get(), len, BN_RAND_TOP_TWO, BN_RAND_BOTTOM_ANY)) {
+        fprintf(stderr, "Failed to generate random in TestBNPow2.\n");
+        return false;
+      }
+
+      // Test BN_is_pow2 on |r|.
+      if (BN_is_pow2(random.get())) {
+        fprintf(stderr, "BN_is_pow2 returned true for a non-power of two.\n");
+        hexdump(stderr, "Arg: ", random->d, random->top * sizeof(BN_ULONG));
+        return false;
+      }
+
+      // Test BN_mod_pow2 on |r|.
+      if (!BN_mod(expected.get(), random.get(), power_of_two.get(), ctx) ||
+          !BN_mod_pow2(actual.get(), random.get(), e) ||
+          BN_cmp(actual.get(), expected.get())) {
+        fprintf(stderr, "BN_mod_pow2 returned the wrong value:\n");
+        hexdump(stderr, "Expected: ", expected->d,
+                expected->top * sizeof(BN_ULONG));
+        hexdump(stderr, "Got:      ", actual->d,
+                actual->top * sizeof(BN_ULONG));
+        return false;
+      }
+
+      // Test BN_nnmod_pow2 on |r|.
+      if (!BN_nnmod(expected.get(), random.get(), power_of_two.get(), ctx) ||
+          !BN_nnmod_pow2(actual.get(), random.get(), e) ||
+          BN_cmp(actual.get(), expected.get())) {
+        fprintf(stderr, "BN_nnmod_pow2 failed on positive input:\n");
+        hexdump(stderr, "Expected: ", expected->d,
+                expected->top * sizeof(BN_ULONG));
+        hexdump(stderr, "Got:      ", actual->d,
+                actual->top * sizeof(BN_ULONG));
+        return false;
+      }
+
+      // Test BN_nnmod_pow2 on -|r|.
+      BN_set_negative(random.get(), 1);
+      if (!BN_nnmod(expected.get(), random.get(), power_of_two.get(), ctx) ||
+          !BN_nnmod_pow2(actual.get(), random.get(), e) ||
+          BN_cmp(actual.get(), expected.get())) {
+        fprintf(stderr, "BN_nnmod_pow2 failed on negative input:\n");
+        hexdump(stderr, "Expected: ", expected->d,
+                expected->top * sizeof(BN_ULONG));
+        hexdump(stderr, "Got:      ", actual->d,
+                actual->top * sizeof(BN_ULONG));
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
 int main(int argc, char *argv[]) {
   CRYPTO_library_init();
 
@@ -1695,7 +1782,8 @@
       !TestSmallPrime(ctx.get()) ||
       !TestCmpWord() ||
       !TestBN2Dec() ||
-      !TestBNSetGetU64()) {
+      !TestBNSetGetU64() ||
+      !TestBNPow2(ctx.get())) {
     return 1;
   }
 
diff --git a/crypto/bn/cmp.c b/crypto/bn/cmp.c
index 9cf33b4..71c0465 100644
--- a/crypto/bn/cmp.c
+++ b/crypto/bn/cmp.c
@@ -212,6 +212,20 @@
   return bn->top > 0 && (bn->d[0] & 1) == 1;
 }
 
+int BN_is_pow2(const BIGNUM *bn) {
+  if (bn->top == 0 || bn->neg) {
+    return 0;
+  }
+
+  for (int i = 0; i < bn->top - 1; i++) {
+    if (bn->d[i] != 0) {
+      return 0;
+    }
+  }
+
+  return 0 == (bn->d[bn->top-1] & (bn->d[bn->top-1] - 1));
+}
+
 int BN_equal_consttime(const BIGNUM *a, const BIGNUM *b) {
   if (a->top != b->top) {
     return 0;
diff --git a/crypto/bn/div.c b/crypto/bn/div.c
index 6e3df7d..de3fa1f 100644
--- a/crypto/bn/div.c
+++ b/crypto/bn/div.c
@@ -58,6 +58,7 @@
 
 #include <assert.h>
 #include <limits.h>
+
 #include <openssl/err.h>
 
 #include "internal.h"
@@ -646,3 +647,82 @@
   }
   return (BN_ULONG)ret;
 }
+
+int BN_mod_pow2(BIGNUM *r, const BIGNUM *a, size_t e) {
+  if (e == 0 || a->top == 0) {
+    BN_zero(r);
+    return 1;
+  }
+
+  size_t num_words = 1 + ((e - 1) / BN_BITS2);
+
+  /* If |a| definitely has less than |e| bits, just BN_copy. */
+  if ((size_t) a->top < num_words) {
+    return BN_copy(r, a) != NULL;
+  }
+
+  /* Otherwise, first make sure we have enough space in |r|.
+   * Note that this will fail if num_words > INT_MAX. */
+  if (bn_wexpand(r, num_words) == NULL) {
+    return 0;
+  }
+
+  /* Copy the content of |a| into |r|. */
+  OPENSSL_memcpy(r->d, a->d, num_words * sizeof(BN_ULONG));
+
+  /* If |e| isn't word-aligned, we have to mask off some of our bits. */
+  size_t top_word_exponent = e % (sizeof(BN_ULONG) * 8);
+  if (top_word_exponent != 0) {
+    r->d[num_words - 1] &= (((BN_ULONG) 1) << top_word_exponent) - 1;
+  }
+
+  /* Fill in the remaining fields of |r|. */
+  r->neg = a->neg;
+  r->top = (int) num_words;
+  bn_correct_top(r);
+  return 1;
+}
+
+int BN_nnmod_pow2(BIGNUM *r, const BIGNUM *a, size_t e) {
+  if (!BN_mod_pow2(r, a, e)) {
+    return 0;
+  }
+
+  /* If the returned value was non-negative, we're done. */
+  if (BN_is_zero(r) || !r->neg) {
+    return 1;
+  }
+
+  size_t num_words = 1 + (e - 1) / BN_BITS2;
+
+  /* Expand |r| to the size of our modulus. */
+  if (bn_wexpand(r, num_words) == NULL) {
+    return 0;
+  }
+
+  /* Clear the upper words of |r|. */
+  OPENSSL_memset(&r->d[r->top], 0, (num_words - r->top) * BN_BYTES);
+
+  /* Set parameters of |r|. */
+  r->neg = 0;
+  r->top = (int) num_words;
+
+  /* Now, invert every word. The idea here is that we want to compute 2^e-|x|,
+   * which is actually equivalent to the twos-complement representation of |x|
+   * in |e| bits, which is -x = ~x + 1. */
+  for (int i = 0; i < r->top; i++) {
+    r->d[i] = ~r->d[i];
+  }
+
+  /* If our exponent doesn't span the top word, we have to mask the rest. */
+  size_t top_word_exponent = e % BN_BITS2;
+  if (top_word_exponent != 0) {
+    r->d[r->top - 1] &= (((BN_ULONG) 1) << top_word_exponent) - 1;
+  }
+
+  /* Keep the correct_top invariant for BN_add. */
+  bn_correct_top(r);
+
+  /* Finally, add one, for the reason described above. */
+  return BN_add(r, r, BN_value_one());
+}
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index 77f6196..a57c23a9 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -476,6 +476,8 @@
 /* BN_is_odd returns one if |bn| is odd and zero otherwise. */
 OPENSSL_EXPORT int BN_is_odd(const BIGNUM *bn);
 
+/* BN_is_pow2 returns 1 if |a| is a power of two, and 0 otherwise. */
+OPENSSL_EXPORT int BN_is_pow2(const BIGNUM *a);
 
 /* Bitwise operations. */
 
@@ -519,6 +521,14 @@
 /* BN_mod_word returns |a| mod |w| or (BN_ULONG)-1 on error. */
 OPENSSL_EXPORT BN_ULONG BN_mod_word(const BIGNUM *a, BN_ULONG w);
 
+/* BN_mod_pow2 sets |r| = |a| mod 2^|e|. It returns 1 on success and
+ * 0 on error. */
+OPENSSL_EXPORT int BN_mod_pow2(BIGNUM *r, const BIGNUM *a, size_t e);
+
+/* BN_nnmod_pow2 sets |r| = |a| mod 2^|e| where |r| is always positive.
+ * It returns 1 on success and 0 on error. */
+OPENSSL_EXPORT int BN_nnmod_pow2(BIGNUM *r, const BIGNUM *a, size_t e);
+
 /* BN_mod is a helper macro that calls |BN_div| and discards the quotient. */
 #define BN_mod(rem, numerator, divisor, ctx) \
   BN_div(NULL, (rem), (numerator), (divisor), (ctx))