Simplify HRSS mod3 circuits.

The multiplication and subtraction circuits were found by djb using GNU
Superoptimizer, and the addition circuit is derived from the subtraction
one by hand. They depend on a different representation: -1 is now (1, 1)
rather than (1, 0), and the latter becomes undefined.

The following Python program checks that the circuits work:

# All three values of a mod-3 coefficient; every ordered pair is checked below.
values = [0, 1, -1]

def toBits(v):
    """Encode a value from {0, 1, -1} as its (s, a) bit pair.

    Raises ValueError for any other input.
    """
    encodings = ((0, (0, 0)), (1, (0, 1)), (-1, (1, 1)))
    for value, bits in encodings:
        if v == value:
            return bits
    raise ValueError(v)

def mul(x, y):
    """Multiply two mod-3 values, each given as an (s, a) bit pair.

    The operands are unpacked in the body rather than in the signature so
    the function is valid on both Python 2 and Python 3.
    """
    s1, a1 = x
    s2, a2 = y
    # The product is non-zero only when both inputs are non-zero; its sign is
    # the XOR of the input signs, masked so that zero stays (0, 0).
    a3 = a1 & a2
    return ((s1 ^ s2) & a3, a3)

def add(x, y):
    """Add two mod-3 values, each given as an (s, a) bit pair.

    The operands are unpacked in the body rather than in the signature so
    the function is valid on both Python 2 and Python 3.
    """
    s1, a1 = x
    s2, a2 = y
    t = s1 ^ a2
    return (t & (s2 ^ a1), (a1 ^ a2) | (t ^ s2))

def sub(x, y):
    """Subtract y from x mod 3, each given as an (s, a) bit pair.

    The operands are unpacked in the body rather than in the signature so
    the function is valid on both Python 2 and Python 3.
    """
    s1, a1 = x
    s2, a2 = y
    t = a1 ^ a2
    return ((s1 ^ a2) & (t ^ s2), t | (s1 ^ s2))

def fromBits(bits):
    """Decode an (s, a) bit pair back to 0, 1 or -1.

    Raises ValueError for the undefined pair (1, 0) and for anything else
    that is not a valid encoding. The pair is unpacked in the body rather
    than in the signature so the function is valid on both Python 2 and
    Python 3.
    """
    s, a = bits
    if s == 0 and a == 0:
        return 0
    if s == 0 and a == 1:
        return 1
    if s == 1 and a == 1:
        return -1
    raise ValueError((s, a))

def wrap(v):
    """Map 2 to -1 and -2 to 1; return any other value unchanged."""
    return -1 if v == 2 else (1 if v == -2 else v)

# Exhaustively compare each circuit against plain integer arithmetic.
for v1 in values:
    for v2 in values:
        # A single parenthesised string prints identically under Python 2 and
        # Python 3.
        print("%d %d" % (v1, v2))

        checks = (
            (mul, v1 * v2),
            (add, wrap(v1 + v2)),
            (sub, wrap(v1 - v2)),
        )
        for circuit, expected in checks:
            result = fromBits(circuit(toBits(v1), toBits(v2)))
            if result != expected:
                raise ValueError((v1, v2, result))

Change-Id: Ie1a4ca5a82c2651057efc62330eca6fdd9878122
Reviewed-on: https://boringssl-review.googlesource.com/c/34344
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/hrss/hrss.c b/crypto/hrss/hrss.c
index 71fa5e3..753d50d 100644
--- a/crypto/hrss/hrss.c
+++ b/crypto/hrss/hrss.c
@@ -488,26 +488,30 @@
 //  -----------------
 //   0  |  0  | 0
 //   0  |  1  | 1
-//   1  |  0  | 2 (aka -1)
-//   1  |  1  | <invalid>
+//   1  |  1  | -1 (aka 2)
+//   1  |  0  | <invalid>
 //
-// ('s' is for sign, and 'a' just a letter.)
+// ('s' is for sign, and 'a' is the absolute value.)
 //
 // Once bitsliced as such, the following circuits can be used to implement
 // addition and multiplication mod 3:
 //
 //   (s3, a3) = (s1, a1) × (s2, a2)
-//   s3 = (a1 ∧ s2) ⊕ (s1 ∧ a2)
-//   a3 = (s1 ∧ s2) ⊕ (a1 ∧ a2)
+//   a3 = a1 ∧ a2
+//   s3 = (s1 ⊕ s2) ∧ a3
 //
 //   (s3, a3) = (s1, a1) + (s2, a2)
