Final cleanup pass in BN_div

Tidy up the setup. Also we can simplify all the sign management. If snum
and sdiv just preserve the sign bits of numerator and denominator, the
remainder will have the correct sign from the start.

(The original code called BN_cmp and BN_add in places, which are
sensitive to the sign.)

Fixed: 358687140
Change-Id: I2d5f952814c9910552330b18462796ffc3fe5dab
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/70228
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/crypto/fipsmodule/bn/div.c b/crypto/fipsmodule/bn/div.c
index aa2e2e5..b321ff5 100644
--- a/crypto/fipsmodule/bn/div.c
+++ b/crypto/fipsmodule/bn/div.c
@@ -185,13 +185,6 @@
   // Inputs to this function are assumed public and may be leaked by timing and
   // cache side channels. Division with secret inputs should use other
   // implementation strategies such as Montgomery reduction.
-  //
-  // Historically, this function diverged from Knuth's algorithm with some
-  // shortcuts in some cases. Those have been removed per "New Branch Prediction
-  // Vulnerabilities in OpenSSL and Necessary Software Countermeasures" by Onur
-  // Acıçmez, Shay Gueron, and Jean-Pierre Seifert. We continue to omit them for
-  // simplicity, but this function is no longer used with secret inputs. (We
-  // implement a variation on "Smooth CRT-RSA" as described in the paper.)
   if (BN_is_zero(divisor)) {
     OPENSSL_PUT_ERROR(BN, BN_R_DIV_BY_ZERO);
     return 0;
@@ -201,13 +194,8 @@
   BIGNUM *tmp = BN_CTX_get(ctx);
   BIGNUM *snum = BN_CTX_get(ctx);
   BIGNUM *sdiv = BN_CTX_get(ctx);
-  BIGNUM *res = NULL;
-  if (quotient == NULL) {
-    res = BN_CTX_get(ctx);
-  } else {
-    res = quotient;
-  }
-  if (sdiv == NULL || res == NULL) {
+  BIGNUM *res = quotient == NULL ? BN_CTX_get(ctx) : quotient;
+  if (tmp == NULL || snum == NULL || sdiv == NULL || res == NULL) {
     goto err;
   }
 
@@ -219,44 +207,37 @@
       !BN_lshift(snum, numerator, norm_shift)) {
     goto err;
   }
+
+  // This algorithm relies on |sdiv| being minimal width. We do not use this
+  // function on secret inputs, so leaking this is fine. Also minimize |snum| to
+  // avoid looping on leading zeros, as we're not trying to be leak-free.
   bn_set_minimal_width(sdiv);
   bn_set_minimal_width(snum);
-  sdiv->neg = 0;
-  snum->neg = 0;
+  int div_n = sdiv->width;
+  const BN_ULONG d0 = sdiv->d[div_n - 1];
+  const BN_ULONG d1 = (div_n == 1) ? 0 : sdiv->d[div_n - 2];
+  assert(d0 & (((BN_ULONG)1) << (BN_BITS2 - 1)));
 
   // Extend |snum| with zeros to satisfy the long division invariants:
   // - |snum| must have at least |div_n| + 1 words.
   // - |snum|'s most significant word must be zero to guarantee the first loop
   //   iteration works with a prefix greater than |sdiv|. (This is the extra u0
   //   digit in Knuth step D1.)
-  int div_n = sdiv->width;
   int num_n = snum->width <= div_n ? div_n + 1 : snum->width + 1;
   if (!bn_resize_words(snum, num_n)) {
     goto err;
   }
 
+  // Knuth step D2: The quotient's width is the difference between numerator and
+  // denominator. Also set up its sign and size a temporary for the loop.
   int loop = num_n - div_n;
-
-  // Get the top 2 words of sdiv.
-  const BN_ULONG d0 = sdiv->d[div_n - 1];
-  const BN_ULONG d1 = (div_n == 1) ? 0 : sdiv->d[div_n - 2];
-
-  // The normalization step ensures that |sdiv|'s MSB is one.
-  assert(d0 & (((BN_ULONG)1) << (BN_BITS2 - 1)));
-
-  // Setup |res|. |numerator| and |res| may alias, so we save |numerator->neg|
-  // for later.
-  const int numerator_neg = numerator->neg;
-  res->neg = (numerator_neg ^ divisor->neg);
-  if (!bn_wexpand(res, loop)) {
+  res->neg = snum->neg ^ sdiv->neg;
+  if (!bn_wexpand(res, loop) ||  //
+      !bn_wexpand(tmp, div_n + 1)) {
     goto err;
   }
   res->width = loop;
 
-  if (!bn_wexpand(tmp, div_n + 1)) {
-    goto err;
-  }
-
   // Knuth steps D2 through D7: Compute the quotient with a word-by-word long
   // division. Note that Knuth indexes words from most to least significant, so
   // our index is reversed. Each loop iteration computes res->d[i] of the
@@ -366,18 +347,15 @@
     res->d[i] = q;
   }
 
+  // Trim leading zeros and correct any negative zeros.
   bn_set_minimal_width(snum);
+  bn_set_minimal_width(res);
 
-  if (rem != NULL) {
-    if (!BN_rshift(rem, snum, norm_shift)) {
-      goto err;
-    }
-    if (!BN_is_zero(rem)) {
-      rem->neg = numerator_neg;
-    }
+  // Knuth step D8: Unnormalize. snum now contains the remainder.
+  if (rem != NULL && !BN_rshift(rem, snum, norm_shift)) {
+    goto err;
   }
 
-  bn_set_minimal_width(res);
   BN_CTX_end(ctx);
   return 1;