Switch HRSS inversion algorithm.

This algorithm is much simplier and more obvious than the one from the
HRSS paper. Unfortunately it's not immediately any faster (roughly a
no-op on most platforms, +5% on ARM) but it does allow a bunch of
constant-time rotation code to be deleted.

Since it's simplier, however, it's easier to speed-up a little with
future changes.

Change-Id: Ic0e92c77c44ea9aeb6fe35940af9767084fe5f58
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/39084
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/hrss/hrss.c b/crypto/hrss/hrss.c
index 0f66e97..67ff4c1 100644
--- a/crypto/hrss/hrss.c
+++ b/crypto/hrss/hrss.c
@@ -51,8 +51,8 @@
 // SXY: https://eprint.iacr.org/2017/1005.pdf
 // NTRUTN14:
 // https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf
-// NTRUCOMP:
-// https://eprint.iacr.org/2018/1174
+// NTRUCOMP: https://eprint.iacr.org/2018/1174
+// SAFEGCD: https://gcd.cr.yp.to/papers.html#safegcd
 
 
 // Vector operations.
@@ -184,13 +184,6 @@
                            0b01010101);
 }
 
-// vec_broadcast_bit15 duplicates the most-significant bit of the first word in
-// |a| to all bits in a vector and returns the result.
-static inline vec_t vec_broadcast_bit15(vec_t a) {
-  return _mm_shuffle_epi32(_mm_srai_epi32(_mm_slli_epi64(a, 63 - 15), 31),
-                           0b01010101);
-}
-
 // vec_get_word returns the |i|th uint16_t in |v|. (This is a macro because the
 // compiler requires that |i| be a compile-time constant.)
 #define vec_get_word(v, i) _mm_extract_epi16(v, i)
@@ -246,11 +239,6 @@
   return vdupq_lane_u16(vget_low_u16(a), 0);
 }
 
-static inline vec_t vec_broadcast_bit15(vec_t a) {
-  a = (vec_t)vshrq_n_s16((int16x8_t)a, 15);
-  return vdupq_lane_u16(vget_low_u16(a), 0);
-}
-
 static inline void poly3_vec_lshift1(vec_t a_s[6], vec_t a_a[6]) {
   vec_t carry_s = {0};
   vec_t carry_a = {0};
@@ -324,99 +312,64 @@
   OPENSSL_memset(&p->v[0], 0, sizeof(crypto_word_t) * WORDS_PER_POLY);
 }
 
-// poly2_cmov sets |out| to |in| iff |mov| is all ones.
-static void poly2_cmov(struct poly2 *out, const struct poly2 *in,
-                       crypto_word_t mov) {
+// word_reverse returns |in| with the bits in reverse order.
+static crypto_word_t word_reverse(crypto_word_t in) {
+#if defined(OPENSSL_64_BIT)
+  static const crypto_word_t kMasks[6] = {
+    UINT64_C(0x5555555555555555),
+    UINT64_C(0x3333333333333333),
+    UINT64_C(0x0f0f0f0f0f0f0f0f),
+    UINT64_C(0x00ff00ff00ff00ff),
+    UINT64_C(0x0000ffff0000ffff),
+    UINT64_C(0x00000000ffffffff),
+  };
+#else
+  static const crypto_word_t kMasks[5] = {
+    0x55555555,
+    0x33333333,
+    0x0f0f0f0f,
+    0x00ff00ff,
+    0x0000ffff,
+  };
+#endif
+
+  for (size_t i = 0; i < OPENSSL_ARRAY_SIZE(kMasks); i++) {
+    in = ((in >> (1 << i)) & kMasks[i]) | ((in & kMasks[i]) << (1 << i));
+  }
+
+  return in;
+}
+
+// lsb_to_all replicates the least-significant bit of |v| to all bits of the
+// word. This is used in bit-slicing operations to make a vector from a fixed
+// value.
+static crypto_word_t lsb_to_all(crypto_word_t v) { return 0u - (v & 1); }
+
+// poly2_mod_phiN reduces |p| by Φ(N).
+static void poly2_mod_phiN(struct poly2 *p) {
+  // m is the term at x^700, replicated to every bit.
+  const crypto_word_t m =
+      lsb_to_all(p->v[WORDS_PER_POLY - 1] >> (BITS_IN_LAST_WORD - 1));
   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
-    out->v[i] = (out->v[i] & ~mov) | (in->v[i] & mov);
+    p->v[i] ^= m;
   }
+  p->v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << (BITS_IN_LAST_WORD - 1)) - 1;
 }
 
