Tidy up BN_mod_exp_mont.

This was primarily for my own understanding, but this should hopefully
also be clearer and more amenable to using unsigned indices later.

Change-Id: I09cc3d55de0f7d9284d3b3168d8b0446274b2ab7
Reviewed-on: https://boringssl-review.googlesource.com/22889
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/exponentiation.c b/crypto/fipsmodule/bn/exponentiation.c
index e2dfb34..2d40e8f 100644
--- a/crypto/fipsmodule/bn/exponentiation.c
+++ b/crypto/fipsmodule/bn/exponentiation.c
@@ -188,9 +188,6 @@
   return ret;
 }
 
-// maximum precomputation table size for *variable* sliding windows
-#define TABLE_SIZE 32
-
 typedef struct bn_recp_ctx_st {
   BIGNUM N;   // the divisor
   BIGNUM Nr;  // the reciprocal
@@ -393,8 +390,8 @@
   return ret;
 }
 
-// BN_window_bits_for_exponent_size -- macro for sliding window mod_exp
-// functions
+// BN_window_bits_for_exponent_size returns the sliding window size for mod_exp
+// with a |b|-bit exponent.
 //
 // For window size 'w' (w >= 2) and a random 'b' bits exponent, the number of
 // multiplications is a constant plus on average
@@ -416,11 +413,26 @@
 //
 // (with draws in between).  Very small exponents are often selected
 // with low Hamming weight, so we use  w = 1  for b <= 23.
