Remove the need for scratch space when squaring
BN_sqr only has scratch space because it needs to compute a temporary
buffer with the a[i]^2 terms to add into the final result. But those
terms can be computed and added in a single pass.
This isn't expected to have any performance impact on assembly-enabled
builds. All those builds have bn_mul_mont optimizations, which means the
plain squaring operation is more-or-less unused. (This raises the
question of why we have assembly optimizations for it when it's only
used in builds that barely use it, but ah well.) On NO_ASM
builds, the plain square operation is used more, but this change only
affects a linear number of terms in an overall quadratic operation.
I was unable to measure a consistent difference with or without this
change. Really the benefit is that, by removing the dependency on
scratch space, we can remove the dependency on BN_CTX and can unify our
various Montgomery multiplication codepaths.
Bug: 42290433
Change-Id: I1527bd212529bbd4a1abedec22bb1dc3d7e12cbb
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/79307
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
Auto-Submit: David Benjamin <davidben@google.com>
diff --git a/crypto/fipsmodule/bn/asm/bn-586.pl b/crypto/fipsmodule/bn/asm/bn-586.pl
index fa7eb12..ad6c7ad 100644
--- a/crypto/fipsmodule/bn/asm/bn-586.pl
+++ b/crypto/fipsmodule/bn/asm/bn-586.pl
@@ -27,7 +27,7 @@
&bn_mul_add_words("bn_mul_add_words");
&bn_mul_words("bn_mul_words");
-&bn_sqr_words("bn_sqr_words");
+&bn_sqr_add_words("bn_sqr_add_words");
&bn_add_words("bn_add_words");
&bn_sub_words("bn_sub_words");
@@ -174,7 +174,7 @@
&function_end($name);
}
-sub bn_sqr_words
+sub bn_sqr_add_words
{
local($name)=@_;
@@ -188,12 +188,21 @@
&mov($r,&wparam(0));
&mov($a,&wparam(1));
&mov($c,&wparam(2));
+ &pxor("mm1","mm1"); # mm1 = carry_in
&set_label("sqr_sse2_loop",16);
&movd("mm0",&DWP(0,$a)); # mm0 = a[i]
+ &movd("mm2",&DWP(0,$r,"",0)); # mm2 = r[i]
+ &movd("mm3",&DWP(4,$r,"",0)); # mm3 = r[i+1]
&pmuludq("mm0","mm0"); # a[i] *= a[i]
&lea($a,&DWP(4,$a)); # a++
- &movq(&QWP(0,$r),"mm0"); # r[i] = a[i]*a[i]
+ &paddq("mm1","mm0"); # carry += a[i] * a[i]
+ &paddq("mm1","mm2"); # carry += r[i]
+ &movd(&DWP(0,$r), "mm1");
+ &psrlq("mm1",32); # carry >>= 32
+ &paddq("mm1","mm3"); # carry += r[i+1]
+ &movd(&DWP(4,$r), "mm1");
+ &psrlq("mm1",32); # carry >>= 32
&sub($c,1);
&lea($r,&DWP(8,$r)); # r += 2
&jnz(&label("sqr_sse2_loop"));
diff --git a/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc b/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc
index cbf7107..adf32d3 100644
--- a/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc
+++ b/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc
@@ -102,8 +102,24 @@
(r) = (carry); \
(carry) = high; \
} while (0)
-#undef sqr
-#define sqr(r0, r1, a) __asm__("mulq %2" : "=a"(r0), "=d"(r1) : "a"(a) : "cc");
+
+// r0:r1:carry = r0:r1 + a^2 + carry:0
+#define sqr_add(r0, r1, a, carry) \
+ do { \
+ BN_ULONG high, low; \
+ /* lo:hi = a^2 */ \
+ __asm__("mulq %2" : "=a"(low), "=d"(high) : "a"(a) : "cc"); \
+ /* carry:hi = lo:hi + carry:0 = a^2 + carry */ \
+ __asm__("addq %2,%0; adcq $0,%1" \
+ : "+r"(carry), "+d"(high) \
+ : "a"(low) \
+ : "cc"); \
+ /* r0:r1:carry = carry:hi + r0:r1 */ \
+ __asm__("addq %2,%0; adcq %3,%1; movq $0, %2; adcq $0, %2" \
+ : "+m"(r0), "+m"(r1), "+r"(carry) \
+ : "d"(high) \
+ : "cc"); \
+ } while (0)
BN_ULONG bn_mul_add_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num,
BN_ULONG w) {
@@ -169,30 +185,31 @@
return c1;
}
-void bn_sqr_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
+void bn_sqr_add_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
if (n == 0) {
return;
}
+ BN_ULONG carry = 0;
while (n & ~3) {
- sqr(r[0], r[1], a[0]);
- sqr(r[2], r[3], a[1]);
- sqr(r[4], r[5], a[2]);
- sqr(r[6], r[7], a[3]);
+ sqr_add(r[0], r[1], a[0], carry);
+ sqr_add(r[2], r[3], a[1], carry);
+ sqr_add(r[4], r[5], a[2], carry);
+ sqr_add(r[6], r[7], a[3], carry);
a += 4;
r += 8;
n -= 4;
}
if (n) {
- sqr(r[0], r[1], a[0]);
+ sqr_add(r[0], r[1], a[0], carry);
if (--n == 0) {
return;
}
- sqr(r[2], r[3], a[1]);
+ sqr_add(r[2], r[3], a[1], carry);
if (--n == 0) {
return;
}
- sqr(r[4], r[5], a[2]);
+ sqr_add(r[4], r[5], a[2], carry);
}
}
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index eb7e82d..10fd3b4 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -2793,7 +2793,7 @@
CHECK_ABI(bn_mul_add_words, r.data(), a.data(), num, 42);
r.resize(2 * num);
- CHECK_ABI(bn_sqr_words, r.data(), a.data(), num);
+ CHECK_ABI(bn_sqr_add_words, r.data(), a.data(), num);
if (num == 4) {
CHECK_ABI(bn_mul_comba4, r.data(), a.data(), b.data());
diff --git a/crypto/fipsmodule/bn/generic.cc.inc b/crypto/fipsmodule/bn/generic.cc.inc
index 9cbbd85..9a42fa0 100644
--- a/crypto/fipsmodule/bn/generic.cc.inc
+++ b/crypto/fipsmodule/bn/generic.cc.inc
@@ -151,22 +151,33 @@
return c1;
}
-void bn_sqr_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
+void bn_sqr_add_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
if (n == 0) {
return;
}
+ BN_ULONG carry = 0, lo, hi;
while (n & ~3) {
- sqr(r[0], r[1], a[0]);
- sqr(r[2], r[3], a[1]);
- sqr(r[4], r[5], a[2]);
- sqr(r[6], r[7], a[3]);
+ sqr(lo, hi, a[0]);
+ r[0] = CRYPTO_addc_w(r[0], lo, carry, &carry);
+ r[1] = CRYPTO_addc_w(r[1], hi, carry, &carry);
+ sqr(lo, hi, a[1]);
+ r[2] = CRYPTO_addc_w(r[2], lo, carry, &carry);
+ r[3] = CRYPTO_addc_w(r[3], hi, carry, &carry);
+ sqr(lo, hi, a[2]);
+ r[4] = CRYPTO_addc_w(r[4], lo, carry, &carry);
+ r[5] = CRYPTO_addc_w(r[5], hi, carry, &carry);
+ sqr(lo, hi, a[3]);
+ r[6] = CRYPTO_addc_w(r[6], lo, carry, &carry);
+ r[7] = CRYPTO_addc_w(r[7], hi, carry, &carry);
a += 4;
r += 8;
n -= 4;
}
while (n) {
- sqr(r[0], r[1], a[0]);
+ sqr(lo, hi, a[0]);
+ r[0] = CRYPTO_addc_w(r[0], lo, carry, &carry);
+ r[1] = CRYPTO_addc_w(r[1], hi, carry, &carry);
a++;
r += 2;
n--;
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 9dfc114..1ec2562 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -186,12 +186,14 @@
// operation. |ap| and |rp| may be equal but otherwise may not alias.
BN_ULONG bn_mul_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num, BN_ULONG w);
-// bn_sqr_words sets |rp[2*i]| and |rp[2*i+1]| to |ap[i]|'s square, for all |i|
-// up to |num|. |ap| is an array of |num| words and |rp| an array of |2*num|
-// words. |ap| and |rp| may not alias.
+// bn_sqr_add_words computes |tmp| where |tmp[2*i]| and |tmp[2*i+1]| are
+// |ap[i]|'s square, for all |i| up to |num|, and adds the result to |rp|. If
+// the result does not fit in |2*num| words, the final carry bit is truncated.
+// |ap| is an array of |num| words and |rp| an array of |2*num| words. |ap| and
+// |rp| may not alias.
//
// This gives the contribution of the |ap[i]*ap[i]| terms when squaring |ap|.
-void bn_sqr_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num);
+void bn_sqr_add_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num);
// bn_add_words adds |ap| to |bp| and places the result in |rp|, each of which
// are |num| words long. It returns the carry bit, which is one if the operation
@@ -645,8 +647,8 @@
void bn_mul_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a, size_t num_a,
const BN_ULONG *b, size_t num_b);
-// bn_sqr_small sets |r| to |a|^2. |num_a| must be at most |BN_SMALL_MAX_WORDS|.
-// |num_r| must be |num_a|*2. |r| and |a| may not alias.
+// bn_sqr_small sets |r| to |a|^2. |num_r| must be |num_a|*2. |r| and |a| may
+// not alias.
void bn_sqr_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a, size_t num_a);
// In the following functions, the modulus must be at most |BN_SMALL_MAX_WORDS|
diff --git a/crypto/fipsmodule/bn/mul.cc.inc b/crypto/fipsmodule/bn/mul.cc.inc
index 4a39ee3..e33fb48 100644
--- a/crypto/fipsmodule/bn/mul.cc.inc
+++ b/crypto/fipsmodule/bn/mul.cc.inc
@@ -25,8 +25,6 @@
#include "internal.h"
-#define BN_SQR_STACK_WORDS 16
-
static void bn_mul_normal(BN_ULONG *r, const BN_ULONG *a, size_t na,
const BN_ULONG *b, size_t nb) {
if (na < nb) {
@@ -228,9 +226,7 @@
}
}
-// tmp must have 2*n words
-static void bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, size_t n,
- BN_ULONG *tmp) {
+static void bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, size_t n) {
if (n == 0) {
return;
}
@@ -262,8 +258,7 @@
bn_add_words(r, r, r, max);
// Add in the contribution of a[i] * a[i] for all i.
- bn_sqr_words(tmp, a, n);
- bn_add_words(r, r, tmp, max);
+ bn_sqr_add_words(r, a, n);
}
int BN_mul_word(BIGNUM *bn, BN_ULONG w) {
@@ -297,8 +292,7 @@
bssl::BN_CTXScope scope(ctx);
BIGNUM *rr = (a != r) ? r : BN_CTX_get(ctx);
- BIGNUM *tmp = BN_CTX_get(ctx);
- if (!rr || !tmp) {
+ if (!rr) {
return 0;
}
@@ -312,15 +306,7 @@
} else if (al == 8) {
bn_sqr_comba8(rr->d, a->d);
} else {
- if (al < BN_SQR_STACK_WORDS) {
- BN_ULONG t[BN_SQR_STACK_WORDS * 2];
- bn_sqr_normal(rr->d, a->d, al, t);
- } else {
- if (!bn_wexpand(tmp, max)) {
- return 0;
- }
- bn_sqr_normal(rr->d, a->d, al, tmp->d);
- }
+ bn_sqr_normal(rr->d, a->d, al);
}
rr->neg = 0;
@@ -342,7 +328,8 @@
}
void bn_sqr_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a, size_t num_a) {
- if (num_r != 2 * num_a || num_a > BN_SMALL_MAX_WORDS) {
+ assert(r != a);
+ if (num_r != 2 * num_a) {
abort();
}
if (num_a == 4) {
@@ -350,8 +337,6 @@
} else if (num_a == 8) {
bn_sqr_comba8(r, a);
} else {
- BN_ULONG tmp[2 * BN_SMALL_MAX_WORDS];
- bn_sqr_normal(r, a, num_a, tmp);
- OPENSSL_cleanse(tmp, 2 * num_a * sizeof(BN_ULONG));
+ bn_sqr_normal(r, a, num_a);
}
}
diff --git a/gen/bcm/bn-586-apple.S b/gen/bcm/bn-586-apple.S
index 3e6f791..e96d0a4 100644
--- a/gen/bcm/bn-586-apple.S
+++ b/gen/bcm/bn-586-apple.S
@@ -132,20 +132,29 @@
popl %ebx
popl %ebp
ret
-.globl _bn_sqr_words
-.private_extern _bn_sqr_words
+.globl _bn_sqr_add_words
+.private_extern _bn_sqr_add_words
.align 4
-_bn_sqr_words:
-L_bn_sqr_words_begin:
+_bn_sqr_add_words:
+L_bn_sqr_add_words_begin:
movl 4(%esp),%eax
movl 8(%esp),%edx
movl 12(%esp),%ecx
+ pxor %mm1,%mm1
.align 4,0x90
L005sqr_sse2_loop:
movd (%edx),%mm0
+ movd (%eax),%mm2
+ movd 4(%eax),%mm3
pmuludq %mm0,%mm0
leal 4(%edx),%edx
- movq %mm0,(%eax)
+ paddq %mm0,%mm1
+ paddq %mm2,%mm1
+ movd %mm1,(%eax)
+ psrlq $32,%mm1
+ paddq %mm3,%mm1
+ movd %mm1,4(%eax)
+ psrlq $32,%mm1
subl $1,%ecx
leal 8(%eax),%eax
jnz L005sqr_sse2_loop
diff --git a/gen/bcm/bn-586-linux.S b/gen/bcm/bn-586-linux.S
index 808f63e..8e2bdb0 100644
--- a/gen/bcm/bn-586-linux.S
+++ b/gen/bcm/bn-586-linux.S
@@ -136,21 +136,30 @@
popl %ebp
ret
.size bn_mul_words,.-.L_bn_mul_words_begin
-.globl bn_sqr_words
-.hidden bn_sqr_words
-.type bn_sqr_words,@function
+.globl bn_sqr_add_words
+.hidden bn_sqr_add_words
+.type bn_sqr_add_words,@function
.align 16
-bn_sqr_words:
-.L_bn_sqr_words_begin:
+bn_sqr_add_words:
+.L_bn_sqr_add_words_begin:
movl 4(%esp),%eax
movl 8(%esp),%edx
movl 12(%esp),%ecx
+ pxor %mm1,%mm1
.align 16
.L005sqr_sse2_loop:
movd (%edx),%mm0
+ movd (%eax),%mm2
+ movd 4(%eax),%mm3
pmuludq %mm0,%mm0
leal 4(%edx),%edx
- movq %mm0,(%eax)
+ paddq %mm0,%mm1
+ paddq %mm2,%mm1
+ movd %mm1,(%eax)
+ psrlq $32,%mm1
+ paddq %mm3,%mm1
+ movd %mm1,4(%eax)
+ psrlq $32,%mm1
subl $1,%ecx
leal 8(%eax),%eax
jnz .L005sqr_sse2_loop
@@ -161,7 +170,7 @@
popl %ebx
popl %ebp
ret
-.size bn_sqr_words,.-.L_bn_sqr_words_begin
+.size bn_sqr_add_words,.-.L_bn_sqr_add_words_begin
.globl bn_add_words
.hidden bn_add_words
.type bn_add_words,@function
diff --git a/gen/bcm/bn-586-win.asm b/gen/bcm/bn-586-win.asm
index 1250eb6..fd3fe01 100644
--- a/gen/bcm/bn-586-win.asm
+++ b/gen/bcm/bn-586-win.asm
@@ -138,19 +138,28 @@
pop ebx
pop ebp
ret
-global _bn_sqr_words
+global _bn_sqr_add_words
align 16
-_bn_sqr_words:
-L$_bn_sqr_words_begin:
+_bn_sqr_add_words:
+L$_bn_sqr_add_words_begin:
mov eax,DWORD [4+esp]
mov edx,DWORD [8+esp]
mov ecx,DWORD [12+esp]
+ pxor mm1,mm1
align 16
L$005sqr_sse2_loop:
movd mm0,DWORD [edx]
+ movd mm2,DWORD [eax]
+ movd mm3,DWORD [4+eax]
pmuludq mm0,mm0
lea edx,[4+edx]
- movq [eax],mm0
+ paddq mm1,mm0
+ paddq mm1,mm2
+ movd DWORD [eax],mm1
+ psrlq mm1,32
+ paddq mm1,mm3
+ movd DWORD [4+eax],mm1
+ psrlq mm1,32
sub ecx,1
lea eax,[8+eax]
jnz NEAR L$005sqr_sse2_loop