-//   x = (a1 ⊕ a2)
-//   y = (s1 ⊕ s2) ⊕ (a1 ∧ a2)
-//   z = (s1 ∧ s2)
-//   s3 = y ∧ ¬x
-//   a3 = z ∨ (x ∧ ¬y)
+//   t = s1 ⊕ a2
+//   s3 = t ∧ (s2 ⊕ a1)
+//   a3 = (a1 ⊕ a2) ∨ (t ⊕ s2)
 //
-// Negating a value just involves swapping s and a.
+//   (s3, a3) = (s1, a1) - (s2, a2)
+//   t = a1 ⊕ a2
+//   s3 = (s1 ⊕ a2) ∧ (t ⊕ s2)
+//   a3 = t ∨ (s1 ⊕ s2)
+//
+// Negating a value just involves XORing s with a.
+//
 // struct poly3 {
 //   struct poly2 s, a;
 // };
@@ -540,22 +544,45 @@
   poly2_zero(&p->a);
 }
 
+// poly3_word_mul sets (|out_s|, |out_a|) to (|s1|, |a1|) × (|s2|, |a2|).
+static void poly3_word_mul(crypto_word_t *out_s, crypto_word_t *out_a,
+                           const crypto_word_t s1, const crypto_word_t a1,
+                           const crypto_word_t s2, const crypto_word_t a2) {
+  *out_a = a1 & a2;
+  *out_s = (s1 ^ s2) & *out_a;
+}
+
+// poly3_word_add sets (|out_s|, |out_a|) to (|s1|, |a1|) + (|s2|, |a2|).
+static void poly3_word_add(crypto_word_t *out_s, crypto_word_t *out_a,
+                           const crypto_word_t s1, const crypto_word_t a1,
+                           const crypto_word_t s2, const crypto_word_t a2) {
+  const crypto_word_t t = s1 ^ a2;
+  *out_s = t & (s2 ^ a1);
+  *out_a = (a1 ^ a2) | (t ^ s2);
+}
+
+// poly3_word_sub sets (|out_s|, |out_a|) to (|s1|, |a1|) - (|s2|, |a2|).
+static void poly3_word_sub(crypto_word_t *out_s, crypto_word_t *out_a,
+                           const crypto_word_t s1, const crypto_word_t a1,
+                           const crypto_word_t s2, const crypto_word_t a2) {
+  const crypto_word_t t = a1 ^ a2;
+  *out_s = (s1 ^ a2) & (t ^ s2);
+  *out_a = t | (s1 ^ s2);
+}
+
 // lsb_to_all replicates the least-significant bit of |v| to all bits of the
 // word. This is used in bit-slicing operations to make a vector from a fixed
 // value.
 static crypto_word_t lsb_to_all(crypto_word_t v) { return 0u - (v & 1); }
 
-// poly3_mul_const sets |p| to |p|×m, where m  = (ms, ma).
+// poly3_mul_const sets |p| to |p|×m, where m = (ms, ma).
 static void poly3_mul_const(struct poly3 *p, crypto_word_t ms,
                             crypto_word_t ma) {
   ms = lsb_to_all(ms);
   ma = lsb_to_all(ma);
 
   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
-    const crypto_word_t s = p->s.v[i];
-    const crypto_word_t a = p->a.v[i];
-    p->s.v[i] = (s & ma) ^ (ms & a);
-    p->a.v[i] = (ms & s) ^ (ma & a);
+    poly3_word_mul(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], ms, ma);
   }
 }
 
@@ -566,23 +593,15 @@
   HRSS_poly2_rotr_consttime(&p->a, bits);
 }
 
-// poly3_fmadd sets |out| to |out| + |in|×m, where m is (ms, ma).
-static void poly3_fmadd(struct poly3 *RESTRICT out,
+// poly3_fmsub sets |out| to |out| - |in|×m, where m is (ms, ma).
+static void poly3_fmsub(struct poly3 *RESTRICT out,
                         const struct poly3 *RESTRICT in, crypto_word_t ms,
                         crypto_word_t ma) {
-  // (See the multiplication and addition circuits given above.)
+  crypto_word_t product_s, product_a;
   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
-    const crypto_word_t s = in->s.v[i];
-    const crypto_word_t a = in->a.v[i];
-    const crypto_word_t product_s = (s & ma) ^ (ms & a);
-    const crypto_word_t product_a = (ms & s) ^ (ma & a);
-
-    const crypto_word_t x = out->a.v[i] ^ product_a;
-    const crypto_word_t y =
-        (out->s.v[i] ^ product_s) ^ (out->a.v[i] & product_a);
-    const crypto_word_t z = (out->s.v[i] & product_s);
-    out->s.v[i] = y & ~x;
-    out->a.v[i] = z | (x & ~y);
+    poly3_word_mul(&product_s, &product_a, in->s.v[i], in->a.v[i], ms, ma);
+    poly3_word_sub(&out->s.v[i], &out->a.v[i], out->s.v[i], out->a.v[i],
+                   product_s, product_a);
   }
 }
 
