Add bn_copy_words.

This makes it easier going to and from non-minimal BIGNUMs and words
without worrying about the widths which are ultimately to become less
friendly.

Bug: 232
Change-Id: Ia57cb29164c560b600573c27b112ad9375a86aad
Reviewed-on: https://boringssl-review.googlesource.com/25245
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn.c b/crypto/fipsmodule/bn/bn.c
index 70f98c8..aa0c619 100644
--- a/crypto/fipsmodule/bn/bn.c
+++ b/crypto/fipsmodule/bn/bn.c
@@ -297,6 +297,35 @@
   return 1;
 }
 
+static int bn_fits_in_words(const BIGNUM *bn, size_t num) {
+  // All words beyond |num| must be zero.
+  BN_ULONG mask = 0;
+  for (size_t i = num; i < (size_t)bn->top; i++) {
+    mask |= bn->d[i];
+  }
+  return mask == 0;
+}
+
+int bn_copy_words(BN_ULONG *out, size_t num, const BIGNUM *bn) {
+  if (bn->neg) {
+    OPENSSL_PUT_ERROR(BN, BN_R_NEGATIVE_NUMBER);
+    return 0;
+  }
+
+  size_t width = (size_t)bn->top;
+  if (width > num) {
+    if (!bn_fits_in_words(bn, num)) {
+      OPENSSL_PUT_ERROR(BN, BN_R_BIGNUM_TOO_LONG);
+      return 0;
+    }
+    width = num;
+  }
+
+  OPENSSL_memset(out, 0, sizeof(BN_ULONG) * num);
+  OPENSSL_memcpy(out, bn->d, sizeof(BN_ULONG) * width);
+  return 1;
+}
+
 int BN_is_negative(const BIGNUM *bn) {
   return bn->neg != 0;
 }
@@ -360,11 +389,7 @@
   }
 
   // All words beyond the new width must be zero.
-  BN_ULONG mask = 0;
-  for (size_t i = words; i < (size_t)bn->top; i++) {
-    mask |= bn->d[i];
-  }
-  if (mask != 0) {
+  if (!bn_fits_in_words(bn, words)) {
     OPENSSL_PUT_ERROR(BN, BN_R_BIGNUM_TOO_LONG);
     return 0;
   }
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index 8e6f4eb..a9442cd 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -387,15 +387,15 @@
   }
 
 #if !defined(BORINGSSL_SHARED_LIBRARY)
