Make the rest of BIGNUM accept non-minimal values.

Test this by re-running bn_tests.txt tests a lot. For the most part,
this was done by scattering bn_minimal_width or bn_correct_top calls as
needed. We'll incrementally tease apart the functions that need to act
on non-minimal BIGNUMs in constant-time.

BN_sqr was switched to call bn_correct_top at the end, rather than
sample bn_minimal_width, in anticipation of later splitting it into
BN_sqr (for calculators) and BN_sqr_fixed (for BN_mod_mul_montgomery).

BN_div_word also uses bn_correct_top because it calls BN_lshift so
officially shouldn't rely on BN_lshift returning something
minimal-width, though I expect we'd want to split off a BN_lshift_fixed
than change that anyway?

The shifts sample bn_minimal_width rather than bn_correct_top because
they all seem to try to be very clever around the bit width. If we need
constant-time versions of them, we can adjust them later.

Bug: 232
Change-Id: Ie17b39034a713542dbe906cf8954c0c5483c7db7
Reviewed-on: https://boringssl-review.googlesource.com/25255
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/add.c b/crypto/fipsmodule/bn/add.c
index 645e647..0c430ae 100644
--- a/crypto/fipsmodule/bn/add.c
+++ b/crypto/fipsmodule/bn/add.c
@@ -251,8 +251,8 @@
   register BN_ULONG t1, t2, *ap, *bp, *rp;
   int i, carry;
 
-  max = a->top;
-  min = b->top;
+  max = bn_minimal_width(a);
+  min = bn_minimal_width(b);
   dif = max - min;
 
   if (dif < 0)  // hmm... should not be happening
@@ -337,7 +337,7 @@
     return i;
   }
 