@@ -601,20 +620,13 @@
 // poly3_mod_phiN reduces |p| by Φ(N).
 static void poly3_mod_phiN(struct poly3 *p) {
   // In order to reduce by Φ(N) we subtract by the value of the greatest
-  // coefficient. That's the same as adding the negative of its value. The
-  // negative of (s, a) is (a, s), so the arguments are swapped in the following
-  // two lines.
-  const crypto_word_t factor_s = final_bit_to_all(p->a.v[WORDS_PER_POLY - 1]);
-  const crypto_word_t factor_a = final_bit_to_all(p->s.v[WORDS_PER_POLY - 1]);
+  // coefficient.
+  const crypto_word_t factor_s = final_bit_to_all(p->s.v[WORDS_PER_POLY - 1]);
+  const crypto_word_t factor_a = final_bit_to_all(p->a.v[WORDS_PER_POLY - 1]);
 
   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
-    const crypto_word_t s = p->s.v[i];
-    const crypto_word_t a = p->a.v[i];
-    const crypto_word_t x = a ^ factor_a;
-    const crypto_word_t y = (s ^ factor_s) ^ (a & factor_a);
-    const crypto_word_t z = (s & factor_s);
-    p->s.v[i] = y & ~x;
-    p->a.v[i] = z | (x & ~y);
+    poly3_word_sub(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], factor_s,
+                   factor_a);
   }
 
   poly2_clear_top_bits(&p->s);
@@ -642,17 +654,6 @@
   crypto_word_t *a;
 };
 
