Clear bottom three bits of password scalar in SPAKE2.

Due to a copy-paste error, the call to |left_shift_3| is missing after
reducing the password scalar in SPAKE2. This means that three bits of
the password leak in Alice's message. (Two in Bob's message as the point
N happens to have order 4l, not 8l.)

The “correct” fix is to put in the missing call to |left_shift_3|, but
that would be a breaking change. In order to fix this in a unilateral
way, we add points of small order to the masking point to bring it into
the prime-order subgroup.

BUG=chromium:778101

Change-Id: I440931a3df7f009b324d2a3e3af2d893a101804f
Reviewed-on: https://boringssl-review.googlesource.com/22445
Reviewed-by: Adam Langley <agl@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Adam Langley <agl@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/crypto/curve25519/internal.h b/crypto/curve25519/internal.h
index 9487a6c..4d79cb9 100644
--- a/crypto/curve25519/internal.h
+++ b/crypto/curve25519/internal.h
@@ -101,6 +101,26 @@
 void x25519_ge_scalarmult(ge_p2 *r, const uint8_t *scalar, const ge_p3 *A);
 void x25519_sc_reduce(uint8_t *s);
 
+enum spake2_state_t {  // Type of |state| in struct spake2_ctx_st.
+  spake2_state_init = 0,
+  spake2_state_msg_generated,  // NOTE(review): presumably after generate_msg.
+  spake2_state_key_generated,
+};
+
+struct spake2_ctx_st {  // Defined here so tests can set fields directly.
+  uint8_t private_key[32];
+  uint8_t my_msg[32];
+  uint8_t password_scalar[32];  // Set by SPAKE2_generate_msg from the hash.
+  uint8_t password_hash[64];  // SHA-512 digest of the password.
+  uint8_t *my_name;
+  size_t my_name_len;
+  uint8_t *their_name;
+  size_t their_name_len;
+  enum spake2_role_t my_role;  // Selects the M or N mask point.
+  enum spake2_state_t state;
+  char disable_password_scalar_hack;  // Non-zero skips the cofactor fix-up.
+};
+
 
 #if defined(__cplusplus)
 }  // extern C
diff --git a/crypto/curve25519/spake25519.c b/crypto/curve25519/spake25519.c
index 8ebedf9..e17d510 100644
--- a/crypto/curve25519/spake25519.c
+++ b/crypto/curve25519/spake25519.c
@@ -14,6 +14,7 @@
 
 #include <openssl/curve25519.h>
 
+#include <assert.h>
 #include <string.h>
 
 #include <openssl/bytestring.h>
@@ -267,25 +268,6 @@
     0xa6, 0x76, 0x81, 0x28, 0xb2, 0x65, 0xe8, 0x47, 0x14, 0xc6, 0x39, 0x06,
 };
 
-enum spake2_state_t {
-  spake2_state_init = 0,
-  spake2_state_msg_generated,
-  spake2_state_key_generated,
-};
-
-struct spake2_ctx_st {
-  uint8_t private_key[32];
-  uint8_t my_msg[32];
-  uint8_t password_scalar[32];
-  uint8_t password_hash[SHA512_DIGEST_LENGTH];
-  uint8_t *my_name;
-  size_t my_name_len;
-  uint8_t *their_name;
-  size_t their_name_len;
-  enum spake2_role_t my_role;
-  enum spake2_state_t state;
-};
-
 SPAKE2_CTX *SPAKE2_CTX_new(enum spake2_role_t my_role,
                            const uint8_t *my_name, size_t my_name_len,
                            const uint8_t *their_name, size_t their_name_len) {
@@ -332,6 +314,48 @@
   }
 }
 
