Drop CECPQ2 support. HRSS itself remains in libcrypto because there are some direct users of it. But this will let it be dropped by the linker in many cases. Change-Id: I870eda30c9ed1d08693c770e9e7df45a2711b7df Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/58645 Commit-Queue: Adam Langley <agl@google.com> Reviewed-by: David Benjamin <davidben@google.com> Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/obj/obj_dat.h b/crypto/obj/obj_dat.h index d879233..7cd1153 100644 --- a/crypto/obj/obj_dat.h +++ b/crypto/obj/obj_dat.h
@@ -8777,7 +8777,7 @@ {"AuthPSK", "auth-psk", NID_auth_psk, 0, NULL, 0}, {"KxANY", "kx-any", NID_kx_any, 0, NULL, 0}, {"AuthANY", "auth-any", NID_auth_any, 0, NULL, 0}, - {"CECPQ2", "CECPQ2", NID_CECPQ2, 0, NULL, 0}, + {NULL, NULL, NID_undef, 0, NULL, 0}, {"ED448", "ED448", NID_ED448, 3, &kObjectData[6181], 0}, {"X448", "X448", NID_X448, 3, &kObjectData[6184], 0}, {"SHA512-256", "sha512-256", NID_sha512_256, 9, &kObjectData[6187], 0}, @@ -8846,7 +8846,6 @@ 110 /* CAST5-CFB */, 109 /* CAST5-ECB */, 111 /* CAST5-OFB */, - 959 /* CECPQ2 */, 894 /* CMAC */, 13 /* CN */, 141 /* CRLReason */, @@ -9758,7 +9757,6 @@ 285 /* Biometric Info */, 179 /* CA Issuers */, 785 /* CA Repository */, - 959 /* CECPQ2 */, 131 /* Code Signing */, 783 /* Diffie-Hellman based MAC */, 382 /* Directory */,
diff --git a/crypto/obj/obj_mac.num b/crypto/obj/obj_mac.num index c0473bc..583f6e3 100644 --- a/crypto/obj/obj_mac.num +++ b/crypto/obj/obj_mac.num
@@ -947,7 +947,6 @@ auth_psk 956 kx_any 957 auth_any 958 -CECPQ2 959 ED448 960 X448 961 sha512_256 962
diff --git a/crypto/obj/objects.txt b/crypto/obj/objects.txt index 11151f9..cad6a3b 100644 --- a/crypto/obj/objects.txt +++ b/crypto/obj/objects.txt
@@ -1333,7 +1333,6 @@ : dh-cofactor-kdf # NIDs for post quantum key agreements (no corresponding OIDs). - : CECPQ2 : X25519Kyber768 : P256Kyber768 : P384Kyber768
diff --git a/include/openssl/nid.h b/include/openssl/nid.h index cf5691d..64c9c9c 100644 --- a/include/openssl/nid.h +++ b/include/openssl/nid.h
@@ -4235,9 +4235,6 @@ #define LN_auth_any "auth-any" #define NID_auth_any 958 -#define SN_CECPQ2 "CECPQ2" -#define NID_CECPQ2 959 - #define SN_ED448 "ED448" #define NID_ED448 960 #define OBJ_ED448 1L, 3L, 101L, 113L
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 68253bd..da78d1a 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h
@@ -2334,7 +2334,6 @@ #define SSL_CURVE_SECP384R1 24 #define SSL_CURVE_SECP521R1 25 #define SSL_CURVE_X25519 29 -#define SSL_CURVE_CECPQ2 16696 #define SSL_CURVE_X25519KYBER768 0x6399 #define SSL_CURVE_P256KYBER768 0xfe32
diff --git a/ssl/extensions.cc b/ssl/extensions.cc index ba92360..4d9651b 100644 --- a/ssl/extensions.cc +++ b/ssl/extensions.cc
@@ -206,7 +206,6 @@ static bool is_post_quantum_group(uint16_t id) { switch (id) { - case SSL_CURVE_CECPQ2: case SSL_CURVE_X25519KYBER768: case SSL_CURVE_P256KYBER768: return true; @@ -414,7 +413,7 @@ bool tls1_check_group_id(const SSL_HANDSHAKE *hs, uint16_t group_id) { if (is_post_quantum_group(group_id) && ssl_protocol_version(hs->ssl) < TLS1_3_VERSION) { - // CECPQ2(b) requires TLS 1.3. + // Post-quantum "groups" require TLS 1.3. return false; }
diff --git a/ssl/ssl_key_share.cc b/ssl/ssl_key_share.cc index 5741c6b..8885246 100644 --- a/ssl/ssl_key_share.cc +++ b/ssl/ssl_key_share.cc
@@ -192,101 +192,6 @@ uint8_t private_key_[32]; }; -class CECPQ2KeyShare : public SSLKeyShare { - public: - CECPQ2KeyShare() {} - - uint16_t GroupID() const override { return SSL_CURVE_CECPQ2; } - - bool Generate(CBB *out) override { - uint8_t x25519_public_key[32]; - X25519_keypair(x25519_public_key, x25519_private_key_); - - uint8_t hrss_entropy[HRSS_GENERATE_KEY_BYTES]; - HRSS_public_key hrss_public_key; - RAND_bytes(hrss_entropy, sizeof(hrss_entropy)); - if (!HRSS_generate_key(&hrss_public_key, &hrss_private_key_, - hrss_entropy)) { - return false; - } - - uint8_t hrss_public_key_bytes[HRSS_PUBLIC_KEY_BYTES]; - HRSS_marshal_public_key(hrss_public_key_bytes, &hrss_public_key); - - if (!CBB_add_bytes(out, x25519_public_key, sizeof(x25519_public_key)) || - !CBB_add_bytes(out, hrss_public_key_bytes, - sizeof(hrss_public_key_bytes))) { - return false; - } - - return true; - } - - bool Encap(CBB *out_ciphertext, Array<uint8_t> *out_secret, - uint8_t *out_alert, Span<const uint8_t> peer_key) override { - Array<uint8_t> secret; - if (!secret.Init(32 + HRSS_KEY_BYTES)) { - return false; - } - - uint8_t x25519_public_key[32]; - X25519_keypair(x25519_public_key, x25519_private_key_); - - HRSS_public_key peer_public_key; - if (peer_key.size() != 32 + HRSS_PUBLIC_KEY_BYTES || - !HRSS_parse_public_key(&peer_public_key, peer_key.data() + 32) || - !X25519(secret.data(), x25519_private_key_, peer_key.data())) { - *out_alert = SSL_AD_DECODE_ERROR; - OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_ECPOINT); - return false; - } - - uint8_t ciphertext[HRSS_CIPHERTEXT_BYTES]; - uint8_t entropy[HRSS_ENCAP_BYTES]; - RAND_bytes(entropy, sizeof(entropy)); - - if (!HRSS_encap(ciphertext, secret.data() + 32, &peer_public_key, - entropy) || - !CBB_add_bytes(out_ciphertext, x25519_public_key, - sizeof(x25519_public_key)) || - !CBB_add_bytes(out_ciphertext, ciphertext, sizeof(ciphertext))) { - return false; - } - - *out_secret = std::move(secret); - return true; - } - - bool Decap(Array<uint8_t> *out_secret, uint8_t *out_alert, - Span<const uint8_t> ciphertext) override { - *out_alert = SSL_AD_INTERNAL_ERROR; - - Array<uint8_t> secret; - if (!secret.Init(32 + HRSS_KEY_BYTES)) { - return false; - } - - if (ciphertext.size() != 32 + HRSS_CIPHERTEXT_BYTES || - !X25519(secret.data(), x25519_private_key_, ciphertext.data())) { - *out_alert = SSL_AD_DECODE_ERROR; - OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_ECPOINT); - return false; - } - - if (!HRSS_decap(secret.data() + 32, &hrss_private_key_, - ciphertext.data() + 32, ciphertext.size() - 32)) { - return false; - } - - *out_secret = std::move(secret); - return true; - } - - private: - uint8_t x25519_private_key_[32]; - HRSS_private_key hrss_private_key_; -}; - class X25519Kyber768KeyShare : public SSLKeyShare { public: X25519Kyber768KeyShare() {} @@ -405,7 +310,6 @@ {NID_secp384r1, SSL_CURVE_SECP384R1, "P-384", "secp384r1"}, {NID_secp521r1, SSL_CURVE_SECP521R1, "P-521", "secp521r1"}, {NID_X25519, SSL_CURVE_X25519, "X25519", "x25519"}, - {NID_CECPQ2, SSL_CURVE_CECPQ2, "CECPQ2", "CECPQ2"}, {NID_X25519Kyber768, SSL_CURVE_X25519KYBER768, "X25519KYBER", "X25519Kyber"}, {NID_P256Kyber768, SSL_CURVE_P256KYBER768, "P256KYBER", "P256Kyber"}, @@ -429,8 +333,6 @@ return MakeUnique<ECKeyShare>(NID_secp521r1, SSL_CURVE_SECP521R1); case SSL_CURVE_X25519: return MakeUnique<X25519KeyShare>(); - case SSL_CURVE_CECPQ2: - return MakeUnique<CECPQ2KeyShare>(); case SSL_CURVE_X25519KYBER768: return MakeUnique<X25519Kyber768KeyShare>(); case SSL_CURVE_P256KYBER768:
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index 85c06a5..854068f 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc
@@ -401,8 +401,8 @@ { SSL_CURVE_SECP256R1 }, }, { - "P-256:CECPQ2", - { SSL_CURVE_SECP256R1, SSL_CURVE_CECPQ2 }, + "P-256:X25519KYBER", + { SSL_CURVE_SECP256R1, SSL_CURVE_X25519KYBER768 }, }, {
diff --git a/ssl/test/fuzzer.h b/ssl/test/fuzzer.h index 00b5e84..8f73fc0 100644 --- a/ssl/test/fuzzer.h +++ b/ssl/test/fuzzer.h
@@ -418,8 +418,9 @@ return false; } - static const int kCurves[] = {NID_CECPQ2, NID_X25519, NID_X9_62_prime256v1, - NID_secp384r1, NID_secp521r1}; + static const int kCurves[] = {NID_X25519Kyber768, NID_X25519, + NID_X9_62_prime256v1, NID_secp384r1, + NID_secp521r1}; if (!SSL_CTX_set1_curves(ctx_.get(), kCurves, OPENSSL_ARRAY_SIZE(kCurves))) { return false;
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go index 3854283..ce06779 100644 --- a/ssl/test/runner/common.go +++ b/ssl/test/runner/common.go
@@ -148,12 +148,12 @@ type CurveID uint16 const ( - CurveP224 CurveID = 21 - CurveP256 CurveID = 23 - CurveP384 CurveID = 24 - CurveP521 CurveID = 25 - CurveX25519 CurveID = 29 - CurveCECPQ2 CurveID = 16696 + CurveP224 CurveID = 21 + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + CurveX25519 CurveID = 29 + CurveX25519Kyber768 CurveID = 0x6399 ) // TLS Elliptic Curve Point Formats @@ -1890,9 +1890,9 @@ // hello retry. FailIfHelloRetryRequested bool - // FailedIfCECPQ2Offered will cause a server to reject a ClientHello if CECPQ2 + // FailedIfKyberOffered will cause a server to reject a ClientHello if Kyber // is supported. - FailIfCECPQ2Offered bool + FailIfKyberOffered bool // ExpectKeyShares, if not nil, lists (in order) the curves that a ClientHello // should have key shares for. @@ -1996,7 +1996,7 @@ return ret } -var defaultCurvePreferences = []CurveID{CurveCECPQ2, CurveX25519, CurveP256, CurveP384, CurveP521} +var defaultCurvePreferences = []CurveID{CurveX25519, CurveP256, CurveP384, CurveP521} func (c *Config) curvePreferences() []CurveID { if c == nil || len(c.CurvePreferences) == 0 {
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index de297a6..4f3cf75 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go
@@ -280,10 +280,10 @@ } } - if config.Bugs.FailIfCECPQ2Offered { + if config.Bugs.FailIfKyberOffered { for _, offeredCurve := range hs.clientHello.supportedCurves { if isPqGroup(offeredCurve) { - return errors.New("tls: CECPQ2 was offered") + return errors.New("tls: X25519Kyber768 was offered") } } } @@ -1467,7 +1467,7 @@ Curves: for _, curve := range hs.clientHello.supportedCurves { if isPqGroup(curve) && c.vers < VersionTLS13 { - // CECPQ2 is TLS 1.3-only. + // Post-quantum is TLS 1.3 only. continue }
diff --git a/ssl/test/runner/hrss/hrss.go b/ssl/test/runner/hrss/hrss.go deleted file mode 100644 index 9f4fdd7..0000000 --- a/ssl/test/runner/hrss/hrss.go +++ /dev/null
@@ -1,1212 +0,0 @@ -package hrss - -import ( - "crypto/hmac" - "crypto/sha256" - "crypto/subtle" - "encoding/binary" - "io" - "math/bits" -) - -const ( - PublicKeySize = modQBytes - CiphertextSize = modQBytes -) - -const ( - N = 701 - Q = 8192 - mod3Bytes = 140 - modQBytes = 1138 -) - -const ( - bitsPerWord = bits.UintSize - wordsPerPoly = (N + bitsPerWord - 1) / bitsPerWord - fullWordsPerPoly = N / bitsPerWord - bitsInLastWord = N % bitsPerWord -) - -// poly3 represents a degree-N polynomial over GF(3). Each coefficient is -// bitsliced across the |s| and |a| arrays, like this: -// -// s | a | value -// ----------------- -// 0 | 0 | 0 -// 0 | 1 | 1 -// 1 | 0 | 2 (aka -1) -// 1 | 1 | <invalid> -// -// ('s' is for sign, and 'a' is just a letter.) -// -// Once bitsliced as such, the following circuits can be used to implement -// addition and multiplication mod 3: -// -// (s3, a3) = (s1, a1) × (s2, a2) -// s3 = (s2 ∧ a1) ⊕ (s1 ∧ a2) -// a3 = (s1 ∧ s2) ⊕ (a1 ∧ a2) -// -// (s3, a3) = (s1, a1) + (s2, a2) -// t1 = ~(s1 ∨ a1) -// t2 = ~(s2 ∨ a2) -// s3 = (a1 ∧ a2) ⊕ (t1 ∧ s2) ⊕ (t2 ∧ s1) -// a3 = (s1 ∧ s2) ⊕ (t1 ∧ a2) ⊕ (t2 ∧ a1) -// -// Negating a value just involves swapping s and a. -type poly3 struct { - s [wordsPerPoly]uint - a [wordsPerPoly]uint -} - -func (p *poly3) trim() { - p.s[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 - p.a[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 -} - -func (p *poly3) zero() { - for i := range p.a { - p.s[i] = 0 - p.a[i] = 0 - } -} - -func (p *poly3) fromDiscrete(in *poly) { - var shift uint - s := p.s[:] - a := p.a[:] - s[0] = 0 - a[0] = 0 - - for _, v := range in { - s[0] >>= 1 - s[0] |= uint((v>>1)&1) << (bitsPerWord - 1) - a[0] >>= 1 - a[0] |= uint(v&1) << (bitsPerWord - 1) - shift++ - if shift == bitsPerWord { - s = s[1:] - a = a[1:] - s[0] = 0 - a[0] = 0 - shift = 0 - } - } - - a[0] >>= bitsPerWord - shift - s[0] >>= bitsPerWord - shift -} - -func (p *poly3) fromModQ(in *poly) int { - var shift uint - s := p.s[:] - a := p.a[:] - s[0] = 0 - a[0] = 0 - ok := 1 - - for _, v := range in { - vMod3, vOk := modQToMod3(v) - ok &= vOk - - s[0] >>= 1 - s[0] |= uint((vMod3>>1)&1) << (bitsPerWord - 1) - a[0] >>= 1 - a[0] |= uint(vMod3&1) << (bitsPerWord - 1) - shift++ - if shift == bitsPerWord { - s = s[1:] - a = a[1:] - s[0] = 0 - a[0] = 0 - shift = 0 - } - } - - a[0] >>= bitsPerWord - shift - s[0] >>= bitsPerWord - shift - - return ok -} - -func (p *poly3) fromDiscreteMod3(in *poly) { - var shift uint - s := p.s[:] - a := p.a[:] - s[0] = 0 - a[0] = 0 - - for _, v := range in { - // This duplicates the 13th bit upwards to the top of the - // uint16, essentially treating it as a sign bit and converting - // into a signed int16. The signed value is reduced mod 3, - // yeilding {-2, -1, 0, 1, 2}. - v = uint16((int16(v<<3)>>3)%3) & 7 - - // We want to map v thus: - // {-2, -1, 0, 1, 2} -> {1, 2, 0, 1, 2}. We take the bottom - // three bits and then the constants below, when shifted by - // those three bits, perform the required mapping. - s[0] >>= 1 - s[0] |= (0xbc >> v) << (bitsPerWord - 1) - a[0] >>= 1 - a[0] |= (0x7a >> v) << (bitsPerWord - 1) - shift++ - if shift == bitsPerWord { - s = s[1:] - a = a[1:] - s[0] = 0 - a[0] = 0 - shift = 0 - } - } - - a[0] >>= bitsPerWord - shift - s[0] >>= bitsPerWord - shift -} - -func (p *poly3) marshal(out []byte) { - s := p.s[:] - a := p.a[:] - sw := s[0] - aw := a[0] - var shift int - - for i := 0; i < 700; i += 5 { - acc, scale := 0, 1 - for j := 0; j < 5; j++ { - v := int(aw&1) | int(sw&1)<<1 - acc += scale * v - scale *= 3 - - shift++ - if shift == bitsPerWord { - s = s[1:] - a = a[1:] - sw = s[0] - aw = a[0] - shift = 0 - } else { - sw >>= 1 - aw >>= 1 - } - } - - out[0] = byte(acc) - out = out[1:] - } -} - -func (p *poly) fromMod2(in *poly2) { - var shift uint - words := in[:] - word := words[0] - - for i := range p { - p[i] = uint16(word & 1) - word >>= 1 - shift++ - if shift == bitsPerWord { - words = words[1:] - word = words[0] - shift = 0 - } - } -} - -func (p *poly) fromMod3(in *poly3) { - var shift uint - s := in.s[:] - a := in.a[:] - sw := s[0] - aw := a[0] - - for i := range p { - p[i] = uint16(aw&1 | (sw&1)<<1) - aw >>= 1 - sw >>= 1 - shift++ - if shift == bitsPerWord { - a = a[1:] - s = s[1:] - aw = a[0] - sw = s[0] - shift = 0 - } - } -} - -func (p *poly) fromMod3ToModQ(in *poly3) { - var shift uint - s := in.s[:] - a := in.a[:] - sw := s[0] - aw := a[0] - - for i := range p { - p[i] = mod3ToModQ(uint16(aw&1 | (sw&1)<<1)) - aw >>= 1 - sw >>= 1 - shift++ - if shift == bitsPerWord { - a = a[1:] - s = s[1:] - aw = a[0] - sw = s[0] - shift = 0 - } - } -} - -func lsbToAll(v uint) uint { - return uint(int(v<<(bitsPerWord-1)) >> (bitsPerWord - 1)) -} - -func (p *poly3) mulConst(ms, ma uint) { - ms = lsbToAll(ms) - ma = lsbToAll(ma) - - for i := range p.a { - p.s[i], p.a[i] = (ma&p.s[i])^(ms&p.a[i]), (ma&p.a[i])^(ms&p.s[i]) - } -} - -func cmovWords(out, in *[wordsPerPoly]uint, mov uint) { - for i := range out { - out[i] = (out[i] & ^mov) | (in[i] & mov) - } -} - -func rotWords(out, in *[wordsPerPoly]uint, bits uint) { - start := bits / bitsPerWord - n := (N - bits) / bitsPerWord - - for i := uint(0); i < n; i++ { - out[i] = in[start+i] - } - - carry := in[wordsPerPoly-1] - - for i := uint(0); i < start; i++ { - out[n+i] = carry | in[i]<<bitsInLastWord - carry = in[i] >> (bitsPerWord - bitsInLastWord) - } - - out[wordsPerPoly-1] = carry -} - -// rotBits right-rotates the bits in |in|. bits must be a non-zero power of two -// and less than bitsPerWord. -func rotBits(out, in *[wordsPerPoly]uint, bits uint) { - if (bits == 0 || (bits & (bits - 1)) != 0 || bits > bitsPerWord/2 || bitsInLastWord < bitsPerWord/2) { - panic("internal error"); - } - - carry := in[wordsPerPoly-1] << (bitsPerWord - bits) - - for i := wordsPerPoly - 2; i >= 0; i-- { - out[i] = carry | in[i]>>bits - carry = in[i] << (bitsPerWord - bits) - } - - out[wordsPerPoly-1] = carry>>(bitsPerWord-bitsInLastWord) | in[wordsPerPoly-1]>>bits -} - -func (p *poly3) rotWords(bits uint, in *poly3) { - rotWords(&p.s, &in.s, bits) - rotWords(&p.a, &in.a, bits) -} - -func (p *poly3) rotBits(bits uint, in *poly3) { - rotBits(&p.s, &in.s, bits) - rotBits(&p.a, &in.a, bits) -} - -func (p *poly3) cmov(in *poly3, mov uint) { - cmovWords(&p.s, &in.s, mov) - cmovWords(&p.a, &in.a, mov) -} - -func (p *poly3) rot(bits uint) { - if bits > N { - panic("invalid") - } - var shifted poly3 - - shift := uint(9) - for ; (1 << shift) >= bitsPerWord; shift-- { - shifted.rotWords(1<<shift, p) - p.cmov(&shifted, lsbToAll(bits>>shift)) - } - for ; shift < 9; shift-- { - shifted.rotBits(1<<shift, p) - p.cmov(&shifted, lsbToAll(bits>>shift)) - } -} - -func (p *poly3) fmadd(ms, ma uint, in *poly3) { - ms = lsbToAll(ms) - ma = lsbToAll(ma) - - for i := range p.a { - products := (ma & in.s[i]) ^ (ms & in.a[i]) - producta := (ma & in.a[i]) ^ (ms & in.s[i]) - - ns1Ana1 := ^p.s[i] & ^p.a[i] - ns2Ana2 := ^products & ^producta - - p.s[i], p.a[i] = (p.a[i]&producta)^(ns1Ana1&products)^(p.s[i]&ns2Ana2), (p.s[i]&products)^(ns1Ana1&producta)^(p.a[i]&ns2Ana2) - } -} - -func (p *poly3) modPhiN() { - factora := uint(int(p.s[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1)) - factors := uint(int(p.a[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1)) - ns2Ana2 := ^factors & ^factora - - for i := range p.s { - ns1Ana1 := ^p.s[i] & ^p.a[i] - p.s[i], p.a[i] = (p.a[i]&factora)^(ns1Ana1&factors)^(p.s[i]&ns2Ana2), (p.s[i]&factors)^(ns1Ana1&factora)^(p.a[i]&ns2Ana2) - } -} - -func (p *poly3) cswap(other *poly3, swap uint) { - for i := range p.s { - sums := swap & (p.s[i] ^ other.s[i]) - p.s[i] ^= sums - other.s[i] ^= sums - - suma := swap & (p.a[i] ^ other.a[i]) - p.a[i] ^= suma - other.a[i] ^= suma - } -} - -func (p *poly3) mulx() { - carrys := (p.s[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1 - carrya := (p.a[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1 - - for i := range p.s { - outCarrys := p.s[i] >> (bitsPerWord - 1) - outCarrya := p.a[i] >> (bitsPerWord - 1) - p.s[i] <<= 1 - p.a[i] <<= 1 - p.s[i] |= carrys - p.a[i] |= carrya - carrys = outCarrys - carrya = outCarrya - } -} - -func (p *poly3) divx() { - var carrys, carrya uint - - for i := len(p.s) - 1; i >= 0; i-- { - outCarrys := p.s[i] & 1 - outCarrya := p.a[i] & 1 - p.s[i] >>= 1 - p.a[i] >>= 1 - p.s[i] |= carrys << (bitsPerWord - 1) - p.a[i] |= carrya << (bitsPerWord - 1) - carrys = outCarrys - carrya = outCarrya - } -} - -type poly2 [wordsPerPoly]uint - -func (p *poly2) fromDiscrete(in *poly) { - var shift uint - words := p[:] - words[0] = 0 - - for _, v := range in { - words[0] >>= 1 - words[0] |= uint(v&1) << (bitsPerWord - 1) - shift++ - if shift == bitsPerWord { - words = words[1:] - words[0] = 0 - shift = 0 - } - } - - words[0] >>= bitsPerWord - shift -} - -func (p *poly2) setPhiN() { - for i := range p { - p[i] = ^uint(0) - } - p[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 -} - -func (p *poly2) cswap(other *poly2, swap uint) { - for i := range p { - sum := swap & (p[i] ^ other[i]) - p[i] ^= sum - other[i] ^= sum - } -} - -func (p *poly2) fmadd(m uint, in *poly2) { - m = ^(m - 1) - - for i := range p { - p[i] ^= in[i] & m - } -} - -func (p *poly2) lshift1() { - var carry uint - for i := range p { - nextCarry := p[i] >> (bitsPerWord - 1) - p[i] <<= 1 - p[i] |= carry - carry = nextCarry - } -} - -func (p *poly2) rshift1() { - var carry uint - for i := len(p) - 1; i >= 0; i-- { - nextCarry := p[i] & 1 - p[i] >>= 1 - p[i] |= carry << (bitsPerWord - 1) - carry = nextCarry - } -} - -func (p *poly2) rot(bits uint) { - if bits > N { - panic("invalid") - } - var shifted [wordsPerPoly]uint - out := (*[wordsPerPoly]uint)(p) - - shift := uint(9) - for ; (1 << shift) >= bitsPerWord; shift-- { - rotWords(&shifted, out, 1<<shift) - cmovWords(out, &shifted, lsbToAll(bits>>shift)) - } - for ; shift < 9; shift-- { - rotBits(&shifted, out, 1<<shift) - cmovWords(out, &shifted, lsbToAll(bits>>shift)) - } -} - -type poly [N]uint16 - -func (in *poly) marshal(out []byte) { - p := in[:] - - for len(p) >= 8 { - out[0] = byte(p[0]) - out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5) - out[2] = byte(p[1] >> 3) - out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2) - out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7) - out[5] = byte(p[3] >> 1) - out[6] = byte(p[3]>>9) | byte((p[4]&0x0f)<<4) - out[7] = byte(p[4] >> 4) - out[8] = byte(p[4]>>12) | byte((p[5]&0x7f)<<1) - out[9] = byte(p[5]>>7) | byte((p[6]&0x03)<<6) - out[10] = byte(p[6] >> 2) - out[11] = byte(p[6]>>10) | byte((p[7]&0x1f)<<3) - out[12] = byte(p[7] >> 5) - - p = p[8:] - out = out[13:] - } - - // There are four remaining values. - out[0] = byte(p[0]) - out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5) - out[2] = byte(p[1] >> 3) - out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2) - out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7) - out[5] = byte(p[3] >> 1) - out[6] = byte(p[3] >> 9) -} - -func (out *poly) unmarshal(in []byte) bool { - p := out[:] - for i := 0; i < 87; i++ { - p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8 - p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11 - p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6 - p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9 - p[4] = uint16(in[6]>>4) | uint16(in[7])<<4 | uint16(in[8]&1)<<12 - p[5] = uint16(in[8]>>1) | uint16(in[9]&0x3f)<<7 - p[6] = uint16(in[9]>>6) | uint16(in[10])<<2 | uint16(in[11]&7)<<10 - p[7] = uint16(in[11]>>3) | uint16(in[12])<<5 - - p = p[8:] - in = in[13:] - } - - // There are four coefficients left over - p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8 - p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11 - p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6 - p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9 - - if in[6]&0xf0 != 0 { - return false - } - - out[N-1] = 0 - var top int - for _, v := range out { - top += int(v) - } - - out[N-1] = uint16(-top) % Q - return true -} - -func (in *poly) marshalS3(out []byte) { - p := in[:] - for len(p) >= 5 { - out[0] = byte(p[0] + p[1]*3 + p[2]*9 + p[3]*27 + p[4]*81) - out = out[1:] - p = p[5:] - } -} - -func (out *poly) unmarshalS3(in []byte) bool { - p := out[:] - for i := 0; i < 140; i++ { - c := in[0] - if c >= 243 { - return false - } - p[0] = uint16(c % 3) - p[1] = uint16((c / 3) % 3) - p[2] = uint16((c / 9) % 3) - p[3] = uint16((c / 27) % 3) - p[4] = uint16((c / 81) % 3) - - p = p[5:] - in = in[1:] - } - - out[N-1] = 0 - return true -} - -func (p *poly) modPhiN() { - for i := range p { - p[i] = (p[i] + Q - p[N-1]) % Q - } -} - -func (out *poly) shortSample(in []byte) { - // b a result - // 00 00 00 - // 00 01 01 - // 00 10 10 - // 00 11 11 - // 01 00 10 - // 01 01 00 - // 01 10 01 - // 01 11 11 - // 10 00 01 - // 10 01 10 - // 10 10 00 - // 10 11 11 - // 11 00 11 - // 11 01 11 - // 11 10 11 - // 11 11 11 - - // 1111 1111 1100 1001 1101 0010 1110 0100 - // f f c 9 d 2 e 4 - const lookup = uint32(0xffc9d2e4) - - p := out[:] - for i := 0; i < 87; i++ { - v := binary.LittleEndian.Uint32(in) - v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555) - for j := 0; j < 8; j++ { - p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3) - v2 >>= 4 - } - p = p[8:] - in = in[4:] - } - - // There are four values remaining. - v := binary.LittleEndian.Uint32(in) - v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555) - for j := 0; j < 4; j++ { - p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3) - v2 >>= 4 - } - - out[N-1] = 0 -} - -func (out *poly) shortSamplePlus(in []byte) { - out.shortSample(in) - - var sum uint16 - for i := 0; i < N-1; i++ { - sum += mod3ResultToModQ(out[i] * out[i+1]) - } - - scale := 1 + (1 & (sum >> 12)) - for i := 0; i < len(out); i += 2 { - out[i] = (out[i] * scale) % 3 - } -} - -func mul(out, scratch, a, b []uint16) { - const schoolbookLimit = 32 - if len(a) < schoolbookLimit { - for i := 0; i < len(a)*2; i++ { - out[i] = 0 - } - for i := range a { - for j := range b { - out[i+j] += a[i] * b[j] - } - } - return - } - - lowLen := len(a) / 2 - highLen := len(a) - lowLen - aLow, aHigh := a[:lowLen], a[lowLen:] - bLow, bHigh := b[:lowLen], b[lowLen:] - - for i := 0; i < lowLen; i++ { - out[i] = aHigh[i] + aLow[i] - } - if highLen != lowLen { - out[lowLen] = aHigh[lowLen] - } - - for i := 0; i < lowLen; i++ { - out[highLen+i] = bHigh[i] + bLow[i] - } - if highLen != lowLen { - out[highLen+lowLen] = bHigh[lowLen] - } - - mul(scratch, scratch[2*highLen:], out[:highLen], out[highLen:highLen*2]) - mul(out[lowLen*2:], scratch[2*highLen:], aHigh, bHigh) - mul(out, scratch[2*highLen:], aLow, bLow) - - for i := 0; i < lowLen*2; i++ { - scratch[i] -= out[i] + out[lowLen*2+i] - } - if lowLen != highLen { - scratch[lowLen*2] -= out[lowLen*4] - } - - for i := 0; i < 2*highLen; i++ { - out[lowLen+i] += scratch[i] - } -} - -func (out *poly) mul(a, b *poly) { - var prod, scratch [2 * N]uint16 - mul(prod[:], scratch[:], a[:], b[:]) - for i := range out { - out[i] = (prod[i] + prod[i+N]) % Q - } -} - -func (p3 *poly3) mulMod3(x, y *poly3) { - // (𝑥^n - 1) is a multiple of Φ(N) so we can work mod (𝑥^n - 1) here and - // (reduce mod Φ(N) afterwards. - x3 := *x - y3 := *y - s := x3.s[:] - a := x3.a[:] - sw := s[0] - aw := a[0] - p3.zero() - var shift uint - for i := 0; i < N; i++ { - p3.fmadd(sw, aw, &y3) - sw >>= 1 - aw >>= 1 - shift++ - if shift == bitsPerWord { - s = s[1:] - a = a[1:] - sw = s[0] - aw = a[0] - shift = 0 - } - y3.mulx() - } - p3.modPhiN() -} - -// mod3ToModQ maps {0, 1, 2, 3} to {0, 1, Q-1, 0xffff} -// The case of n == 3 should never happen but is included so that modQToMod3 -// can easily catch invalid inputs. -func mod3ToModQ(n uint16) uint16 { - return uint16(uint64(0xffff1fff00010000) >> (16 * n)) -} - -// modQToMod3 maps {0, 1, Q-1} to {(0, 0), (0, 1), (1, 0)} and also returns an int -// which is one if the input is in range and zero otherwise. -func modQToMod3(n uint16) (uint16, int) { - result := (n&3 - (n>>1)&1) - return result, subtle.ConstantTimeEq(int32(mod3ToModQ(result)), int32(n)) -} - -// mod3ResultToModQ maps {0, 1, 2, 4} to {0, 1, Q-1, 1} -func mod3ResultToModQ(n uint16) uint16 { - return ((((uint16(0x13) >> n) & 1) - 1) & 0x1fff) | ((uint16(0x12) >> n) & 1) - //shift := (uint(0x324) >> (2 * n)) & 3 - //return uint16(uint64(0x00011fff00010000) >> (16 * shift)) -} - -// mulXMinus1 sets out to a×(𝑥 - 1) mod (𝑥^n - 1) -func (out *poly) mulXMinus1() { - // Multiplying by (𝑥 - 1) means negating each coefficient and adding in - // the value of the previous one. - origOut700 := out[700] - - for i := N - 1; i > 0; i-- { - out[i] = (Q - out[i] + out[i-1]) % Q - } - out[0] = (Q - out[0] + origOut700) % Q -} - -func (out *poly) lift(a *poly) { - // We wish to calculate a/(𝑥-1) mod Φ(N) over GF(3), where Φ(N) is the - // Nth cyclotomic polynomial, i.e. 1 + 𝑥 + … + 𝑥^700 (since N is prime). - - // 1/(𝑥-1) has a fairly basic structure that we can exploit to speed this up: - // - // R.<x> = PolynomialRing(GF(3)…) - // inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n)) - // list(inv)[:15] - // [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2] - // - // This three-element pattern of coefficients repeats for the whole - // polynomial. - // - // Next define the overbar operator such that z̅ = z[0] + - // reverse(z[1:]). (Index zero of a polynomial here is the coefficient - // of the constant term. So index one is the coefficient of 𝑥 and so - // on.) - // - // A less odd way to define this is to see that z̅ negates the indexes, - // so z̅[0] = z[-0], z̅[1] = z[-1] and so on. - // - // The use of z̅ is that, when working mod (𝑥^701 - 1), vz[0] = <v, - // z̅>, vz[1] = <v, 𝑥z̅>, …. (Where <a, b> is the inner product: the sum - // of the point-wise products.) Although we calculated the inverse mod - // Φ(N), we can work mod (𝑥^N - 1) and reduce mod Φ(N) at the end. - // (That's because (𝑥^N - 1) is a multiple of Φ(N).) - // - // When working mod (𝑥^N - 1), multiplication by 𝑥 is a right-rotation - // of the list of coefficients. - // - // Thus we can consider what the pattern of z̅, 𝑥z̅, 𝑥^2z̅, … looks like: - // - // def reverse(xs): - // suffix = list(xs[1:]) - // suffix.reverse() - // return [xs[0]] + suffix - // - // def rotate(xs): - // return [xs[-1]] + xs[:-1] - // - // zoverbar = reverse(list(inv) + [0]) - // xzoverbar = rotate(reverse(list(inv) + [0])) - // x2zoverbar = rotate(rotate(reverse(list(inv) + [0]))) - // - // zoverbar[:15] - // [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1] - // xzoverbar[:15] - // [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0] - // x2zoverbar[:15] - // [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] - // - // (For a formula for z̅, see lemma two of appendix B.) - // - // After the first three elements have been taken care of, all then have - // a repeating three-element cycle. The next value (𝑥^3z̅) involves - // three rotations of the first pattern, thus the three-element cycle - // lines up. However, the discontinuity in the first three elements - // obviously moves to a different position. Consider the difference - // between 𝑥^3z̅ and z̅: - // - // [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15] - // [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - // - // This pattern of differences is the same for all elements, although it - // obviously moves right with the rotations. - // - // From this, we reach algorithm eight of appendix B. - - // Handle the first three elements of the inner products. - out[0] = a[0] + a[2] - out[1] = a[1] - out[2] = 2*a[0] + a[2] - - // Use the repeating pattern to complete the first three inner products. - for i := 3; i < 699; i += 3 { - out[0] += 2*a[i] + a[i+2] - out[1] += a[i] + 2*a[i+1] - out[2] += a[i+1] + 2*a[i+2] - } - - // Handle the fact that the three-element pattern doesn't fill the - // polynomial exactly (since 701 isn't a multiple of three). - out[2] += a[700] - out[0] += 2 * a[699] - out[1] += a[699] + 2*a[700] - - out[0] = out[0] % 3 - out[1] = out[1] % 3 - out[2] = out[2] % 3 - - // Calculate the remaining inner products by taking advantage of the - // fact that the pattern repeats every three cycles and the pattern of - // differences is moves with the rotation. - for i := 3; i < N; i++ { - // Add twice something is the same as subtracting when working - // mod 3. Doing it this way avoids underflow. Underflow is bad - // because "% 3" doesn't work correctly for negative numbers - // here since underflow will wrap to 2^16-1 and 2^16 isn't a - // multiple of three. - out[i] = (out[i-3] + 2*(a[i-2]+a[i-1]+a[i])) % 3 - } - - // Reduce mod Φ(N) by subtracting a multiple of out[700] from every - // element and convert to mod Q. (See above about adding twice as - // subtraction.) - v := out[700] * 2 - for i := range out { - out[i] = mod3ToModQ((out[i] + v) % 3) - } - - out.mulXMinus1() -} - -func (a *poly) cswap(b *poly, swap uint16) { - for i := range a { - sum := swap & (a[i] ^ b[i]) - a[i] ^= sum - b[i] ^= sum - } -} - -func lt(a, b uint) uint { - if a < b { - return ^uint(0) - } - return 0 -} - -func bsMul(s1, a1, s2, a2 uint) (s3, a3 uint) { - s3 = (a1 & s2) ^ (s1 & a2) - a3 = (a1 & a2) ^ (s1 & s2) - return -} - -func (out *poly3) invertMod3(in *poly3) { - // This algorithm follows algorithm 10 in the paper. (Although note that - // the paper appears to have a bug: k should start at zero, not one.) - // The best explanation for why it works is in the "Why it works" - // section of - // https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf. - var k uint - degF, degG := uint(N-1), uint(N-1) - - var b, c, g poly3 - f := *in - - for i := range g.a { - g.a[i] = ^uint(0) - } - - b.a[0] = 1 - - var f0s, f0a uint - stillGoing := ^uint(0) - for i := 0; i < 2*(N-1)-1; i++ { - ss, sa := bsMul(f.s[0], f.a[0], g.s[0], g.a[0]) - ss, sa = sa&stillGoing&1, ss&stillGoing&1 - shouldSwap := ^uint(int((ss|sa)-1)>>(bitsPerWord-1)) & lt(degF, degG) - f.cswap(&g, shouldSwap) - b.cswap(&c, shouldSwap) - degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap) - f.fmadd(ss, sa, &g) - b.fmadd(ss, sa, &c) - - f.divx() - f.s[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1 - f.a[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1 - c.mulx() - c.s[0] &= ^uint(1) - c.a[0] &= ^uint(1) - - degF-- - k += 1 & stillGoing - f0s = (stillGoing & f.s[0]) | (^stillGoing & f0s) - f0a = (stillGoing & f.a[0]) | (^stillGoing & f0a) - stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1)) - } - - k -= N & lt(N, k) - *out = b - out.rot(k) - out.mulConst(f0s, f0a) - out.modPhiN() -} - -func (out *poly) invertMod2(a *poly) { - // This algorithm follows mix of algorithm 10 in the paper and the first - // page of the PDF linked below. (Although note that the paper appears - // to have a bug: k should start at zero, not one.) The best explanation - // for why it works is in the "Why it works" section of - // https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf. - var k uint - degF, degG := uint(N-1), uint(N-1) - - var f poly2 - f.fromDiscrete(a) - var b, c, g poly2 - g.setPhiN() - b[0] = 1 - - stillGoing := ^uint(0) - for i := 0; i < 2*(N-1)-1; i++ { - s := uint(f[0]&1) & stillGoing - shouldSwap := ^(s - 1) & lt(degF, degG) - f.cswap(&g, shouldSwap) - b.cswap(&c, shouldSwap) - degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap) - f.fmadd(s, &g) - b.fmadd(s, &c) - - f.rshift1() - c.lshift1() - - degF-- - k += 1 & stillGoing - stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1)) - } - - k -= N & lt(N, k) - b.rot(k) - out.fromMod2(&b) -} - -func (out *poly) invert(origA *poly) { - // Inversion mod Q, which is done based on the result of inverting mod - // 2. See the NTRU paper, page three. - var a, tmp, tmp2, b poly - b.invertMod2(origA) - - // Negate a. - for i := range a { - a[i] = Q - origA[i] - } - - // We are working mod Q=2**13 and we need to iterate ceil(log_2(13)) - // times, which is four. - for i := 0; i < 4; i++ { - tmp.mul(&a, &b) - tmp[0] += 2 - tmp2.mul(&b, &tmp) - b = tmp2 - } - - *out = b -} - -type PublicKey struct { - h poly -} - -func ParsePublicKey(in []byte) (*PublicKey, bool) { - ret := new(PublicKey) - if !ret.h.unmarshal(in) { - return nil, false - } - return ret, true -} - -func (pub *PublicKey) Marshal() []byte { - ret := make([]byte, modQBytes) - pub.h.marshal(ret) - return ret -} - -func (pub *PublicKey) Encap(rand io.Reader) (ciphertext []byte, sharedKey []byte) { - var randBytes [352 + 352]byte - if _, err := io.ReadFull(rand, randBytes[:]); err != nil { - panic("rand failed") - } - - var m, r poly - m.shortSample(randBytes[:352]) - r.shortSample(randBytes[352:]) - - var mBytes, rBytes [mod3Bytes]byte - m.marshalS3(mBytes[:]) - r.marshalS3(rBytes[:]) - - ciphertext = pub.owf(&m, &r) - - h := sha256.New() - h.Write([]byte("shared key\x00")) - h.Write(mBytes[:]) - h.Write(rBytes[:]) - h.Write(ciphertext) - sharedKey = h.Sum(nil) - - return ciphertext, sharedKey -} - -func (pub *PublicKey) owf(m, r *poly) []byte { - for i := range r { - r[i] = mod3ToModQ(r[i]) - } - - var mq poly - mq.lift(m) - - var e poly - e.mul(r, &pub.h) - for i := range e { - e[i] = (e[i] + mq[i]) % Q - } - - ret := make([]byte, modQBytes) - e.marshal(ret[:]) - return ret -} - -type PrivateKey struct { - PublicKey - f, fp poly3 - hInv poly - hmacKey [32]byte -} - -func (priv *PrivateKey) Marshal() []byte { - var ret [2*mod3Bytes + modQBytes]byte - priv.f.marshal(ret[:]) - priv.fp.marshal(ret[mod3Bytes:]) - priv.h.marshal(ret[2*mod3Bytes:]) - return ret[:] -} - -func (priv *PrivateKey) Decap(ciphertext []byte) (sharedKey []byte, ok bool) { - if len(ciphertext) != modQBytes { - return nil, false - } - - var e poly - if !e.unmarshal(ciphertext) { - return nil, false - } - - var f poly - f.fromMod3ToModQ(&priv.f) - - var v1, m poly - v1.mul(&e, &f) - - var v13 poly3 - v13.fromDiscreteMod3(&v1) - // Note: v13 is not reduced mod phi(n). - - var m3 poly3 - m3.mulMod3(&v13, &priv.fp) - m3.modPhiN() - m.fromMod3(&m3) - - var mLift, delta poly - mLift.lift(&m) - for i := range delta { - delta[i] = (e[i] - mLift[i] + Q) % Q - } - delta.mul(&delta, &priv.hInv) - delta.modPhiN() - - var r poly3 - allOk := r.fromModQ(&delta) - - var mBytes, rBytes [mod3Bytes]byte - m.marshalS3(mBytes[:]) - r.marshal(rBytes[:]) - - var rPoly poly - rPoly.fromMod3(&r) - expectedCiphertext := priv.PublicKey.owf(&m, &rPoly) - - allOk &= subtle.ConstantTimeCompare(ciphertext, expectedCiphertext) - - hmacHash := hmac.New(sha256.New, priv.hmacKey[:]) - hmacHash.Write(ciphertext) - hmacDigest := hmacHash.Sum(nil) - - h := sha256.New() - h.Write([]byte("shared key\x00")) - h.Write(mBytes[:]) - h.Write(rBytes[:]) - h.Write(ciphertext) - sharedKey = h.Sum(nil) - - mask := uint8(allOk - 1) - for i := range sharedKey { - sharedKey[i] = (sharedKey[i] & ^mask) | (hmacDigest[i] & mask) - } - - return sharedKey, true -} - -func GenerateKey(rand io.Reader) PrivateKey { - var randBytes [352 + 352]byte - if _, err := io.ReadFull(rand, randBytes[:]); err != nil { - panic("rand failed") - } - - var f poly - f.shortSamplePlus(randBytes[:352]) - var priv PrivateKey - priv.f.fromDiscrete(&f) - priv.fp.invertMod3(&priv.f) - - var g poly - g.shortSamplePlus(randBytes[352:]) - - var pgPhi1 poly - for i := range g { - pgPhi1[i] = mod3ToModQ(g[i]) - } - for i := range pgPhi1 { - pgPhi1[i] = (pgPhi1[i] * 3) % Q - } - pgPhi1.mulXMinus1() - - var fModQ poly - fModQ.fromMod3ToModQ(&priv.f) - - var pfgPhi1 poly - pfgPhi1.mul(&fModQ, &pgPhi1) - - var i poly - i.invert(&pfgPhi1) - - priv.h.mul(&i, &pgPhi1) - priv.h.mul(&priv.h, &pgPhi1) - - priv.hInv.mul(&i, &fModQ) - priv.hInv.mul(&priv.hInv, &fModQ) - - return priv -}
diff --git a/ssl/test/runner/key_agreement.go b/ssl/test/runner/key_agreement.go index 47cdbb8..5739888 100644 --- a/ssl/test/runner/key_agreement.go +++ b/ssl/test/runner/key_agreement.go
@@ -17,7 +17,6 @@ "io" "math/big" - "boringssl.googlesource.com/boringssl/ssl/test/runner/hrss" "golang.org/x/crypto/curve25519" ) @@ -341,90 +340,6 @@ return out[:], nil } -// cecpq2KEM implements CECPQ2, which is HRSS+SXY combined with X25519. -type cecpq2KEM struct { - x25519PrivateKey [32]byte - hrssPrivateKey hrss.PrivateKey -} - -func (e *cecpq2KEM) generate(rand io.Reader) (publicKey []byte, err error) { - if _, err := io.ReadFull(rand, e.x25519PrivateKey[:]); err != nil { - return nil, err - } - - var x25519Public [32]byte - curve25519.ScalarBaseMult(&x25519Public, &e.x25519PrivateKey) - - e.hrssPrivateKey = hrss.GenerateKey(rand) - hrssPublic := e.hrssPrivateKey.PublicKey.Marshal() - - var ret []byte - ret = append(ret, x25519Public[:]...) - ret = append(ret, hrssPublic...) - return ret, nil -} - -func (e *cecpq2KEM) encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error) { - if len(peerKey) != 32+hrss.PublicKeySize { - return nil, nil, errors.New("tls: bad length CECPQ2 offer") - } - - if _, err := io.ReadFull(rand, e.x25519PrivateKey[:]); err != nil { - return nil, nil, err - } - - var x25519Shared, x25519PeerKey, x25519Public [32]byte - copy(x25519PeerKey[:], peerKey) - curve25519.ScalarBaseMult(&x25519Public, &e.x25519PrivateKey) - curve25519.ScalarMult(&x25519Shared, &e.x25519PrivateKey, &x25519PeerKey) - - // Per RFC 7748, reject the all-zero value in constant time. - var zeros [32]byte - if subtle.ConstantTimeCompare(zeros[:], x25519Shared[:]) == 1 { - return nil, nil, errors.New("tls: X25519 value with wrong order") - } - - hrssPublicKey, ok := hrss.ParsePublicKey(peerKey[32:]) - if !ok { - return nil, nil, errors.New("tls: bad CECPQ2 offer") - } - - hrssCiphertext, hrssShared := hrssPublicKey.Encap(rand) - - ciphertext = append(ciphertext, x25519Public[:]...) - ciphertext = append(ciphertext, hrssCiphertext...) - secret = append(secret, x25519Shared[:]...) - secret = append(secret, hrssShared...) - - return ciphertext, secret, nil -} - -func (e *cecpq2KEM) decap(ciphertext []byte) (secret []byte, err error) { - if len(ciphertext) != 32+hrss.CiphertextSize { - return nil, errors.New("tls: bad length CECPQ2 reply") - } - - var x25519Shared, x25519PeerKey [32]byte - copy(x25519PeerKey[:], ciphertext) - curve25519.ScalarMult(&x25519Shared, &e.x25519PrivateKey, &x25519PeerKey) - - // Per RFC 7748, reject the all-zero value in constant time. - var zeros [32]byte - if subtle.ConstantTimeCompare(zeros[:], x25519Shared[:]) == 1 { - return nil, errors.New("tls: X25519 value with wrong order") - } - - hrssShared, ok := e.hrssPrivateKey.Decap(ciphertext[32:]) - if !ok { - return nil, errors.New("tls: invalid HRSS ciphertext") - } - - secret = append(secret, x25519Shared[:]...) - secret = append(secret, hrssShared...) - - return secret, nil -} - func kemForCurveID(id CurveID, config *Config) (kemImplementation, bool) { switch id { case CurveP224: @@ -437,8 +352,6 @@ return &ecdhKEM{curve: elliptic.P521(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true case CurveX25519: return &x25519KEM{setHighBit: config.Bugs.SetX25519HighBit}, true - case CurveCECPQ2: - return &cecpq2KEM{}, true default: return nil, false } @@ -587,7 +500,7 @@ NextCandidate: for _, candidate := range preferredCurves { if isPqGroup(candidate) && version < VersionTLS13 { - // CECPQ2 is TLS 1.3-only. + // Post-quantum "groups" require TLS 1.3. continue }
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index 3d660da..54bb7b4 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go
@@ -11371,13 +11371,12 @@ {"P-384", CurveP384}, {"P-521", CurveP521}, {"X25519", CurveX25519}, - {"CECPQ2", CurveCECPQ2}, } const bogusCurve = 0x1234 func isPqGroup(r CurveID) bool { - return r == CurveCECPQ2 + return r == CurveX25519Kyber768 } func addCurveTests() { @@ -11841,78 +11840,79 @@ }, }) - // CECPQ2 should not be offered by a TLS < 1.3 client. + // Kyber should not be offered by a TLS < 1.3 client. testCases = append(testCases, testCase{ - name: "CECPQ2NotInTLS12", + name: "KyberNotInTLS12", config: Config{ Bugs: ProtocolBugs{ - FailIfCECPQ2Offered: true, + FailIfKyberOffered: true, }, }, flags: []string{ "-max-version", strconv.Itoa(VersionTLS12), - "-curves", strconv.Itoa(int(CurveCECPQ2)), + "-curves", strconv.Itoa(int(CurveX25519Kyber768)), "-curves", strconv.Itoa(int(CurveX25519)), }, }) - // CECPQ2 should not crash a TLS < 1.3 client if the server mistakenly + // Kyber should not crash a TLS < 1.3 client if the server mistakenly // selects it. testCases = append(testCases, testCase{ - name: "CECPQ2NotAcceptedByTLS12Client", + name: "KyberNotAcceptedByTLS12Client", config: Config{ Bugs: ProtocolBugs{ - SendCurve: CurveCECPQ2, + SendCurve: CurveX25519Kyber768, }, }, flags: []string{ "-max-version", strconv.Itoa(VersionTLS12), - "-curves", strconv.Itoa(int(CurveCECPQ2)), + "-curves", strconv.Itoa(int(CurveX25519Kyber768)), "-curves", strconv.Itoa(int(CurveX25519)), }, shouldFail: true, expectedError: ":WRONG_CURVE:", }) - // CECPQ2 should not be offered by default as a client. + // Kyber should not be offered by default as a client. testCases = append(testCases, testCase{ - name: "CECPQ2NotEnabledByDefaultInClients", + name: "KyberNotEnabledByDefaultInClients", config: Config{ MinVersion: VersionTLS13, Bugs: ProtocolBugs{ - FailIfCECPQ2Offered: true, + FailIfKyberOffered: true, }, }, }) - // If CECPQ2 is offered, both X25519 and CECPQ2 should have a key-share. + // If Kyber is offered, both X25519 and Kyber should have a key-share. testCases = append(testCases, testCase{ - name: "NotJustCECPQ2KeyShare", + name: "NotJustKyberKeyShare", config: Config{ MinVersion: VersionTLS13, Bugs: ProtocolBugs{ - ExpectedKeyShares: []CurveID{CurveCECPQ2, CurveX25519}, + ExpectedKeyShares: []CurveID{CurveX25519Kyber768, CurveX25519}, }, }, flags: []string{ - "-curves", strconv.Itoa(int(CurveCECPQ2)), + "-curves", strconv.Itoa(int(CurveX25519Kyber768)), "-curves", strconv.Itoa(int(CurveX25519)), - "-expect-curve-id", strconv.Itoa(int(CurveCECPQ2)), + // Cannot expect Kyber until we have a Go implementation of it. + // "-expect-curve-id", strconv.Itoa(int(CurveX25519Kyber768)), }, }) // ... and the other way around testCases = append(testCases, testCase{ - name: "CECPQ2KeyShareIncludedSecond", + name: "KyberKeyShareIncludedSecond", config: Config{ MinVersion: VersionTLS13, Bugs: ProtocolBugs{ - ExpectedKeyShares: []CurveID{CurveX25519, CurveCECPQ2}, + ExpectedKeyShares: []CurveID{CurveX25519, CurveX25519Kyber768}, }, }, flags: []string{ "-curves", strconv.Itoa(int(CurveX25519)), - "-curves", strconv.Itoa(int(CurveCECPQ2)), + "-curves", strconv.Itoa(int(CurveX25519Kyber768)), "-expect-curve-id", strconv.Itoa(int(CurveX25519)), }, }) @@ -11921,44 +11921,46 @@ // first classical and first post-quantum "curves" that get key shares // included. testCases = append(testCases, testCase{ - name: "CECPQ2KeyShareIncludedThird", + name: "KyberKeyShareIncludedThird", config: Config{ MinVersion: VersionTLS13, Bugs: ProtocolBugs{ - ExpectedKeyShares: []CurveID{CurveX25519, CurveCECPQ2}, + ExpectedKeyShares: []CurveID{CurveX25519, CurveX25519Kyber768}, }, }, flags: []string{ "-curves", strconv.Itoa(int(CurveX25519)), "-curves", strconv.Itoa(int(CurveP256)), - "-curves", strconv.Itoa(int(CurveCECPQ2)), + "-curves", strconv.Itoa(int(CurveX25519Kyber768)), "-expect-curve-id", strconv.Itoa(int(CurveX25519)), }, }) - // If CECPQ2 is the only configured curve, the key share is sent. + // If Kyber is the only configured curve, the key share is sent. testCases = append(testCases, testCase{ - name: "JustConfiguringCECPQ2Works", + name: "JustConfiguringKyberWorks", config: Config{ MinVersion: VersionTLS13, Bugs: ProtocolBugs{ - ExpectedKeyShares: []CurveID{CurveCECPQ2}, + ExpectedKeyShares: []CurveID{CurveX25519Kyber768}, }, }, flags: []string{ - "-curves", strconv.Itoa(int(CurveCECPQ2)), - "-expect-curve-id", strconv.Itoa(int(CurveCECPQ2)), + "-curves", strconv.Itoa(int(CurveX25519Kyber768)), + "-expect-curve-id", strconv.Itoa(int(CurveX25519Kyber768)), }, + shouldFail: true, + expectedLocalError: "no curve supported by both client and server", }) - // As a server, CECPQ2 is not yet supported by default. + // As a server, Kyber is not yet supported by default. testCases = append(testCases, testCase{ testType: serverTest, - name: "CECPQ2NotEnabledByDefaultForAServer", + name: "KyberNotEnabledByDefaultForAServer", config: Config{ MinVersion: VersionTLS13, - CurvePreferences: []CurveID{CurveCECPQ2, CurveX25519}, - DefaultCurves: []CurveID{CurveCECPQ2}, + CurvePreferences: []CurveID{CurveX25519Kyber768, CurveX25519}, + DefaultCurves: []CurveID{CurveX25519Kyber768}, }, flags: []string{ "-server-preference",
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc index 109c69e..09faf04 100644 --- a/ssl/test/test_config.cc +++ b/ssl/test/test_config.cc
@@ -1909,8 +1909,8 @@ nids.push_back(NID_X25519); break; - case SSL_CURVE_CECPQ2: - nids.push_back(NID_CECPQ2); + case SSL_CURVE_X25519KYBER768: + nids.push_back(NID_X25519Kyber768); break; } if (!SSL_set1_curves(ssl.get(), &nids[0], nids.size())) {