-  if ((a->top == 1) && (a->d[0] < w)) {
+  if ((bn_minimal_width(a) == 1) && (a->d[0] < w)) {
     a->d[0] = w - a->d[0];
     a->neg = 1;
     return 1;
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index f36656f..dcb55f8 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -107,34 +107,62 @@
   return ret;
 }
 
-static bssl::UniquePtr<BIGNUM> GetBIGNUM(FileTest *t, const char *attribute) {
-  std::string hex;
-  if (!t->GetAttribute(&hex, attribute)) {
-    return nullptr;
+// A BIGNUMFileTest wraps a FileTest to give |BIGNUM| values and also allows
+// injecting oversized |BIGNUM|s.
+class BIGNUMFileTest {
+ public:
+  BIGNUMFileTest(FileTest *t, unsigned large_mask)
+      : t_(t), large_mask_(large_mask), num_bignums_(0) {}
+
+  unsigned num_bignums() const { return num_bignums_; }
+
+  bssl::UniquePtr<BIGNUM> GetBIGNUM(const char *attribute) {
+    return GetBIGNUMImpl(attribute, true /* resize */);
   }
 
-  bssl::UniquePtr<BIGNUM> ret;
-  if (HexToBIGNUM(&ret, hex.c_str()) != static_cast<int>(hex.size())) {
-    t->PrintLine("Could not decode '%s'.", hex.c_str());
-    return nullptr;
-  }
-  return ret;
-}
+  bool GetInt(int *out, const char *attribute) {
+    bssl::UniquePtr<BIGNUM> ret =
+        GetBIGNUMImpl(attribute, false /* don't resize */);
+    if (!ret) {
+      return false;
+    }
 
-static bool GetInt(FileTest *t, int *out, const char *attribute) {
-  bssl::UniquePtr<BIGNUM> ret = GetBIGNUM(t, attribute);
-  if (!ret) {
-    return false;
+    BN_ULONG word = BN_get_word(ret.get());
+    if (word > INT_MAX) {
+      return false;
+    }
+
+    *out = static_cast<int>(word);
+    return true;
   }
 
-  BN_ULONG word = BN_get_word(ret.get());
-  if (word > INT_MAX) {
-    return false;
+ private:
+  bssl::UniquePtr<BIGNUM> GetBIGNUMImpl(const char *attribute, bool resize) {
+    std::string hex;
+    if (!t_->GetAttribute(&hex, attribute)) {
+      return nullptr;
+    }
+
+    bssl::UniquePtr<BIGNUM> ret;
+    if (HexToBIGNUM(&ret, hex.c_str()) != static_cast<int>(hex.size())) {
+      t_->PrintLine("Could not decode '%s'.", hex.c_str());
+      return nullptr;
+    }
+    if (resize) {
+      // Test with an oversized |BIGNUM| if necessary.
+      if ((large_mask_ & (1 << num_bignums_)) &&
+          !bn_resize_words(ret.get(), ret->top * 2 + 1)) {
+        return nullptr;
+      }
+      num_bignums_++;
+    }
+    return ret;
   }
 
-  *out = static_cast<int>(word);
-  return true;
-}
+  FileTest *t_;
+  unsigned large_mask_;
+  unsigned num_bignums_;
+};
 
 static testing::AssertionResult AssertBIGNUMSEqual(
     const char *operation_expr, const char *expected_expr,
@@ -160,10 +188,10 @@
 #define EXPECT_BIGNUMS_EQUAL(op, a, b) \
   EXPECT_PRED_FORMAT3(AssertBIGNUMSEqual, op, a, b)
 
-static void TestSum(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
-  bssl::UniquePtr<BIGNUM> sum = GetBIGNUM(t, "Sum");
+static void TestSum(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> b = t->GetBIGNUM("B");
+  bssl::UniquePtr<BIGNUM> sum = t->GetBIGNUM("Sum");
   ASSERT_TRUE(a);
   ASSERT_TRUE(b);
   ASSERT_TRUE(sum);
@@ -262,9 +290,9 @@
   }
 }
 
-static void TestLShift1(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> lshift1 = GetBIGNUM(t, "LShift1");
+static void TestLShift1(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> lshift1 = t->GetBIGNUM("LShift1");
   bssl::UniquePtr<BIGNUM> zero(BN_new());
   ASSERT_TRUE(a);
   ASSERT_TRUE(lshift1);
@@ -308,13 +336,13 @@
   EXPECT_BIGNUMS_EQUAL("(LShift | 1) >> 1", a.get(), ret.get());
 }
 
-static void TestLShift(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> lshift = GetBIGNUM(t, "LShift");
+static void TestLShift(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> lshift = t->GetBIGNUM("LShift");
   ASSERT_TRUE(a);
   ASSERT_TRUE(lshift);
   int n = 0;
-  ASSERT_TRUE(GetInt(t, &n, "N"));
+  ASSERT_TRUE(t->GetInt(&n, "N"));
 
   bssl::UniquePtr<BIGNUM> ret(BN_new());
   ASSERT_TRUE(ret);
@@ -325,13 +353,13 @@
   EXPECT_BIGNUMS_EQUAL("A >> N", a.get(), ret.get());
 }
 
-static void TestRShift(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> rshift = GetBIGNUM(t, "RShift");
+static void TestRShift(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> rshift = t->GetBIGNUM("RShift");
   ASSERT_TRUE(a);
   ASSERT_TRUE(rshift);
   int n = 0;
-  ASSERT_TRUE(GetInt(t, &n, "N"));
+  ASSERT_TRUE(t->GetInt(&n, "N"));
 
   bssl::UniquePtr<BIGNUM> ret(BN_new());
   ASSERT_TRUE(ret);
@@ -339,9 +367,9 @@
   EXPECT_BIGNUMS_EQUAL("A >> N", rshift.get(), ret.get());
 }
 
-static void TestSquare(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> square = GetBIGNUM(t, "Square");
+static void TestSquare(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> square = t->GetBIGNUM("Square");
   bssl::UniquePtr<BIGNUM> zero(BN_new());
   ASSERT_TRUE(a);
   ASSERT_TRUE(square);
@@ -412,10 +440,10 @@
 #endif
 }
 
-static void TestProduct(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
-  bssl::UniquePtr<BIGNUM> product = GetBIGNUM(t, "Product");
+static void TestProduct(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> b = t->GetBIGNUM("B");
+  bssl::UniquePtr<BIGNUM> product = t->GetBIGNUM("Product");
   bssl::UniquePtr<BIGNUM> zero(BN_new());
   ASSERT_TRUE(a);
   ASSERT_TRUE(b);
@@ -475,11 +503,11 @@
 #endif
 }
 
-static void TestQuotient(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
-  bssl::UniquePtr<BIGNUM> quotient = GetBIGNUM(t, "Quotient");
-  bssl::UniquePtr<BIGNUM> remainder = GetBIGNUM(t, "Remainder");
+static void TestQuotient(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> b = t->GetBIGNUM("B");
+  bssl::UniquePtr<BIGNUM> quotient = t->GetBIGNUM("Quotient");
+  bssl::UniquePtr<BIGNUM> remainder = t->GetBIGNUM("Remainder");
   ASSERT_TRUE(a);
   ASSERT_TRUE(b);
   ASSERT_TRUE(quotient);
@@ -523,11 +551,11 @@
   }
 }
 
-static void TestModMul(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> b = GetBIGNUM(t, "B");
-  bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
-  bssl::UniquePtr<BIGNUM> mod_mul = GetBIGNUM(t, "ModMul");
+static void TestModMul(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> b = t->GetBIGNUM("B");
+  bssl::UniquePtr<BIGNUM> m = t->GetBIGNUM("M");
+  bssl::UniquePtr<BIGNUM> mod_mul = t->GetBIGNUM("ModMul");
   ASSERT_TRUE(a);
   ASSERT_TRUE(b);
   ASSERT_TRUE(m);
@@ -581,10 +609,10 @@
   }
 }
 
-static void TestModSquare(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
-  bssl::UniquePtr<BIGNUM> mod_square = GetBIGNUM(t, "ModSquare");
+static void TestModSquare(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> m = t->GetBIGNUM("M");
+  bssl::UniquePtr<BIGNUM> mod_square = t->GetBIGNUM("ModSquare");
   ASSERT_TRUE(a);
   ASSERT_TRUE(m);
   ASSERT_TRUE(mod_square);
@@ -658,11 +686,11 @@
   }
 }
 
-static void TestModExp(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> e = GetBIGNUM(t, "E");
-  bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
-  bssl::UniquePtr<BIGNUM> mod_exp = GetBIGNUM(t, "ModExp");
+static void TestModExp(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> e = t->GetBIGNUM("E");
+  bssl::UniquePtr<BIGNUM> m = t->GetBIGNUM("M");
+  bssl::UniquePtr<BIGNUM> mod_exp = t->GetBIGNUM("ModExp");
   ASSERT_TRUE(a);
   ASSERT_TRUE(e);
   ASSERT_TRUE(m);
@@ -708,10 +736,10 @@
   }
 }
 
-static void TestExp(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> e = GetBIGNUM(t, "E");
-  bssl::UniquePtr<BIGNUM> exp = GetBIGNUM(t, "Exp");
+static void TestExp(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> e = t->GetBIGNUM("E");
+  bssl::UniquePtr<BIGNUM> exp = t->GetBIGNUM("Exp");
   ASSERT_TRUE(a);
   ASSERT_TRUE(e);
   ASSERT_TRUE(exp);
@@ -722,10 +750,10 @@
   EXPECT_BIGNUMS_EQUAL("A ^ E", exp.get(), ret.get());
 }
 
-static void TestModSqrt(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> p = GetBIGNUM(t, "P");
-  bssl::UniquePtr<BIGNUM> mod_sqrt = GetBIGNUM(t, "ModSqrt");
+static void TestModSqrt(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> p = t->GetBIGNUM("P");
+  bssl::UniquePtr<BIGNUM> mod_sqrt = t->GetBIGNUM("ModSqrt");
   bssl::UniquePtr<BIGNUM> mod_sqrt2(BN_new());
   ASSERT_TRUE(a);
   ASSERT_TRUE(p);
@@ -747,9 +775,9 @@
   }
 }
 
-static void TestNotModSquare(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> not_mod_square = GetBIGNUM(t, "NotModSquare");
-  bssl::UniquePtr<BIGNUM> p = GetBIGNUM(t, "P");
+static void TestNotModSquare(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> not_mod_square = t->GetBIGNUM("NotModSquare");
+  bssl::UniquePtr<BIGNUM> p = t->GetBIGNUM("P");
   bssl::UniquePtr<BIGNUM> ret(BN_new());
   ASSERT_TRUE(not_mod_square);
   ASSERT_TRUE(p);
@@ -764,10 +792,10 @@
   ERR_clear_error();
 }
 
-static void TestModInv(FileTest *t, BN_CTX *ctx) {
-  bssl::UniquePtr<BIGNUM> a = GetBIGNUM(t, "A");
-  bssl::UniquePtr<BIGNUM> m = GetBIGNUM(t, "M");
-  bssl::UniquePtr<BIGNUM> mod_inv = GetBIGNUM(t, "ModInv");
+static void TestModInv(BIGNUMFileTest *t, BN_CTX *ctx) {
+  bssl::UniquePtr<BIGNUM> a = t->GetBIGNUM("A");
+  bssl::UniquePtr<BIGNUM> m = t->GetBIGNUM("M");
+  bssl::UniquePtr<BIGNUM> mod_inv = t->GetBIGNUM("ModInv");
   ASSERT_TRUE(a);
   ASSERT_TRUE(m);
   ASSERT_TRUE(mod_inv);
@@ -794,7 +822,7 @@
 TEST_F(BNTest, TestVectors) {
   static const struct {
     const char *name;
-    void (*func)(FileTest *t, BN_CTX *ctx);
+    void (*func)(BIGNUMFileTest *t, BN_CTX *ctx);
   } kTests[] = {
       {"Sum", TestSum},
       {"LShift1", TestLShift1},
@@ -813,13 +841,34 @@
   };
 
   FileTestGTest("crypto/fipsmodule/bn/bn_tests.txt", [&](FileTest *t) {
+    void (*func)(BIGNUMFileTest *t, BN_CTX *ctx) = nullptr;
     for (const auto &test : kTests) {
       if (t->GetType() == test.name) {
-        test.func(t, ctx());
-        return;
+        func = test.func;
+        break;
       }
     }
-    FAIL() << "Unknown test type: " << t->GetType();
+    if (!func) {
+      FAIL() << "Unknown test type: " << t->GetType();
+      return;
+    }
+
+    // Run the test with normalize-sized |BIGNUM|s.
+    BIGNUMFileTest bn_test(t, 0);
+    BN_CTX_start(ctx());
+    func(&bn_test, ctx());
+    BN_CTX_end(ctx());
+    unsigned num_bignums = bn_test.num_bignums();
+
+    // Repeat the test with all combinations of large and small |BIGNUM|s.
+    for (unsigned large_mask = 1; large_mask < (1u << num_bignums);
+         large_mask++) {
+      SCOPED_TRACE(large_mask);
+      BIGNUMFileTest bn_test2(t, large_mask);
+      BN_CTX_start(ctx());
+      func(&bn_test2, ctx());
+      BN_CTX_end(ctx());
+    }
   });
 }
 
@@ -865,7 +914,6 @@
               Bytes(out, sizeof(out) - bytes));
     EXPECT_EQ(Bytes(reference, bytes), Bytes(out + sizeof(out) - bytes, bytes));
 
-#if !defined(BORINGSSL_SHARED_LIBRARY)
     // Repeat some tests with a non-minimal |BIGNUM|.
     EXPECT_TRUE(bn_resize_words(n.get(), 32));
 
@@ -874,7 +922,6 @@
     ASSERT_TRUE(BN_bn2bin_padded(out, bytes + 1, n.get()));
     EXPECT_EQ(0u, out[0]);
     EXPECT_EQ(Bytes(reference, bytes), Bytes(out + 1, bytes));
-#endif
   }
 }
 
@@ -1897,6 +1944,7 @@
   EXPECT_EQ(0, bn_less_than_words(NULL, NULL, 0));
   EXPECT_EQ(0, bn_in_range_words(NULL, 0, NULL, 0));
 }
+#endif  // !BORINGSSL_SHARED_LIBRARY
 
 TEST_F(BNTest, NonMinimal) {
   bssl::UniquePtr<BIGNUM> ten(BN_new());
@@ -1992,5 +2040,3 @@
   EXPECT_EQ(mont->N.top, mont2->N.top);
   EXPECT_EQ(0, BN_cmp(&mont->RR, &mont2->RR));
 }
-
-#endif  // !BORINGSSL_SHARED_LIBRARY
diff --git a/crypto/fipsmodule/bn/div.c b/crypto/fipsmodule/bn/div.c
index 7f261f1..be8c582 100644
--- a/crypto/fipsmodule/bn/div.c
+++ b/crypto/fipsmodule/bn/div.c
@@ -202,10 +202,16 @@
   BN_ULONG d0, d1;
   int num_n, div_n;
 
-  // Invalid zero-padding would have particularly bad consequences
-  // so don't just rely on bn_check_top() here
-  if ((numerator->top > 0 && numerator->d[numerator->top - 1] == 0) ||
-      (divisor->top > 0 && divisor->d[divisor->top - 1] == 0)) {
+  // This function relies on the historical minimal-width |BIGNUM| invariant.
+  // It is already not constant-time (constant-time reductions should use
+  // Montgomery logic), so we shrink all inputs and intermediate values to
+  // retain the previous behavior.
+
+  // Invalid zero-padding would have particularly bad consequences.
+  int numerator_width = bn_minimal_width(numerator);
+  int divisor_width = bn_minimal_width(divisor);
+  if ((numerator_width > 0 && numerator->d[numerator_width - 1] == 0) ||
+      (divisor_width > 0 && divisor->d[divisor_width - 1] == 0)) {
     OPENSSL_PUT_ERROR(BN, BN_R_NOT_INITIALIZED);
     return 0;
   }
@@ -234,11 +240,13 @@
   if (!BN_lshift(sdiv, divisor, norm_shift)) {
     goto err;
   }
+  bn_correct_top(sdiv);
   sdiv->neg = 0;
   norm_shift += BN_BITS2;
   if (!BN_lshift(snum, numerator, norm_shift)) {
     goto err;
   }
+  bn_correct_top(snum);
   snum->neg = 0;
 
   // Since we don't want to have special-case logic for the case where snum is
@@ -604,14 +612,7 @@
     a->d[i] = d;
   }
 
-  if ((a->top > 0) && (a->d[a->top - 1] == 0)) {
-    a->top--;
-  }
-
-  if (a->top == 0) {
-    a->neg = 0;
-  }
-
+  bn_correct_top(a);
   ret >>= j;
   return ret;
 }
diff --git a/crypto/fipsmodule/bn/exponentiation.c b/crypto/fipsmodule/bn/exponentiation.c
index 9e0ddfb..b62292d 100644
--- a/crypto/fipsmodule/bn/exponentiation.c
+++ b/crypto/fipsmodule/bn/exponentiation.c
@@ -982,7 +982,6 @@
                               const BIGNUM *m, BN_CTX *ctx,
                               const BN_MONT_CTX *mont) {
   int i, ret = 0, window, wvalue;
-  int top;
   BN_MONT_CTX *new_mont = NULL;
 
   int numPowers;
@@ -997,8 +996,6 @@
     return 0;
   }
 
-  top = m->top;
-
   // Use all bits stored in |p|, rather than |BN_num_bits|, so we do not leak
   // whether the top bits are zero.
   int max_bits = p->top * BN_BITS2;
@@ -1021,6 +1018,10 @@
     mont = new_mont;
   }
 
+  // Use the width in |mont->N|, rather than the copy in |m|. The assembly
+  // implementation assumes it can use |top| to size R.
+  int top = mont->N.top;
+
   if (a->neg || BN_ucmp(a, m) >= 0) {
     new_a = BN_new();
     if (new_a == NULL ||
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index f3b8d8a..2fec60c 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -220,7 +220,7 @@
 // Do not call this function outside of unit tests. Most functions currently
 // require |BIGNUM|s be minimal. This function breaks that invariant. It is
 // introduced early so the invariant may be relaxed incrementally.
-int bn_resize_words(BIGNUM *bn, size_t words);
+OPENSSL_EXPORT int bn_resize_words(BIGNUM *bn, size_t words);
 
 // bn_set_words sets |bn| to the value encoded in the |num| words in |words|,
 // least significant word first.
diff --git a/crypto/fipsmodule/bn/mul.c b/crypto/fipsmodule/bn/mul.c
index b93f558..3303478 100644
--- a/crypto/fipsmodule/bn/mul.c
+++ b/crypto/fipsmodule/bn/mul.c
@@ -866,13 +866,8 @@
   }
 
   rr->neg = 0;
-  // If the most-significant half of the top word of 'a' is zero, then
-  // the square of 'a' will max-1 words.
-  if (a->d[al - 1] == (a->d[al - 1] & BN_MASK2l)) {
-    rr->top = max - 1;
-  } else {
-    rr->top = max;
-  }
+  rr->top = max;
+  bn_correct_top(rr);
 
   if (rr != r && !BN_copy(r, rr)) {
     goto err;
diff --git a/crypto/fipsmodule/bn/shift.c b/crypto/fipsmodule/bn/shift.c
index d4ed79e..bedab09 100644
--- a/crypto/fipsmodule/bn/shift.c
+++ b/crypto/fipsmodule/bn/shift.c
@@ -142,10 +142,11 @@
     return 0;
   }
 
+  int a_width = bn_minimal_width(a);
   nw = n / BN_BITS2;
   rb = n % BN_BITS2;
   lb = BN_BITS2 - rb;
-  if (nw >= a->top || a->top == 0) {
+  if (nw >= a_width || a_width == 0) {
     BN_zero(r);
     return 1;
   }
@@ -163,7 +164,7 @@
 
   f = &(a->d[nw]);
   t = r->d;
-  j = a->top - nw;
+  j = a_width - nw;
   r->top = i;
 
   if (rb == 0) {
@@ -198,7 +199,7 @@
     BN_zero(r);
     return 1;
   }
-  i = a->top;
+  i = bn_minimal_width(a);
   ap = a->d;
   j = i - (ap[i - 1] == 1);
   if (a != r) {