Add a test for session ID context logic.

We almost forgot to handle this in TLS 1.3, so add a test for it.

Change-Id: I28600325d8fb6c09365e909db607cbace12ecac7
Reviewed-on: https://boringssl-review.googlesource.com/9093
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: Adam Langley <agl@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 0dc0884..97ade03 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -1096,7 +1096,8 @@
 }
 
 static bool ConnectClientAndServer(ScopedSSL *out_client, ScopedSSL *out_server,
-                                   SSL_CTX *client_ctx, SSL_CTX *server_ctx) {
+                                   SSL_CTX *client_ctx, SSL_CTX *server_ctx,
+                                   SSL_SESSION *session) {
   ScopedSSL client(SSL_new(client_ctx)), server(SSL_new(server_ctx));
   if (!client || !server) {
     return false;
@@ -1104,6 +1105,8 @@
   SSL_set_connect_state(client.get());
   SSL_set_accept_state(server.get());
 
+  SSL_set_session(client.get(), session);
+
   BIO *bio1, *bio2;
   if (!BIO_new_bio_pair(&bio1, 0, &bio2, 0)) {
     return false;
@@ -1159,7 +1162,7 @@
 
   ScopedSSL client, server;
   if (!ConnectClientAndServer(&client, &server, client_ctx.get(),
-                              server_ctx.get())) {
+                              server_ctx.get(), nullptr /* no session */)) {
     return false;
   }
 
@@ -1228,7 +1231,7 @@
 
   ScopedSSL client, server;
   if (!ConnectClientAndServer(&client, &server, client_ctx.get(),
-                              server_ctx.get())) {
+                              server_ctx.get(), nullptr /* no session */)) {
     return false;
   }
 
@@ -1283,7 +1286,7 @@
 
   ScopedSSL client, server;
   if (!ConnectClientAndServer(&client, &server, client_ctx.get(),
-                              server_ctx.get())) {
+                              server_ctx.get(), nullptr /* no session */)) {
     return false;
   }
 
@@ -1494,7 +1497,8 @@
     SSL_CTX_set_cert_verify_callback(ctx.get(), VerifySucceed, NULL);
 
     ScopedSSL client, server;
-    if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get())) {
+    if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get(),
+                                nullptr /* no session */)) {
       return false;
     }
 
@@ -1561,7 +1565,8 @@
     SSL_CTX_set_retain_only_sha256_of_client_certs(ctx.get(), 1);
 
     ScopedSSL client, server;
-    if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get())) {
+    if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get(),
+                                nullptr /* no session */)) {
       return false;
     }
 
@@ -1712,6 +1717,129 @@
   return true;
 }
 
+static ScopedSSL_SESSION g_last_session;
+
+static int SaveLastSession(SSL *ssl, SSL_SESSION *session) {
+  // Save the most recent session.
+  g_last_session.reset(session);
+  return 1;
+}
+
+static ScopedSSL_SESSION CreateClientSession(SSL_CTX *client_ctx,
+                                             SSL_CTX *server_ctx) {
+  g_last_session = nullptr;
+  SSL_CTX_sess_set_new_cb(client_ctx, SaveLastSession);
+
+  // Connect client and server to get a session.
+  ScopedSSL client, server;
+  if (!ConnectClientAndServer(&client, &server, client_ctx, server_ctx,
+                              nullptr /* no session */)) {
+    fprintf(stderr, "Failed to connect client and server.\n");
+    return nullptr;
+  }
+
+  // Run the read loop to account for post-handshake tickets in TLS 1.3.
+  SSL_read(client.get(), nullptr, 0);
+
+  SSL_CTX_sess_set_new_cb(client_ctx, nullptr);
+
+  if (!g_last_session) {
+    fprintf(stderr, "Client did not receive a session.\n");
+    return nullptr;
+  }
+  return std::move(g_last_session);
+}
+
+static bool ExpectSessionReused(SSL_CTX *client_ctx, SSL_CTX *server_ctx,
+                                SSL_SESSION *session,
+                                bool reused) {
+  ScopedSSL client, server;
+  if (!ConnectClientAndServer(&client, &server, client_ctx,
+                              server_ctx, session)) {
+    fprintf(stderr, "Failed to connect client and server.\n");
+    return false;
+  }
+
+  if (SSL_session_reused(client.get()) != SSL_session_reused(server.get())) {
+    fprintf(stderr, "Client and server were inconsistent.\n");
+    return false;
+  }
+
+  bool was_reused = !!SSL_session_reused(client.get());
+  if (was_reused != reused) {
+    fprintf(stderr, "Session was%s reused, but we expected the opposite.\n",
+            was_reused ? "" : " not");
+    return false;
+  }
+
+  return true;
+}
+
+static bool TestSessionIDContext() {
+  ScopedX509 cert = GetTestCertificate();
+  ScopedEVP_PKEY key = GetTestKey();
+  if (!cert || !key) {
+    return false;
+  }
+
+  static const uint8_t kContext1[] = {1};
+  static const uint8_t kContext2[] = {2};
+
+  for (uint16_t version : kVersions) {
+    // TODO(davidben): Enable this when TLS 1.3 resumption is implemented.
+    if (version == TLS1_3_VERSION) {
+      continue;
+    }
+
+    ScopedSSL_CTX server_ctx(SSL_CTX_new(TLS_method()));
+    ScopedSSL_CTX client_ctx(SSL_CTX_new(TLS_method()));
+    if (!server_ctx || !client_ctx ||
+        !SSL_CTX_use_certificate(server_ctx.get(), cert.get()) ||
+        !SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()) ||
+        !SSL_CTX_set_session_id_context(server_ctx.get(), kContext1,
+                                        sizeof(kContext1))) {
+      return false;
+    }
+
+    SSL_CTX_set_min_version(client_ctx.get(), version);
+    SSL_CTX_set_max_version(client_ctx.get(), version);
+    SSL_CTX_set_session_cache_mode(client_ctx.get(), SSL_SESS_CACHE_BOTH);
+
+    SSL_CTX_set_min_version(server_ctx.get(), version);
+    SSL_CTX_set_max_version(server_ctx.get(), version);
+    SSL_CTX_set_session_cache_mode(server_ctx.get(), SSL_SESS_CACHE_BOTH);
+
+    ScopedSSL_SESSION session =
+        CreateClientSession(client_ctx.get(), server_ctx.get());
+    if (!session) {
+      fprintf(stderr, "Error getting session (version = %04x).\n", version);
+      return false;
+    }
+
+    if (!ExpectSessionReused(client_ctx.get(), server_ctx.get(), session.get(),
+                             true /* expect session reused */)) {
+      fprintf(stderr, "Error resuming session (version = %04x).\n", version);
+      return false;
+    }
+
+    // Change the session ID context.
+    if (!SSL_CTX_set_session_id_context(server_ctx.get(), kContext2,
+                                        sizeof(kContext2))) {
+      return false;
+    }
+
+    if (!ExpectSessionReused(client_ctx.get(), server_ctx.get(), session.get(),
+                             false /* expect session not reused */)) {
+      fprintf(stderr,
+              "Error connection with different context (version = %04x).\n",
+              version);
+      return false;
+    }
+  }
+
+  return true;
+}
+
 int main() {
   CRYPTO_library_init();
 
@@ -1743,7 +1871,8 @@
       !TestSetBIO() ||
       !TestGetPeerCertificate() ||
       !TestRetainOnlySHA256OfCerts() ||
-      !TestClientHello()) {
+      !TestClientHello() ||
+      !TestSessionIDContext()) {
     ERR_print_errors_fp(stderr);
     return 1;
   }