-// poly2_rotr_words performs a right-rotate on |in|, writing the result to
-// |out|. The shift count, |bits|, must be a non-zero multiple of the word size.
-static void poly2_rotr_words(struct poly2 *out, const struct poly2 *in,
-                             size_t bits) {
-  assert(bits >= BITS_PER_WORD && bits % BITS_PER_WORD == 0);
-  assert(out != in);
-
-  const size_t start = bits / BITS_PER_WORD;
-  const size_t n = (N - bits) / BITS_PER_WORD;
-
-  // The rotate is by a whole number of words so the first few words are easy:
-  // just move them down.
-  for (size_t i = 0; i < n; i++) {
-    out->v[i] = in->v[start + i];
+// poly2_reverse_700 reverses the order of the first 700 bits of |in| and writes
+// the result to |out|.
+static void poly2_reverse_700(struct poly2 *out, const struct poly2 *in) {
+  struct poly2 t;
+  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
+    t.v[i] = word_reverse(in->v[i]);
   }
 
-  // Since the last word is only partially filled, however, the remainder needs
-  // shifting and merging of words to take care of that.
-  crypto_word_t carry = in->v[WORDS_PER_POLY - 1];
-
-  for (size_t i = 0; i < start; i++) {
-    out->v[n + i] = carry | in->v[i] << BITS_IN_LAST_WORD;
-    carry = in->v[i] >> (BITS_PER_WORD - BITS_IN_LAST_WORD);
+  static const size_t shift = BITS_PER_WORD - ((N-1) % BITS_PER_WORD);
+  for (size_t i = 0; i < WORDS_PER_POLY-1; i++) {
+    out->v[i] = t.v[WORDS_PER_POLY-1-i] >> shift;
+    out->v[i] |= t.v[WORDS_PER_POLY-2-i] << (BITS_PER_WORD - shift);
   }
-
-  out->v[WORDS_PER_POLY - 1] = carry;
-}
-
-// poly2_rotr_bits performs a right-rotate on |in|, writing the result to |out|.
-// The shift count, |bits|, must be a power of two that is less than
-// |BITS_PER_WORD|.
-static void poly2_rotr_bits(struct poly2 *out, const struct poly2 *in,
-                            size_t bits) {
-  assert(bits <= BITS_PER_WORD / 2);
-  assert(bits != 0);
-  assert((bits & (bits - 1)) == 0);
-  assert(out != in);
-
-  // BITS_PER_WORD/2 is the greatest legal value of |bits|. If
-  // |BITS_IN_LAST_WORD| is smaller than this then the code below doesn't work
-  // because more than the last word needs to carry down in the previous one and
-  // so on.
-  OPENSSL_STATIC_ASSERT(
-      BITS_IN_LAST_WORD >= BITS_PER_WORD / 2,
-      "there are more carry bits than fit in BITS_IN_LAST_WORD");
-
-  crypto_word_t carry = in->v[WORDS_PER_POLY - 1] << (BITS_PER_WORD - bits);
-
-  for (size_t i = WORDS_PER_POLY - 2; i < WORDS_PER_POLY; i--) {
-    out->v[i] = carry | in->v[i] >> bits;
-    carry = in->v[i] << (BITS_PER_WORD - bits);
-  }
-
-  crypto_word_t last_word = carry >> (BITS_PER_WORD - BITS_IN_LAST_WORD) |
-                            in->v[WORDS_PER_POLY - 1] >> bits;
-  last_word &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
-  out->v[WORDS_PER_POLY - 1] = last_word;
-}
-
-// HRSS_poly2_rotr_consttime right-rotates |p| by |bits| in constant-time.
-void HRSS_poly2_rotr_consttime(struct poly2 *p, size_t bits) {
-  assert(bits <= N);
-  assert(p->v[WORDS_PER_POLY-1] >> BITS_IN_LAST_WORD == 0);
-
-  // Constant-time rotation is implemented by calculating the rotations of
-  // powers-of-two bits and throwing away the unneeded values. 2^9 (i.e. 512) is
-  // the largest power-of-two shift that we need to consider because 2^10 > N.
-#define HRSS_POLY2_MAX_SHIFT 9
-  size_t shift = HRSS_POLY2_MAX_SHIFT;
-  OPENSSL_STATIC_ASSERT((1 << (HRSS_POLY2_MAX_SHIFT + 1)) > N,
-                        "maximum shift is too small");
-  OPENSSL_STATIC_ASSERT((1 << HRSS_POLY2_MAX_SHIFT) <= N,
-                        "maximum shift is too large");
-  struct poly2 shifted;
-
-  for (; (UINT64_C(1) << shift) >= BITS_PER_WORD; shift--) {
-    poly2_rotr_words(&shifted, p, UINT64_C(1) << shift);
-    poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
-  }
-
-  for (; shift < HRSS_POLY2_MAX_SHIFT; shift--) {
-    poly2_rotr_bits(&shifted, p, UINT64_C(1) << shift);
-    poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
-  }
-#undef HRSS_POLY2_MAX_SHIFT
+  out->v[WORDS_PER_POLY-1] = t.v[0] >> shift;
 }
 
 // poly2_cswap exchanges the values of |a| and |b| if |swap| is all ones.
@@ -537,7 +490,14 @@
   poly2_zero(&p->a);
 }
 
-// poly3_word_mul sets (|out_s|, |out_a) to (|s1|, |a1|) × (|s2|, |a2|).
+// poly3_reverse_700 reverses the order of the first 700 terms of |in| and
+// writes them to |out|.
+static void poly3_reverse_700(struct poly3 *out, const struct poly3 *in) {
+  poly2_reverse_700(&out->a, &in->a);
+  poly2_reverse_700(&out->s, &in->s);
+}
+
+// poly3_word_mul sets (|out_s|, |out_a|) to (|s1|, |a1|) × (|s2|, |a2|).
 static void poly3_word_mul(crypto_word_t *out_s, crypto_word_t *out_a,
                            const crypto_word_t s1, const crypto_word_t a1,
                            const crypto_word_t s2, const crypto_word_t a2) {
@@ -563,11 +523,6 @@
   *out_a = t | (s1 ^ s2);
 }
 
-// lsb_to_all replicates the least-significant bit of |v| to all bits of the
-// word. This is used in bit-slicing operations to make a vector from a fixed
-// value.
-static crypto_word_t lsb_to_all(crypto_word_t v) { return 0u - (v & 1); }
-
 // poly3_mul_const sets |p| to |p|×m, where m = (ms, ma).
 static void poly3_mul_const(struct poly3 *p, crypto_word_t ms,
                             crypto_word_t ma) {
@@ -579,13 +534,6 @@
   }
 }
 
