Make bn_mul_part_recursive constant-time.
This follows similar lines as the previous cleanups and fixes the
documentation of the preconditions.
And with that, RSA private key operations, provided p and q have the
same bit length, should be constant time, as far as I know. (Though I'm
sure I've missed something.)
bn_cmp_part_words and bn_cmp_words are no longer used and deleted.
Bug: 234
Change-Id: Iceefa39f57e466c214794c69b335c4d2c81f5577
Reviewed-on: https://boringssl-review.googlesource.com/25404
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/bn/cmp.c b/crypto/fipsmodule/bn/cmp.c
index f5b4fdc..7790a8d 100644
--- a/crypto/fipsmodule/bn/cmp.c
+++ b/crypto/fipsmodule/bn/cmp.c
@@ -135,48 +135,6 @@
return 0;
}
-int bn_cmp_words(const BN_ULONG *a, const BN_ULONG *b, int n) {
- int i;
- BN_ULONG aa, bb;
-
- aa = a[n - 1];
- bb = b[n - 1];
- if (aa != bb) {
- return (aa > bb) ? 1 : -1;
- }
-
- for (i = n - 2; i >= 0; i--) {
- aa = a[i];
- bb = b[i];
- if (aa != bb) {
- return (aa > bb) ? 1 : -1;
- }
- }
- return 0;
-}
-
-int bn_cmp_part_words(const BN_ULONG *a, const BN_ULONG *b, int cl, int dl) {
- int n, i;
- n = cl - 1;
-
- if (dl < 0) {
- for (i = dl; i < 0; i++) {
- if (b[n - i] != 0) {
- return -1; // a < b
- }
- }
- }
- if (dl > 0) {
- for (i = dl; i > 0; i--) {
- if (a[n + i] != 0) {
- return 1; // a > b
- }
- }
- }
-
- return bn_cmp_words(a, b, cl);
-}
-
static int bn_less_than_words_impl(const BN_ULONG *a, size_t a_len,
const BN_ULONG *b, size_t b_len) {
OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 56a0703..20945a9 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -280,16 +280,6 @@
// bn_sqr_comba4 sets |r| to |a|^2.
void bn_sqr_comba4(BN_ULONG r[8], const BN_ULONG a[4]);
-// bn_cmp_words returns a value less than, equal to or greater than zero if
-// the, length |n|, array |a| is less than, equal to or greater than |b|.
-int bn_cmp_words(const BN_ULONG *a, const BN_ULONG *b, int n);
-
-// bn_cmp_words returns a value less than, equal to or greater than zero if the
-// array |a| is less than, equal to or greater than |b|. The arrays can be of
-// different lengths: |cl| gives the minimum of the two lengths and |dl| gives
-// the length of |a| minus the length of |b|.
-int bn_cmp_part_words(const BN_ULONG *a, const BN_ULONG *b, int cl, int dl);
-
// bn_less_than_words returns one if |a| < |b| and zero otherwise, where |a|
// and |b| both are |len| words long. It runs in constant time.
int bn_less_than_words(const BN_ULONG *a, const BN_ULONG *b, size_t len);
diff --git a/crypto/fipsmodule/bn/mul.c b/crypto/fipsmodule/bn/mul.c
index 38c70ca..92389fe 100644
--- a/crypto/fipsmodule/bn/mul.c
+++ b/crypto/fipsmodule/bn/mul.c
@@ -409,69 +409,63 @@
assert(c == 0);
}
-// n+tn is the word length
-// t needs to be n*4 is size, as does r
-// tnX may not be negative but less than n
+// bn_mul_part_recursive sets |r| to |a| * |b|, using |t| as scratch space. |r|
+// has length 4*|n|, |a| has length |n| + |tna|, |b| has length |n| + |tnb|, and
+// |t| has length 8*|n|. |n| must be a power of two. Additionally, we must have
+// 0 <= tna < n and 0 <= tnb < n, and |tna| and |tnb| must differ by at most
+// one.
+//
+// TODO(davidben): Make this take |size_t| and perhaps the actual lengths of |a|
+// and |b|.
static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a,
const BN_ULONG *b, int n, int tna, int tnb,
BN_ULONG *t) {
- int i, j, n2 = n * 2;
- int c1, c2, neg;
- BN_ULONG ln, lo, *p;
+ // |n| is a power of two.
+ assert(n != 0 && (n & (n - 1)) == 0);
+ // Check |tna| and |tnb| are in range.
+ assert(0 <= tna && tna < n);
+ assert(0 <= tnb && tnb < n);
+ assert(-1 <= tna - tnb && tna - tnb <= 1);
+ int n2 = n * 2;
if (n < 8) {
bn_mul_normal(r, a, n + tna, b, n + tnb);
+ OPENSSL_memset(r + n2 + tna + tnb, 0, n2 - tna - tnb);
return;
}
- // TODO(davidben): This function is not constant-time, but should be. See
- // https://crbug.com/boringssl/234.
+ // Split |a| and |b| into a0,a1 and b0,b1, where a0 and b0 have size |n|. |a1|
+ // and |b1| have size |tna| and |tnb|, respectively.
+ // Split |t| into t0,t1,t2,t3, each of size |n|, with the remaining 4*|n| used
+ // for recursive calls.
+ // Split |r| into r0,r1,r2,r3. We must contribute a0*b0 to r0,r1, a0*a1+b0*b1
+ // to r1,r2, and a1*b1 to r2,r3. The middle term we will compute as:
+ //
+ // a0*a1 + b0*b1 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0
- // r=(a[0]-a[1])*(b[1]-b[0])
- c1 = bn_cmp_part_words(a, &(a[n]), tna, n - tna);
- c2 = bn_cmp_part_words(&(b[n]), b, tnb, tnb - n);
- neg = 0;
- switch (c1 * 3 + c2) {
- case -4:
- bn_sub_part_words(t, &(a[n]), a, tna, tna - n); // -
- bn_sub_part_words(&(t[n]), b, &(b[n]), tnb, n - tnb); // -
- break;
- case -3:
- // break;
- case -2:
- bn_sub_part_words(t, &(a[n]), a, tna, tna - n); // -
- bn_sub_part_words(&(t[n]), &(b[n]), b, tnb, tnb - n); // +
- neg = 1;
- break;
- case -1:
- case 0:
- case 1:
- // break;
- case 2:
- bn_sub_part_words(t, a, &(a[n]), tna, n - tna); // +
- bn_sub_part_words(&(t[n]), b, &(b[n]), tnb, n - tnb); // -
- neg = 1;
- break;
- case 3:
- // break;
- case 4:
- bn_sub_part_words(t, a, &(a[n]), tna, n - tna);
- bn_sub_part_words(&(t[n]), &(b[n]), b, tnb, tnb - n);
- break;
- }
+ // t0 = a0 - a1 and t1 = b1 - b0. The result will be multiplied, so we XOR
+ // their sign masks, giving the sign of (a0 - a1)*(b1 - b0). t0 and t1
+ // themselves store the absolute value.
+ BN_ULONG neg = bn_abs_sub_part_words(t, a, &a[n], tna, n - tna, &t[n2]);
+ neg ^= bn_abs_sub_part_words(&t[n], &b[n], b, tnb, tnb - n, &t[n2]);
+ // Compute:
+ // t2,t3 = t0 * t1 = |(a0 - a1)*(b1 - b0)|
+ // r0,r1 = a0 * b0
+ // r2,r3 = a1 * b1
if (n == 8) {
- bn_mul_comba8(&(t[n2]), t, &(t[n]));
+ bn_mul_comba8(&t[n2], t, &t[n]);
bn_mul_comba8(r, a, b);
- bn_mul_normal(&(r[n2]), &(a[n]), tna, &(b[n]), tnb);
- OPENSSL_memset(&(r[n2 + tna + tnb]), 0, sizeof(BN_ULONG) * (n2 - tna - tnb));
+
+ bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
+ // |bn_mul_normal| only writes |tna| + |tna| words. Zero the rest.
+ OPENSSL_memset(&r[n2 + tna + tnb], 0, sizeof(BN_ULONG) * (n2 - tna - tnb));
} else {
- p = &(t[n2 * 2]);
- bn_mul_recursive(&(t[n2]), t, &(t[n]), n, 0, 0, p);
+ BN_ULONG *p = &t[n2 * 2];
+ bn_mul_recursive(&t[n2], t, &t[n], n, 0, 0, p);
bn_mul_recursive(r, a, b, n, 0, 0, p);
- i = n / 2;
- // If there is only a bottom half to the number,
- // just do it
+
+ int i = n / 2, j;
if (tna > tnb) {
j = tna - i;
} else {
@@ -479,32 +473,37 @@
}
if (j == 0) {
- bn_mul_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i, p);
- OPENSSL_memset(&(r[n2 + i * 2]), 0, sizeof(BN_ULONG) * (n2 - i * 2));
+ // If there is only a bottom half to the number, just do it. We know the
+ // larger of |tna - i| and |tnb - i| is zero. The other is zero or -1
+ // because |tna| and |tnb| differ by at most one.
+ bn_mul_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
+ // |bn_mul_recursive| only writes the bottom |i|*2 words.
+ OPENSSL_memset(&r[n2 + i * 2], 0, sizeof(BN_ULONG) * (n2 - i * 2));
} else if (j > 0) {
- // eg, n == 16, i == 8 and tn == 11
- bn_mul_part_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i, p);
- OPENSSL_memset(&(r[n2 + tna + tnb]), 0,
- sizeof(BN_ULONG) * (n2 - tna - tnb));
+ // E.g,, n == 16, i == 8 and tna == 11.
+ // |tna| and |tnb| are within one of each other, so if |tna| is larger and
+ // tna > i, then we know tnb >= i, and this call is valid.
+ bn_mul_part_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
} else {
- // (j < 0) eg, n == 16, i == 8 and tn == 5
- OPENSSL_memset(&(r[n2]), 0, sizeof(BN_ULONG) * n2);
+ // (j < 0) E.g., n == 16, i == 8 and tn == 5
+ OPENSSL_memset(&r[n2], 0, sizeof(BN_ULONG) * n2);
if (tna < BN_MUL_RECURSIVE_SIZE_NORMAL &&
tnb < BN_MUL_RECURSIVE_SIZE_NORMAL) {
- bn_mul_normal(&(r[n2]), &(a[n]), tna, &(b[n]), tnb);
+ bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
} else {
for (;;) {
i /= 2;
- // these simplified conditions work
- // exclusively because difference
- // between tna and tnb is 1 or 0
+ // These simplified conditions work exclusively because difference
+ // between |tna| and |tnb| is 1 or 0.
+ //
+ // TODO(davidben): This loop condition is exactly the same as the
+ // |j > 0| one but more complicated. Merge them.
if (i < tna || i < tnb) {
- bn_mul_part_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i,
- tnb - i, p);
+ bn_mul_part_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
break;
- } else if (i == tna || i == tnb) {
- bn_mul_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i,
- p);
+ }
+ if (i == tna || i == tnb) {
+ bn_mul_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
break;
}
}
@@ -512,42 +511,32 @@
}
}
- // t[32] holds (a[0]-a[1])*(b[1]-b[0]), c1 is the sign
- // r[10] holds (a[0]*b[0])
- // r[32] holds (b[1]*b[1])
+ // t0,t1,c = r0,r1 + r2,r3 = a0*b0 + a1*b1
+ BN_ULONG c = bn_add_words(t, r, &r[n2], n2);
- c1 = (int)(bn_add_words(t, r, &(r[n2]), n2));
+ // t2,t3,c = t0,t1,c + neg*t2,t3 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0.
+ // The second term is stored as the absolute value, so we do this with a
+ // constant-time select.
+ BN_ULONG c_neg = c - bn_sub_words(&t[n2 * 2], t, &t[n2], n2);
+ BN_ULONG c_pos = c + bn_add_words(&t[n2], t, &t[n2], n2);
+ bn_select_words(&t[n2], neg, &t[n2 * 2], &t[n2], n2);
+ OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
+ crypto_word_t_too_small);
+ c = constant_time_select_w(neg, c_neg, c_pos);
- if (neg) {
- // if t[32] is negative
- c1 -= (int)(bn_sub_words(&(t[n2]), t, &(t[n2]), n2));
- } else {
- // Might have a carry
- c1 += (int)(bn_add_words(&(t[n2]), &(t[n2]), t, n2));
+ // We now have our three components. Add them together.
+ // r1,r2,c = r1,r2 + t2,t3,c
+ c += bn_add_words(&r[n], &r[n], &t[n2], n2);
+
+ // Propagate the carry bit to the end.
+ for (int i = n + n2; i < n2 + n2; i++) {
+ BN_ULONG old = r[i];
+ r[i] = old + c;
+ c = r[i] < old;
}
- // t[32] holds (a[0]-a[1])*(b[1]-b[0])+(a[0]*b[0])+(a[1]*b[1])
- // r[10] holds (a[0]*b[0])
- // r[32] holds (b[1]*b[1])
- // c1 holds the carry bits
- c1 += (int)(bn_add_words(&(r[n]), &(r[n]), &(t[n2]), n2));
- if (c1) {
- p = &(r[n + n2]);
- lo = *p;
- ln = lo + c1;
- *p = ln;
-
- // The overflow will stop before we over write
- // words we should not overwrite
- if (ln < (BN_ULONG)c1) {
- do {
- p++;
- lo = *p;
- ln = lo + 1;
- *p = ln;
- } while (ln == 0);
- }
- }
+ // The product should fit without carries.
+ assert(c == 0);
}
// bn_mul_impl implements |BN_mul| and |bn_mul_fixed|. Note this function breaks
@@ -605,6 +594,9 @@
goto err;
}
if (al > j || bl > j) {
+ // We know |al| and |bl| are at most one from each other, so if al > j,
+ // bl >= j, and vice versa. Thus we can use |bn_mul_part_recursive|.
+ assert(al >= j && bl >= j);
// TODO(davidben): Check that these are correctly-sized, after rewriting
// |bn_mul_part_recursive|.
if (!bn_wexpand(t, j * 8) ||