Make BN_cmp constant-time.

This is a bit easier to read than BN_less_than_consttime when we must do
>= or <=, about as much work to compute, and lots of code calls BN_cmp
on secret data. This also, by extension, makes BN_cmp_word
constant-time.

BN_equal_consttime is probably a little more efficient and is perfectly
readable, so leave that one around.

Change-Id: Id2e07fe312f01cb6fd10a1306dcbf6397990cf13
Reviewed-on: https://boringssl-review.googlesource.com/25444
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index a11d8c8..3d4867d 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -1988,24 +1988,20 @@
 
     EXPECT_TRUE(BN_equal_consttime(ten.get(), ten_copy.get()));
     EXPECT_TRUE(BN_equal_consttime(ten_copy.get(), ten.get()));
-    EXPECT_FALSE(BN_less_than_consttime(ten.get(), ten_copy.get()));
-    EXPECT_FALSE(BN_less_than_consttime(ten_copy.get(), ten.get()));
     EXPECT_EQ(BN_cmp(ten.get(), ten_copy.get()), 0);
+    EXPECT_EQ(BN_cmp(ten_copy.get(), ten.get()), 0);
 
     EXPECT_FALSE(BN_equal_consttime(ten.get(), eight.get()));
-    EXPECT_FALSE(BN_less_than_consttime(ten.get(), eight.get()));
-    EXPECT_TRUE(BN_less_than_consttime(eight.get(), ten.get()));
     EXPECT_LT(BN_cmp(eight.get(), ten.get()), 0);
+    EXPECT_GT(BN_cmp(ten.get(), eight.get()), 0);
 
     EXPECT_FALSE(BN_equal_consttime(ten.get(), forty_two.get()));
-    EXPECT_TRUE(BN_less_than_consttime(ten.get(), forty_two.get()));
-    EXPECT_FALSE(BN_less_than_consttime(forty_two.get(), ten.get()));
     EXPECT_GT(BN_cmp(forty_two.get(), ten.get()), 0);
+    EXPECT_LT(BN_cmp(ten.get(), forty_two.get()), 0);
 
     EXPECT_FALSE(BN_equal_consttime(ten.get(), two_exp_256.get()));
-    EXPECT_TRUE(BN_less_than_consttime(ten.get(), two_exp_256.get()));
-    EXPECT_FALSE(BN_less_than_consttime(two_exp_256.get(), ten.get()));
     EXPECT_GT(BN_cmp(two_exp_256.get(), ten.get()), 0);
+    EXPECT_LT(BN_cmp(ten.get(), two_exp_256.get()), 0);
 
     EXPECT_EQ(4u, BN_num_bits(ten.get()));
     EXPECT_EQ(1u, BN_num_bytes(ten.get()));
diff --git a/crypto/fipsmodule/bn/cmp.c b/crypto/fipsmodule/bn/cmp.c
index 7790a8d..89775c0 100644
--- a/crypto/fipsmodule/bn/cmp.c
+++ b/crypto/fipsmodule/bn/cmp.c
@@ -63,32 +63,43 @@
 #include "../../internal.h"
 
 
