Add a test for session ID context logic. A session must not be resumed once the server's session ID context has changed. We almost forgot to handle this in TLS 1.3, so add a test for it. (The TLS 1.3 case is skipped in the test until TLS 1.3 resumption is implemented.) Change-Id: I28600325d8fb6c09365e909db607cbace12ecac7 Reviewed-on: https://boringssl-review.googlesource.com/9093 Reviewed-by: Adam Langley <agl@google.com> Commit-Queue: Adam Langley <agl@google.com> CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index 0dc0884..97ade03 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc
@@ -1096,7 +1096,8 @@ } static bool ConnectClientAndServer(ScopedSSL *out_client, ScopedSSL *out_server, - SSL_CTX *client_ctx, SSL_CTX *server_ctx) { + SSL_CTX *client_ctx, SSL_CTX *server_ctx, + SSL_SESSION *session) { ScopedSSL client(SSL_new(client_ctx)), server(SSL_new(server_ctx)); if (!client || !server) { return false; @@ -1104,6 +1105,8 @@ SSL_set_connect_state(client.get()); SSL_set_accept_state(server.get()); + SSL_set_session(client.get(), session); + BIO *bio1, *bio2; if (!BIO_new_bio_pair(&bio1, 0, &bio2, 0)) { return false; @@ -1159,7 +1162,7 @@ ScopedSSL client, server; if (!ConnectClientAndServer(&client, &server, client_ctx.get(), - server_ctx.get())) { + server_ctx.get(), nullptr /* no session */)) { return false; } @@ -1228,7 +1231,7 @@ ScopedSSL client, server; if (!ConnectClientAndServer(&client, &server, client_ctx.get(), - server_ctx.get())) { + server_ctx.get(), nullptr /* no session */)) { return false; } @@ -1283,7 +1286,7 @@ ScopedSSL client, server; if (!ConnectClientAndServer(&client, &server, client_ctx.get(), - server_ctx.get())) { + server_ctx.get(), nullptr /* no session */)) { return false; } @@ -1494,7 +1497,8 @@ SSL_CTX_set_cert_verify_callback(ctx.get(), VerifySucceed, NULL); ScopedSSL client, server; - if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get())) { + if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get(), + nullptr /* no session */)) { return false; } @@ -1561,7 +1565,8 @@ SSL_CTX_set_retain_only_sha256_of_client_certs(ctx.get(), 1); ScopedSSL client, server; - if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get())) { + if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get(), + nullptr /* no session */)) { return false; } @@ -1712,6 +1717,129 @@ return true; } +static ScopedSSL_SESSION g_last_session; + +static int SaveLastSession(SSL *ssl, SSL_SESSION *session) { + // Save the most recent session. 
+ g_last_session.reset(session); + return 1; +} + +static ScopedSSL_SESSION CreateClientSession(SSL_CTX *client_ctx, + SSL_CTX *server_ctx) { + g_last_session = nullptr; + SSL_CTX_sess_set_new_cb(client_ctx, SaveLastSession); + + // Connect client and server to get a session. + ScopedSSL client, server; + if (!ConnectClientAndServer(&client, &server, client_ctx, server_ctx, + nullptr /* no session */)) { + fprintf(stderr, "Failed to connect client and server.\n"); + return nullptr; + } + + // Run the read loop to account for post-handshake tickets in TLS 1.3. + SSL_read(client.get(), nullptr, 0); + + SSL_CTX_sess_set_new_cb(client_ctx, nullptr); + + if (!g_last_session) { + fprintf(stderr, "Client did not receive a session.\n"); + return nullptr; + } + return std::move(g_last_session); +} + +static bool ExpectSessionReused(SSL_CTX *client_ctx, SSL_CTX *server_ctx, + SSL_SESSION *session, + bool reused) { + ScopedSSL client, server; + if (!ConnectClientAndServer(&client, &server, client_ctx, + server_ctx, session)) { + fprintf(stderr, "Failed to connect client and server.\n"); + return false; + } + + if (SSL_session_reused(client.get()) != SSL_session_reused(server.get())) { + fprintf(stderr, "Client and server were inconsistent.\n"); + return false; + } + + bool was_reused = !!SSL_session_reused(client.get()); + if (was_reused != reused) { + fprintf(stderr, "Session was%s reused, but we expected the opposite.\n", + was_reused ? "" : " not"); + return false; + } + + return true; +} + +static bool TestSessionIDContext() { + ScopedX509 cert = GetTestCertificate(); + ScopedEVP_PKEY key = GetTestKey(); + if (!cert || !key) { + return false; + } + + static const uint8_t kContext1[] = {1}; + static const uint8_t kContext2[] = {2}; + + for (uint16_t version : kVersions) { + // TODO(davidben): Enable this when TLS 1.3 resumption is implemented. 
+ if (version == TLS1_3_VERSION) { + continue; + } + + ScopedSSL_CTX server_ctx(SSL_CTX_new(TLS_method())); + ScopedSSL_CTX client_ctx(SSL_CTX_new(TLS_method())); + if (!server_ctx || !client_ctx || + !SSL_CTX_use_certificate(server_ctx.get(), cert.get()) || + !SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()) || + !SSL_CTX_set_session_id_context(server_ctx.get(), kContext1, + sizeof(kContext1))) { + return false; + } + + SSL_CTX_set_min_version(client_ctx.get(), version); + SSL_CTX_set_max_version(client_ctx.get(), version); + SSL_CTX_set_session_cache_mode(client_ctx.get(), SSL_SESS_CACHE_BOTH); + + SSL_CTX_set_min_version(server_ctx.get(), version); + SSL_CTX_set_max_version(server_ctx.get(), version); + SSL_CTX_set_session_cache_mode(server_ctx.get(), SSL_SESS_CACHE_BOTH); + + ScopedSSL_SESSION session = + CreateClientSession(client_ctx.get(), server_ctx.get()); + if (!session) { + fprintf(stderr, "Error getting session (version = %04x).\n", version); + return false; + } + + if (!ExpectSessionReused(client_ctx.get(), server_ctx.get(), session.get(), + true /* expect session reused */)) { + fprintf(stderr, "Error resuming session (version = %04x).\n", version); + return false; + } + + // Change the session ID context. + if (!SSL_CTX_set_session_id_context(server_ctx.get(), kContext2, + sizeof(kContext2))) { + return false; + } + + if (!ExpectSessionReused(client_ctx.get(), server_ctx.get(), session.get(), + false /* expect session not reused */)) { + fprintf(stderr, + "Error connection with different context (version = %04x).\n", + version); + return false; + } + } + + return true; +} + int main() { CRYPTO_library_init(); @@ -1743,7 +1871,8 @@ !TestSetBIO() || !TestGetPeerCertificate() || !TestRetainOnlySHA256OfCerts() || - !TestClientHello()) { + !TestClientHello() || + !TestSessionIDContext()) { ERR_print_errors_fp(stderr); return 1; }