Revert "Remove Karatsuba multiplication in BIGNUM"

This reverts commit 89eb6ddbf52887631ca45ad2242862515cd698ab. Sadly,
exactfloat, a decidedly non-cryptographic use case, uses BN_mul with
large enough inputs that Karatsuba is actually a load-bearing
optimization. (They're also why we need to support allocating giant
BIGNUMs.)

This CL just reverts the change for now. We should revise thresholds and
rearrange code so that this code is not reachable from any of the
cryptographic code. From there, we can revert the work to make it
constant-time, which will be better for exactfloat and also remove some
complexity.

Bug: 406497222
Change-Id: I08b6d12e19c2a6ae741ac490d81cc534ba260145
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/78047
Commit-Queue: Bob Beck <bbe@google.com>
Reviewed-by: Bob Beck <bbe@google.com>
diff --git a/crypto/fipsmodule/bn/mul.cc.inc b/crypto/fipsmodule/bn/mul.cc.inc
index f28e140..4d7a919 100644
--- a/crypto/fipsmodule/bn/mul.cc.inc
+++ b/crypto/fipsmodule/bn/mul.cc.inc
@@ -25,7 +25,16 @@
 #include "internal.h"
 
 
-#define BN_SQR_STACK_WORDS 16
+#define BN_MUL_RECURSIVE_SIZE_NORMAL 16
+#define BN_SQR_RECURSIVE_SIZE_NORMAL BN_MUL_RECURSIVE_SIZE_NORMAL
+
+// bn_abs_sub_words sets |r| to abs(|a| - |b|), using |tmp| as scratch space.
+static void bn_abs_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
+                             size_t num, BN_ULONG *tmp) {
+  BN_ULONG borrow = bn_sub_words(tmp, a, b, num);  // tmp = a - b
+  bn_sub_words(r, b, a, num);                      // r = b - a
+  bn_select_words(r, 0 - borrow, r /* tmp < 0 */, tmp /* tmp >= 0 */, num);
+}
 
 static void bn_mul_normal(BN_ULONG *r, const BN_ULONG *a, size_t na,
                           const BN_ULONG *b, size_t nb) {
@@ -112,11 +121,6 @@
 //
 // TODO(davidben): Make this take |size_t|. The |cl| + |dl| calling convention
 // is confusing.
-//
-// TODO(davidben): This function used to be used as part of a general Karatsuba
-// multiplication implementation, which had to account for differently-sized
-// inputs. Now it is only used as part of RSA key generation, which does not
-// need all this.
 static BN_ULONG bn_abs_sub_part_words(BN_ULONG *r, const BN_ULONG *a,
                                       const BN_ULONG *b, int cl, int dl,
                                       BN_ULONG *tmp) {
@@ -144,6 +148,223 @@
   return ok;
 }
 
+// Karatsuba recursive multiplication algorithm
+// (cf. Knuth, The Art of Computer Programming, Vol. 2)
+
+// bn_mul_recursive sets |r| to |a| * |b|, using |t| as scratch space. |r| has
+// length 2*|n2|, |a| has length |n2| + |dna|, |b| has length |n2| + |dnb|, and
+// |t| has length 4*|n2|. |n2| must be a power of two. Finally, we must have
+// -|BN_MUL_RECURSIVE_SIZE_NORMAL|/2 <= |dna| <= 0 and
+// -|BN_MUL_RECURSIVE_SIZE_NORMAL|/2 <= |dnb| <= 0.
+//
+// TODO(davidben): Simplify and |size_t| the calling convention around lengths
+// here.
+static void bn_mul_recursive(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
+                             int n2, int dna, int dnb, BN_ULONG *t) {
+  // |n2| is a power of two.
+  assert(n2 != 0 && (n2 & (n2 - 1)) == 0);
+  // Check |dna| and |dnb| are in range.
+  assert(-BN_MUL_RECURSIVE_SIZE_NORMAL / 2 <= dna && dna <= 0);
+  assert(-BN_MUL_RECURSIVE_SIZE_NORMAL / 2 <= dnb && dnb <= 0);
+
+  // Only call |bn_mul_comba8| if n2 == 8 and both inputs are full-length, i.e.
+  // |dna| and |dnb| are both zero.
+  if (n2 == 8 && dna == 0 && dnb == 0) {
+    bn_mul_comba8(r, a, b);
+    return;
+  }
+
+  // Below the recursion threshold, do a normal multiply and zero-pad |r|.
+  if (n2 < BN_MUL_RECURSIVE_SIZE_NORMAL) {
+    bn_mul_normal(r, a, n2 + dna, b, n2 + dnb);
+    if (dna + dnb < 0) {
+      OPENSSL_memset(&r[2 * n2 + dna + dnb], 0,
+                     sizeof(BN_ULONG) * -(dna + dnb));
+    }
+    return;
+  }
+
+  // Split |a| and |b| into a0,a1 and b0,b1, where a0 and b0 have size |n|.
+  // Split |t| into t0,t1,t2,t3, each of size |n|, with the remaining 4*|n| used
+  // for recursive calls.
+  // Split |r| into r0,r1,r2,r3. We must contribute a0*b0 to r0,r1, a0*b1+a1*b0
+  // to r1,r2, and a1*b1 to r2,r3. The middle term we will compute as:
+  //
+  //   a0*b1 + a1*b0 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0
+  //
+  // Note that we know |n| >= |BN_MUL_RECURSIVE_SIZE_NORMAL|/2 above, so
+  // |tna| and |tnb| are non-negative.
+  int n = n2 / 2, tna = n + dna, tnb = n + dnb;
+
+  // t0 = a0 - a1 and t1 = b1 - b0. The result will be multiplied, so we XOR
+  // their sign masks, giving the sign of (a0 - a1)*(b1 - b0). t0 and t1
+  // themselves store the absolute value.
+  BN_ULONG neg = bn_abs_sub_part_words(t, a, &a[n], tna, n - tna, &t[n2]);
+  neg ^= bn_abs_sub_part_words(&t[n], &b[n], b, tnb, tnb - n, &t[n2]);
+
+  // Compute:
+  // t2,t3 = t0 * t1 = |(a0 - a1)*(b1 - b0)|
+  // r0,r1 = a0 * b0
+  // r2,r3 = a1 * b1
+  if (n == 4 && dna == 0 && dnb == 0) {
+    bn_mul_comba4(&t[n2], t, &t[n]);
+
+    bn_mul_comba4(r, a, b);
+    bn_mul_comba4(&r[n2], &a[n], &b[n]);
+  } else if (n == 8 && dna == 0 && dnb == 0) {
+    bn_mul_comba8(&t[n2], t, &t[n]);
+
+    bn_mul_comba8(r, a, b);
+    bn_mul_comba8(&r[n2], &a[n], &b[n]);
+  } else {
+    BN_ULONG *p = &t[n2 * 2];
+    bn_mul_recursive(&t[n2], t, &t[n], n, 0, 0, p);
+    bn_mul_recursive(r, a, b, n, 0, 0, p);
+    bn_mul_recursive(&r[n2], &a[n], &b[n], n, dna, dnb, p);
+  }
+
+  // t0,t1,c = r0,r1 + r2,r3 = a0*b0 + a1*b1
+  BN_ULONG c = bn_add_words(t, r, &r[n2], n2);
+
+  // t2,t3,c = t0,t1,c + neg*t2,t3 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0.
+  // The second term is stored as the absolute value, so we do this with a
+  // constant-time select.
+  BN_ULONG c_neg = c - bn_sub_words(&t[n2 * 2], t, &t[n2], n2);
+  BN_ULONG c_pos = c + bn_add_words(&t[n2], t, &t[n2], n2);
+  bn_select_words(&t[n2], neg, &t[n2 * 2], &t[n2], n2);
+  static_assert(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
+                "crypto_word_t is too small");
+  c = constant_time_select_w(neg, c_neg, c_pos);
+
+  // We now have our three components. Add them together.
+  // r1,r2,c = r1,r2 + t2,t3,c
+  c += bn_add_words(&r[n], &r[n], &t[n2], n2);
+
+  // Propagate the carry bit to the end.
+  for (int i = n + n2; i < n2 + n2; i++) {
+    BN_ULONG old = r[i];
+    r[i] = old + c;
+    c = r[i] < old;
+  }
+
+  // The product should fit without carries.
+  declassify_assert(c == 0);
+}
+
+// bn_mul_part_recursive sets |r| to |a| * |b|, using |t| as scratch space. |r|
+// has length 4*|n|, |a| has length |n| + |tna|, |b| has length |n| + |tnb|, and
+// |t| has length 8*|n|. |n| must be a power of two. Additionally, we must have
+// 0 <= tna < n and 0 <= tnb < n, and |tna| and |tnb| must differ by at most
+// one.
+//
+// TODO(davidben): Make this take |size_t| and perhaps the actual lengths of |a|
+// and |b|.
+static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a,
+                                  const BN_ULONG *b, int n, int tna, int tnb,
+                                  BN_ULONG *t) {
+  // |n| is a power of two.
+  assert(n != 0 && (n & (n - 1)) == 0);
+  // Check |tna| and |tnb| are in range.
+  assert(0 <= tna && tna < n);
+  assert(0 <= tnb && tnb < n);
+  assert(-1 <= tna - tnb && tna - tnb <= 1);
+
+  int n2 = n * 2;
+  if (n < 8) {
+    bn_mul_normal(r, a, n + tna, b, n + tnb);
+    OPENSSL_memset(r + n2 + tna + tnb, 0, sizeof(BN_ULONG) * (n2 - tna - tnb));
+    return;
+  }
+
+  // Split |a| and |b| into a0,a1 and b0,b1, where a0 and b0 have size |n|. |a1|
+  // and |b1| have size |tna| and |tnb|, respectively.
+  // Split |t| into t0,t1,t2,t3, each of size |n|, with the remaining 4*|n| used
+  // for recursive calls.
+  // Split |r| into r0,r1,r2,r3. We must contribute a0*b0 to r0,r1, a0*b1+a1*b0
+  // to r1,r2, and a1*b1 to r2,r3. The middle term we will compute as:
+  //
+  //   a0*b1 + a1*b0 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0
+
+  // t0 = a0 - a1 and t1 = b1 - b0. The result will be multiplied, so we XOR
+  // their sign masks, giving the sign of (a0 - a1)*(b1 - b0). t0 and t1
+  // themselves store the absolute value.
+  BN_ULONG neg = bn_abs_sub_part_words(t, a, &a[n], tna, n - tna, &t[n2]);
+  neg ^= bn_abs_sub_part_words(&t[n], &b[n], b, tnb, tnb - n, &t[n2]);
+
+  // Compute:
+  // t2,t3 = t0 * t1 = |(a0 - a1)*(b1 - b0)|
+  // r0,r1 = a0 * b0
+  // r2,r3 = a1 * b1
+  if (n == 8) {
+    bn_mul_comba8(&t[n2], t, &t[n]);
+    bn_mul_comba8(r, a, b);
+
+    bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
+    // |bn_mul_normal| only writes |tna| + |tnb| words. Zero the rest.
+    OPENSSL_memset(&r[n2 + tna + tnb], 0, sizeof(BN_ULONG) * (n2 - tna - tnb));
+  } else {
+    BN_ULONG *p = &t[n2 * 2];
+    bn_mul_recursive(&t[n2], t, &t[n], n, 0, 0, p);
+    bn_mul_recursive(r, a, b, n, 0, 0, p);
+
+    OPENSSL_memset(&r[n2], 0, sizeof(BN_ULONG) * n2);
+    if (tna < BN_MUL_RECURSIVE_SIZE_NORMAL &&
+        tnb < BN_MUL_RECURSIVE_SIZE_NORMAL) {
+      bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
+    } else {
+      int i = n;
+      for (;;) {
+        i /= 2;
+        if (i < tna || i < tnb) {
+          // E.g., n == 16, i == 8 and tna == 11. |tna| and |tnb| are within one
+          // of each other, so if |tna| is larger and tna > i, then we know
+          // tnb >= i, and this call is valid.
+          bn_mul_part_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
+          break;
+        }
+        if (i == tna || i == tnb) {
+          // If there is only a bottom half to the number, just do it. We know
+          // the larger of |tna - i| and |tnb - i| is zero. The other is zero or
+          // -1 because |tna| and |tnb| differ by at most one.
+          bn_mul_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
+          break;
+        }
+
+        // This loop will eventually terminate when |i| falls below
+        // |BN_MUL_RECURSIVE_SIZE_NORMAL| because we know one of |tna| and |tnb|
+        // is at least that.
+      }
+    }
+  }
+
+  // t0,t1,c = r0,r1 + r2,r3 = a0*b0 + a1*b1
+  BN_ULONG c = bn_add_words(t, r, &r[n2], n2);
+
+  // t2,t3,c = t0,t1,c + neg*t2,t3 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0.
+  // The second term is stored as the absolute value, so we do this with a
+  // constant-time select.
+  BN_ULONG c_neg = c - bn_sub_words(&t[n2 * 2], t, &t[n2], n2);
+  BN_ULONG c_pos = c + bn_add_words(&t[n2], t, &t[n2], n2);
+  bn_select_words(&t[n2], neg, &t[n2 * 2], &t[n2], n2);
+  static_assert(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
+                "crypto_word_t is too small");
+  c = constant_time_select_w(neg, c_neg, c_pos);
+
+  // We now have our three components. Add them together.
+  // r1,r2,c = r1,r2 + t2,t3,c
+  c += bn_add_words(&r[n], &r[n], &t[n2], n2);
+
+  // Propagate the carry bit to the end.
+  for (int i = n + n2; i < n2 + n2; i++) {
+    BN_ULONG old = r[i];
+    r[i] = old + c;
+    c = r[i] < old;
+  }
+
+  // The product should fit without carries.
+  declassify_assert(c == 0);
+}
+
 // bn_mul_impl implements |BN_mul| and |bn_mul_consttime|. Note this function
 // breaks |BIGNUM| invariants and may return a negative zero. This is handled by
 // the callers.
