Use EC_SCALAR for compute_wNAF.
Note this switches from walking BN_num_bits to the full bit length of
the scalar. But that can only cause it to add a few extra zeros to the
front of the schedule, which r_is_at_infinity will skip over.
Change-Id: I91e087c9c03505566b68f75fb37dfb53db467652
Reviewed-on: https://boringssl-review.googlesource.com/25147
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/ec/wnaf.c b/crypto/fipsmodule/ec/wnaf.c
index 0a2bcba..e2e5871 100644
--- a/crypto/fipsmodule/ec/wnaf.c
+++ b/crypto/fipsmodule/ec/wnaf.c
@@ -75,6 +75,7 @@
#include <openssl/thread.h>
#include "internal.h"
+#include "../bn/internal.h"
#include "../../internal.h"
@@ -90,27 +91,16 @@
// where at most one of any w+1 consecutive digits is non-zero
// with the exception that the most significant digit may be only
// w-1 zeros away from that next non-zero digit.
-static int8_t *compute_wNAF(const BIGNUM *scalar, int w, size_t *ret_len) {
+static int8_t *compute_wNAF(const EC_GROUP *group, const EC_SCALAR *scalar,
+ size_t bits, int w, size_t *ret_len) {
int window_val;
int ok = 0;
int8_t *r = NULL;
- int sign = 1;
int bit, next_bit, mask;
- size_t len = 0, j;
-
- if (BN_is_zero(scalar)) {
- r = OPENSSL_malloc(1);
- if (!r) {
- OPENSSL_PUT_ERROR(EC, ERR_R_MALLOC_FAILURE);
- goto err;
- }
- r[0] = 0;
- *ret_len = 1;
- return r;
- }
+ size_t j;
// 'int8_t' can represent integers with absolute values less than 2^7.
- if (w <= 0 || w > 7) {
+ if (w <= 0 || w > 7 || bits == 0) {
OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
goto err;
}
@@ -118,23 +108,17 @@
next_bit = bit << 1; // at most 256
mask = next_bit - 1; // at most 255
- if (BN_is_negative(scalar)) {
- sign = -1;
- }
-
- len = BN_num_bits(scalar);
// The modified wNAF may be one digit longer than binary representation
- // (*ret_len will be set to the actual length, i.e. at most
- // BN_num_bits(scalar) + 1).
- r = OPENSSL_malloc(len + 1);
+ // (*ret_len will be set to the actual length, i.e. at most |bits| + 1.
+ r = OPENSSL_malloc(bits + 1);
if (r == NULL) {
OPENSSL_PUT_ERROR(EC, ERR_R_MALLOC_FAILURE);
goto err;
}
- window_val = scalar->d[0] & mask;
+ window_val = scalar->words[0] & mask;
j = 0;
- // If j+w+1 >= len, window_val will not increase.
- while (window_val != 0 || j + w + 1 < len) {
+ // If j+w+1 >= bits, window_val will not increase.
+ while (window_val != 0 || j + w + 1 < bits) {
int digit = 0;
// 0 <= window_val <= 2^(w+1)
@@ -146,7 +130,7 @@
digit = window_val - next_bit; // -2^w < digit < 0
#if 1 // modified wNAF
- if (j + w + 1 >= len) {
+ if (j + w + 1 >= bits) {
// special case for generating modified wNAFs:
// no new bits will be added into window_val,
// so using a positive digit here will decrease
@@ -174,10 +158,11 @@
}
}
- r[j++] = sign * digit;
+ r[j++] = digit;
window_val >>= 1;
- window_val += bit * BN_is_bit_set(scalar, j + w);
+ window_val +=
+ bit * bn_is_bit_set_words(scalar->words, group->order.top, j + w);
if (window_val > next_bit) {
OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
@@ -185,11 +170,11 @@
}
}
- if (j > len + 1) {
+ if (j > bits + 1) {
OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
goto err;
}
- len = j;
+ bits = j;
ok = 1;
err:
@@ -198,7 +183,7 @@
r = NULL;
}
if (ok) {
- *ret_len = len;
+ *ret_len = bits;
}
return r;
}
@@ -223,9 +208,8 @@
return 1;
}
-int ec_wNAF_mul(const EC_GROUP *group, EC_POINT *r,
- const EC_SCALAR *g_scalar_raw, const EC_POINT *p,
- const EC_SCALAR *p_scalar_raw, BN_CTX *ctx) {
+int ec_wNAF_mul(const EC_GROUP *group, EC_POINT *r, const EC_SCALAR *g_scalar,
+ const EC_POINT *p, const EC_SCALAR *p_scalar, BN_CTX *ctx) {
BN_CTX *new_ctx = NULL;
const EC_POINT *generator = NULL;
EC_POINT *tmp = NULL;
@@ -247,32 +231,13 @@
goto err;
}
}
- BN_CTX_start(ctx);
-
- // Convert from |EC_SCALAR| to |BIGNUM|. |BIGNUM| is not constant-time, but
- // neither is the rest of this function.
- BIGNUM *g_scalar = NULL, *p_scalar = NULL;
- if (g_scalar_raw != NULL) {
- g_scalar = BN_CTX_get(ctx);
- if (g_scalar == NULL ||
- !bn_set_words(g_scalar, g_scalar_raw->words, group->order.top)) {
- goto err;
- }
- }
- if (p_scalar_raw != NULL) {
- p_scalar = BN_CTX_get(ctx);
- if (p_scalar == NULL ||
- !bn_set_words(p_scalar, p_scalar_raw->words, group->order.top)) {
- goto err;
- }
- }
// TODO: This function used to take |points| and |scalars| as arrays of
// |num| elements. The code below should be simplified to work in terms of |p|
// and |p_scalar|.
size_t num = p != NULL ? 1 : 0;
const EC_POINT **points = p != NULL ? &p : NULL;
- BIGNUM **scalars = p != NULL ? &p_scalar : NULL;
+ const EC_SCALAR **scalars = p != NULL ? &p_scalar : NULL;
total_num = num;
@@ -301,10 +266,11 @@
goto err;
}
- size_t wsize = window_bits_for_scalar_size(BN_num_bits(&group->order));
+ size_t bits = BN_num_bits(&group->order);
+ size_t wsize = window_bits_for_scalar_size(bits);
for (i = 0; i < total_num; i++) {
- wNAF[i] =
- compute_wNAF((i < num ? scalars[i] : g_scalar), wsize, &wNAF_len[i]);
+ wNAF[i] = compute_wNAF(group, (i < num ? scalars[i] : g_scalar), bits,
+ wsize, &wNAF_len[i]);
if (wNAF[i] == NULL) {
goto err;
}
@@ -421,9 +387,6 @@
ret = 1;
err:
- if (ctx != NULL) {
- BN_CTX_end(ctx);
- }
BN_CTX_free(new_ctx);
EC_POINT_free(tmp);
OPENSSL_free(wNAF_len);