Remove optimized even modulus mod-exp implementation

There's a whole lot of logic here for some kind of windowed reciprocal
exponentiation. While probably interesting, it is not relevant for any
cryptographic use cases. Replace it with a naive square-and-multiply
algorithm atop BN_mod_mul and BN_mod_sqr.

Change-Id: Ic2290fa1eccccd3bb21732d5171830f65b71670d
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/77427
Reviewed-by: Bob Beck <bbe@google.com>
Auto-Submit: David Benjamin <davidben@google.com>
Commit-Queue: Bob Beck <bbe@google.com>
diff --git a/crypto/fipsmodule/bn/exponentiation.cc.inc b/crypto/fipsmodule/bn/exponentiation.cc.inc
index 6c3d448..53081c4 100644
--- a/crypto/fipsmodule/bn/exponentiation.cc.inc
+++ b/crypto/fipsmodule/bn/exponentiation.cc.inc
@@ -122,209 +122,6 @@
   return ret;
 }
 
-namespace {
-typedef struct bn_recp_ctx_st {
-  BIGNUM N;   // the divisor
-  BIGNUM Nr;  // the reciprocal
-  int num_bits;
-  int shift;
-  int flags;
-} BN_RECP_CTX;
-}  // namespace
-
-static void BN_RECP_CTX_init(BN_RECP_CTX *recp) {
-  BN_init(&recp->N);
-  BN_init(&recp->Nr);
-  recp->num_bits = 0;
-  recp->shift = 0;
-  recp->flags = 0;
-}
-
-static void BN_RECP_CTX_free(BN_RECP_CTX *recp) {
-  if (recp == nullptr) {
-    return;
-  }
-  BN_free(&recp->N);
-  BN_free(&recp->Nr);
-}
-
-static int BN_RECP_CTX_set(BN_RECP_CTX *recp, const BIGNUM *d, BN_CTX *ctx) {
-  if (!BN_copy(&(recp->N), d)) {
-    return 0;
-  }
-  BN_zero(&recp->Nr);
-  recp->num_bits = BN_num_bits(d);
-  recp->shift = 0;
-
-  return 1;
-}
-
-// len is the expected size of the result We actually calculate with an extra
-// word of precision, so we can do faster division if the remainder is not
-// required.
-// r := 2^len / m
-static int BN_reciprocal(BIGNUM *r, const BIGNUM *m, int len, BN_CTX *ctx) {
-  int ret = -1;
-  BIGNUM *t;
-
-  BN_CTX_start(ctx);
-  t = BN_CTX_get(ctx);
-  if (t == NULL) {
-    goto err;
-  }
-
-  if (!BN_set_bit(t, len)) {
-    goto err;
-  }
-
-  if (!BN_div(r, NULL, t, m, ctx)) {
-    goto err;
-  }
-
-  ret = len;
-
-err:
-  BN_CTX_end(ctx);
-  return ret;
-}
-
-static int BN_div_recp(BIGNUM *dv, BIGNUM *rem, const BIGNUM *m,
-                       BN_RECP_CTX *recp, BN_CTX *ctx) {
-  int i, j, ret = 0;
-  BIGNUM *a, *b, *d, *r;
-
-  BN_CTX_start(ctx);
-  a = BN_CTX_get(ctx);
-  b = BN_CTX_get(ctx);
-  if (dv != NULL) {
-    d = dv;
-  } else {
-    d = BN_CTX_get(ctx);
-  }
-
-  if (rem != NULL) {
-    r = rem;
-  } else {
-    r = BN_CTX_get(ctx);
-  }
-
-  if (a == NULL || b == NULL || d == NULL || r == NULL) {
-    goto err;
-  }
-
-  if (BN_ucmp(m, &recp->N) < 0) {
-    BN_zero(d);
-    if (!BN_copy(r, m)) {
-      goto err;
-    }
-    BN_CTX_end(ctx);
-    return 1;
-  }
-
-  // We want the remainder
-  // Given input of ABCDEF / ab
-  // we need multiply ABCDEF by 3 digests of the reciprocal of ab
-
-  // i := max(BN_num_bits(m), 2*BN_num_bits(N))
-  i = BN_num_bits(m);
-  j = recp->num_bits << 1;
-  if (j > i) {
-    i = j;
-  }
-
-  // Nr := round(2^i / N)
-  if (i != recp->shift) {
-    recp->shift =
-        BN_reciprocal(&(recp->Nr), &(recp->N), i,
-                      ctx);  // BN_reciprocal returns i, or -1 for an error
-  }
-
-  if (recp->shift == -1) {
-    goto err;
-  }
-
-  // d := |round(round(m / 2^BN_num_bits(N)) * recp->Nr / 2^(i -
-  // BN_num_bits(N)))|
-  //    = |round(round(m / 2^BN_num_bits(N)) * round(2^i / N) / 2^(i -
-  // BN_num_bits(N)))|
-  //   <= |(m / 2^BN_num_bits(N)) * (2^i / N) * (2^BN_num_bits(N) / 2^i)|
-  //    = |m/N|
-  if (!BN_rshift(a, m, recp->num_bits)) {
-    goto err;
-  }
-  if (!BN_mul(b, a, &(recp->Nr), ctx)) {
-    goto err;
-  }
-  if (!BN_rshift(d, b, i - recp->num_bits)) {
-    goto err;
-  }
-  d->neg = 0;
-
-  if (!BN_mul(b, &(recp->N), d, ctx)) {
-    goto err;
-  }
-  if (!BN_usub(r, m, b)) {
-    goto err;
-  }
-  r->neg = 0;
-
-  j = 0;
-  while (BN_ucmp(r, &(recp->N)) >= 0) {
-    if (j++ > 2) {
-      OPENSSL_PUT_ERROR(BN, BN_R_BAD_RECIPROCAL);
-      goto err;
-    }
-    if (!BN_usub(r, r, &(recp->N))) {
-      goto err;
-    }
-    if (!BN_add_word(d, 1)) {
-      goto err;
-    }
-  }
-
-  r->neg = BN_is_zero(r) ? 0 : m->neg;
-  d->neg = m->neg ^ recp->N.neg;
-  ret = 1;
-
-err:
-  BN_CTX_end(ctx);
-  return ret;
-}
-
-static int BN_mod_mul_reciprocal(BIGNUM *r, const BIGNUM *x, const BIGNUM *y,
-                                 BN_RECP_CTX *recp, BN_CTX *ctx) {
-  int ret = 0;
-  BIGNUM *a;
-  const BIGNUM *ca;
-
-  BN_CTX_start(ctx);
-  a = BN_CTX_get(ctx);
-  if (a == NULL) {
-    goto err;
-  }
-
-  if (y != NULL) {
-    if (x == y) {
-      if (!BN_sqr(a, x, ctx)) {
-        goto err;
-      }
-    } else {
-      if (!BN_mul(a, x, y, ctx)) {
-        goto err;
-      }
-    }
-    ca = a;
-  } else {
-    ca = x;  // Just do the mod
-  }
-
-  ret = BN_div_recp(NULL, r, ca, recp, ctx);
-
-err:
-  BN_CTX_end(ctx);
-  return ret;
-}
-
 // BN_window_bits_for_exponent_size returns sliding window size for mod_exp with
 // a |b| bit exponent.
 //
