Remove the need for scratch space when squaring

BN_sqr only has scratch space because it needs to compute a temporary
buffer with the a[i]^2 terms to add into the final result. But those
terms can be computed and added in a single pass.

This isn't expected to have any performance impact on assembly-enabled
builds. All those builds have bn_mul_mont optimizations, which means the
plain squaring operation is more-or-less unused. (This begs the question
why we have assembly optimizations for it, when it's only used in
conjunction with builds that barely use it, but ah well.) On NO_ASM
builds, the plain square operation is used more, but this impacts
linearly many terms out of an overall quadratic operation.

I was unable to measure a consistent difference with or without this
change. Really the benefit is that, by removing the dependency on
scratch space, we can remove the dependency on BN_CTX and can unify our
various Montgomery multiplication codepaths.

Bug: 42290433
Change-Id: I1527bd212529bbd4a1abedec22bb1dc3d7e12cbb
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/79307
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
Auto-Submit: David Benjamin <davidben@google.com>
diff --git a/crypto/fipsmodule/bn/asm/bn-586.pl b/crypto/fipsmodule/bn/asm/bn-586.pl
index fa7eb12..ad6c7ad 100644
--- a/crypto/fipsmodule/bn/asm/bn-586.pl
+++ b/crypto/fipsmodule/bn/asm/bn-586.pl
@@ -27,7 +27,7 @@
 
 &bn_mul_add_words("bn_mul_add_words");
 &bn_mul_words("bn_mul_words");
-&bn_sqr_words("bn_sqr_words");
+&bn_sqr_add_words("bn_sqr_add_words");
 &bn_add_words("bn_add_words");
 &bn_sub_words("bn_sub_words");
 
@@ -174,7 +174,7 @@
 	&function_end($name);
 	}
 
-sub bn_sqr_words
+sub bn_sqr_add_words
 	{
 	local($name)=@_;
 
@@ -188,12 +188,21 @@
 		&mov($r,&wparam(0));
 		&mov($a,&wparam(1));
 		&mov($c,&wparam(2));
+		&pxor("mm1","mm1");		# mm1 = carry_in
 
 	&set_label("sqr_sse2_loop",16);
 		&movd("mm0",&DWP(0,$a));	# mm0 = a[i]
+		&movd("mm2",&DWP(0,$r,"",0));	# mm2 = r[i]
+		&movd("mm3",&DWP(4,$r,"",0));	# mm3 = r[i+1]
 		&pmuludq("mm0","mm0");		# a[i] *= a[i]
 		&lea($a,&DWP(4,$a));		# a++
-		&movq(&QWP(0,$r),"mm0");	# r[i] = a[i]*a[i]
+		&paddq("mm1","mm0");		# carry += a[i] * a[i]
+		&paddq("mm1","mm2");		# carry += r[i]
+		&movd(&DWP(0,$r), "mm1");
+		&psrlq("mm1",32);		# carry >>= 32
+		&paddq("mm1","mm3");		# carry += r[i+1]
+		&movd(&DWP(4,$r), "mm1");
+		&psrlq("mm1",32);		# carry >>= 32
 		&sub($c,1);
 		&lea($r,&DWP(8,$r));		# r += 2
 		&jnz(&label("sqr_sse2_loop"));
diff --git a/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc b/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc
index cbf7107..adf32d3 100644
--- a/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc
+++ b/crypto/fipsmodule/bn/asm/x86_64-gcc.cc.inc
@@ -102,8 +102,24 @@
     (r) = (carry);                                                         \
     (carry) = high;                                                        \
   } while (0)
-#undef sqr
-#define sqr(r0, r1, a) __asm__("mulq %2" : "=a"(r0), "=d"(r1) : "a"(a) : "cc");
+
+// r0:r1:carry = r0:r1 + a^2 + carry:0
+#define sqr_add(r0, r1, a, carry)                               \
+  do {                                                          \
+    BN_ULONG high, low;                                         \
+    /* lo:hi = a^2 */                                           \
+    __asm__("mulq %2" : "=a"(low), "=d"(high) : "a"(a) : "cc"); \
+    /* carry:hi = lo:hi + carry:0 = a^2 + carry */              \
+    __asm__("addq %2,%0; adcq $0,%1"                            \
+            : "+r"(carry), "+d"(high)                           \
+            : "a"(low)                                          \
+            : "cc");                                            \
+    /* r0:r1:carry = carry:hi + r0:r1 */                        \
+    __asm__("addq %2,%0; adcq %3,%1; movq $0, %2; adcq $0, %2"  \
+            : "+m"(r0), "+m"(r1), "+r"(carry)                   \
+            : "d"(high)                                         \
+            : "cc");                                            \
+  } while (0)
 
 BN_ULONG bn_mul_add_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num,
                           BN_ULONG w) {
@@ -169,30 +185,31 @@
   return c1;
 }
 