-  if (static_cast<size_t>(a->top) <= BN_SMALL_MAX_WORDS) {
-    for (size_t num_a = a->top; num_a <= BN_SMALL_MAX_WORDS; num_a++) {
+  int a_width = bn_minimal_width(a.get());
+  if (a_width <= BN_SMALL_MAX_WORDS) {
+    for (size_t num_a = a_width; num_a <= BN_SMALL_MAX_WORDS; num_a++) {
       SCOPED_TRACE(num_a);
       size_t num_r = 2 * num_a;
       // Use newly-allocated buffers so ASan will catch out-of-bounds writes.
       std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[num_a]),
           r_words(new BN_ULONG[num_r]);
-      OPENSSL_memset(a_words.get(), 0, num_a * sizeof(BN_ULONG));
-      OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
+      ASSERT_TRUE(bn_copy_words(a_words.get(), num_a, a.get()));
 
       ASSERT_TRUE(bn_mul_small(r_words.get(), num_r, a_words.get(), num_a,
                                a_words.get(), num_a));
@@ -445,22 +445,25 @@
   }
 
 #if !defined(BORINGSSL_SHARED_LIBRARY)
-  if (!BN_is_negative(product.get()) &&
-      static_cast<size_t>(a->top) <= BN_SMALL_MAX_WORDS &&
-      static_cast<size_t>(b->top) <= BN_SMALL_MAX_WORDS) {
-    for (size_t num_a = a->top; num_a <= BN_SMALL_MAX_WORDS; num_a++) {
+  BN_set_negative(a.get(), 0);
+  BN_set_negative(b.get(), 0);
+  BN_set_negative(product.get(), 0);
+
+  int a_width = bn_minimal_width(a.get());
+  int b_width = bn_minimal_width(b.get());
+  if (a_width <= BN_SMALL_MAX_WORDS && b_width <= BN_SMALL_MAX_WORDS) {
+    for (size_t num_a = static_cast<size_t>(a_width);
+         num_a <= BN_SMALL_MAX_WORDS; num_a++) {
       SCOPED_TRACE(num_a);
-      for (size_t num_b = b->top; num_b <= BN_SMALL_MAX_WORDS; num_b++) {
+      for (size_t num_b = static_cast<size_t>(b_width);
+           num_b <= BN_SMALL_MAX_WORDS; num_b++) {
         SCOPED_TRACE(num_b);
         size_t num_r = num_a + num_b;
         // Use newly-allocated buffers so ASan will catch out-of-bounds writes.
         std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[num_a]),
             b_words(new BN_ULONG[num_b]), r_words(new BN_ULONG[num_r]);
-        OPENSSL_memset(a_words.get(), 0, num_a * sizeof(BN_ULONG));
-        OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
-
-        OPENSSL_memset(b_words.get(), 0, num_b * sizeof(BN_ULONG));
-        OPENSSL_memcpy(b_words.get(), b->d, b->top * sizeof(BN_ULONG));
+        ASSERT_TRUE(bn_copy_words(a_words.get(), num_a, a.get()));
+        ASSERT_TRUE(bn_copy_words(b_words.get(), num_b, b.get()));
 
         ASSERT_TRUE(bn_mul_small(r_words.get(), num_r, a_words.get(), num_a,
                                  b_words.get(), num_b));
@@ -554,24 +557,23 @@
                          ret.get());
 
 #if !defined(BORINGSSL_SHARED_LIBRARY)
-    if (m->top <= BN_SMALL_MAX_WORDS) {
-      std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[m->top]),
-          b_words(new BN_ULONG[m->top]), r_words(new BN_ULONG[m->top]);
-      OPENSSL_memset(a_words.get(), 0, m->top * sizeof(BN_ULONG));
-      OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
-      OPENSSL_memset(b_words.get(), 0, m->top * sizeof(BN_ULONG));
-      OPENSSL_memcpy(b_words.get(), b->d, b->top * sizeof(BN_ULONG));
-      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m->top, a_words.get(),
-                                         m->top, mont.get()));
-      ASSERT_TRUE(bn_to_montgomery_small(b_words.get(), m->top, b_words.get(),
-                                         m->top, mont.get()));
+    size_t m_width = static_cast<size_t>(bn_minimal_width(m.get()));
+    if (m_width <= BN_SMALL_MAX_WORDS) {
+      std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[m_width]),
+          b_words(new BN_ULONG[m_width]), r_words(new BN_ULONG[m_width]);
+      ASSERT_TRUE(bn_copy_words(a_words.get(), m_width, a.get()));
+      ASSERT_TRUE(bn_copy_words(b_words.get(), m_width, b.get()));
+      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m_width, a_words.get(),
+                                         m_width, mont.get()));
+      ASSERT_TRUE(bn_to_montgomery_small(b_words.get(), m_width, b_words.get(),
+                                         m_width, mont.get()));
       ASSERT_TRUE(bn_mod_mul_montgomery_small(
-          r_words.get(), m->top, a_words.get(), m->top, b_words.get(), m->top,
+          r_words.get(), m_width, a_words.get(), m_width, b_words.get(), m_width,
           mont.get()));
       // Use the second half of |tmp| so ASan will catch out-of-bounds writes.
-      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
-                                           m->top, mont.get()));
-      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m_width, r_words.get(),
+                                           m_width, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A * B (mod M) (Montgomery, words)", mod_mul.get(),
                            ret.get());
     }
@@ -623,32 +625,32 @@
                          ret.get());
 
 #if !defined(BORINGSSL_SHARED_LIBRARY)
-    if (m->top <= BN_SMALL_MAX_WORDS) {
-      std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[m->top]),
-          a_copy_words(new BN_ULONG[m->top]), r_words(new BN_ULONG[m->top]);
-      OPENSSL_memset(a_words.get(), 0, m->top * sizeof(BN_ULONG));
-      OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
-      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m->top, a_words.get(),
-                                         m->top, mont.get()));
+    size_t m_width = static_cast<size_t>(bn_minimal_width(m.get()));
+    if (m_width <= BN_SMALL_MAX_WORDS) {
+      std::unique_ptr<BN_ULONG[]> a_words(new BN_ULONG[m_width]),
+          a_copy_words(new BN_ULONG[m_width]), r_words(new BN_ULONG[m_width]);
+      ASSERT_TRUE(bn_copy_words(a_words.get(), m_width, a.get()));
+      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m_width, a_words.get(),
+                                         m_width, mont.get()));
       ASSERT_TRUE(bn_mod_mul_montgomery_small(
-          r_words.get(), m->top, a_words.get(), m->top, a_words.get(), m->top,
-          mont.get()));
-      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
-                                           m->top, mont.get()));
-      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+          r_words.get(), m_width, a_words.get(), m_width, a_words.get(),
+          m_width, mont.get()));
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m_width,
+                                           r_words.get(), m_width, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A * A (mod M) (Montgomery, words)",
                            mod_square.get(), ret.get());
 
       // Repeat the operation with |a_copy_words|.
       OPENSSL_memcpy(a_copy_words.get(), a_words.get(),
-                     m->top * sizeof(BN_ULONG));
+                     m_width * sizeof(BN_ULONG));
       ASSERT_TRUE(bn_mod_mul_montgomery_small(
-          r_words.get(), m->top, a_words.get(), m->top, a_copy_words.get(),
-          m->top, mont.get()));
+          r_words.get(), m_width, a_words.get(), m_width, a_copy_words.get(),
+          m_width, mont.get()));
       // Use the second half of |tmp| so ASan will catch out-of-bounds writes.
