Make bn_mul_part_recursive constant-time.

This follows similar lines as the previous cleanups and fixes the
documentation of the preconditions.

And with that, RSA private key operations, provided p and q have the
same bit length, should be constant time, as far as I know. (Though I'm
sure I've missed something.)

bn_cmp_part_words and bn_cmp_words are no longer used and deleted.

Bug: 234
Change-Id: Iceefa39f57e466c214794c69b335c4d2c81f5577
Reviewed-on: https://boringssl-review.googlesource.com/25404
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/cmp.c b/crypto/fipsmodule/bn/cmp.c
index f5b4fdc..7790a8d 100644
--- a/crypto/fipsmodule/bn/cmp.c
+++ b/crypto/fipsmodule/bn/cmp.c
@@ -135,48 +135,6 @@
   return 0;
 }
 
-int bn_cmp_words(const BN_ULONG *a, const BN_ULONG *b, int n) {
-  int i;
-  BN_ULONG aa, bb;
-
-  aa = a[n - 1];
-  bb = b[n - 1];
-  if (aa != bb) {
-    return (aa > bb) ? 1 : -1;
-  }
-
-  for (i = n - 2; i >= 0; i--) {
-    aa = a[i];
-    bb = b[i];
-    if (aa != bb) {
-      return (aa > bb) ? 1 : -1;
-    }
-  }
-  return 0;
-}
-
-int bn_cmp_part_words(const BN_ULONG *a, const BN_ULONG *b, int cl, int dl) {
-  int n, i;
-  n = cl - 1;
-
-  if (dl < 0) {
-    for (i = dl; i < 0; i++) {
-      if (b[n - i] != 0) {
-        return -1;  // a < b
-      }
-    }
-  }
-  if (dl > 0) {
-    for (i = dl; i > 0; i--) {
-      if (a[n + i] != 0) {
-        return 1;  // a > b
-      }
-    }
-  }
-
-  return bn_cmp_words(a, b, cl);
-}
-
 static int bn_less_than_words_impl(const BN_ULONG *a, size_t a_len,
                                    const BN_ULONG *b, size_t b_len) {
   OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 56a0703..20945a9 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -280,16 +280,6 @@
 // bn_sqr_comba4 sets |r| to |a|^2.
 void bn_sqr_comba4(BN_ULONG r[8], const BN_ULONG a[4]);
 
-// bn_cmp_words returns a value less than, equal to or greater than zero if
-// the, length |n|, array |a| is less than, equal to or greater than |b|.
-int bn_cmp_words(const BN_ULONG *a, const BN_ULONG *b, int n);
-
-// bn_cmp_words returns a value less than, equal to or greater than zero if the
-// array |a| is less than, equal to or greater than |b|. The arrays can be of
-// different lengths: |cl| gives the minimum of the two lengths and |dl| gives
-// the length of |a| minus the length of |b|.
-int bn_cmp_part_words(const BN_ULONG *a, const BN_ULONG *b, int cl, int dl);
-
 // bn_less_than_words returns one if |a| < |b| and zero otherwise, where |a|
 // and |b| both are |len| words long. It runs in constant time.
 int bn_less_than_words(const BN_ULONG *a, const BN_ULONG *b, size_t len);
diff --git a/crypto/fipsmodule/bn/mul.c b/crypto/fipsmodule/bn/mul.c
index 38c70ca..92389fe 100644
--- a/crypto/fipsmodule/bn/mul.c
+++ b/crypto/fipsmodule/bn/mul.c
@@ -409,69 +409,63 @@
   assert(c == 0);
 }
 