-// poly3_word_add sets (|out_s|, |out_a|) to (|s1|, |a1|) + (|s2|, |a2|).
-static void poly3_word_add(crypto_word_t *out_s, crypto_word_t *out_a,
-                           const crypto_word_t s1, const crypto_word_t a1,
-                           const crypto_word_t s2, const crypto_word_t a2) {
-  const crypto_word_t x = a1 ^ a2;
-  const crypto_word_t y = (s1 ^ s2) ^ (a1 & a2);
-  const crypto_word_t z = s1 & s2;
-  *out_s = y & ~x;
-  *out_a = z | (x & ~y);
-}
-
 // poly3_span_add adds |n| words of values from |a| and |b| and writes the
 // result to |out|.
 static void poly3_span_add(const struct poly3_span *out,
@@ -667,8 +668,7 @@
 static void poly3_span_sub(const struct poly3_span *a,
                            const struct poly3_span *b, size_t n) {
   for (size_t i = 0; i < n; i++) {
-    // Swapping |b->s| and |b->a| negates the value being added.
-    poly3_word_add(&a->s[i], &a->a[i], a->s[i], a->a[i], b->a[i], b->s[i]);
+    poly3_word_sub(&a->s[i], &a->a[i], a->s[i], a->a[i], b->s[i], b->a[i]);
   }
 }
 
@@ -688,14 +688,11 @@
 
     for (size_t i = 0; i < BITS_PER_WORD; i++) {
       // Multiply (s, a) by the next value from (b_s, b_a).
-      const crypto_word_t v_s = lsb_to_all(b_s);
-      const crypto_word_t v_a = lsb_to_all(b_a);
+      crypto_word_t m_s, m_a;
+      poly3_word_mul(&m_s, &m_a, a_s, a_a, lsb_to_all(b_s), lsb_to_all(b_a));
       b_s >>= 1;
       b_a >>= 1;
 
-      const crypto_word_t m_s = (v_s & a_a) ^ (a_s & v_a);
-      const crypto_word_t m_a = (a_s & v_s) ^ (a_a & v_a);
-
       if (i == 0) {
         // Special case otherwise the code tries to shift by BITS_PER_WORD
         // below, which is undefined.
@@ -816,21 +813,22 @@
   }
 }
 
-// poly3_vec_fmadd adds (|ms|, |ma|) × (|b_s|, |b_a|) to (|a_s|, |a_a|).
-static inline void poly3_vec_fmadd(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
+// poly3_vec_fmsub subtracts (|ms|, |ma|) × (|b_s|, |b_a|) from (|a_s|, |a_a|).
+static inline void poly3_vec_fmsub(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
                                    vec_t b_a[6], const vec_t ms,
                                    const vec_t ma) {
   for (int i = 0; i < 6; i++) {
+    // See the bitslice formula, above.
     const vec_t s = b_s[i];
     const vec_t a = b_a[i];
-    const vec_t product_s = (s & ma) ^ (ms & a);
-    const vec_t product_a = (ms & s) ^ (ma & a);
+    const vec_t product_a = a & ma;
+    const vec_t product_s = (s ^ ms) & product_a;
 
-    const vec_t x = a_a[i] ^ product_a;
-    const vec_t y = (a_s[i] ^ product_s) ^ (a_a[i] & product_a);
-    const vec_t z = (a_s[i] & product_s);
-    a_s[i] = y & ~x;
-    a_a[i] = z | (x & ~y);
+    const vec_t out_s = a_s[i];
+    const vec_t out_a = a_a[i];
+    const vec_t t = out_a ^ product_a;
+    a_s[i] = (out_s ^ product_a) & (t ^ product_s);
+    a_a[i] = t | (out_s ^ product_s);
   }
 }
 
@@ -874,19 +872,18 @@
   memset(&still_going, 0xff, sizeof(still_going));
 
   for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
-    const vec_t s_a = vec_broadcast_bit(
-        still_going & ((f_a[0] & g_s[0]) ^ (f_s[0] & g_a[0])));
-    const vec_t s_s = vec_broadcast_bit(
-        still_going & ((f_a[0] & g_a[0]) ^ (f_s[0] & g_s[0])));
+    const vec_t s_a = vec_broadcast_bit(still_going & (f_a[0] & g_a[0]));
+    const vec_t s_s =
+        vec_broadcast_bit(still_going & ((f_s[0] ^ g_s[0]) & s_a));
     const vec_t should_swap =
         (s_s | s_a) & vec_broadcast_bit15(deg_f - deg_g);
 
     poly3_vec_cswap(f_s, f_a, g_s, g_a, should_swap);
-    poly3_vec_fmadd(f_s, f_a, g_s, g_a, s_s, s_a);
+    poly3_vec_fmsub(f_s, f_a, g_s, g_a, s_s, s_a);
     poly3_vec_rshift1(f_s, f_a);
 
     poly3_vec_cswap(b_s, b_a, c_s, c_a, should_swap);
-    poly3_vec_fmadd(b_s, b_a, c_s, c_a, s_s, s_a);
+    poly3_vec_fmsub(b_s, b_a, c_s, c_a, s_s, s_a);
     poly3_vec_lshift1(c_s, c_a);
 
     const vec_t deg_sum = should_swap & (deg_f ^ deg_g);
@@ -959,9 +956,9 @@
 
   for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
     const crypto_word_t s_a = lsb_to_all(
-        still_going & ((f.a.v[0] & g.s.v[0]) ^ (f.s.v[0] & g.a.v[0])));
+        still_going & (f.a.v[0] & g.a.v[0]));
     const crypto_word_t s_s = lsb_to_all(
-        still_going & ((f.a.v[0] & g.a.v[0]) ^ (f.s.v[0] & g.s.v[0])));
+        still_going & ((f.s.v[0] ^ g.s.v[0]) & s_a));
     const crypto_word_t should_swap =
         (s_s | s_a) & constant_time_lt_w(deg_f, deg_g);
 
@@ -973,8 +970,8 @@
     deg_g ^= deg_sum;
     assert(deg_g >= 1);
 
-    poly3_fmadd(&f, &g, s_s, s_a);
-    poly3_fmadd(b, &c, s_s, s_a);
+    poly3_fmsub(&f, &g, s_s, s_a);
+    poly3_fmsub(b, &c, s_s, s_a);
     poly3_rshift1(&f);
     poly3_lshift1(&c);
 
@@ -1486,9 +1483,10 @@
     // The signed value is reduced mod 3, yielding {0, 1, 2}.
     const uint16_t v = mod3((int16_t)(in->v[i] << 3) >> 3);
     s >>= 1;
-    s |= (crypto_word_t)(v & 2) << (BITS_PER_WORD - 2);
+    const crypto_word_t s_bit = (crypto_word_t)(v & 2) << (BITS_PER_WORD - 2);
+    s |= s_bit;
     a >>= 1;