-      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
-                                           m->top, mont.get()));
-      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m_width,
+                                           r_words.get(), m_width, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A * A_copy (mod M) (Montgomery, words)",
                            mod_square.get(), ret.get());
     }
@@ -683,22 +685,22 @@
                          ret.get());
 
 #if !defined(BORINGSSL_SHARED_LIBRARY)
-    if (m->top <= BN_SMALL_MAX_WORDS) {
+    size_t m_width = static_cast<size_t>(bn_minimal_width(m.get()));
+    if (m_width <= BN_SMALL_MAX_WORDS) {
       bssl::UniquePtr<BN_MONT_CTX> mont(BN_MONT_CTX_new());
       ASSERT_TRUE(mont.get());
       ASSERT_TRUE(BN_MONT_CTX_set(mont.get(), m.get(), ctx));
       ASSERT_TRUE(BN_nnmod(a.get(), a.get(), m.get(), ctx));
-      std::unique_ptr<BN_ULONG[]> r_words(new BN_ULONG[m->top]),
-          a_words(new BN_ULONG[m->top]);
-      OPENSSL_memset(a_words.get(), 0, m->top * sizeof(BN_ULONG));
-      OPENSSL_memcpy(a_words.get(), a->d, a->top * sizeof(BN_ULONG));
-      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m->top, a_words.get(),
-                                         m->top, mont.get()));
-      ASSERT_TRUE(bn_mod_exp_mont_small(r_words.get(), m->top, a_words.get(),
-                                        m->top, e->d, e->top, mont.get()));
-      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m->top, r_words.get(),
-                                           m->top, mont.get()));
-      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m->top));
+      std::unique_ptr<BN_ULONG[]> r_words(new BN_ULONG[m_width]),
+          a_words(new BN_ULONG[m_width]);
+      ASSERT_TRUE(bn_copy_words(a_words.get(), m_width, a.get()));
+      ASSERT_TRUE(bn_to_montgomery_small(a_words.get(), m_width, a_words.get(),
+                                         m_width, mont.get()));
+      ASSERT_TRUE(bn_mod_exp_mont_small(r_words.get(), m_width, a_words.get(),
+                                        m_width, e->d, e->top, mont.get()));
+      ASSERT_TRUE(bn_from_montgomery_small(r_words.get(), m_width,
+                                           r_words.get(), m_width, mont.get()));
+      ASSERT_TRUE(bn_set_words(ret.get(), r_words.get(), m_width));
       EXPECT_BIGNUMS_EQUAL("A ^ E (mod M) (Montgomery, words)", mod_exp.get(),
                            ret.get());
     }
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 1a5dc1b..140353f 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -226,6 +226,10 @@
 // least significant word first.
 int bn_set_words(BIGNUM *bn, const BN_ULONG *words, size_t num);
 
+// bn_copy_words copies the value of |bn| to |out| and returns one if the value
+// is representable in |num| words. Otherwise, it returns zero.
+int bn_copy_words(BN_ULONG *out, size_t num, const BIGNUM *bn);
+
 // bn_mul_add_words multiples |ap| by |w|, adds the result to |rp|, and places
 // the result in |rp|. |ap| and |rp| must both be |num| words long. It returns
 // the carry word of the operation. |ap| and |rp| may be equal but otherwise may
diff --git a/crypto/fipsmodule/ec/ec.c b/crypto/fipsmodule/ec/ec.c
index c9687a6..f002ccd 100644
--- a/crypto/fipsmodule/ec/ec.c
+++ b/crypto/fipsmodule/ec/ec.c
@@ -952,12 +952,10 @@
 
 int ec_bignum_to_scalar_unchecked(const EC_GROUP *group, EC_SCALAR *out,
                                   const BIGNUM *in) {
-  if (BN_is_negative(in) || in->top > group->order.top) {
+  if (!bn_copy_words(out->words, group->order.top, in)) {
     OPENSSL_PUT_ERROR(EC, EC_R_INVALID_SCALAR);
     return 0;
   }
-  OPENSSL_memset(out->words, 0, group->order.top * sizeof(BN_ULONG));
-  OPENSSL_memcpy(out->words, in->d, in->top * sizeof(BN_ULONG));
   return 1;
 }
 
diff --git a/crypto/fipsmodule/ec/ec_test.cc b/crypto/fipsmodule/ec/ec_test.cc
index e69f8d7..85dc8f2 100644
--- a/crypto/fipsmodule/ec/ec_test.cc
+++ b/crypto/fipsmodule/ec/ec_test.cc
@@ -28,6 +28,7 @@
 #include <openssl/nid.h>
 #include <openssl/obj.h>
 
+#include "../bn/internal.h"
 #include "../../test/test_util.h"
 
 
@@ -553,6 +554,32 @@
   EXPECT_EQ(0, EC_POINT_cmp(group.get(), result.get(), generator, nullptr));
 }
 