-// poly3_rotr_consttime right-rotates |p| by |bits| in constant-time.
-static void poly3_rotr_consttime(struct poly3 *p, size_t bits) {
-  assert(bits <= N);
-  HRSS_poly2_rotr_consttime(&p->s, bits);
-  HRSS_poly2_rotr_consttime(&p->a, bits);
-}
-
 // poly3_fmadd sets |out| to |out| - |in|×m, where m is (ms, ma).
 static void poly3_fmsub(struct poly3 *RESTRICT out,
                         const struct poly3 *RESTRICT in, crypto_word_t ms,
@@ -828,83 +776,64 @@
 // poly3_invert_vec sets |*out| to |in|^-1, i.e. such that |out|×|in| == 1 mod
 // Φ(N).
 static void poly3_invert_vec(struct poly3 *out, const struct poly3 *in) {
-  // See the comment in |HRSS_poly3_invert| about this algorithm. In addition to
-  // the changes described there, this implementation attempts to use vector
-  // registers to speed up the computation. Even non-poly3 variables are held in
-  // vectors where possible to minimise the amount of data movement between
-  // the vector and general-purpose registers.
-
-  vec_t b_s[6], b_a[6], c_s[6], c_a[6], f_s[6], f_a[6], g_s[6], g_a[6];
+  // This algorithm is taken from section 7.1 of [SAFEGCD].
   const vec_t kZero = {0};
   const vec_t kOne = {1};
-  static const uint8_t kOneBytes[sizeof(vec_t)] = {1};
   static const uint8_t kBottomSixtyOne[sizeof(vec_t)] = {
       0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1f};
 
-  memset(b_s, 0, sizeof(b_s));
-  memcpy(b_a, kOneBytes, sizeof(kOneBytes));
-  memset(&b_a[1], 0, 5 * sizeof(vec_t));
+  vec_t v_s[6], v_a[6], r_s[6], r_a[6], f_s[6], f_a[6], g_s[6], g_a[6];
+  // v = 0
+  memset(&v_s, 0, sizeof(v_s));
+  memset(&v_a, 0, sizeof(v_a));
+  // r = 1
+  memset(&r_s, 0, sizeof(r_s));
+  memset(&r_a, 0, sizeof(r_a));
+  r_a[0] = kOne;
+  // f = all ones.
+  memset(f_s, 0, sizeof(f_s));
+  memset(f_a, 0xff, 5 * sizeof(vec_t));
+  memcpy(&f_a[5], kBottomSixtyOne, sizeof(kBottomSixtyOne));
+  // g is the reversal of |in|.
+  struct poly3 in_reversed;
+  poly3_reverse_700(&in_reversed, in);
+  g_s[5] = kZero;
+  memcpy(&g_s, &in_reversed.s.v, WORDS_PER_POLY * sizeof(crypto_word_t));
+  g_a[5] = kZero;
+  memcpy(&g_a, &in_reversed.a.v, WORDS_PER_POLY * sizeof(crypto_word_t));
 
-  memset(c_s, 0, sizeof(c_s));
-  memset(c_a, 0, sizeof(c_a));
+  int delta = 1;
 
-  f_s[5] = kZero;
-  memcpy(f_s, in->s.v, WORDS_PER_POLY * sizeof(crypto_word_t));
-  f_a[5] = kZero;
-  memcpy(f_a, in->a.v, WORDS_PER_POLY * sizeof(crypto_word_t));
+  for (size_t i = 0; i < (2*(N-1)) - 1; i++) {
+    poly3_vec_lshift1(v_s, v_a);
 
-  // Set g to all ones.
-  memset(g_s, 0, sizeof(g_s));
-  memset(g_a, 0xff, 5 * sizeof(vec_t));
-  memcpy(&g_a[5], kBottomSixtyOne, sizeof(kBottomSixtyOne));
+    const crypto_word_t delta_sign_bit = (delta >> (sizeof(delta) * 8 - 1)) & 1;
+    const crypto_word_t delta_is_non_negative = delta_sign_bit - 1;
+    const crypto_word_t delta_is_non_zero = ~constant_time_is_zero_w(delta);
+    const vec_t g_has_constant_term = vec_broadcast_bit(g_a[0]);
+    const vec_t mask_w =
+        {delta_is_non_negative & delta_is_non_zero};
+    const vec_t mask = vec_broadcast_bit(mask_w) & g_has_constant_term;
 
-  vec_t deg_f = {N - 1}, deg_g = {N - 1}, rotation = kZero;
-  vec_t k = kOne;
-  vec_t f0s = {0}, f0a = {0};
-  vec_t still_going;
-  memset(&still_going, 0xff, sizeof(still_going));
+    const vec_t c_a = vec_broadcast_bit(f_a[0] & g_a[0]);
+    const vec_t c_s = vec_broadcast_bit((f_s[0] ^ g_s[0]) & c_a);
 
-  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
-    const vec_t s_a = vec_broadcast_bit(still_going & (f_a[0] & g_a[0]));
-    const vec_t s_s =
-        vec_broadcast_bit(still_going & ((f_s[0] ^ g_s[0]) & s_a));
-    const vec_t should_swap =
-        (s_s | s_a) & vec_broadcast_bit15(deg_f - deg_g);
+    delta = constant_time_select_int(lsb_to_all(mask[0]), -delta, delta);
+    delta++;
 
-    poly3_vec_cswap(f_s, f_a, g_s, g_a, should_swap);
-    poly3_vec_fmsub(f_s, f_a, g_s, g_a, s_s, s_a);
-    poly3_vec_rshift1(f_s, f_a);
+    poly3_vec_cswap(f_s, f_a, g_s, g_a, mask);
+    poly3_vec_fmsub(g_s, g_a, f_s, f_a, c_s, c_a);
+    poly3_vec_rshift1(g_s, g_a);
 
-    poly3_vec_cswap(b_s, b_a, c_s, c_a, should_swap);
-    poly3_vec_fmsub(b_s, b_a, c_s, c_a, s_s, s_a);
-    poly3_vec_lshift1(c_s, c_a);
-
-    const vec_t deg_sum = should_swap & (deg_f ^ deg_g);
-    deg_f ^= deg_sum;
-    deg_g ^= deg_sum;
-
-    deg_f -= kOne;
-    still_going &= ~vec_broadcast_bit15(deg_f - kOne);
-
-    const vec_t f0_is_nonzero = vec_broadcast_bit(f_s[0] | f_a[0]);
-    // |f0_is_nonzero| implies |still_going|.
-    rotation ^= f0_is_nonzero & (k ^ rotation);
-    k += kOne;
-
-    const vec_t f0s_sum = f0_is_nonzero & (f_s[0] ^ f0s);
-    f0s ^= f0s_sum;
-    const vec_t f0a_sum = f0_is_nonzero & (f_a[0] ^ f0a);
-    f0a ^= f0a_sum;
+    poly3_vec_cswap(v_s, v_a, r_s, r_a, mask);
+    poly3_vec_fmsub(r_s, r_a, v_s, v_a, c_s, c_a);
   }
 
-  crypto_word_t rotation_word = vec_get_word(rotation, 0);
-  rotation_word -= N & constant_time_lt_w(N, rotation_word);
-  memcpy(out->s.v, b_s, WORDS_PER_POLY * sizeof(crypto_word_t));
-  memcpy(out->a.v, b_a, WORDS_PER_POLY * sizeof(crypto_word_t));
-  assert(poly3_top_bits_are_clear(out));
-  poly3_rotr_consttime(out, rotation_word);
-  poly3_mul_const(out, vec_get_word(f0s, 0), vec_get_word(f0a, 0));
-  poly3_mod_phiN(out);
+  assert(delta == 0);
+  memcpy(out->s.v, v_s, WORDS_PER_POLY * sizeof(crypto_word_t));
+  memcpy(out->a.v, v_a, WORDS_PER_POLY * sizeof(crypto_word_t));
+  poly3_mul_const(out, vec_get_word(f_s[0], 0), vec_get_word(f_a[0], 0));
+  poly3_reverse_700(out, out);
 }
 
 #endif  // HRSS_HAVE_VECTOR_UNIT
@@ -921,71 +850,50 @@
   }
 #endif
 