-int BN_ucmp(const BIGNUM *a, const BIGNUM *b) {
-  int a_width = bn_minimal_width(a);
-  int b_width = bn_minimal_width(b);
-  int i = a_width - b_width;
-  if (i != 0) {
-    return i;
+static int bn_cmp_words_consttime(const BN_ULONG *a, size_t a_len,
+                                  const BN_ULONG *b, size_t b_len) {
+  OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
+                         crypto_word_t_too_small);
+  int ret = 0;
+  // Process the common words in little-endian order.
+  size_t min = a_len < b_len ? a_len : b_len;
+  for (size_t i = 0; i < min; i++) {
+    crypto_word_t eq = constant_time_eq_w(a[i], b[i]);
+    crypto_word_t lt = constant_time_lt_w(a[i], b[i]);
+    ret =
+        constant_time_select_int(eq, ret, constant_time_select_int(lt, -1, 1));
   }
 
-  const BN_ULONG *ap = a->d;
-  const BN_ULONG *bp = b->d;
-  for (i = a_width - 1; i >= 0; i--) {
-    BN_ULONG t1 = ap[i];
-    BN_ULONG t2 = bp[i];
-    if (t1 != t2) {
-      return (t1 > t2) ? 1 : -1;
+  // If |a| or |b| has non-zero words beyond |min|, they take precedence.
+  if (a_len < b_len) {
+    crypto_word_t mask = 0;
+    for (size_t i = a_len; i < b_len; i++) {
+      mask |= b[i];
     }
+    ret = constant_time_select_int(constant_time_is_zero_w(mask), ret, -1);
+  } else if (b_len < a_len) {
+    crypto_word_t mask = 0;
+    for (size_t i = b_len; i < a_len; i++) {
+      mask |= a[i];
+    }
+    ret = constant_time_select_int(constant_time_is_zero_w(mask), ret, 1);
   }
 
-  return 0;
+  return ret;
+}
+
+int BN_ucmp(const BIGNUM *a, const BIGNUM *b) {
+  return bn_cmp_words_consttime(a->d, a->width, b->d, b->width);
 }
 
 int BN_cmp(const BIGNUM *a, const BIGNUM *b) {
-  int i;
-  int gt, lt;
-  BN_ULONG t1, t2;
-
   if ((a == NULL) || (b == NULL)) {
     if (a != NULL) {
       return -1;
@@ -99,75 +110,21 @@
     }
   }
 
+  // We do not attempt to process the sign bit in constant time. Negative
+  // |BIGNUM|s should never occur in crypto, only calculators.
   if (a->neg != b->neg) {
     if (a->neg) {
       return -1;
     }
     return 1;
   }
-  if (a->neg == 0) {
-    gt = 1;
-    lt = -1;
-  } else {
-    gt = -1;
-    lt = 1;
-  }
 
-  int a_width = bn_minimal_width(a);
-  int b_width = bn_minimal_width(b);
-  if (a_width > b_width) {
-    return gt;
-  }
-  if (a_width < b_width) {
-    return lt;
-  }
-
-  for (i = a_width - 1; i >= 0; i--) {
-    t1 = a->d[i];
-    t2 = b->d[i];
-    if (t1 > t2) {
-      return gt;
-    } if (t1 < t2) {
-      return lt;
-    }
-  }
-
-  return 0;
-}
-
-static int bn_less_than_words_impl(const BN_ULONG *a, size_t a_len,
-                                   const BN_ULONG *b, size_t b_len) {
-  OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
-                         crypto_word_t_too_small);
-  int ret = 0;
-  // Process the common words in little-endian order.
-  size_t min = a_len < b_len ? a_len : b_len;
-  for (size_t i = 0; i < min; i++) {
-    crypto_word_t eq = constant_time_eq_w(a[i], b[i]);
-    crypto_word_t lt = constant_time_lt_w(a[i], b[i]);
-    ret = constant_time_select_int(eq, ret, constant_time_select_int(lt, 1, 0));
-  }
-
-  // If |a| or |b| has non-zero words beyond |min|, they take precedence.
-  if (a_len < b_len) {
-    crypto_word_t mask = 0;
-    for (size_t i = a_len; i < b_len; i++) {
-      mask |= b[i];
-    }
-    ret = constant_time_select_int(constant_time_is_zero_w(mask), ret, 1);
-  } else if (b_len < a_len) {
-    crypto_word_t mask = 0;
-    for (size_t i = b_len; i < a_len; i++) {
-      mask |= a[i];
-    }
-    ret = constant_time_select_int(constant_time_is_zero_w(mask), ret, 0);
-  }
-
-  return ret;
+  int ret = BN_ucmp(a, b);
+  return a->neg ? -ret : ret;
 }
 
 int bn_less_than_words(const BN_ULONG *a, const BN_ULONG *b, size_t len) {
-  return bn_less_than_words_impl(a, len, b, len);
+  return bn_cmp_words_consttime(a, len, b, len) < 0;
 }
 
 int BN_abs_is_word(const BIGNUM *bn, BN_ULONG w) {
@@ -241,20 +198,3 @@
   mask |= (a->neg ^ b->neg);
   return mask == 0;
 }
-
-int BN_less_than_consttime(const BIGNUM *a, const BIGNUM *b) {
-  // We do not attempt to process the sign bit in constant time. Negative
-  // |BIGNUM|s should never occur in crypto, only calculators.
-  if (a->neg && !b->neg) {
-    return 1;
-  }
-  if (b->neg && !a->neg) {
-    return 0;
-  }
-  if (a->neg && b->neg) {
-    const BIGNUM *tmp = a;
-    a = b;
-    b = tmp;
-  }
-  return bn_less_than_words_impl(a->d, a->width, b->d, b->width);
-}
diff --git a/crypto/fipsmodule/ec/ec_key.c b/crypto/fipsmodule/ec/ec_key.c
index 084d33b..3fcc04f 100644
--- a/crypto/fipsmodule/ec/ec_key.c
+++ b/crypto/fipsmodule/ec/ec_key.c
@@ -260,7 +260,6 @@
     return 0;
   }
 
-  // XXX: |BN_cmp| is not constant time.
   if (BN_is_negative(priv_key) ||
       BN_cmp(priv_key, EC_GROUP_get0_order(key->group)) >= 0) {
     OPENSSL_PUT_ERROR(EC, EC_R_WRONG_ORDER);
@@ -334,7 +333,6 @@
   // in case the priv_key is present :
   // check if generator * priv_key == pub_key
   if (eckey->priv_key) {
-    // XXX: |BN_cmp| is not constant time.
     if (BN_is_negative(eckey->priv_key) ||
         BN_cmp(eckey->priv_key, EC_GROUP_get0_order(eckey->group)) >= 0) {
       OPENSSL_PUT_ERROR(EC, EC_R_WRONG_ORDER);
diff --git a/crypto/fipsmodule/rsa/rsa_impl.c b/crypto/fipsmodule/rsa/rsa_impl.c
index b3981db..8612977 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.c
+++ b/crypto/fipsmodule/rsa/rsa_impl.c
@@ -191,7 +191,7 @@
       if (rsa->inv_small_mod_large_mont == NULL) {
         BIGNUM *inv_small_mod_large_mont = BN_new();
         int ok;
-        if (BN_less_than_consttime(rsa->p, rsa->q)) {
+        if (BN_cmp(rsa->p, rsa->q) < 0) {
           ok = inv_small_mod_large_mont != NULL &&
                bn_mod_inverse_secret_prime(inv_small_mod_large_mont, rsa->p,
                                            rsa->q, ctx, rsa->mont_q) &&
@@ -816,7 +816,7 @@
   // larger. Canonicalize fields so that |p| is the larger prime.
   const BIGNUM *p = rsa->p, *q = rsa->q, *dmp1 = rsa->dmp1, *dmq1 = rsa->dmq1;
   const BN_MONT_CTX *mont_p = rsa->mont_p, *mont_q = rsa->mont_q;
-  if (BN_less_than_consttime(rsa->p, rsa->q)) {
+  if (BN_cmp(rsa->p, rsa->q) < 0) {
     p = rsa->q;
     q = rsa->p;
     mont_p = rsa->mont_q;
@@ -964,7 +964,7 @@
     // retrying. That is, we reject a negligible fraction of primes that are
     // within the FIPS bound, but we will never accept a prime outside the
     // bound, ensuring the resulting RSA key is the right size.
-    if (!BN_less_than_consttime(sqrt2, out)) {
+    if (BN_cmp(out, sqrt2) <= 0) {
       continue;
     }
 
diff --git a/include/openssl/bn.h b/include/openssl/bn.h
index 4e9d54f..495e338 100644
--- a/include/openssl/bn.h
+++ b/include/openssl/bn.h
@@ -440,11 +440,6 @@
 // independent of the contents (including the signs) of |a| and |b|.
 OPENSSL_EXPORT int BN_equal_consttime(const BIGNUM *a, const BIGNUM *b);
 
-// BN_less_than_consttime returns one if |a| is less than |b|, and zero
-// otherwise. It takes an amount of time dependent on the sizes and signs of |a|
-// and |b|, but independent of the contents of |a| and |b|.
-OPENSSL_EXPORT int BN_less_than_consttime(const BIGNUM *a, const BIGNUM *b);
-
 // BN_abs_is_word returns one if the absolute value of |bn| equals |w| and zero
 // otherwise.
 OPENSSL_EXPORT int BN_abs_is_word(const BIGNUM *bn, BN_ULONG w);