-void bn_sqr_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
+void bn_sqr_add_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
   if (n == 0) {
     return;
   }
 
+  BN_ULONG carry = 0;
   while (n & ~3) {
-    sqr(r[0], r[1], a[0]);
-    sqr(r[2], r[3], a[1]);
-    sqr(r[4], r[5], a[2]);
-    sqr(r[6], r[7], a[3]);
+    sqr_add(r[0], r[1], a[0], carry);
+    sqr_add(r[2], r[3], a[1], carry);
+    sqr_add(r[4], r[5], a[2], carry);
+    sqr_add(r[6], r[7], a[3], carry);
     a += 4;
     r += 8;
     n -= 4;
   }
   if (n) {
-    sqr(r[0], r[1], a[0]);
+    sqr_add(r[0], r[1], a[0], carry);
     if (--n == 0) {
       return;
     }
-    sqr(r[2], r[3], a[1]);
+    sqr_add(r[2], r[3], a[1], carry);
     if (--n == 0) {
       return;
     }
-    sqr(r[4], r[5], a[2]);
+    sqr_add(r[4], r[5], a[2], carry);
   }
 }
 
diff --git a/crypto/fipsmodule/bn/bn_test.cc b/crypto/fipsmodule/bn/bn_test.cc
index eb7e82d..10fd3b4 100644
--- a/crypto/fipsmodule/bn/bn_test.cc
+++ b/crypto/fipsmodule/bn/bn_test.cc
@@ -2793,7 +2793,7 @@
     CHECK_ABI(bn_mul_add_words, r.data(), a.data(), num, 42);
 
     r.resize(2 * num);
-    CHECK_ABI(bn_sqr_words, r.data(), a.data(), num);
+    CHECK_ABI(bn_sqr_add_words, r.data(), a.data(), num);
 
     if (num == 4) {
       CHECK_ABI(bn_mul_comba4, r.data(), a.data(), b.data());
diff --git a/crypto/fipsmodule/bn/generic.cc.inc b/crypto/fipsmodule/bn/generic.cc.inc
index 9cbbd85..9a42fa0 100644
--- a/crypto/fipsmodule/bn/generic.cc.inc
+++ b/crypto/fipsmodule/bn/generic.cc.inc
@@ -151,22 +151,33 @@
   return c1;
 }
 
-void bn_sqr_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
+void bn_sqr_add_words(BN_ULONG *r, const BN_ULONG *a, size_t n) {
   if (n == 0) {
     return;
   }
 
+  BN_ULONG carry = 0, lo, hi;
   while (n & ~3) {
-    sqr(r[0], r[1], a[0]);
-    sqr(r[2], r[3], a[1]);
-    sqr(r[4], r[5], a[2]);
-    sqr(r[6], r[7], a[3]);
+    sqr(lo, hi, a[0]);
+    r[0] = CRYPTO_addc_w(r[0], lo, carry, &carry);
+    r[1] = CRYPTO_addc_w(r[1], hi, carry, &carry);
+    sqr(lo, hi, a[1]);
+    r[2] = CRYPTO_addc_w(r[2], lo, carry, &carry);
+    r[3] = CRYPTO_addc_w(r[3], hi, carry, &carry);
+    sqr(lo, hi, a[2]);
+    r[4] = CRYPTO_addc_w(r[4], lo, carry, &carry);
+    r[5] = CRYPTO_addc_w(r[5], hi, carry, &carry);
+    sqr(lo, hi, a[3]);
+    r[6] = CRYPTO_addc_w(r[6], lo, carry, &carry);
+    r[7] = CRYPTO_addc_w(r[7], hi, carry, &carry);
     a += 4;
     r += 8;
     n -= 4;
   }
   while (n) {
-    sqr(r[0], r[1], a[0]);
+    sqr(lo, hi, a[0]);
+    r[0] = CRYPTO_addc_w(r[0], lo, carry, &carry);
+    r[1] = CRYPTO_addc_w(r[1], hi, carry, &carry);
     a++;
     r += 2;
     n--;
diff --git a/crypto/fipsmodule/bn/internal.h b/crypto/fipsmodule/bn/internal.h
index 9dfc114..1ec2562 100644
--- a/crypto/fipsmodule/bn/internal.h
+++ b/crypto/fipsmodule/bn/internal.h
@@ -186,12 +186,14 @@
 // operation. |ap| and |rp| may be equal but otherwise may not alias.
 BN_ULONG bn_mul_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num, BN_ULONG w);
 
-// bn_sqr_words sets |rp[2*i]| and |rp[2*i+1]| to |ap[i]|'s square, for all |i|
-// up to |num|. |ap| is an array of |num| words and |rp| an array of |2*num|
-// words. |ap| and |rp| may not alias.
+// bn_sqr_add_words computes |tmp| where |tmp[2*i]| and |tmp[2*i+1]| are
+// |ap[i]|'s square, for all |i| up to |num|, and adds the result to |rp|. If
+// the result does not fit in |2*num| words, the final carry bit is truncated.
+// |ap| is an array of |num| words and |rp| an array of |2*num| words. |ap| and
+// |rp| may not alias.
 //
 // This gives the contribution of the |ap[i]*ap[i]| terms when squaring |ap|.
-void bn_sqr_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num);
+void bn_sqr_add_words(BN_ULONG *rp, const BN_ULONG *ap, size_t num);
 
 // bn_add_words adds |ap| to |bp| and places the result in |rp|, each of which
 // are |num| words long. It returns the carry bit, which is one if the operation