@@ -182,6 +403,52 @@
   }
 
   top = al + bl;
+  // TODO(crbug.com/406497222): The recursive implementation is actually worse
+  // for cryptographic use cases, but we need to retain it in |BN_mul| for the
+  // projects misusing BIGNUM as a general-purpose calculator library with
+  // giant integers. Disconnect this code from our cryptographic primitives.
+  static const int kMulNormalSize = 16;
+  if (al >= kMulNormalSize && bl >= kMulNormalSize) {
+    if (-1 <= i && i <= 1) {
+      // Find the largest power of two less than or equal to the larger length.
+      int j;
+      if (i >= 0) {
+        j = BN_num_bits_word((BN_ULONG)al);
+      } else {
+        j = BN_num_bits_word((BN_ULONG)bl);
+      }
+      j = 1 << (j - 1);
+      assert(j <= al || j <= bl);
+      BIGNUM *t = BN_CTX_get(ctx);
+      if (t == NULL) {
+        goto err;
+      }
+      if (al > j || bl > j) {
+        // We know |al| and |bl| are at most one from each other, so if al > j,
+        // bl >= j, and vice versa. Thus we can use |bn_mul_part_recursive|.
+        //
+        // TODO(davidben): This codepath is almost unused in standard
+        // algorithms. Is this optimization necessary? See notes in
+        // https://boringssl-review.googlesource.com/q/I0bd604e2cd6a75c266f64476c23a730ca1721ea6
+        assert(al >= j && bl >= j);
+        if (!bn_wexpand(t, j * 8) || !bn_wexpand(rr, j * 4)) {
+          goto err;
+        }
+        bn_mul_part_recursive(rr->d, a->d, b->d, j, al - j, bl - j, t->d);
+      } else {
+        // al <= j && bl <= j. Additionally, we know j <= al or j <= bl, so one
+        // of al - j or bl - j is zero. The other, by the bound on |i| above, is
+        // zero or -1. Thus, we can use |bn_mul_recursive|.
+        if (!bn_wexpand(t, j * 4) || !bn_wexpand(rr, j * 2)) {
+          goto err;
+        }
+        bn_mul_recursive(rr->d, a->d, b->d, j, al - j, bl - j, t->d);
+      }
+      rr->width = top;
+      goto end;
+    }
+  }
+
   if (!bn_wexpand(rr, top)) {
     goto err;
   }