-// n+tn is the word length
-// t needs to be n*4 is size, as does r
-// tnX may not be negative but less than n
+// bn_mul_part_recursive sets |r| to |a| * |b|, using |t| as scratch space. |r|
+// has length 4*|n|, |a| has length |n| + |tna|, |b| has length |n| + |tnb|, and
+// |t| has length 8*|n|. |n| must be a power of two. Additionally, we must have
+// 0 <= tna < n and 0 <= tnb < n, and |tna| and |tnb| must differ by at most
+// one.
+//
+// TODO(davidben): Make this take |size_t| and perhaps the actual lengths of |a|
+// and |b|.
 static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a,
                                   const BN_ULONG *b, int n, int tna, int tnb,
                                   BN_ULONG *t) {
-  int i, j, n2 = n * 2;
-  int c1, c2, neg;
-  BN_ULONG ln, lo, *p;
+  // |n| is a power of two.
+  assert(n != 0 && (n & (n - 1)) == 0);
+  // Check |tna| and |tnb| are in range.
+  assert(0 <= tna && tna < n);
+  assert(0 <= tnb && tnb < n);
+  assert(-1 <= tna - tnb && tna - tnb <= 1);
 
+  int n2 = n * 2;
   if (n < 8) {
     bn_mul_normal(r, a, n + tna, b, n + tnb);
+    OPENSSL_memset(r + n2 + tna + tnb, 0, n2 - tna - tnb);
     return;
   }
 
-  // TODO(davidben): This function is not constant-time, but should be. See
-  // https://crbug.com/boringssl/234.
+  // Split |a| and |b| into a0,a1 and b0,b1, where a0 and b0 have size |n|. |a1|
+  // and |b1| have size |tna| and |tnb|, respectively.
+  // Split |t| into t0,t1,t2,t3, each of size |n|, with the remaining 4*|n| used
+  // for recursive calls.
+  // Split |r| into r0,r1,r2,r3. We must contribute a0*b0 to r0,r1, a0*a1+b0*b1
+  // to r1,r2, and a1*b1 to r2,r3. The middle term we will compute as:
+  //
+  //   a0*a1 + b0*b1 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0
 
-  // r=(a[0]-a[1])*(b[1]-b[0])
-  c1 = bn_cmp_part_words(a, &(a[n]), tna, n - tna);
-  c2 = bn_cmp_part_words(&(b[n]), b, tnb, tnb - n);
-  neg = 0;
-  switch (c1 * 3 + c2) {
-    case -4:
-      bn_sub_part_words(t, &(a[n]), a, tna, tna - n);        // -
-      bn_sub_part_words(&(t[n]), b, &(b[n]), tnb, n - tnb);  // -
-      break;
-    case -3:
-      // break;
-    case -2:
-      bn_sub_part_words(t, &(a[n]), a, tna, tna - n);        // -
-      bn_sub_part_words(&(t[n]), &(b[n]), b, tnb, tnb - n);  // +
-      neg = 1;
-      break;
-    case -1:
-    case 0:
-    case 1:
-      // break;
-    case 2:
-      bn_sub_part_words(t, a, &(a[n]), tna, n - tna);        // +
-      bn_sub_part_words(&(t[n]), b, &(b[n]), tnb, n - tnb);  // -
-      neg = 1;
-      break;
-    case 3:
-      // break;
-    case 4:
-      bn_sub_part_words(t, a, &(a[n]), tna, n - tna);
-      bn_sub_part_words(&(t[n]), &(b[n]), b, tnb, tnb - n);
-      break;
-  }
+  // t0 = a0 - a1 and t1 = b1 - b0. The result will be multiplied, so we XOR
+  // their sign masks, giving the sign of (a0 - a1)*(b1 - b0). t0 and t1
+  // themselves store the absolute value.
+  BN_ULONG neg = bn_abs_sub_part_words(t, a, &a[n], tna, n - tna, &t[n2]);
+  neg ^= bn_abs_sub_part_words(&t[n], &b[n], b, tnb, tnb - n, &t[n2]);
 
+  // Compute:
+  // t2,t3 = t0 * t1 = |(a0 - a1)*(b1 - b0)|
+  // r0,r1 = a0 * b0
+  // r2,r3 = a1 * b1
   if (n == 8) {
-    bn_mul_comba8(&(t[n2]), t, &(t[n]));
+    bn_mul_comba8(&t[n2], t, &t[n]);
     bn_mul_comba8(r, a, b);
-    bn_mul_normal(&(r[n2]), &(a[n]), tna, &(b[n]), tnb);
-    OPENSSL_memset(&(r[n2 + tna + tnb]), 0, sizeof(BN_ULONG) * (n2 - tna - tnb));
+
+    bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
+    // |bn_mul_normal| only writes |tna| + |tna| words. Zero the rest.
+    OPENSSL_memset(&r[n2 + tna + tnb], 0, sizeof(BN_ULONG) * (n2 - tna - tnb));
   } else {
-    p = &(t[n2 * 2]);
-    bn_mul_recursive(&(t[n2]), t, &(t[n]), n, 0, 0, p);
+    BN_ULONG *p = &t[n2 * 2];
+    bn_mul_recursive(&t[n2], t, &t[n], n, 0, 0, p);
     bn_mul_recursive(r, a, b, n, 0, 0, p);
-    i = n / 2;
-    // If there is only a bottom half to the number,
-    // just do it
+
+    int i = n / 2, j;
     if (tna > tnb) {
       j = tna - i;
     } else {
@@ -479,32 +473,37 @@
     }
 
     if (j == 0) {
-      bn_mul_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i, p);
-      OPENSSL_memset(&(r[n2 + i * 2]), 0, sizeof(BN_ULONG) * (n2 - i * 2));
+      // If there is only a bottom half to the number, just do it. We know the
+      // larger of |tna - i| and |tnb - i| is zero. The other is zero or -1
+      // because |tna| and |tnb| differ by at most one.
+      bn_mul_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
+      // |bn_mul_recursive| only writes the bottom |i|*2 words.
+      OPENSSL_memset(&r[n2 + i * 2], 0, sizeof(BN_ULONG) * (n2 - i * 2));
     } else if (j > 0) {
-      // eg, n == 16, i == 8 and tn == 11
-      bn_mul_part_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i, p);
-      OPENSSL_memset(&(r[n2 + tna + tnb]), 0,
-                     sizeof(BN_ULONG) * (n2 - tna - tnb));
+      // E.g,, n == 16, i == 8 and tna == 11.
+      // |tna| and |tnb| are within one of each other, so if |tna| is larger and
+      // tna > i, then we know tnb >= i, and this call is valid.
+      bn_mul_part_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
     } else {
-      // (j < 0) eg, n == 16, i == 8 and tn == 5
-      OPENSSL_memset(&(r[n2]), 0, sizeof(BN_ULONG) * n2);
+      // (j < 0) E.g., n == 16, i == 8 and tn == 5
+      OPENSSL_memset(&r[n2], 0, sizeof(BN_ULONG) * n2);
       if (tna < BN_MUL_RECURSIVE_SIZE_NORMAL &&
           tnb < BN_MUL_RECURSIVE_SIZE_NORMAL) {
-        bn_mul_normal(&(r[n2]), &(a[n]), tna, &(b[n]), tnb);
+        bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
       } else {
         for (;;) {
           i /= 2;
-          // these simplified conditions work
-          // exclusively because difference
-          // between tna and tnb is 1 or 0
+          // These simplified conditions work exclusively because difference
+          // between |tna| and |tnb| is 1 or 0.
+          //
+          // TODO(davidben): This loop condition is exactly the same as the
+          // |j > 0| one but more complicated. Merge them.
           if (i < tna || i < tnb) {
-            bn_mul_part_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i,
-                                  tnb - i, p);
+            bn_mul_part_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
             break;
-          } else if (i == tna || i == tnb) {
-            bn_mul_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i,
-                             p);
+          }
+          if (i == tna || i == tnb) {
+            bn_mul_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
             break;
           }
         }