-  // This algorithm mostly follows algorithm 10 in the paper. Some changes:
-  //   1) k should start at zero, not one. In the code below k is omitted and
-  //      the loop counter, |i|, is used instead.
-  //   2) The rotation count is conditionally updated to handle trailing zero
-  //      coefficients.
-  // The best explanation for why it works is in the "Why it works" section of
-  // [NTRUTN14].
+  // This algorithm is taken from section 7.1 of [SAFEGCD].
+  struct poly3 v, r, f, g;
+  // v = 0
+  poly3_zero(&v);
+  // r = 1
+  poly3_zero(&r);
+  r.a.v[0] = 1;
+  // f = all ones.
+  OPENSSL_memset(&f.s, 0, sizeof(struct poly2));
+  OPENSSL_memset(&f.a, 0xff, sizeof(struct poly2));
+  f.a.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;
+  // g is the reversal of |in|.
+  poly3_reverse_700(&g, in);
+  int delta = 1;
 
-  struct poly3 c, f, g;
-  OPENSSL_memcpy(&f, in, sizeof(f));
+  for (size_t i = 0; i < (2*(N-1)) - 1; i++) {
+    poly3_lshift1(&v);
 
-  // Set g to all ones.
-  OPENSSL_memset(&g.s, 0, sizeof(struct poly2));
-  OPENSSL_memset(&g.a, 0xff, sizeof(struct poly2));
-  g.a.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;
+    const crypto_word_t delta_sign_bit = (delta >> (sizeof(delta) * 8 - 1)) & 1;
+    const crypto_word_t delta_is_non_negative = delta_sign_bit - 1;
+    const crypto_word_t delta_is_non_zero = ~constant_time_is_zero_w(delta);
+    const crypto_word_t g_has_constant_term = lsb_to_all(g.a.v[0]);
+    const crypto_word_t mask =
+        g_has_constant_term & delta_is_non_negative & delta_is_non_zero;
 
-  struct poly3 *b = out;
-  poly3_zero(b);
-  poly3_zero(&c);
-  // Set b to one.
-  b->a.v[0] = 1;
+    crypto_word_t c_s, c_a;
+    poly3_word_mul(&c_s, &c_a, f.s.v[0], f.a.v[0], g.s.v[0], g.a.v[0]);
+    c_s = lsb_to_all(c_s);
+    c_a = lsb_to_all(c_a);
 
-  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
-  crypto_word_t f0s = 0, f0a = 0;
-  crypto_word_t still_going = CONSTTIME_TRUE_W;
+    delta = constant_time_select_int(mask, -delta, delta);
+    delta++;
 
-  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
-    const crypto_word_t s_a = lsb_to_all(
-        still_going & (f.a.v[0] & g.a.v[0]));
-    const crypto_word_t s_s = lsb_to_all(
-        still_going & ((f.s.v[0] ^ g.s.v[0]) & s_a));
-    const crypto_word_t should_swap =
-        (s_s | s_a) & constant_time_lt_w(deg_f, deg_g);
+    poly3_cswap(&f, &g, mask);
+    poly3_fmsub(&g, &f, c_s, c_a);
+    poly3_rshift1(&g);
 
-    poly3_cswap(&f, &g, should_swap);
-    poly3_cswap(b, &c, should_swap);
-
-    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
-    deg_f ^= deg_sum;
-    deg_g ^= deg_sum;
-    assert(deg_g >= 1);
-
-    poly3_fmsub(&f, &g, s_s, s_a);
-    poly3_fmsub(b, &c, s_s, s_a);
-    poly3_rshift1(&f);
-    poly3_lshift1(&c);
-
-    deg_f--;
-    const crypto_word_t f0_is_nonzero =
-        lsb_to_all(f.s.v[0]) | lsb_to_all(f.a.v[0]);
-    // |f0_is_nonzero| implies |still_going|.
-    assert(!(f0_is_nonzero && !still_going));
-    still_going &= ~constant_time_is_zero_w(deg_f);
-
-    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
-    f0s = constant_time_select_w(f0_is_nonzero, f.s.v[0], f0s);
-    f0a = constant_time_select_w(f0_is_nonzero, f.a.v[0], f0a);
+    poly3_cswap(&v, &r, mask);
+    poly3_fmsub(&r, &v, c_s, c_a);
   }
 