-    a |= (crypto_word_t)(v & 1) << (BITS_PER_WORD - 1);
+    a |= s_bit | (crypto_word_t)(v & 1) << (BITS_PER_WORD - 1);
     shift++;
 
     if (shift == BITS_PER_WORD) {
@@ -1528,9 +1526,11 @@
     ok &= constant_time_eq_w(v, expected);
 
     s >>= 1;
-    s |= (crypto_word_t)(mod3 & 2) << (BITS_PER_WORD - 2);
+    const crypto_word_t s_bit = (crypto_word_t)(mod3 & 2)
+                                << (BITS_PER_WORD - 2);
+    s |= s_bit;
     a >>= 1;
-    a |= (crypto_word_t)(mod3 & 1) << (BITS_PER_WORD - 1);
+    a |= s_bit | (crypto_word_t)(mod3 & 1) << (BITS_PER_WORD - 1);
     shift++;
 
     if (shift == BITS_PER_WORD) {
diff --git a/crypto/hrss/hrss_test.cc b/crypto/hrss/hrss_test.cc
index 596db07..97c1bf0 100644
--- a/crypto/hrss/hrss_test.cc
+++ b/crypto/hrss/hrss_test.cc
@@ -81,20 +81,18 @@
   RAND_bytes(reinterpret_cast<uint8_t *>(p), sizeof(poly3));
   p->s.v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
   p->a.v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
-  // (s, a) = (1, 1) is invalid. Map those to one.
+  // (s, a) = (1, 0) is invalid. Map those to -1.
   for (size_t j = 0; j < WORDS_PER_POLY; j++) {
-    p->s.v[j] ^= p->s.v[j] & p->a.v[j];
+    p->a.v[j] |= p->s.v[j];
   }
 }
 
 // poly3_word_add sets (|s1|, |a1|) += (|s2|, |a2|).
 static void poly3_word_add(crypto_word_t *s1, crypto_word_t *a1,
                            const crypto_word_t s2, const crypto_word_t a2) {
-  const crypto_word_t x = *a1 ^ a2;
-  const crypto_word_t y = (*s1 ^ s2) ^ (*a1 & a2);
-  const crypto_word_t z = *s1 & s2;
-  *s1 = y & ~x;
-  *a1 = z | (x & ~y);
+  const crypto_word_t t = *s1 ^ a2;
+  *s1 = t & (s2 ^ *a1);
+  *a1 = (*a1 ^ a2) | (t ^ s2);
 }
 
 TEST(HRSS, Poly3Invert) {
@@ -105,6 +103,7 @@
 
   // The inverse of -1 is -1.
   p.s.v[0] = 1;
+  p.a.v[0] = 1;
   HRSS_poly3_invert(&inverse, &p);
   EXPECT_EQ(Bytes(reinterpret_cast<const uint8_t*>(&p), sizeof(p)),
             Bytes(reinterpret_cast<const uint8_t*>(&inverse), sizeof(inverse)));
@@ -144,7 +143,7 @@
   // |r| is probably already not reduced mod Φ(N), but add x^701 - 1 and
   // recompute to ensure that we get the same answer. (Since (x^701 - 1) ≡ 0 mod
   // Φ(N).)
-  poly3_word_add(&r.s.v[0], &r.a.v[0], 1, 0);
+  poly3_word_add(&r.s.v[0], &r.a.v[0], 1, 1);
   poly3_word_add(&r.s.v[WORDS_PER_POLY - 1], &r.a.v[WORDS_PER_POLY - 1], 0,
                  UINT64_C(1) << BITS_IN_LAST_WORD);
 
@@ -160,11 +159,11 @@
 
   for (size_t i = 0; i < WORDS_PER_POLY-1; i++) {
     EXPECT_EQ(CONSTTIME_TRUE_W, result.s.v[i]);
-    EXPECT_EQ(0u, result.a.v[i]);
+    EXPECT_EQ(CONSTTIME_TRUE_W, result.a.v[i]);
   }
   EXPECT_EQ((UINT64_C(1) << (BITS_IN_LAST_WORD - 1)) - 1,
             result.s.v[WORDS_PER_POLY - 1]);
-  EXPECT_EQ(0u, result.a.v[WORDS_PER_POLY - 1]);
+  EXPECT_EQ(result.s.v[WORDS_PER_POLY - 1], result.a.v[WORDS_PER_POLY - 1]);
 }
 
 TEST(HRSS, Basic) {