-#define BN_window_bits_for_exponent_size(b) \
-		((b) > 671 ? 6 : \
-		 (b) > 239 ? 5 : \
-		 (b) >  79 ? 4 : \
-		 (b) >  23 ? 3 : 1)
+static int BN_window_bits_for_exponent_size(int b) {
+  if (b > 671) {
+    return 6;
+  }
+  if (b > 239) {
+    return 5;
+  }
+  if (b > 79) {
+    return 4;
+  }
+  if (b > 23) {
+    return 3;
+  }
+  return 1;
+}
+
+// TABLE_SIZE is the maximum precomputation table size for *variable* sliding
+// windows. This must be 2^(max_window - 1), where max_window is the largest
+// value returned from |BN_window_bits_for_exponent_size|.
+#define TABLE_SIZE 32
 
 static int mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
                         const BIGNUM *m, BN_CTX *ctx) {
@@ -501,7 +513,7 @@
     int wvalue;  // The 'value' of the window
     int wend;  // The bottom bit of the window
 
-    if (BN_is_bit_set(p, wstart) == 0) {
+    if (!BN_is_bit_set(p, wstart)) {
       if (!start) {
         if (!BN_mod_mul_reciprocal(r, r, r, &recp, ctx)) {
           goto err;
@@ -573,19 +585,11 @@
 
 int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                     const BIGNUM *m, BN_CTX *ctx, const BN_MONT_CTX *mont) {
-  int i, j, bits, ret = 0, wstart, window;
-  int start = 1;
-  BIGNUM *d, *r;
-  const BIGNUM *aa;
-  // Table of variables obtained from 'ctx'
-  BIGNUM *val[TABLE_SIZE];
-  BN_MONT_CTX *new_mont = NULL;
-
   if (!BN_is_odd(m)) {
     OPENSSL_PUT_ERROR(BN, BN_R_CALLED_WITH_EVEN_MODULUS);
     return 0;
   }
-  bits = BN_num_bits(p);
+  int bits = BN_num_bits(p);
   if (bits == 0) {
     // x**0 mod 1 is still zero.
     if (BN_is_one(m)) {
@@ -595,9 +599,13 @@
     return BN_one(rr);
   }
 
+  int ret = 0;
+  BIGNUM *val[TABLE_SIZE];
+  BN_MONT_CTX *new_mont = NULL;
+
   BN_CTX_start(ctx);
-  d = BN_CTX_get(ctx);
-  r = BN_CTX_get(ctx);
+  BIGNUM *d = BN_CTX_get(ctx);
+  BIGNUM *r = BN_CTX_get(ctx);
   val[0] = BN_CTX_get(ctx);
   if (!d || !r || !val[0]) {
     goto err;
@@ -612,6 +620,7 @@
     mont = new_mont;
   }
 
+  const BIGNUM *aa;
   if (a->neg || BN_ucmp(a, m) >= 0) {
     if (!BN_nnmod(val[0], a, m, ctx)) {
       goto err;
@@ -626,53 +635,52 @@
     ret = 1;
     goto err;
   }
-  if (!BN_to_montgomery(val[0], aa, mont, ctx)) {
-    goto err;  // 1
-  }
 
-  window = BN_window_bits_for_exponent_size(bits);
+  // We exponentiate by looking at sliding windows of the exponent and
+  // precomputing powers of |aa|. Windows may be shifted so they always end on a
+  // set bit, so only precompute odd powers. We compute val[i] = aa^(2*i + 1)
+  // for i = 0 to 2^(window-1) - 1, all in Montgomery form.
+  int window = BN_window_bits_for_exponent_size(bits);
+  if (!BN_to_montgomery(val[0], aa, mont, ctx)) {
+    goto err;
+  }
   if (window > 1) {
     if (!BN_mod_mul_montgomery(d, val[0], val[0], mont, ctx)) {
-      goto err;  // 2
+      goto err;
     }
-    j = 1 << (window - 1);
-    for (i = 1; i < j; i++) {
-      if (((val[i] = BN_CTX_get(ctx)) == NULL) ||
+    for (int i = 1; i < 1 << (window - 1); i++) {
+      val[i] = BN_CTX_get(ctx);
+      if (val[i] == NULL ||
           !BN_mod_mul_montgomery(val[i], val[i - 1], d, mont, ctx)) {
         goto err;
       }
     }
   }
 
-  start = 1;  // This is used to avoid multiplication etc
-              // when there is only the value '1' in the
-              // buffer.
-  wstart = bits - 1;  // The top bit of the window
-
-  j = m->top;  // borrow j
-  if (m->d[j - 1] & (((BN_ULONG)1) << (BN_BITS2 - 1))) {
-    if (!bn_wexpand(r, j)) {
+  // Set |r| to one in Montgomery form. If the high bit of |m| is set, |m| is
+  // close to R and we subtract rather than perform Montgomery reduction.
+  if (m->d[m->top - 1] & (((BN_ULONG)1) << (BN_BITS2 - 1))) {
+    if (!bn_wexpand(r, m->top)) {
       goto err;
     }
-    // 2^(top*BN_BITS2) - m
+    // r = 2^(top*BN_BITS2) - m
     r->d[0] = 0 - m->d[0];
-    for (i = 1; i < j; i++) {
+    for (int i = 1; i < m->top; i++) {
       r->d[i] = ~m->d[i];
     }
-    r->top = j;
-    // Upper words will be zero if the corresponding words of 'm'
-    // were 0xfff[...], so decrement r->top accordingly.
+    r->top = m->top;
+    // The upper words will be zero if the corresponding words of |m| were
+    // 0xfff[...], so call |bn_correct_top|.
     bn_correct_top(r);
   } else if (!BN_to_montgomery(r, BN_value_one(), mont, ctx)) {
     goto err;
   }
 
+  int r_is_one = 1;
+  int wstart = bits - 1;  // The top bit of the window.
   for (;;) {
-    int wvalue;  // The 'value' of the window
-    int wend;  // The bottom bit of the window
-
-    if (BN_is_bit_set(p, wstart) == 0) {
-      if (!start && !BN_mod_mul_montgomery(r, r, r, mont, ctx)) {
+    if (!BN_is_bit_set(p, wstart)) {
+      if (!r_is_one && !BN_mod_mul_montgomery(r, r, r, mont, ctx)) {
         goto err;
       }
       if (wstart == 0) {
@@ -682,44 +690,37 @@
       continue;
     }
 
-    // We now have wstart on a 'set' bit, we now need to work out how bit a
-    // window to do.  To do this we need to scan forward until the last set bit
-    // before the end of the window
-    wvalue = 1;
-    wend = 0;
-    for (i = 1; i < window; i++) {
-      if (wstart - i < 0) {
-        break;
-      }
+    // We now have wstart on a set bit. Find the largest window we can use.
+    int wvalue = 1;
+    int wsize = 0;
+    for (int i = 1; i < window && i <= wstart; i++) {
       if (BN_is_bit_set(p, wstart - i)) {
-        wvalue <<= (i - wend);
+        wvalue <<= (i - wsize);
         wvalue |= 1;
-        wend = i;
+        wsize = i;
       }
     }
 
-    // wend is the size of the current window
-    j = wend + 1;
-    // add the 'bytes above'
-    if (!start) {
-      for (i = 0; i < j; i++) {
+    // Shift |r| to the end of the window.
+    if (!r_is_one) {
+      for (int i = 0; i < wsize + 1; i++) {
         if (!BN_mod_mul_montgomery(r, r, r, mont, ctx)) {
           goto err;
         }
       }
     }
 
-    // wvalue will be an odd number < 2^window
+    assert(wvalue & 1);
+    assert(wvalue < (1 << window));
     if (!BN_mod_mul_montgomery(r, r, val[wvalue >> 1], mont, ctx)) {
       goto err;
     }
 
-    // move the 'window' down further
-    wstart -= wend + 1;
-    start = 0;
-    if (wstart < 0) {
+    r_is_one = 0;
+    if (wstart == wsize) {
       break;
     }
+    wstart -= wsize + 1;
   }
 
   if (!BN_from_montgomery(rr, r, mont, ctx)) {