Simplify Montgomery RR precomputation.

Use equivalent but simpler math, and explain the simpler math. Move the
discussion of multipying-vs-doubling to be after the discussion of
squaring-vs-doubling so that the discussion order follows the code
order, and so that we can combine the multipying-vs-doubling discussion
with the explanation of why no multiply/doubling is needed at all.
Expand the existing discussion to be a little more explicit.

Retain |threshold|, but change the type of |threshold| was changed to
|int| to avoid a signed/unsigned comparison in the added assertion
(|bn_mod_lshift_consttime| takes the shift count as an |int| anyway).

Change-Id: I24e4687e76944a34a8621b5f2fdee15a5201ac88
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/63906
Reviewed-by: Bob Beck <bbe@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Bob Beck <bbe@google.com>
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index a2292b4..d556488 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -149,6 +149,7 @@
 #endif
 
 #define BN_BITS2 64
+#define BN_BITS2_LG 6
 #define BN_BYTES 8
 #define BN_BITS4 32
 #define BN_MASK2 (0xffffffffffffffffUL)
@@ -165,6 +166,7 @@
 #define BN_ULLONG uint64_t
 #define BN_CAN_DIVIDE_ULLONG
 #define BN_BITS2 32
+#define BN_BITS2_LG 5
 #define BN_BYTES 4
 #define BN_BITS4 16
 #define BN_MASK2 (0xffffffffUL)
diff --git a/crypto/fipsmodule/bn/montgomery_inv.c b/crypto/fipsmodule/bn/montgomery_inv.c
index f496fe9..068d00b 100644
--- a/crypto/fipsmodule/bn/montgomery_inv.c
+++ b/crypto/fipsmodule/bn/montgomery_inv.c
@@ -179,42 +179,43 @@
   // Montgomery domain, 2R or 2^(lgBigR+1), and then use Montgomery
   // square-and-multiply to exponentiate.
   //
-  // The multiply steps take 2^n R to 2^(n+1) R. It is faster to double
-  // the value instead. The square steps take 2^n R to 2^(2n) R. This is
-  // equivalent to doubling n times. When n is below some threshold, doubling is
-  // faster. When above, squaring is faster.
+  // The square steps take 2^n R to (2^n)*(2^n) R = 2^2n R. This is the same as
+  // doubling 2^n R, n times (doubling any x, n times, computes 2^n * x). When n
+  // is below some threshold, doubling is faster; when above, squaring is
+  // faster. From benchmarking various 32-bit and 64-bit architectures, the word
+  // count seems to work well as a threshold. (Doubling scales linearly and
+  // Montgomery reduction scales quadratically, so the threshold should scale
+  // roughly linearly.)
   //
-  // We double to this threshold, then switch to Montgomery squaring. From
-  // benchmarking various 32-bit and 64-bit architectures, the word count seems
-  // to work well as a threshold. (Doubling scales linearly and Montgomery
-  // reduction scales quadratically, so the threshold should scale roughly
-  // linearly.)
-  unsigned threshold = mont->N.width;
-  unsigned iters;
-  for (iters = 0; iters < sizeof(lgBigR) * 8; iters++) {
-    if ((lgBigR >> iters) <= threshold) {
-      break;
-    }
-  }
+  // The multiply steps take 2^n R to 2*2^n R = 2^(n+1) R. It is faster to
+  // double the value instead, so the square-and-multiply exponentiation would
+  // become square-and-double. However, when using the word count as the
+  // threshold, it turns out that no multiply/double steps will be needed at
+  // all, because squaring any x, i times, computes x^(2^i):
+  //
+  //   (2^threshold)^(2^BN_BITS2_LG) R
+  //   (2^mont->N.width)^BN_BITS2 R
+  // = 2^(mont->N.width*BN_BITS2) R
+  // = 2^lgBigR R
+  // = RR
+  int threshold = mont->N.width;
 
-  // Compute 2^(lgBigR >> iters) R, or 2^((lgBigR >> iters) + lgBigR), by
-  // doubling. The first n_bits - 1 doubles can be skipped because we don't need
-  // to reduce.
+  // Calculate 2^threshold R = 2^(threshold + lgBigR) by doubling. The
+  // first n_bits - 1 doubles can be skipped because we don't need to reduce.
   if (!BN_set_bit(&mont->RR, n_bits - 1) ||
       !bn_mod_lshift_consttime(&mont->RR, &mont->RR,
-                               (lgBigR >> iters) + lgBigR - (n_bits - 1),
+                               threshold + (lgBigR - (n_bits - 1)),
                                &mont->N, ctx)) {
     return 0;
   }
 
-  for (unsigned i = iters - 1; i < iters; i--) {
+  // The above steps are the same regardless of the threshold. The steps below
+  // need to be modified if the threshold changes.
+  assert(threshold == mont->N.width);
+  for (unsigned i = 0; i < BN_BITS2_LG; i++) {
     if (!BN_mod_mul_montgomery(&mont->RR, &mont->RR, &mont->RR, mont, ctx)) {
       return 0;
     }
-    if ((lgBigR & (1u << i)) != 0 &&
-        !bn_mod_lshift1_consttime(&mont->RR, &mont->RR, &mont->N, ctx)) {
-      return 0;
-    }
   }
 
   return bn_resize_words(&mont->RR, mont->N.width);