@@ -645,8 +647,8 @@
 void bn_mul_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a, size_t num_a,
                  const BN_ULONG *b, size_t num_b);
 
-// bn_sqr_small sets |r| to |a|^2. |num_a| must be at most |BN_SMALL_MAX_WORDS|.
-// |num_r| must be |num_a|*2. |r| and |a| may not alias.
+// bn_sqr_small sets |r| to |a|^2. |num_r| must be |num_a|*2. |r| and |a| may
+// not alias.
 void bn_sqr_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a, size_t num_a);
 
 // In the following functions, the modulus must be at most |BN_SMALL_MAX_WORDS|
diff --git a/crypto/fipsmodule/bn/mul.cc.inc b/crypto/fipsmodule/bn/mul.cc.inc
index 4a39ee3..e33fb48 100644
--- a/crypto/fipsmodule/bn/mul.cc.inc
+++ b/crypto/fipsmodule/bn/mul.cc.inc
@@ -25,8 +25,6 @@
 #include "internal.h"
 
 
-#define BN_SQR_STACK_WORDS 16
-
 static void bn_mul_normal(BN_ULONG *r, const BN_ULONG *a, size_t na,
                           const BN_ULONG *b, size_t nb) {
   if (na < nb) {
@@ -228,9 +226,7 @@
   }
 }
 