@@ -512,42 +511,32 @@
     }
   }
 
-  // t[32] holds (a[0]-a[1])*(b[1]-b[0]), c1 is the sign
-  // r[10] holds (a[0]*b[0])
-  // r[32] holds (b[1]*b[1])
+  // t0,t1,c = r0,r1 + r2,r3 = a0*b0 + a1*b1
+  BN_ULONG c = bn_add_words(t, r, &r[n2], n2);
 
-  c1 = (int)(bn_add_words(t, r, &(r[n2]), n2));
+  // t2,t3,c = t0,t1,c + neg*t2,t3 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0.
+  // The second term is stored as the absolute value, so we do this with a
+  // constant-time select.
+  BN_ULONG c_neg = c - bn_sub_words(&t[n2 * 2], t, &t[n2], n2);
+  BN_ULONG c_pos = c + bn_add_words(&t[n2], t, &t[n2], n2);
+  bn_select_words(&t[n2], neg, &t[n2 * 2], &t[n2], n2);
+  OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
+                         crypto_word_t_too_small);
+  c = constant_time_select_w(neg, c_neg, c_pos);
 
-  if (neg) {
-    // if t[32] is negative
-    c1 -= (int)(bn_sub_words(&(t[n2]), t, &(t[n2]), n2));
-  } else {
-    // Might have a carry
-    c1 += (int)(bn_add_words(&(t[n2]), &(t[n2]), t, n2));
+  // We now have our three components. Add them together.
+  // r1,r2,c = r1,r2 + t2,t3,c
+  c += bn_add_words(&r[n], &r[n], &t[n2], n2);
+
+  // Propagate the carry bit to the end.
+  for (int i = n + n2; i < n2 + n2; i++) {
+    BN_ULONG old = r[i];
+    r[i] = old + c;
+    c = r[i] < old;
   }
 
-  // t[32] holds (a[0]-a[1])*(b[1]-b[0])+(a[0]*b[0])+(a[1]*b[1])
-  // r[10] holds (a[0]*b[0])
-  // r[32] holds (b[1]*b[1])
-  // c1 holds the carry bits
-  c1 += (int)(bn_add_words(&(r[n]), &(r[n]), &(t[n2]), n2));
-  if (c1) {
-    p = &(r[n + n2]);
-    lo = *p;
-    ln = lo + c1;
-    *p = ln;
-
-    // The overflow will stop before we over write
-    // words we should not overwrite
-    if (ln < (BN_ULONG)c1) {
-      do {
-        p++;
-        lo = *p;
-        ln = lo + 1;
-        *p = ln;
-      } while (ln == 0);
-    }
-  }
+  // The product should fit without carries.
+  assert(c == 0);
 }
 
 // bn_mul_impl implements |BN_mul| and |bn_mul_fixed|. Note this function breaks
@@ -605,6 +594,9 @@
         goto err;
       }
       if (al > j || bl > j) {
+        // We know |al| and |bl| are at most one from each other, so if al > j,
+        // bl >= j, and vice versa. Thus we can use |bn_mul_part_recursive|.
+        assert(al >= j && bl >= j);
         // TODO(davidben): Check that these are correctly-sized, after rewriting
         // |bn_mul_part_recursive|.
         if (!bn_wexpand(t, j * 8) ||