@@ -271,6 +538,66 @@
   bn_add_words(r, r, tmp, max);
 }
 
+// bn_sqr_recursive sets |r| to |a|^2, using |t| as scratch space. |r| has
+// length 2*|n2|, |a| has length |n2|, and |t| has length 4*|n2|. |n2| must be
+// a power of two.
+static void bn_sqr_recursive(BN_ULONG *r, const BN_ULONG *a, size_t n2,
+                             BN_ULONG *t) {
+  // |n2| is a power of two.
+  assert(n2 != 0 && (n2 & (n2 - 1)) == 0);
+
+  if (n2 == 4) {
+    bn_sqr_comba4(r, a);
+    return;
+  }
+  if (n2 == 8) {
+    bn_sqr_comba8(r, a);
+    return;
+  }
+  if (n2 < BN_SQR_RECURSIVE_SIZE_NORMAL) {
+    bn_sqr_normal(r, a, n2, t);
+    return;
+  }
+
+  // Split |a| into a0,a1, each of size |n|.
+  // Split |t| into t0,t1,t2,t3, each of size |n|, with the remaining 4*|n| used
+  // for recursive calls.
+  // Split |r| into r0,r1,r2,r3. We must contribute a0^2 to r0,r1, 2*a0*a1 to
+  // r1,r2, and a1^2 to r2,r3.
+  size_t n = n2 / 2;
+  BN_ULONG *t_recursive = &t[n2 * 2];
+
+  // t0 = |a0 - a1|. t1 is not live yet, so it doubles as the scratch buffer.
+  bn_abs_sub_words(t, a, &a[n], n, &t[n]);
+  // t2,t3 = t0^2 = |a0 - a1|^2 = a0^2 - 2*a0*a1 + a1^2
+  bn_sqr_recursive(&t[n2], t, n, t_recursive);
+
+  // r0,r1 = a0^2
+  bn_sqr_recursive(r, a, n, t_recursive);
+
+  // r2,r3 = a1^2
+  bn_sqr_recursive(&r[n2], &a[n], n, t_recursive);
+
+  // t0,t1,c = r0,r1 + r2,r3 = a0^2 + a1^2
+  BN_ULONG c = bn_add_words(t, r, &r[n2], n2);
+  // t2,t3,c = t0,t1,c - t2,t3 = 2*a0*a1
+  c -= bn_sub_words(&t[n2], t, &t[n2], n2);
+
+  // We now have our three components. Add them together.
+  // r1,r2,c = r1,r2 + t2,t3,c
+  c += bn_add_words(&r[n], &r[n], &t[n2], n2);
+
+  // Propagate the carry bit to the end.
+  for (size_t i = n + n2; i < n2 + n2; i++) {
+    BN_ULONG old = r[i];
+    r[i] = old + c;
+    c = r[i] < old;
+  }
+
+  // The square should fit without carries.
+  assert(c == 0);
+}
+
 int BN_mul_word(BIGNUM *bn, BN_ULONG w) {
   if (!bn->width) {
     return 1;
@@ -318,14 +645,28 @@
   } else if (al == 8) {
     bn_sqr_comba8(rr->d, a->d);
   } else {
-    if (al < BN_SQR_STACK_WORDS) {
-      BN_ULONG t[BN_SQR_STACK_WORDS * 2];
+    if (al < BN_SQR_RECURSIVE_SIZE_NORMAL) {
+      BN_ULONG t[BN_SQR_RECURSIVE_SIZE_NORMAL * 2];
       bn_sqr_normal(rr->d, a->d, al, t);
     } else {
-      if (!bn_wexpand(tmp, max)) {
-        goto err;
+      // If |al| is a power of two, we can use |bn_sqr_recursive|.
+      //
+      // TODO(crbug.com/406497222): The recursive implementation is actually
+      // worse for cryptographic use cases, but we need to retain it in |BN_mul|
+      // for the projects misusing BIGNUM as a general-purpose calculator
+      // library with giant integers. Disconnect this code from our
+      // cryptographic primitives.
+      if (al != 0 && (al & (al - 1)) == 0) {
+        if (!bn_wexpand(tmp, al * 4)) {
+          goto err;
+        }
+        bn_sqr_recursive(rr->d, a->d, al, tmp->d);
+      } else {
+        if (!bn_wexpand(tmp, max)) {
+          goto err;
+        }
+        bn_sqr_normal(rr->d, a->d, al, tmp->d);
       }
-      bn_sqr_normal(rr->d, a->d, al, tmp->d);
     }
   }