Add an interface for QUIC integration.

0-RTT support and APIs to consume NewSessionTicket will be added in a
follow-up.

Change-Id: Ib2b2c6b618b3e33a74355fb53fdbd2ffafcc5c56
Reviewed-on: https://boringssl-review.googlesource.com/c/31744
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index 61a47d3..c237809 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc
@@ -17,6 +17,7 @@ #include <time.h> #include <algorithm> +#include <limits> #include <string> #include <utility> #include <vector> @@ -104,6 +105,26 @@ std::vector<uint16_t> expected; }; +template <typename T> +class UnownedSSLExData { + public: + UnownedSSLExData() { + index_ = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + } + + T *Get(const SSL *ssl) { + return index_ < 0 ? nullptr + : static_cast<T *>(SSL_get_ex_data(ssl, index_)); + } + + bool Set(SSL *ssl, T *t) { + return index_ >= 0 && SSL_set_ex_data(ssl, index_, t); + } + + private: + int index_; +}; + static const CipherTest kCipherTests[] = { // Selecting individual ciphers should work. { @@ -4422,6 +4443,546 @@ } #endif +constexpr size_t kNumQUICLevels = 4; +static_assert(ssl_encryption_initial < kNumQUICLevels, + "kNumQUICLevels is wrong"); +static_assert(ssl_encryption_early_data < kNumQUICLevels, + "kNumQUICLevels is wrong"); +static_assert(ssl_encryption_handshake < kNumQUICLevels, + "kNumQUICLevels is wrong"); +static_assert(ssl_encryption_application < kNumQUICLevels, + "kNumQUICLevels is wrong"); + +class MockQUICTransport { + public: + MockQUICTransport() { + // The caller is expected to configure initial secrets. 
+ levels_[ssl_encryption_initial].write_secret = {1}; + levels_[ssl_encryption_initial].read_secret = {1}; + } + + void set_peer(MockQUICTransport *peer) { peer_ = peer; } + + bool has_alert() const { return has_alert_; } + ssl_encryption_level_t alert_level() const { return alert_level_; } + uint8_t alert() const { return alert_; } + + bool PeerSecretsMatch(ssl_encryption_level_t level) const { + return levels_[level].write_secret == peer_->levels_[level].read_secret && + levels_[level].read_secret == peer_->levels_[level].write_secret; + } + + bool HasSecrets(ssl_encryption_level_t level) const { + return !levels_[level].write_secret.empty() || + !levels_[level].read_secret.empty(); + } + + bool SetEncryptionSecrets(ssl_encryption_level_t level, + const uint8_t *read_secret, + const uint8_t *write_secret, size_t secret_len) { + if (HasSecrets(level)) { + ADD_FAILURE() << "duplicate keys configured"; + return false; + } + if (level != ssl_encryption_early_data && + (read_secret == nullptr || write_secret == nullptr)) { + ADD_FAILURE() << "key was unexpectedly null"; + return false; + } + if (read_secret != nullptr) { + levels_[level].read_secret.assign(read_secret, read_secret + secret_len); + } + if (write_secret != nullptr) { + levels_[level].write_secret.assign(write_secret, + write_secret + secret_len); + } + return true; + } + + bool WriteHandshakeData(ssl_encryption_level_t level, + Span<const uint8_t> data) { + if (levels_[level].write_secret.empty()) { + ADD_FAILURE() << "data written before keys configured"; + return false; + } + levels_[level].write_data.insert(levels_[level].write_data.end(), + data.begin(), data.end()); + return true; + } + + bool SendAlert(ssl_encryption_level_t level, uint8_t alert_value) { + if (has_alert_) { + ADD_FAILURE() << "duplicate alert sent"; + return false; + } + + if (levels_[level].write_secret.empty()) { + ADD_FAILURE() << "alert sent before keys configured"; + return false; + } + + has_alert_ = true; + alert_level_ = 
level; + alert_ = alert_value; + return true; + } + + bool ReadHandshakeData(std::vector<uint8_t> *out, + ssl_encryption_level_t level, + size_t num = std::numeric_limits<size_t>::max()) { + if (levels_[level].read_secret.empty()) { + ADD_FAILURE() << "data read before keys configured"; + return false; + } + // The peer may not have configured any keys yet. + if (peer_->levels_[level].write_secret.empty()) { + return true; + } + // Check the peer computed the same key. + if (peer_->levels_[level].write_secret != levels_[level].read_secret) { + ADD_FAILURE() << "peer write key does not match read key"; + return false; + } + std::vector<uint8_t> *peer_data = &peer_->levels_[level].write_data; + num = std::min(num, peer_data->size()); + out->assign(peer_data->begin(), peer_data->begin() + num); + peer_data->erase(peer_data->begin(), peer_data->begin() + num); + return true; + } + + private: + MockQUICTransport *peer_ = nullptr; + + bool has_alert_ = false; + ssl_encryption_level_t alert_level_ = ssl_encryption_initial; + uint8_t alert_ = 0; + + struct Level { + std::vector<uint8_t> write_data; + std::vector<uint8_t> write_secret; + std::vector<uint8_t> read_secret; + }; + Level levels_[kNumQUICLevels]; +}; + +class MockQUICTransportPair { + public: + MockQUICTransportPair() { + server_.set_peer(&client_); + client_.set_peer(&server_); + } + + ~MockQUICTransportPair() { + server_.set_peer(nullptr); + client_.set_peer(nullptr); + } + + MockQUICTransport *client() { return &client_; } + MockQUICTransport *server() { return &server_; } + + bool SecretsMatch(ssl_encryption_level_t level) const { + return client_.PeerSecretsMatch(level); + } + + private: + MockQUICTransport client_; + MockQUICTransport server_; +}; + +class QUICMethodTest : public testing::Test { + protected: + void SetUp() override { + client_ctx_.reset(SSL_CTX_new(TLS_method())); + server_ctx_.reset(SSL_CTX_new(TLS_method())); + ASSERT_TRUE(client_ctx_); + ASSERT_TRUE(server_ctx_); + + 
bssl::UniquePtr<X509> cert = GetTestCertificate(); + bssl::UniquePtr<EVP_PKEY> key = GetTestKey(); + ASSERT_TRUE(cert); + ASSERT_TRUE(key); + ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx_.get(), cert.get())); + ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx_.get(), key.get())); + + SSL_CTX_set_min_proto_version(server_ctx_.get(), TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(server_ctx_.get(), TLS1_3_VERSION); + SSL_CTX_set_min_proto_version(client_ctx_.get(), TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(client_ctx_.get(), TLS1_3_VERSION); + } + + static MockQUICTransport *TransportFromSSL(const SSL *ssl) { + return ex_data_.Get(ssl); + } + + static bool ProvideHandshakeData( + SSL *ssl, size_t num = std::numeric_limits<size_t>::max()) { + MockQUICTransport *transport = TransportFromSSL(ssl); + ssl_encryption_level_t level = SSL_quic_read_level(ssl); + std::vector<uint8_t> data; + return transport->ReadHandshakeData(&data, level, num) && + SSL_provide_quic_data(ssl, level, data.data(), data.size()); + } + + bool CreateClientAndServer() { + client_.reset(SSL_new(client_ctx_.get())); + server_.reset(SSL_new(server_ctx_.get())); + if (!client_ || !server_) { + return false; + } + + SSL_set_connect_state(client_.get()); + SSL_set_accept_state(server_.get()); + + ex_data_.Set(client_.get(), transport_.client()); + ex_data_.Set(server_.get(), transport_.server()); + return true; + } + + // The following functions may be configured on an |SSL_QUIC_METHOD| as + // default implementations. 
+ + static int SetEncryptionSecretsCallback(SSL *ssl, + ssl_encryption_level_t level, + const uint8_t *read_key, + const uint8_t *write_key, + size_t key_len) { + return TransportFromSSL(ssl)->SetEncryptionSecrets(level, read_key, + write_key, key_len); + } + + static int AddMessageCallback(SSL *ssl, enum ssl_encryption_level_t level, + const uint8_t *data, size_t len) { + EXPECT_EQ(level, SSL_quic_write_level(ssl)); + return TransportFromSSL(ssl)->WriteHandshakeData(level, + MakeConstSpan(data, len)); + } + + static int FlushFlightCallback(SSL *ssl) { return 1; } + + static int SendAlertCallback(SSL *ssl, ssl_encryption_level_t level, + uint8_t alert) { + EXPECT_EQ(level, SSL_quic_write_level(ssl)); + return TransportFromSSL(ssl)->SendAlert(level, alert); + } + + bssl::UniquePtr<SSL_CTX> client_ctx_; + bssl::UniquePtr<SSL_CTX> server_ctx_; + + static UnownedSSLExData<MockQUICTransport> ex_data_; + MockQUICTransportPair transport_; + + bssl::UniquePtr<SSL> client_; + bssl::UniquePtr<SSL> server_; +}; + +UnownedSSLExData<MockQUICTransport> QUICMethodTest::ex_data_; + +// Test a full handshake works. 
+TEST_F(QUICMethodTest, Basic) { + const SSL_QUIC_METHOD quic_method = { + SetEncryptionSecretsCallback, + AddMessageCallback, + FlushFlightCallback, + SendAlertCallback, + }; + + ASSERT_TRUE(SSL_CTX_set_quic_method(client_ctx_.get(), &quic_method)); + ASSERT_TRUE(SSL_CTX_set_quic_method(server_ctx_.get(), &quic_method)); + ASSERT_TRUE(CreateClientAndServer()); + + for (;;) { + ASSERT_TRUE(ProvideHandshakeData(client_.get())); + int client_ret = SSL_do_handshake(client_.get()); + if (client_ret != 1) { + ASSERT_EQ(client_ret, -1); + ASSERT_EQ(SSL_get_error(client_.get(), client_ret), SSL_ERROR_WANT_READ); + } + + ASSERT_TRUE(ProvideHandshakeData(server_.get())); + int server_ret = SSL_do_handshake(server_.get()); + if (server_ret != 1) { + ASSERT_EQ(server_ret, -1); + ASSERT_EQ(SSL_get_error(server_.get(), server_ret), SSL_ERROR_WANT_READ); + } + + if (client_ret == 1 && server_ret == 1) { + break; + } + } + + EXPECT_EQ(SSL_do_handshake(client_.get()), 1); + EXPECT_EQ(SSL_do_handshake(server_.get()), 1); + EXPECT_TRUE(transport_.SecretsMatch(ssl_encryption_application)); + EXPECT_FALSE(transport_.client()->has_alert()); + EXPECT_FALSE(transport_.server()->has_alert()); + + // The server sent NewSessionTicket messages in the handshake. + // + // TODO(davidben,svaldez): Add an API for the client to consume post-handshake + // messages and update these tests. + std::vector<uint8_t> new_session_ticket; + ASSERT_TRUE(transport_.client()->ReadHandshakeData( + &new_session_ticket, ssl_encryption_application)); + EXPECT_FALSE(new_session_ticket.empty()); +} + +// Test only releasing data to QUIC one byte at a time on request, to maximize +// state machine pauses. Additionally, test that existing asynchronous callbacks +// still work. 
+TEST_F(QUICMethodTest, Async) { + const SSL_QUIC_METHOD quic_method = { + SetEncryptionSecretsCallback, + AddMessageCallback, + FlushFlightCallback, + SendAlertCallback, + }; + + ASSERT_TRUE(SSL_CTX_set_quic_method(client_ctx_.get(), &quic_method)); + ASSERT_TRUE(SSL_CTX_set_quic_method(server_ctx_.get(), &quic_method)); + ASSERT_TRUE(CreateClientAndServer()); + + // Install an asynchronous certificate callback. + bool cert_cb_ok = false; + SSL_set_cert_cb(server_.get(), + [](SSL *, void *arg) -> int { + return *static_cast<bool *>(arg) ? 1 : -1; + }, + &cert_cb_ok); + + for (;;) { + int client_ret = SSL_do_handshake(client_.get()); + if (client_ret != 1) { + ASSERT_EQ(client_ret, -1); + ASSERT_EQ(SSL_get_error(client_.get(), client_ret), SSL_ERROR_WANT_READ); + ASSERT_TRUE(ProvideHandshakeData(client_.get(), 1)); + } + + int server_ret = SSL_do_handshake(server_.get()); + if (server_ret != 1) { + ASSERT_EQ(server_ret, -1); + int ssl_err = SSL_get_error(server_.get(), server_ret); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + ASSERT_TRUE(ProvideHandshakeData(server_.get(), 1)); + break; + case SSL_ERROR_WANT_X509_LOOKUP: + ASSERT_FALSE(cert_cb_ok); + cert_cb_ok = true; + break; + default: + FAIL() << "Unexpected SSL_get_error result: " << ssl_err; + } + } + + if (client_ret == 1 && server_ret == 1) { + break; + } + } + + EXPECT_EQ(SSL_do_handshake(client_.get()), 1); + EXPECT_EQ(SSL_do_handshake(server_.get()), 1); + EXPECT_TRUE(transport_.SecretsMatch(ssl_encryption_application)); + EXPECT_FALSE(transport_.client()->has_alert()); + EXPECT_FALSE(transport_.server()->has_alert()); +} + +// Test buffering write data until explicit flushes. 
+TEST_F(QUICMethodTest, Buffered) { + struct BufferedFlight { + std::vector<uint8_t> data[kNumQUICLevels]; + }; + static UnownedSSLExData<BufferedFlight> buffered_flights; + + auto add_message = [](SSL *ssl, enum ssl_encryption_level_t level, + const uint8_t *data, size_t len) -> int { + BufferedFlight *flight = buffered_flights.Get(ssl); + flight->data[level].insert(flight->data[level].end(), data, data + len); + return 1; + }; + + auto flush_flight = [](SSL *ssl) -> int { + BufferedFlight *flight = buffered_flights.Get(ssl); + for (size_t level = 0; level < kNumQUICLevels; level++) { + if (!flight->data[level].empty()) { + if (!TransportFromSSL(ssl)->WriteHandshakeData( + static_cast<ssl_encryption_level_t>(level), + flight->data[level])) { + return 0; + } + flight->data[level].clear(); + } + } + return 1; + }; + + const SSL_QUIC_METHOD quic_method = { + SetEncryptionSecretsCallback, + add_message, + flush_flight, + SendAlertCallback, + }; + + ASSERT_TRUE(SSL_CTX_set_quic_method(client_ctx_.get(), &quic_method)); + ASSERT_TRUE(SSL_CTX_set_quic_method(server_ctx_.get(), &quic_method)); + ASSERT_TRUE(CreateClientAndServer()); + + BufferedFlight client_flight, server_flight; + buffered_flights.Set(client_.get(), &client_flight); + buffered_flights.Set(server_.get(), &server_flight); + + for (;;) { + ASSERT_TRUE(ProvideHandshakeData(client_.get())); + int client_ret = SSL_do_handshake(client_.get()); + if (client_ret != 1) { + ASSERT_EQ(client_ret, -1); + ASSERT_EQ(SSL_get_error(client_.get(), client_ret), SSL_ERROR_WANT_READ); + } + + ASSERT_TRUE(ProvideHandshakeData(server_.get())); + int server_ret = SSL_do_handshake(server_.get()); + if (server_ret != 1) { + ASSERT_EQ(server_ret, -1); + ASSERT_EQ(SSL_get_error(server_.get(), server_ret), SSL_ERROR_WANT_READ); + } + + if (client_ret == 1 && server_ret == 1) { + break; + } + } + + EXPECT_EQ(SSL_do_handshake(client_.get()), 1); + EXPECT_EQ(SSL_do_handshake(server_.get()), 1); + 
EXPECT_TRUE(transport_.SecretsMatch(ssl_encryption_application)); + EXPECT_FALSE(transport_.client()->has_alert()); + EXPECT_FALSE(transport_.server()->has_alert()); +} + +// Test that excess data at one level is rejected. That is, if a single +// |SSL_provide_quic_data| call included both ServerHello and +// EncryptedExtensions in a single chunk, BoringSSL notices and rejects this on +// key change. +TEST_F(QUICMethodTest, ExcessProvidedData) { + auto add_message = [](SSL *ssl, enum ssl_encryption_level_t level, + const uint8_t *data, size_t len) -> int { + // Switch everything to the initial level. + return TransportFromSSL(ssl)->WriteHandshakeData(ssl_encryption_initial, + MakeConstSpan(data, len)); + }; + + const SSL_QUIC_METHOD quic_method = { + SetEncryptionSecretsCallback, + add_message, + FlushFlightCallback, + SendAlertCallback, + }; + + ASSERT_TRUE(SSL_CTX_set_quic_method(client_ctx_.get(), &quic_method)); + ASSERT_TRUE(SSL_CTX_set_quic_method(server_ctx_.get(), &quic_method)); + ASSERT_TRUE(CreateClientAndServer()); + + // Send the ClientHello and ServerHello through Finished. + ASSERT_EQ(SSL_do_handshake(client_.get()), -1); + ASSERT_EQ(SSL_get_error(client_.get(), -1), SSL_ERROR_WANT_READ); + ASSERT_TRUE(ProvideHandshakeData(server_.get())); + ASSERT_EQ(SSL_do_handshake(server_.get()), -1); + ASSERT_EQ(SSL_get_error(server_.get(), -1), SSL_ERROR_WANT_READ); + + // The client is still waiting for the ServerHello at initial + // encryption. + ASSERT_EQ(ssl_encryption_initial, SSL_quic_read_level(client_.get())); + + // |add_message| incorrectly wrote everything at the initial level, so this + // queues up ServerHello through Finished in one chunk. + ASSERT_TRUE(ProvideHandshakeData(client_.get())); + + // The client reads ServerHello successfully, but then rejects the buffered + // EncryptedExtensions on key change. 
+ ASSERT_EQ(SSL_do_handshake(client_.get()), -1); + ASSERT_EQ(SSL_get_error(client_.get(), -1), SSL_ERROR_SSL); + uint32_t err = ERR_get_error(); + EXPECT_EQ(ERR_GET_LIB(err), ERR_LIB_SSL); + EXPECT_EQ(ERR_GET_REASON(err), SSL_R_BUFFERED_MESSAGES_ON_CIPHER_CHANGE); + + // The client sends an alert in response to this. + ASSERT_TRUE(transport_.client()->has_alert()); + EXPECT_EQ(transport_.client()->alert_level(), ssl_encryption_initial); + EXPECT_EQ(transport_.client()->alert(), SSL_AD_UNEXPECTED_MESSAGE); + + // Sanity-check client did get far enough to process the ServerHello and + // install keys. + EXPECT_TRUE(transport_.client()->HasSecrets(ssl_encryption_handshake)); +} + +// Test that |SSL_provide_quic_data| will reject data at the wrong level. +TEST_F(QUICMethodTest, ProvideWrongLevel) { + const SSL_QUIC_METHOD quic_method = { + SetEncryptionSecretsCallback, + AddMessageCallback, + FlushFlightCallback, + SendAlertCallback, + }; + + ASSERT_TRUE(SSL_CTX_set_quic_method(client_ctx_.get(), &quic_method)); + ASSERT_TRUE(SSL_CTX_set_quic_method(server_ctx_.get(), &quic_method)); + ASSERT_TRUE(CreateClientAndServer()); + + // Send the ClientHello and ServerHello through Finished. + ASSERT_EQ(SSL_do_handshake(client_.get()), -1); + ASSERT_EQ(SSL_get_error(client_.get(), -1), SSL_ERROR_WANT_READ); + ASSERT_TRUE(ProvideHandshakeData(server_.get())); + ASSERT_EQ(SSL_do_handshake(server_.get()), -1); + ASSERT_EQ(SSL_get_error(server_.get(), -1), SSL_ERROR_WANT_READ); + + // The client is still waiting for the ServerHello at initial + // encryption. + ASSERT_EQ(ssl_encryption_initial, SSL_quic_read_level(client_.get())); + + // Data cannot be provided at the next level. + std::vector<uint8_t> data; + ASSERT_TRUE( + transport_.client()->ReadHandshakeData(&data, ssl_encryption_initial)); + ASSERT_FALSE(SSL_provide_quic_data(client_.get(), ssl_encryption_handshake, + data.data(), data.size())); + ERR_clear_error(); + + // Progress to EncryptedExtensions. 
+ ASSERT_TRUE(SSL_provide_quic_data(client_.get(), ssl_encryption_initial, + data.data(), data.size())); + ASSERT_EQ(SSL_do_handshake(client_.get()), -1); + ASSERT_EQ(SSL_get_error(client_.get(), -1), SSL_ERROR_WANT_READ); + ASSERT_EQ(ssl_encryption_handshake, SSL_quic_read_level(client_.get())); + + // Data cannot be provided at the previous level. + ASSERT_TRUE( + transport_.client()->ReadHandshakeData(&data, ssl_encryption_handshake)); + ASSERT_FALSE(SSL_provide_quic_data(client_.get(), ssl_encryption_initial, + data.data(), data.size())); +} + +TEST_F(QUICMethodTest, TooMuchData) { + const SSL_QUIC_METHOD quic_method = { + SetEncryptionSecretsCallback, + AddMessageCallback, + FlushFlightCallback, + SendAlertCallback, + }; + + ASSERT_TRUE(SSL_CTX_set_quic_method(client_ctx_.get(), &quic_method)); + ASSERT_TRUE(SSL_CTX_set_quic_method(server_ctx_.get(), &quic_method)); + ASSERT_TRUE(CreateClientAndServer()); + + size_t limit = + SSL_quic_max_handshake_flight_len(client_.get(), ssl_encryption_initial); + uint8_t b = 0; + for (size_t i = 0; i < limit; i++) { + ASSERT_TRUE( + SSL_provide_quic_data(client_.get(), ssl_encryption_initial, &b, 1)); + } + + EXPECT_FALSE( + SSL_provide_quic_data(client_.get(), ssl_encryption_initial, &b, 1)); +} + // TODO(davidben): Convert this file to GTest properly. TEST(SSLTest, AllTests) { if (!TestSSL_SESSIONEncoding(kOpenSSLSession) ||