+typedef union {  // 256-bit scalar, viewable as bytes or as 32-bit words.
+  uint8_t bytes[32];  // Little-endian byte string, as in |kOrder|.
+  uint32_t words[8];  // Limbs used by the scalar_* helpers.
+} scalar;  // NOTE(review): the views line up only on little-endian hosts.
+
+// kOrder is l, the order of the prime-order subgroup of curve25519, encoded
+// as 32 bytes in little-endian order (least-significant byte first).
+static const scalar kOrder = {{0xed, 0xd3, 0xf5, 0x5c, 0x1a, 0x63, 0x12, 0x58,
+                               0xd6, 0x9c, 0xf7, 0xa2, 0xde, 0xf9, 0xde, 0x14,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10}};
+
+// scalar_cmov overwrites |dest| with |src| when |mask| is all ones.
+static void scalar_cmov(scalar *dest, const scalar *src, crypto_word_t mask) {
+  uint32_t *const out = dest->words;
+  for (size_t idx = 0; idx < 8; idx++) {
+    out[idx] = constant_time_select_w(mask, src->words[idx], out[idx]);
+  }
+}
+
+// scalar_double sets |s| to |2×s|, discarding any carry out of the top word.
+static void scalar_double(scalar *s) {
+  uint32_t carry_in = 0;
+
+  for (size_t idx = 0; idx < 8; idx++) {
+    const uint32_t word = s->words[idx];
+    s->words[idx] = (word << 1) | carry_in;
+    carry_in = word >> 31;
+  }
+}
+
+// scalar_add sets |dest| to |dest| plus |src|, dropping the final carry.
+static void scalar_add(scalar *dest, const scalar *src) {
+  uint64_t acc = 0;
+
+  for (size_t idx = 0; idx < 8; idx++) {
+    acc += (uint64_t)dest->words[idx] + src->words[idx];
+    dest->words[idx] = (uint32_t)acc;
+    acc >>= 32;
+  }
+}
+
 int SPAKE2_generate_msg(SPAKE2_CTX *ctx, uint8_t *out, size_t *out_len,
                          size_t max_out_len, const uint8_t *password,
                          size_t password_len) {
@@ -359,13 +383,61 @@
   SHA512(password, password_len, password_tmp);
   OPENSSL_memcpy(ctx->password_hash, password_tmp, sizeof(ctx->password_hash));
   x25519_sc_reduce(password_tmp);
-  OPENSSL_memcpy(ctx->password_scalar, password_tmp, sizeof(ctx->password_scalar));
+
+  // Due to a copy-paste error, the call to |left_shift_3| was omitted after
+  // the |x25519_sc_reduce|, just above. This meant that |ctx->password_scalar|
+  // was not a multiple of eight to clear the cofactor and thus three bits of
+  // the password hash would leak. In order to fix this in a unilateral way,
+  // points of small order are added to the mask point such that it is in the
+  // prime-order subgroup. Since the ephemeral scalar is a multiple of eight,
+  // these points will cancel out when calculating the shared secret.
+  //
+  // Adding points of small order is the same as adding multiples of the prime
+  // order to the password scalar. Since that's faster, that is what is done
+  // below. The prime order (kOrder) is a large prime, thus odd, thus the LSB
+  // is one. So adding it will flip the LSB. Adding twice it will flip the next
+  // bit and so on for all the bottom three bits.
+
+  scalar password_scalar;
+  OPENSSL_memcpy(&password_scalar, password_tmp, sizeof(password_scalar));
+
+  // |password_scalar| is the result of |x25519_sc_reduce| and thus is, at
+  // most, $l-1$ (where $l$ is |kOrder|, the order of the prime-order subgroup
+  // of Ed25519). In the following, we may add $l + 2×l + 4×l$ for a max value
+  // of $8×l-1$. That is < 2**256, as required.
+
+  if (!ctx->disable_password_scalar_hack) {
+    scalar order = kOrder;
+    scalar tmp;
+
+    OPENSSL_memset(&tmp, 0, sizeof(tmp));
+    scalar_cmov(&tmp, &order,
+                constant_time_eq_w(password_scalar.bytes[0] & 1, 1));
+    scalar_add(&password_scalar, &tmp);
+
+    scalar_double(&order);
+    OPENSSL_memset(&tmp, 0, sizeof(tmp));
+    scalar_cmov(&tmp, &order,
+                constant_time_eq_w(password_scalar.bytes[0] & 2, 2));
+    scalar_add(&password_scalar, &tmp);
+
+    scalar_double(&order);
+    OPENSSL_memset(&tmp, 0, sizeof(tmp));
+    scalar_cmov(&tmp, &order,
+                constant_time_eq_w(password_scalar.bytes[0] & 4, 4));
+    scalar_add(&password_scalar, &tmp);
+
+    assert((password_scalar.bytes[0] & 7) == 0);
+  }
+
+  OPENSSL_memcpy(ctx->password_scalar, password_scalar.bytes,
+                 sizeof(ctx->password_scalar));
 
   ge_p3 mask;
   x25519_ge_scalarmult_small_precomp(&mask, ctx->password_scalar,
-                              ctx->my_role == spake2_role_alice
-                                  ? kSpakeMSmallPrecomp
-                                  : kSpakeNSmallPrecomp);
+                                     ctx->my_role == spake2_role_alice
+                                         ? kSpakeMSmallPrecomp
+                                         : kSpakeNSmallPrecomp);
 
   // P* = P + mask.
   ge_cached mask_cached;
diff --git a/crypto/curve25519/spake25519_test.cc b/crypto/curve25519/spake25519_test.cc
index cdf4ff5..3ebd0a9 100644
--- a/crypto/curve25519/spake25519_test.cc
+++ b/crypto/curve25519/spake25519_test.cc
@@ -23,6 +23,7 @@
 #include <gtest/gtest.h>
 
 #include "../internal.h"
+#include "internal.h"
 
 
 // TODO(agl): add tests with fixed vectors once SPAKE2 is nailed down.
@@ -46,6 +47,13 @@
       return false;
     }
 
+    if (alice_disable_password_scalar_hack) {
+      alice->disable_password_scalar_hack = 1;
+    }
+    if (bob_disable_password_scalar_hack) {
+      bob->disable_password_scalar_hack = 1;
+    }
+
     uint8_t alice_msg[SPAKE2_MAX_MSG_SIZE];
     uint8_t bob_msg[SPAKE2_MAX_MSG_SIZE];
     size_t alice_msg_len, bob_msg_len;
@@ -90,6 +98,8 @@
   std::string bob_password = "password";
   std::pair<std::string, std::string> alice_names = {"alice", "bob"};
   std::pair<std::string, std::string> bob_names = {"bob", "alice"};
+  bool alice_disable_password_scalar_hack = false;
+  bool bob_disable_password_scalar_hack = false;
   int alice_corrupt_msg_bit = -1;
 
  private:
@@ -104,6 +114,24 @@
   }
 }
 
+TEST(SPAKE25519Test, OldAlice) {  // Unfixed Alice must interop with fixed Bob.
+  for (int round = 0; round < 20; ++round) {
+    SPAKE2Run run;
+    run.alice_disable_password_scalar_hack = true;
+    ASSERT_TRUE(run.Run());
+    EXPECT_TRUE(run.key_matches());
+  }
+}
+
+TEST(SPAKE25519Test, OldBob) {  // Unfixed Bob must interop with fixed Alice.
+  for (int round = 0; round < 20; ++round) {
+    SPAKE2Run run;
+    run.bob_disable_password_scalar_hack = true;
+    ASSERT_TRUE(run.Run());
+    EXPECT_TRUE(run.key_matches());
+  }
+}
+
 TEST(SPAKE25519Test, WrongPassword) {
   SPAKE2Run spake2;
   spake2.bob_password = "wrong password";