+#if !defined(BORINGSSL_SHARED_LIBRARY)
+TEST_P(ECCurveTest, MulNonMinimal) {
+  bssl::UniquePtr<EC_GROUP> group(EC_GROUP_new_by_curve_name(GetParam().nid));
+  ASSERT_TRUE(group);
+
+  bssl::UniquePtr<BIGNUM> forty_two(BN_new());
+  ASSERT_TRUE(forty_two);
+  ASSERT_TRUE(BN_set_word(forty_two.get(), 42));
+
+  // Compute g × 42.
+  bssl::UniquePtr<EC_POINT> point(EC_POINT_new(group.get()));
+  ASSERT_TRUE(point);
+  ASSERT_TRUE(EC_POINT_mul(group.get(), point.get(), forty_two.get(), nullptr,
+                           nullptr, nullptr));
+
+  // Compute it again with a non-minimal 42, much larger than the scalar.
+  ASSERT_TRUE(bn_resize_words(forty_two.get(), 64));
+
+  bssl::UniquePtr<EC_POINT> point2(EC_POINT_new(group.get()));
+  ASSERT_TRUE(point2);
+  ASSERT_TRUE(EC_POINT_mul(group.get(), point2.get(), forty_two.get(), nullptr,
+                           nullptr, nullptr));
+  EXPECT_EQ(0, EC_POINT_cmp(group.get(), point.get(), point2.get(), nullptr));
+}
+#endif  // BORINGSSL_SHARED_LIBRARY
+
 // Test that EC_KEY_set_private_key rejects invalid values.
 TEST_P(ECCurveTest, SetInvalidPrivateKey) {
   bssl::UniquePtr<EC_KEY> key(EC_KEY_new_by_curve_name(GetParam().nid));
diff --git a/crypto/fipsmodule/ec/p256-x86_64.c b/crypto/fipsmodule/ec/p256-x86_64.c
index 0e79b6d..ec371bf 100644
--- a/crypto/fipsmodule/ec/p256-x86_64.c
+++ b/crypto/fipsmodule/ec/p256-x86_64.c
@@ -205,13 +205,7 @@
 // returns one if it fits. Otherwise it returns zero.
 static int ecp_nistz256_bignum_to_field_elem(BN_ULONG out[P256_LIMBS],
                                              const BIGNUM *in) {
-  if (in->top > P256_LIMBS) {
-    return 0;
-  }
-
-  OPENSSL_memset(out, 0, sizeof(BN_ULONG) * P256_LIMBS);
-  OPENSSL_memcpy(out, in->d, sizeof(BN_ULONG) * in->top);
-  return 1;
+  return bn_copy_words(out, P256_LIMBS, in);
 }
 
 // r = p * p_scalar
diff --git a/crypto/fipsmodule/ec/p256-x86_64_test.cc b/crypto/fipsmodule/ec/p256-x86_64_test.cc
index a802bfb..4f33de0 100644
--- a/crypto/fipsmodule/ec/p256-x86_64_test.cc
+++ b/crypto/fipsmodule/ec/p256-x86_64_test.cc
@@ -160,10 +160,9 @@
     return false;
   }
 
-  OPENSSL_memset(out, 0, sizeof(P256_POINT_AFFINE));
-
   if (BN_is_zero(z.get())) {
     // The point at infinity is represented as (0, 0).
+    OPENSSL_memset(out, 0, sizeof(P256_POINT_AFFINE));
     return true;
   }
 
@@ -185,12 +184,11 @@
       !BN_mod_mul_montgomery(y.get(), y.get(), z.get(), mont.get(),
                              ctx.get()) ||
       !BN_mod_mul_montgomery(y.get(), y.get(), z.get(), mont.get(),
-                             ctx.get())) {
+                             ctx.get()) ||
+      !bn_copy_words(out->X, P256_LIMBS, x.get()) ||
+      !bn_copy_words(out->Y, P256_LIMBS, y.get())) {
     return false;
   }
-
-  OPENSSL_memcpy(out->X, x->d, sizeof(BN_ULONG) * x->top);
-  OPENSSL_memcpy(out->Y, y->d, sizeof(BN_ULONG) * y->top);
   return true;
 }