-// tmp must have 2*n words
-static void bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, size_t n,
-                          BN_ULONG *tmp) {
+static void bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, size_t n) {
   if (n == 0) {
     return;
   }
@@ -262,8 +258,7 @@
   bn_add_words(r, r, r, max);
 
   // Add in the contribution of a[i] * a[i] for all i.
-  bn_sqr_words(tmp, a, n);
-  bn_add_words(r, r, tmp, max);
+  bn_sqr_add_words(r, a, n);
 }
 
 int BN_mul_word(BIGNUM *bn, BN_ULONG w) {
@@ -297,8 +292,7 @@
 
   bssl::BN_CTXScope scope(ctx);
   BIGNUM *rr = (a != r) ? r : BN_CTX_get(ctx);
-  BIGNUM *tmp = BN_CTX_get(ctx);
-  if (!rr || !tmp) {
+  if (!rr) {
     return 0;
   }
 
@@ -312,15 +306,7 @@
   } else if (al == 8) {
     bn_sqr_comba8(rr->d, a->d);
   } else {
-    if (al < BN_SQR_STACK_WORDS) {
-      BN_ULONG t[BN_SQR_STACK_WORDS * 2];
-      bn_sqr_normal(rr->d, a->d, al, t);
-    } else {
-      if (!bn_wexpand(tmp, max)) {
-        return 0;
-      }
-      bn_sqr_normal(rr->d, a->d, al, tmp->d);
-    }
+    bn_sqr_normal(rr->d, a->d, al);
   }
 
   rr->neg = 0;
@@ -342,7 +328,8 @@
 }
 
 void bn_sqr_small(BN_ULONG *r, size_t num_r, const BN_ULONG *a, size_t num_a) {
-  if (num_r != 2 * num_a || num_a > BN_SMALL_MAX_WORDS) {
+  assert(r != a);
+  if (num_r != 2 * num_a) {
     abort();
   }
   if (num_a == 4) {
@@ -350,8 +337,6 @@
   } else if (num_a == 8) {
     bn_sqr_comba8(r, a);
   } else {
-    BN_ULONG tmp[2 * BN_SMALL_MAX_WORDS];
-    bn_sqr_normal(r, a, num_a, tmp);
-    OPENSSL_cleanse(tmp, 2 * num_a * sizeof(BN_ULONG));
+    bn_sqr_normal(r, a, num_a);
   }
 }
diff --git a/gen/bcm/bn-586-apple.S b/gen/bcm/bn-586-apple.S
index 3e6f791..e96d0a4 100644
--- a/gen/bcm/bn-586-apple.S
+++ b/gen/bcm/bn-586-apple.S
@@ -132,20 +132,29 @@
 	popl	%ebx
 	popl	%ebp
 	ret
-.globl	_bn_sqr_words
-.private_extern	_bn_sqr_words
+.globl	_bn_sqr_add_words
+.private_extern	_bn_sqr_add_words
 .align	4
-_bn_sqr_words:
-L_bn_sqr_words_begin:
+_bn_sqr_add_words:
+L_bn_sqr_add_words_begin:
 	movl	4(%esp),%eax
 	movl	8(%esp),%edx
 	movl	12(%esp),%ecx
+	pxor	%mm1,%mm1
 .align	4,0x90
 L005sqr_sse2_loop:
 	movd	(%edx),%mm0
+	movd	(%eax),%mm2
+	movd	4(%eax),%mm3
 	pmuludq	%mm0,%mm0
 	leal	4(%edx),%edx
-	movq	%mm0,(%eax)
+	paddq	%mm0,%mm1
+	paddq	%mm2,%mm1
+	movd	%mm1,(%eax)
+	psrlq	$32,%mm1
+	paddq	%mm3,%mm1
+	movd	%mm1,4(%eax)
+	psrlq	$32,%mm1
 	subl	$1,%ecx
 	leal	8(%eax),%eax
 	jnz	L005sqr_sse2_loop
diff --git a/gen/bcm/bn-586-linux.S b/gen/bcm/bn-586-linux.S
index 808f63e..8e2bdb0 100644
--- a/gen/bcm/bn-586-linux.S
+++ b/gen/bcm/bn-586-linux.S
@@ -136,21 +136,30 @@
 	popl	%ebp
 	ret
 .size	bn_mul_words,.-.L_bn_mul_words_begin
-.globl	bn_sqr_words
-.hidden	bn_sqr_words
-.type	bn_sqr_words,@function
+.globl	bn_sqr_add_words
+.hidden	bn_sqr_add_words
+.type	bn_sqr_add_words,@function
 .align	16
-bn_sqr_words:
-.L_bn_sqr_words_begin:
+bn_sqr_add_words:
+.L_bn_sqr_add_words_begin:
 	movl	4(%esp),%eax
 	movl	8(%esp),%edx
 	movl	12(%esp),%ecx
+	pxor	%mm1,%mm1
 .align	16
 .L005sqr_sse2_loop:
 	movd	(%edx),%mm0
+	movd	(%eax),%mm2
+	movd	4(%eax),%mm3
 	pmuludq	%mm0,%mm0
 	leal	4(%edx),%edx
-	movq	%mm0,(%eax)
+	paddq	%mm0,%mm1
+	paddq	%mm2,%mm1
+	movd	%mm1,(%eax)
+	psrlq	$32,%mm1
+	paddq	%mm3,%mm1
+	movd	%mm1,4(%eax)
+	psrlq	$32,%mm1
 	subl	$1,%ecx
 	leal	8(%eax),%eax
 	jnz	.L005sqr_sse2_loop
@@ -161,7 +170,7 @@
 	popl	%ebx
 	popl	%ebp
 	ret
-.size	bn_sqr_words,.-.L_bn_sqr_words_begin
+.size	bn_sqr_add_words,.-.L_bn_sqr_add_words_begin
 .globl	bn_add_words
 .hidden	bn_add_words
 .type	bn_add_words,@function
diff --git a/gen/bcm/bn-586-win.asm b/gen/bcm/bn-586-win.asm
index 1250eb6..fd3fe01 100644
--- a/gen/bcm/bn-586-win.asm
+++ b/gen/bcm/bn-586-win.asm
@@ -138,19 +138,28 @@
 	pop	ebx
 	pop	ebp
 	ret
-global	_bn_sqr_words
+global	_bn_sqr_add_words
 align	16
-_bn_sqr_words:
-L$_bn_sqr_words_begin:
+_bn_sqr_add_words:
+L$_bn_sqr_add_words_begin:
 	mov	eax,DWORD [4+esp]
 	mov	edx,DWORD [8+esp]
 	mov	ecx,DWORD [12+esp]
+	pxor	mm1,mm1
 align	16
 L$005sqr_sse2_loop:
 	movd	mm0,DWORD [edx]
+	movd	mm2,DWORD [eax]
+	movd	mm3,DWORD [4+eax]
 	pmuludq	mm0,mm0
 	lea	edx,[4+edx]
-	movq	[eax],mm0
+	paddq	mm1,mm0
+	paddq	mm1,mm2
+	movd	DWORD [eax],mm1
+	psrlq	mm1,32
+	paddq	mm1,mm3
+	movd	DWORD [4+eax],mm1
+	psrlq	mm1,32
 	sub	ecx,1
 	lea	eax,[8+eax]
 	jnz	NEAR L$005sqr_sse2_loop