-  rotation++;
-  rotation -= N & constant_time_lt_w(N, rotation);
-  assert(poly3_top_bits_are_clear(out));
-  poly3_rotr_consttime(out, rotation);
-  poly3_mul_const(out, f0s, f0a);
-  poly3_mod_phiN(out);
+  assert(delta == 0);
+  poly3_mul_const(&v, f.s.v[0], f.a.v[0]);
+  poly3_reverse_700(out, &v);
 }
 
 // Polynomials in Q.
@@ -1593,52 +1501,50 @@
 // Φ(N)), all mod 2. This isn't useful in itself, but is part of doing inversion
 // mod Q.
 static void poly_invert_mod2(struct poly *out, const struct poly *in) {
-  // This algorithm follows algorithm 10 in the paper. (Although, in contrast to
-  // the paper, k should start at zero, not one, and the rotation count is needs
-  // to handle trailing zero coefficients.) The best explanation for why it
-  // works is in the "Why it works" section of [NTRUTN14].
+  // This algorithm is taken from section 7.1 of [SAFEGCD].
+  struct poly2 v, r, f, g;
 
-  struct poly2 b, c, f, g;
-  poly2_from_poly(&f, in);
-  OPENSSL_memset(&b, 0, sizeof(b));
-  b.v[0] = 1;
-  OPENSSL_memset(&c, 0, sizeof(c));
+  // v = 0
+  poly2_zero(&v);
+  // r = 1
+  poly2_zero(&r);
+  r.v[0] = 1;
+  // f = all ones.
+  OPENSSL_memset(&f, 0xff, sizeof(struct poly2));
+  f.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;
+  // g is the reversal of |in|.
+  poly2_from_poly(&g, in);
+  poly2_mod_phiN(&g);
+  poly2_reverse_700(&g, &g);
+  int delta = 1;
 
-  // Set g to all ones.
-  OPENSSL_memset(&g, 0xff, sizeof(struct poly2));
-  g.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;
+  for (size_t i = 0; i < (2*(N-1)) - 1; i++) {
+    poly2_lshift1(&v);
 
-  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
-  crypto_word_t still_going = CONSTTIME_TRUE_W;
+    const crypto_word_t delta_sign_bit = (delta >> (sizeof(delta) * 8 - 1)) & 1;
+    const crypto_word_t delta_is_non_negative = delta_sign_bit - 1;
+    const crypto_word_t delta_is_non_zero = ~constant_time_is_zero_w(delta);
+    const crypto_word_t g_has_constant_term = lsb_to_all(g.v[0]);
+    const crypto_word_t mask =
+        g_has_constant_term & delta_is_non_negative & delta_is_non_zero;
 
-  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
-    const crypto_word_t s = still_going & lsb_to_all(f.v[0]);
-    const crypto_word_t should_swap = s & constant_time_lt_w(deg_f, deg_g);
-    poly2_cswap(&f, &g, should_swap);
-    poly2_cswap(&b, &c, should_swap);
-    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
-    deg_f ^= deg_sum;
-    deg_g ^= deg_sum;
-    assert(deg_g >= 1);
-    poly2_fmadd(&f, &g, s);
-    poly2_fmadd(&b, &c, s);
+    const crypto_word_t c = lsb_to_all(f.v[0] & g.v[0]);
 
-    poly2_rshift1(&f);
-    poly2_lshift1(&c);
+    delta = constant_time_select_int(mask, -delta, delta);
+    delta++;
 
-    deg_f--;
-    const crypto_word_t f0_is_nonzero = lsb_to_all(f.v[0]);
-    // |f0_is_nonzero| implies |still_going|.
-    assert(!(f0_is_nonzero && !still_going));
-    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
-    still_going &= ~constant_time_is_zero_w(deg_f);
+    poly2_cswap(&f, &g, mask);
+    poly2_fmadd(&g, &f, c);
+    poly2_rshift1(&g);
+
+    poly2_cswap(&v, &r, mask);
+    poly2_fmadd(&r, &v, c);
   }
 
-  rotation++;
-  rotation -= N & constant_time_lt_w(N, rotation);
-  assert(poly2_top_bits_are_clear(&b));
-  HRSS_poly2_rotr_consttime(&b, rotation);
-  poly_from_poly2(out, &b);
+  assert(delta == 0);
+  assert(f.v[0] & 1);
+  poly2_reverse_700(&v, &v);
+  poly_from_poly2(out, &v);
 }
 
 // poly_invert sets |*out| to |in^-1| (i.e. such that |*out|×|in| = 1 mod Φ(N)).