@@ -378,141 +175,37 @@
 // |BN_BITS2| * |BN_SMALL_MAX_WORDS|.
 #define TABLE_SIZE_SMALL (1 << (TABLE_BITS_SMALL - 1))
 
-static int mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
+static int mod_exp_even(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
                         const BIGNUM *m, BN_CTX *ctx) {
-  int i, j, ret = 0, wstart, window;
-  int start = 1;
-  BIGNUM *aa;
-  // Table of variables obtained from 'ctx'
-  BIGNUM *val[TABLE_SIZE];
-  BN_RECP_CTX recp;
-
-  // This function is only called on even moduli.
-  assert(!BN_is_odd(m));
-
+  // No cryptographic operations require modular exponentiation with an even
+  // modulus. We support it for backwards compatibility with any applications
+  // that may have relied on the operation, but optimize for simplicity over
+  // performance with a straightforward square-and-multiply routine.
   int bits = BN_num_bits(p);
   if (bits == 0) {
     return BN_one(r);
   }
 
-  BN_RECP_CTX_init(&recp);
-  BN_CTX_start(ctx);
-  aa = BN_CTX_get(ctx);
-  val[0] = BN_CTX_get(ctx);
-  if (!aa || !val[0]) {
-    goto err;
+  // Make a copy of |a|, in case it aliases |r|.
+  bssl::BN_CTXScope scope(ctx);
+  BIGNUM *tmp = BN_CTX_get(ctx);
+  if (tmp == nullptr || !BN_copy(tmp, a)) {
+    return 0;
   }
 
-  if (m->neg) {
-    // ignore sign of 'm'
-    if (!BN_copy(aa, m)) {
-      goto err;
-    }
-    aa->neg = 0;
-    if (BN_RECP_CTX_set(&recp, aa, ctx) <= 0) {
-      goto err;
-    }
-  } else {
-    if (BN_RECP_CTX_set(&recp, m, ctx) <= 0) {
-      goto err;
+  assert(BN_is_bit_set(p, bits - 1));
+  if (!BN_copy(r, tmp)) {
+    return 0;
+  }
+
+  for (int i = bits - 2; i >= 0; i--) {
+    if (!BN_mod_sqr(r, r, m, ctx) ||
+        (BN_is_bit_set(p, i) && !BN_mod_mul(r, r, tmp, m, ctx))) {
+      return 0;
     }
   }
 
-  if (!BN_nnmod(val[0], a, m, ctx)) {
-    goto err;  // 1
-  }
-  if (BN_is_zero(val[0])) {
-    BN_zero(r);
-    ret = 1;
-    goto err;
-  }
-
-  window = BN_window_bits_for_exponent_size(bits);
-  if (window > 1) {
-    if (!BN_mod_mul_reciprocal(aa, val[0], val[0], &recp, ctx)) {
-      goto err;  // 2
-    }
-    j = 1 << (window - 1);
-    for (i = 1; i < j; i++) {
-      if (((val[i] = BN_CTX_get(ctx)) == NULL) ||
-          !BN_mod_mul_reciprocal(val[i], val[i - 1], aa, &recp, ctx)) {
-        goto err;
-      }
-    }
-  }
-
-  start = 1;          // This is used to avoid multiplication etc
-                      // when there is only the value '1' in the
-                      // buffer.
-  wstart = bits - 1;  // The top bit of the window
-
-  if (!BN_one(r)) {
-    goto err;
-  }
-
-  for (;;) {
-    int wvalue;  // The 'value' of the window
-    int wend;    // The bottom bit of the window
-
-    if (!BN_is_bit_set(p, wstart)) {
-      if (!start) {
-        if (!BN_mod_mul_reciprocal(r, r, r, &recp, ctx)) {
-          goto err;
-        }
-      }
-      if (wstart == 0) {
-        break;
-      }
-      wstart--;
-      continue;
-    }
-
-    // We now have wstart on a 'set' bit, we now need to work out
-    // how bit a window to do.  To do this we need to scan
-    // forward until the last set bit before the end of the
-    // window
-    wvalue = 1;
-    wend = 0;
-    for (i = 1; i < window; i++) {
-      if (wstart - i < 0) {
-        break;
-      }
-      if (BN_is_bit_set(p, wstart - i)) {
-        wvalue <<= (i - wend);
-        wvalue |= 1;
-        wend = i;
-      }
-    }
-
-    // wend is the size of the current window
-    j = wend + 1;
-    // add the 'bytes above'
-    if (!start) {
-      for (i = 0; i < j; i++) {
-        if (!BN_mod_mul_reciprocal(r, r, r, &recp, ctx)) {
-          goto err;
-        }
-      }
-    }
-
-    // wvalue will be an odd number < 2^window
-    if (!BN_mod_mul_reciprocal(r, r, val[wvalue >> 1], &recp, ctx)) {
-      goto err;
-    }
-
-    // move the 'window' down further
-    wstart -= wend + 1;
-    start = 0;
-    if (wstart < 0) {
-      break;
-    }
-  }
-  ret = 1;
-
-err:
-  BN_CTX_end(ctx);
-  BN_RECP_CTX_free(&recp);
-  return ret;
+  return 1;
 }
 
 int BN_mod_exp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, const BIGNUM *m,
@@ -532,7 +225,7 @@
     return BN_mod_exp_mont(r, a, p, m, ctx, NULL);
   }
 
-  return mod_exp_recp(r, a, p, m, ctx);
+  return mod_exp_even(r, a, p, m, ctx);
 }
 
 int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,