Add API for configuring client key shares This change introduces a new API to allow callers to configure the exact set of key shares to be sent in a client's key_share extension. If the supported groups for a connection are modified, any previously-selected key shares are cleared if they are no longer compatible. Clients are allowed to configure an empty list of key shares, which results in always taking a round-trip for HelloRetryRequest. Bug: 437414371 Change-Id: Ibd2a9b217fa3c746dec194a027d3ec45a81bb578 Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81567 Commit-Queue: Lily Chen <chlily@google.com> Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/mem_internal.h b/crypto/mem_internal.h index 20b45fd..5ba44dc 100644 --- a/crypto/mem_internal.h +++ b/crypto/mem_internal.h
@@ -379,6 +379,8 @@ template <typename T, size_t N> class InplaceVector { public: + using value_type = std::remove_cv_t<T>; + InplaceVector() = default; InplaceVector(const InplaceVector &other) { *this = other; } InplaceVector(InplaceVector &&other) { *this = std::move(other); }
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 79eb0c7..81de2fd 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h
@@ -2605,6 +2605,47 @@ OPENSSL_EXPORT int SSL_get_negotiated_group(const SSL *ssl); +// Client key shares. +// +// The key_share extension in TLS 1.3 (RFC 8446 section 4.2.8) may be sent in +// the initial ClientHello to provide key exchange parameters for a subset of +// the groups offered in the client's supported_groups extension, in hopes of +// saving a round-trip by having proactively started a key exchange for the +// ultimately-negotiated group. +// +// If not otherwise configured, the default client key share selection logic +// outputs key shares for up to two supported groups, at most one of which is +// post-quantum. + +// SSL_set1_client_key_shares, when called by a client before the handshake, +// configures |ssl| to send a key_share extension in the initial ClientHello +// containing exactly the groups given by |group_ids|, in the order given. Each +// member of |group_ids| should be one of the |SSL_GROUP_*| constants, and they +// must be unique. This function returns one on success and zero on failure. +// +// If non-empty, the sequence of |group_ids| must be a (not necessarily +// contiguous) subsequence of the groups supported by |ssl|, which may have been +// configured explicitly on |ssl| or its context, or populated by default. +// Caller should finish configuring the group list before calling this function. +// Changing the supported groups for |ssl| after having set client key shares +// will result in the key share selections being reset if this constraint no +// longer holds. +// +// Setting an empty sequence of |group_ids| results in an empty client +// key_share, which will cause the handshake to always take an extra round-trip +// for HelloRetryRequest. +// +// An extra round-trip will be needed if the server's choice of group is not +// among the key shares sent; conversely, sending any key shares other than the +// server's choice wastes CPU and bandwidth (the latter is particularly costly +// for post-quantum key exchanges). To avoid these sub-optimal outcomes, +// key shares should be chosen such that they are likely to be supported by the +// peer server. +OPENSSL_EXPORT int SSL_set1_client_key_shares(SSL *ssl, + const uint16_t *group_ids, + size_t num_group_ids); + + // Certificate verification. // // SSL may authenticate either endpoint with an X.509 certificate. Typically
diff --git a/ssl/extensions.cc b/ssl/extensions.cc index a09594e..2169966 100644 --- a/ssl/extensions.cc +++ b/ssl/extensions.cc
@@ -2197,6 +2197,45 @@ return true; } + const Span<const uint16_t> supported_group_list = + hs->config->supported_group_list; + if (supported_group_list.empty()) { + OPENSSL_PUT_ERROR(SSL, SSL_R_NO_GROUPS_SPECIFIED); + return false; + } + + InplaceVector<uint16_t, 2> default_key_shares; + + Span<const uint16_t> selected_key_shares; + + // Determine the key shares to send. + if (override_group_id != 0) { + assert(std::find(supported_group_list.begin(), supported_group_list.end(), + override_group_id) != supported_group_list.end()); + selected_key_shares = Span(&override_group_id, 1u); + } else if (ssl->config->client_key_share_selections.has_value()) { + selected_key_shares = *(ssl->config->client_key_share_selections); + } else { + // By default, predict the most preferred group. + if (!default_key_shares.TryPushBack(supported_group_list[0])) { + return false; + } + // We'll try to include one post-quantum and one classical initial key + // share. + for (size_t i = 1; i < supported_group_list.size(); i++) { + if (is_post_quantum_group(default_key_shares[0]) == + is_post_quantum_group(supported_group_list[i])) { + continue; + } + if (!default_key_shares.TryPushBack(supported_group_list[i])) { + return false; + } + assert(default_key_shares[1] != default_key_shares[0]); + break; + } + selected_key_shares = default_key_shares; + } + bssl::ScopedCBB cbb; if (!CBB_init(cbb.get(), 64)) { return false; @@ -2211,45 +2250,13 @@ } } - uint16_t group_id = override_group_id; - uint16_t second_group_id = 0; - if (override_group_id == 0) { - // Predict the most preferred group. - Span<const uint16_t> groups = hs->config->supported_group_list; - if (groups.empty()) { - OPENSSL_PUT_ERROR(SSL, SSL_R_NO_GROUPS_SPECIFIED); - return false; - } - - group_id = groups[0]; - - // We'll try to include one post-quantum and one classical initial key - // share. - for (size_t i = 1; i < groups.size() && second_group_id == 0; i++) { - if (is_post_quantum_group(group_id) != is_post_quantum_group(groups[i])) { - second_group_id = groups[i]; - assert(second_group_id != group_id); - } - } - } - CBB key_exchange; - { + for (const uint16_t group_id : selected_key_shares) { UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id); - if (key_share == nullptr || !CBB_add_u16(cbb.get(), group_id) || - !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) || - !key_share->Generate(&key_exchange) || - !hs->key_shares.TryPushBack(std::move(key_share))) { - return false; - } - } - - if (second_group_id != 0) { - // TODO(chlily): Fix temporary code duplication. - UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(second_group_id); - if (key_share == nullptr || !CBB_add_u16(cbb.get(), second_group_id) || - !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) || - !key_share->Generate(&key_exchange) || + if (key_share == nullptr || // + !CBB_add_u16(cbb.get(), group_id) || // + !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) || // + !key_share->Generate(&key_exchange) || // !hs->key_shares.TryPushBack(std::move(key_share))) { return false; } @@ -2267,7 +2274,12 @@ return true; } - assert(!hs->key_share_bytes.empty()); + // The caller may explicitly configure empty key shares to request a + // HelloRetryRequest. + assert(!hs->key_share_bytes.empty() || + (hs->config->client_key_share_selections.has_value() && + hs->config->client_key_share_selections->empty())); + CBB contents, kse_bytes; if (!CBB_add_u16(out_compressible, TLSEXT_TYPE_key_share) || !CBB_add_u16_length_prefixed(out_compressible, &contents) ||
diff --git a/ssl/internal.h b/ssl/internal.h index 37eb54e..d72f7a5 100644 --- a/ssl/internal.h +++ b/ssl/internal.h
@@ -1770,9 +1770,7 @@ UniquePtr<ERR_SAVE_STATE> error; // key_shares are the current key exchange instances, in preference order. Any - // members of this vector must be non-null. By default, no more than two are - // used, and the second is only used as a client if we believe that we should - // offer two key shares in a ClientHello. + // members of this vector must be non-null. InplaceVector<UniquePtr<SSLKeyShare>, kNumNamedGroups> key_shares; // transcript is the current handshake transcript. @@ -2130,9 +2128,19 @@ bool ssl_setup_extension_permutation(SSL_HANDSHAKE *hs); // ssl_setup_key_shares computes client key shares and saves them in |hs|. It -// returns true on success and false on failure. If |override_group_id| is zero, -// it offers the default groups, including GREASE. If it is non-zero, it offers -// a single key share of the specified group. +// returns true on success and false on failure. In order of precedence: +// +// - If |override_group_id| is non-zero, it offers a single key share of the +// specified group. +// +// - If key shares were previously specified by the caller via +// |SSL_set_client_key_shares|, those are used. +// +// - If |override_group_id| is zero and no selections were made by the caller +// via |SSL_set_client_key_shares|, it selects the first supported group and +// may select a second if at most one of the two is a post-quantum group. +// +// GREASE will be included if enabled, when |override_group_id| is zero. bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id); // ssl_setup_pake_shares computes the client PAKE shares and saves them in |hs|. @@ -3275,7 +3283,19 @@ // Trust anchor IDs to be requested in the trust_anchors extension. std::optional<Array<uint8_t>> requested_trust_anchors; - Array<uint16_t> supported_group_list; // our list + // Our list of supported groups. If this list is modified, for a client, + // |client_key_share_selections| must be reset if the key shares are no longer + // a valid subsequence of the supported group list. + Array<uint16_t> supported_group_list; + + // For a client, this may contain a subsequence of the group IDs in + // |suppported_group_list|, which gives the groups for which key shares should + // be sent in the client's key_share extension. This is non-nullopt iff + // |SSL_set_client_key_shares| was successfully called to configure key + // shares. If non-nullopt, these groups are in the same order as they appear + // in |supported_group_list|, and may not contain duplicates. + std::optional<InplaceVector<uint16_t, kNumNamedGroups>> + client_key_share_selections; // channel_id_private is the client's Channel ID private key, or null if // Channel ID should not be offered on this connection.
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index 2e0db35..532744f 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc
@@ -1844,6 +1844,37 @@ return check_no_duplicates(group_ids); } +// validate_key_shares returns whether the `requested_key_shares` are free of +// duplicates and are a (correctly ordered) subsequence of the supported +// `groups`. +static bool validate_key_shares(Span<const uint16_t> requested_key_shares, + Span<const uint16_t> groups) { + if (!check_no_duplicates(requested_key_shares)) { + return false; + } + if (requested_key_shares.size() > groups.size()) { + return false; + } + size_t key_shares_idx = 0u, groups_idx = 0u; + while (key_shares_idx < requested_key_shares.size() && + groups_idx < groups.size()) { + if (requested_key_shares[key_shares_idx] == groups[groups_idx++]) { + ++key_shares_idx; + } + } + return key_shares_idx == requested_key_shares.size(); +} + +static void clear_key_shares_if_invalid(SSL_CONFIG *config) { + if (!config->client_key_share_selections) { + return; + } + if (!validate_key_shares(*(config->client_key_share_selections), + config->supported_group_list)) { + config->client_key_share_selections.reset(); + } +} + int SSL_CTX_set1_group_ids(SSL_CTX *ctx, const uint16_t *group_ids, size_t num_group_ids) { auto span = Span(group_ids, num_group_ids); @@ -1862,8 +1893,12 @@ if (span.empty()) { span = DefaultSupportedGroupIds(); } - return check_group_ids(span) && - ssl->config->supported_group_list.CopyFrom(span); + if (check_group_ids(span) && + ssl->config->supported_group_list.CopyFrom(span)) { + clear_key_shares_if_invalid(ssl->config.get()); + return 1; + } + return 0; } static bool ssl_nids_to_group_ids(Array<uint16_t> *out_group_ids, @@ -1899,8 +1934,12 @@ if (!ssl->config) { return 0; } - return ssl_nids_to_group_ids(&ssl->config->supported_group_list, - Span(groups, num_groups)); + if (ssl_nids_to_group_ids(&ssl->config->supported_group_list, + Span(groups, num_groups))) { + clear_key_shares_if_invalid(ssl->config.get()); + return 1; + } + return 0; } static bool ssl_str_to_group_ids(Array<uint16_t> *out_group_ids, @@ -1951,7 +1990,11 @@ if (!ssl->config) { return 0; } - return ssl_str_to_group_ids(&ssl->config->supported_group_list, groups); + if (ssl_str_to_group_ids(&ssl->config->supported_group_list, groups)) { + clear_key_shares_if_invalid(ssl->config.get()); + return 1; + } + return 0; } uint16_t SSL_get_group_id(const SSL *ssl) { @@ -1971,6 +2014,23 @@ return ssl_group_id_to_nid(group_id); } +int SSL_set1_client_key_shares(SSL *ssl, const uint16_t *group_ids, + size_t num_group_ids) { + if (!ssl->config) { + return 0; + } + auto requested_key_shares = Span(group_ids, num_group_ids); + if (!validate_key_shares(requested_key_shares, + ssl->config->supported_group_list)) { + return 0; + } + + assert(requested_key_shares.size() <= kNumNamedGroups); + ssl->config->client_key_share_selections.emplace(); + ssl->config->client_key_share_selections->CopyFrom(requested_key_shares); + return 1; +} + int SSL_CTX_set_tmp_dh(SSL_CTX *ctx, const DH *dh) { return 1; } int SSL_set_tmp_dh(SSL *ssl, const DH *dh) { return 1; }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index 96edb61..1b34c14 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc
@@ -722,6 +722,146 @@ } } +TEST(SSLTest, SetClientKeyShares) { + const struct { + const char *description; + std::vector<uint16_t> supported_groups; + std::vector<uint16_t> key_shares; + bool expected_success; + } kTests[] = { + { + "Empty key shares with default supported groups", + {}, + {}, + true, + }, + { + "Empty key shares with custom supported groups", + {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768}, + {}, + true, + }, + { + "One key share matching default supported groups", + {}, + {SSL_GROUP_X25519}, + true, + }, + { + "One key share matching custom supported groups", + {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768}, + {SSL_GROUP_X25519}, + true, + }, + { + "Key share not in supported default groups", + {}, + {SSL_GROUP_MLKEM1024}, + false, + }, + { + "Key share not in supported custom groups", + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1}, + {SSL_GROUP_X25519_MLKEM768}, + false, + }, + { + "Multiple key shares, in correct order", + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768}, + {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768}, + true, + }, + { + "Multiple key shares, out of order", + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768}, + {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519}, + false, + }, + { + "More than two key shares", + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768, + SSL_GROUP_MLKEM1024}, + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768}, + true, + }, + { + "Key shares cover all supported groups", + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768}, + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768}, + true, + }, + { + "Multiple key shares, not all valid", + {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768}, + {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768}, + false, + }, + { + "Key shares contain duplicates", + {}, + {SSL_GROUP_X25519, SSL_GROUP_X25519}, + false, + }, + }; + + for (const auto &t : kTests) { + SCOPED_TRACE(t.description); + bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method())); + ASSERT_TRUE(ctx); + bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get())); + ASSERT_TRUE(ssl); + ASSERT_FALSE(ssl->config->client_key_share_selections.has_value()); + + ASSERT_TRUE(SSL_set1_group_ids(ssl.get(), t.supported_groups.data(), + t.supported_groups.size())); + EXPECT_EQ(SSL_set1_client_key_shares(ssl.get(), t.key_shares.data(), + t.key_shares.size()), + t.expected_success); + if (t.expected_success) { + ASSERT_TRUE(ssl->config->client_key_share_selections.has_value()); + EXPECT_THAT(ssl->config->client_key_share_selections.value(), + ElementsAreArray(t.key_shares)); + } + } +} + +// Test the behavior that modifying the SSL's supported groups results in +// clearing the previously set client key shares, iff the supported groups +// become incompatible with the key shares. +TEST(SSLTest, ClientKeySharesResetAfterChangingGroups) { + bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method())); + ASSERT_TRUE(ctx); + bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get())); + ASSERT_TRUE(ssl); + ASSERT_FALSE(ssl->config->client_key_share_selections.has_value()); + + // An initial groups list and key shares that are compatible. + const uint16_t kGroups1[] = {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519}; + const uint16_t kKeyShares[] = {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519}; + ASSERT_TRUE( + SSL_set1_group_ids(ssl.get(), kGroups1, std::size(kGroups1))); + ASSERT_TRUE(SSL_set1_client_key_shares(ssl.get(), kKeyShares, + std::size(kKeyShares))); + ASSERT_TRUE(ssl->config->client_key_share_selections.has_value()); + EXPECT_EQ(ssl->config->client_key_share_selections->size(), 2u); + + // A new groups list that is still compatible with the previously set key + // shares. + const uint16_t kGroups2[] = {SSL_GROUP_MLKEM1024, SSL_GROUP_X25519_MLKEM768, + SSL_GROUP_X25519}; + ASSERT_TRUE( + SSL_set1_group_ids(ssl.get(), kGroups2, std::size(kGroups2))); + ASSERT_TRUE(ssl->config->client_key_share_selections.has_value()); + EXPECT_EQ(ssl->config->client_key_share_selections->size(), 2u); + + // A new groups list that is no longer compatible with the previously set key + // shares. + const uint16_t kGroups3[] = {SSL_GROUP_MLKEM1024, SSL_GROUP_X25519}; + ASSERT_TRUE( + SSL_set1_group_ids(ssl.get(), kGroups3, std::size(kGroups3))); + EXPECT_FALSE(ssl->config->client_key_share_selections.has_value()); +} + // kOpenSSLSession is a serialized SSL_SESSION. static const char kOpenSSLSession[] = "MIIFqgIBAQICAwMEAsAvBCAG5Q1ndq4Yfmbeo1zwLkNRKmCXGdNgWvGT3cskV0yQ"
diff --git a/ssl/test/runner/pake_tests.go b/ssl/test/runner/pake_tests.go index ddd2eb3..677a3a4 100644 --- a/ssl/test/runner/pake_tests.go +++ b/ssl/test/runner/pake_tests.go
@@ -14,7 +14,10 @@ package runner -import "errors" +import ( + "errors" + "strconv" +) func addPAKETests() { spakeCredential := Credential{ @@ -231,17 +234,22 @@ shimCredentials: []*Credential{&spakeCredential}, }) testCases = append(testCases, testCase{ - // A PAKE client will not offer key shares, so the client should - // reject a HelloRetryRequest requesting a different key share. + // A PAKE client will not offer key shares, even if explicitly configured, + // and the client should reject a HelloRetryRequest requesting a different + // key share. name: "PAKE-Client-HRRKeyShare", testType: clientTest, config: Config{ MinVersion: VersionTLS13, Credential: &spakeCredential, Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{}, SendHelloRetryRequestCurve: CurveX25519, }, }, + flags: []string{ + "-key-shares", strconv.Itoa(int(CurveP256)), + }, shimCredentials: []*Credential{&spakeCredential}, shouldFail: true, expectedError: ":UNEXPECTED_EXTENSION:",
diff --git a/ssl/test/runner/tls13_tests.go b/ssl/test/runner/tls13_tests.go index 2b08ec8..b9826eb 100644 --- a/ssl/test/runner/tls13_tests.go +++ b/ssl/test/runner/tls13_tests.go
@@ -132,6 +132,175 @@ }) testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-TLS13", + config: Config{ + MinVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{CurveP256}, + }, + }, + flags: []string{ + "-curves", strconv.Itoa(int(CurveX25519)), + "-curves", strconv.Itoa(int(CurveP256)), + "-key-shares", strconv.Itoa(int(CurveP256)), + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-DefaultSupportedGroups-TLS13", + config: Config{ + MinVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{CurveX25519}, + }, + }, + flags: []string{ + // Configure key shares without explicitly configuring supported groups. + "-key-shares", strconv.Itoa(int(CurveX25519)), + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-Multiple-TLS13", + config: Config{ + MinVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{CurveX25519, CurveX25519MLKEM768}, + }, + }, + flags: []string{ + "-curves", strconv.Itoa(int(CurveX25519)), + "-curves", strconv.Itoa(int(CurveX25519MLKEM768)), + "-curves", strconv.Itoa(int(CurveP256)), + // Predict the top 2 out of 3 supported curves. + "-key-shares", strconv.Itoa(int(CurveX25519)), + "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)), + "-expect-curve-id", strconv.Itoa(int(CurveX25519)), + "-expect-no-hrr", + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-MultipleNonContiguous-TLS13", + config: Config{ + MinVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{CurveX25519, CurveP256}, + }, + }, + flags: []string{ + "-curves", strconv.Itoa(int(CurveX25519)), + "-curves", strconv.Itoa(int(CurveX25519MLKEM768)), + "-curves", strconv.Itoa(int(CurveP256)), + // Predict the first and last of 3 supported curves. + "-key-shares", strconv.Itoa(int(CurveX25519)), + "-key-shares", strconv.Itoa(int(CurveP256)), + "-expect-curve-id", strconv.Itoa(int(CurveX25519)), + "-expect-no-hrr", + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-NotMostPreferred-HRR-TLS13", + config: Config{ + MinVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveP256}, + }, + }, + flags: []string{ + "-curves", strconv.Itoa(int(CurveX25519)), + "-curves", strconv.Itoa(int(CurveX25519MLKEM768)), + "-curves", strconv.Itoa(int(CurveP256)), + // Predict some curves that we support, not including the top one. + "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)), + "-key-shares", strconv.Itoa(int(CurveP256)), + "-expect-curve-id", strconv.Itoa(int(CurveX25519)), + // Check that we triggered a HelloRetryRequest. + "-expect-hrr", + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-PeerSelectedLaterKeyShare-TLS13", + config: Config{ + MinVersion: VersionTLS13, + CurvePreferences: []CurveID{CurveP256}, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveP256}, + // Make the server select one of the groups with a key_share, but not + // the most preferred one. + SendCurve: CurveP256, + }, + }, + flags: []string{ + "-curves", strconv.Itoa(int(CurveX25519)), + "-curves", strconv.Itoa(int(CurveX25519MLKEM768)), + "-curves", strconv.Itoa(int(CurveP256)), + "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)), + "-key-shares", strconv.Itoa(int(CurveP256)), + "-expect-curve-id", strconv.Itoa(int(CurveP256)), + "-expect-no-hrr", + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-All-TLS13", + config: Config{ + MinVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{ + CurveP256, + CurveP384, + CurveP521, + CurveX25519, + CurveX25519Kyber768, + CurveX25519MLKEM768, + CurveMLKEM1024, + }, + }, + }, + flags: []string{ + "-curves", strconv.Itoa(int(CurveP256)), + "-curves", strconv.Itoa(int(CurveP384)), + "-curves", strconv.Itoa(int(CurveP521)), + "-curves", strconv.Itoa(int(CurveX25519)), + "-curves", strconv.Itoa(int(CurveX25519Kyber768)), + "-curves", strconv.Itoa(int(CurveX25519MLKEM768)), + "-curves", strconv.Itoa(int(CurveMLKEM1024)), + "-key-shares", strconv.Itoa(int(CurveP256)), + "-key-shares", strconv.Itoa(int(CurveP384)), + "-key-shares", strconv.Itoa(int(CurveP521)), + "-key-shares", strconv.Itoa(int(CurveX25519)), + "-key-shares", strconv.Itoa(int(CurveX25519Kyber768)), + "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)), + "-key-shares", strconv.Itoa(int(CurveMLKEM1024)), + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "CustomKeyShares-Empty-HRR-TLS13", + config: Config{ + MinVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectedKeyShares: []CurveID{}, + }, + }, + flags: []string{ + "-no-key-shares", + "-expect-hrr", + }, + }) + + testCases = append(testCases, testCase{ testType: serverTest, name: "SkipEarlyData-TLS13", config: Config{
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc index f973d9e..0cfdc37 100644 --- a/ssl/test/test_config.cc +++ b/ssl/test/test_config.cc
@@ -53,12 +53,17 @@ template <typename Config> struct Flag { const char *name; + // has_param, if true, causes the parser to look for a param value following + // this flag's name. bool has_param; // skip_handshaker, if true, causes this flag to be skipped when // forwarding flags to the handshaker. This should be used with flags // that only impact connecting to the runner. bool skip_handshaker; - // If |has_param| is false, |param| will be nullptr. + // set_param is called after parsing to interpret and set the result on + // `config`. If `has_param` is false for this flag, `param` will be nullptr. + // This function should return whether the param value (or lack thereof) was + // valid for this flag. std::function<bool(Config *config, const char *param)> set_param; }; @@ -168,6 +173,39 @@ }}; } +// Defines a flag which adds an integer param value to an optional vector of +// integers. +template <typename Config, typename T> +Flag<Config> OptionalIntVectorFlag(const char *name, + std::optional<std::vector<T>> Config::*field, + bool skip_handshaker = false) { + return Flag<Config>{name, true, skip_handshaker, + [=](Config *config, const char *param) -> bool { + if (!(config->*field)) { + (config->*field).emplace(); + } + T value; + if (!StringToInt(&value, param)) { + return false; + } + (config->*field)->push_back(value); + return true; + }}; +} + +// Defines a flag which resets a std::optional field to its default constructed +// value. +template <typename Config, typename T> +Flag<Config> OptionalDefaultInitFlag(const char *name, + std::optional<T> Config::*field, + bool skip_handshaker = false) { + return Flag<Config>{name, false, skip_handshaker, + [=](Config *config, const char *) -> bool { + (config->*field).emplace(); + return true; + }}; +} + template <typename Config> Flag<Config> StringFlag(const char *name, std::string Config::*field, bool skip_handshaker = false) { @@ -327,6 +365,8 @@ IntVectorFlag("-expect-peer-verify-pref", &TestConfig::expect_peer_verify_prefs), IntVectorFlag("-curves", &TestConfig::curves), + OptionalIntVectorFlag("-key-shares", &TestConfig::key_shares), + OptionalDefaultInitFlag("-no-key-shares", &TestConfig::key_shares), StringFlag("-trust-cert", &TestConfig::trust_cert), StringFlag("-expect-server-name", &TestConfig::expect_server_name), BoolFlag("-enable-ech-grease", &TestConfig::enable_ech_grease), @@ -679,6 +719,9 @@ if (!skip) { if (out != nullptr) { if (!flag->set_param(out, param)) { + if (!param) { + param = "(no parameter)"; + } fprintf(stderr, "Invalid parameter for %s: %s\n", name, param); return false; } @@ -687,6 +730,9 @@ if (!flag->set_param(out_initial, param) || !flag->set_param(out_resume, param) || !flag->set_param(out_retry, param)) { + if (!param) { + param = "(no parameter)"; + } fprintf(stderr, "Invalid parameter for %s: %s\n", name, param); return false; } @@ -2481,6 +2527,11 @@ !SSL_set1_group_ids(ssl.get(), curves.data(), curves.size())) { return nullptr; } + if (key_shares.has_value() && + !SSL_set1_client_key_shares(ssl.get(), key_shares->data(), + key_shares->size())) { + return nullptr; + } if (initial_timeout_duration_ms > 0) { DTLSv1_set_initial_timeout_duration(ssl.get(), initial_timeout_duration_ms); }
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h index de2f5c5..3b2f4e1 100644 --- a/ssl/test/test_config.h +++ b/ssl/test/test_config.h
@@ -65,6 +65,7 @@ std::vector<uint16_t> verify_prefs; std::vector<uint16_t> expect_peer_verify_prefs; std::vector<uint16_t> curves; + std::optional<std::vector<uint16_t>> key_shares; std::string key_file; std::string cert_file; std::string trust_cert;
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc index dc417b9..a664321 100644 --- a/ssl/tls13_client.cc +++ b/ssl/tls13_client.cc
@@ -253,7 +253,11 @@ // The ECH extension, if present, was already parsed by // |check_ech_confirmation|. SSLExtension cookie(TLSEXT_TYPE_cookie), - key_share(TLSEXT_TYPE_key_share, !hs->key_share_bytes.empty()), + // If offering PAKE, we won't send key_share extensions and we should + // reject key_share from the peer. Otherwise, it is valid to have sent an + // empty key_share extension, and expect the HelloRetryRequest to contain + // a key_share. + key_share(TLSEXT_TYPE_key_share, !hs->pake_prover), supported_versions(TLSEXT_TYPE_supported_versions), ech_unused(TLSEXT_TYPE_encrypted_client_hello, hs->selected_ech_config || hs->config->ech_grease_enabled); @@ -286,8 +290,6 @@ } if (key_share.present) { - // If offering PAKE, we won't send key_share extensions, in which case we - // would have rejected key_share from the peer. assert(!hs->pake_prover); uint16_t group_id;