diff --git a/crypto/hrss/hrss_test.cc b/crypto/hrss/hrss_test.cc
index 493255d..66b9047 100644
--- a/crypto/hrss/hrss_test.cc
+++ b/crypto/hrss/hrss_test.cc
@@ -22,59 +22,6 @@
 #include "../test/test_util.h"
 #include "internal.h"
 
-// poly2_from_bits takes the least-significant bit from each byte of |in| and
-// sets the bits of |*out| to match.
-static void poly2_from_bits(struct poly2 *out, const uint8_t in[N]) {
-  crypto_word_t *words = out->v;
-  unsigned shift = 0;
-  crypto_word_t word = 0;
-
-  for (unsigned i = 0; i < N; i++) {
-    word >>= 1;
-    word |= (crypto_word_t)(in[i] & 1) << (BITS_PER_WORD - 1);
-    shift++;
-
-    if (shift == BITS_PER_WORD) {
-      *words = word;
-      words++;
-      word = 0;
-      shift = 0;
-    }
-  }
-
-  word >>= BITS_PER_WORD - shift;
-  *words = word;
-}
-
-TEST(HRSS, Poly2RotateRight) {
-  uint8_t bits[N];
-  RAND_bytes(bits, sizeof(bits));
-  for (size_t i = 0; i < N; i++) {
-    bits[i] &= 1;
-  };
-
-  struct poly2 p, orig, shifted;
-  poly2_from_bits(&p, bits);
-  OPENSSL_memcpy(&orig, &p, sizeof(orig));
-
-  // Test |HRSS_poly2_rotr_consttime| by manually rotating |bits| step-by-step
-  // and testing every possible shift to ensure that it produces the correct
-  // answer.
-  for (size_t shift = 0; shift <= N; shift++) {
-    SCOPED_TRACE(shift);
-
-    OPENSSL_memcpy(&p, &orig, sizeof(orig));
-    HRSS_poly2_rotr_consttime(&p, shift);
-    poly2_from_bits(&shifted, bits);
-    ASSERT_EQ(
-        Bytes(reinterpret_cast<const uint8_t *>(&shifted), sizeof(shifted)),
-        Bytes(reinterpret_cast<const uint8_t *>(&p), sizeof(p)));
-
-    const uint8_t least_significant_bit = bits[0];
-    OPENSSL_memmove(bits, &bits[1], N-1);
-    bits[N-1] = least_significant_bit;
-  }
-}
 
 // poly3_rand sets |r| to a random value (albeit with bias).
 static void poly3_rand(poly3 *p) {
@@ -101,6 +48,21 @@
   memset(&inverse, 0, sizeof(inverse));
   memset(&result, 0, sizeof(result));
 
+  p.s.v[0] = 0;
+  p.a.v[0] = 1;
+  for (size_t i = 0; i < N - 1; i++) {
+    SCOPED_TRACE(i);
+    poly3 r;
+    OPENSSL_memset(&r, 0, sizeof(r));
+    r.a.v[i / BITS_PER_WORD] = (UINT64_C(1) << (i % BITS_PER_WORD));
+    HRSS_poly3_invert(&inverse, &r);
+    HRSS_poly3_mul(&result, &inverse, &r);
+    // r×r⁻¹ = 1, and |p| contains 1.
+    EXPECT_EQ(
+        Bytes(reinterpret_cast<const uint8_t *>(&p), sizeof(p)),
+        Bytes(reinterpret_cast<const uint8_t *>(&result), sizeof(result)));
+  }
+
   // The inverse of -1 is -1.
   p.s.v[0] = 1;
   p.a.v[0] = 1;
@@ -118,6 +80,10 @@
   for (size_t i = 0; i < 500; i++) {
     poly3 r;
     poly3_rand(&r);
+    // Drop the term at x^700 because |HRSS_poly3_invert| only handles reduced
+    // inputs.
+    r.s.v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << (BITS_IN_LAST_WORD - 1)) - 1;
+    r.a.v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << (BITS_IN_LAST_WORD - 1)) - 1;
     HRSS_poly3_invert(&inverse, &r);
     HRSS_poly3_mul(&result, &inverse, &r);
     // r×r⁻¹ = 1, and |p| contains 1.
