Pull the malloc out of compute_wNAF.

This is to simplify clearing unnecessary mallocs out of ec_wNAF_mul, and
perhaps to use it in tuned variable-time multiplication functions.

Change-Id: Ic390d2e8e20d0ee50f3643830a582e94baebba95
Reviewed-on: https://boringssl-review.googlesource.com/25149
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/ec/wnaf.c b/crypto/fipsmodule/ec/wnaf.c
index 8c71038..9e7b034 100644
--- a/crypto/fipsmodule/ec/wnaf.c
+++ b/crypto/fipsmodule/ec/wnaf.c
@@ -84,30 +84,25 @@
 //   http://link.springer.com/chapter/10.1007%2F3-540-45537-X_13
 //   http://www.bmoeller.de/pdf/TI-01-08.multiexp.pdf
 
-// Determine the modified width-(w+1) Non-Adjacent Form (wNAF) of 'scalar'.
-// This is an array  r[]  of |bits| + 1 values that are either zero or odd with
-// an absolute value less than  2^w  satisfying
-//     scalar = \sum_j r[j]*2^j
+// compute_wNAF writes the modified width-(w+1) Non-Adjacent Form (wNAF) of
+// |scalar| to |out| and returns one on success or zero on internal error. |out|
+// must have room for |bits| + 1 elements, each of which will be either zero or
+// odd with an absolute value less than  2^w  satisfying
+//     scalar = \sum_j out[j]*2^j
 // where at most one of any  w+1  consecutive digits is non-zero
 // with the exception that the most significant digit may be only
 // w-1 zeros away from that next non-zero digit.
-static int8_t *compute_wNAF(const EC_GROUP *group, const EC_SCALAR *scalar,
-                            size_t bits, int w) {
+static int compute_wNAF(const EC_GROUP *group, int8_t *out,
+                        const EC_SCALAR *scalar, size_t bits, int w) {
   // 'int8_t' can represent integers with absolute values less than 2^7.
   if (w <= 0 || w > 7 || bits == 0) {
     OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
-    return NULL;
+    return 0;
   }
   int bit = 1 << w;         // at most 128
   int next_bit = bit << 1;  // at most 256
   int mask = next_bit - 1;  // at most 255
 
-  // The modified wNAF will be one digit longer than binary representation.
-  int8_t *r = OPENSSL_malloc(bits + 1);
-  if (r == NULL) {
-    OPENSSL_PUT_ERROR(EC, ERR_R_MALLOC_FAILURE);
-    goto err;
-  }
   int window_val = scalar->words[0] & mask;
   size_t j = 0;
   // If j+w+1 >= bits, window_val will not increase.
@@ -138,7 +133,7 @@
 
       if (digit <= -bit || digit >= bit || !(digit & 1)) {
         OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
-        goto err;
+        return 0;
       }
 
       window_val -= digit;
@@ -147,11 +142,11 @@
       // for modified window NAFs, it may also be 2^w.
       if (window_val != 0 && window_val != next_bit && window_val != bit) {
         OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
-        goto err;
+        return 0;
       }
     }
 
-    r[j++] = digit;
+    out[j++] = digit;
 
     window_val >>= 1;
     window_val +=
@@ -159,24 +154,20 @@
 
     if (window_val > next_bit) {
       OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
-      goto err;
+      return 0;
     }
   }
 
   // Fill the rest of the wNAF with zeros.
   if (j > bits + 1) {
     OPENSSL_PUT_ERROR(EC, ERR_R_INTERNAL_ERROR);
-    goto err;
+    return 0;
   }
   for (size_t i = j; i < bits + 1; i++) {
-    r[i] = 0;
+    out[i] = 0;
   }
 
-  return r;
-
-err:
-  OPENSSL_free(r);
-  return NULL;
+  return 1;
 }
 
 
@@ -257,9 +248,10 @@
   size_t wsize = window_bits_for_scalar_size(bits);
   size_t wNAF_len = bits + 1;
   for (i = 0; i < total_num; i++) {
-    wNAF[i] =
-        compute_wNAF(group, (i < num ? scalars[i] : g_scalar), bits, wsize);
-    if (wNAF[i] == NULL) {
+    wNAF[i] = OPENSSL_malloc(wNAF_len);
+    if (wNAF[i] == NULL ||
+        !compute_wNAF(group, wNAF[i], (i < num ? scalars[i] : g_scalar), bits,
+                      wsize)) {
       goto err;
     }
   }