@@ -132,6 +98,10 @@
   // Φ(N).
   poly3 r, inverse, result, one;
   poly3_rand(&r);
+  // Drop the term at x^700 because |HRSS_poly3_invert| only handles reduced
+  // inputs.
+  r.s.v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << (BITS_IN_LAST_WORD - 1)) - 1;
+  r.a.v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << (BITS_IN_LAST_WORD - 1)) - 1;
   HRSS_poly3_invert(&inverse, &r);
   HRSS_poly3_mul(&result, &inverse, &r);
 
@@ -140,9 +110,8 @@
   EXPECT_EQ(Bytes(reinterpret_cast<const uint8_t *>(&one), sizeof(one)),
             Bytes(reinterpret_cast<const uint8_t *>(&result), sizeof(result)));
 
-  // |r| is probably already not reduced mod Φ(N), but add x^701 - 1 and
-  // recompute to ensure that we get the same answer. (Since (x^701 - 1) ≡ 0 mod
-  // Φ(N).)
+  // |r| is reduced mod Φ(N), so add x^701 - 1 and recompute to ensure that we
+  // get the same answer. (Since (x^701 - 1) ≡ 0 mod Φ(N).)
   poly3_word_add(&r.s.v[0], &r.a.v[0], 1, 1);
   poly3_word_add(&r.s.v[WORDS_PER_POLY - 1], &r.a.v[WORDS_PER_POLY - 1], 0,
                  UINT64_C(1) << BITS_IN_LAST_WORD);
diff --git a/crypto/hrss/internal.h b/crypto/hrss/internal.h
index 7cfe010..c0d9bd2 100644
--- a/crypto/hrss/internal.h
+++ b/crypto/hrss/internal.h
@@ -36,7 +36,6 @@
   struct poly2 s, a;
 };
 
-OPENSSL_EXPORT void HRSS_poly2_rotr_consttime(struct poly2 *p, size_t bits);
 OPENSSL_EXPORT void HRSS_poly3_mul(struct poly3 *out, const struct poly3 *x,
                                    const struct poly3 *y);
 OPENSSL_EXPORT void HRSS_poly3_invert(struct poly3 *out,