Reduce bouncing on the cache lock in ssl_update_cache.

ssl_update_cache takes the cache lock to add to the session cache,
releases it, and then immediately takes and releases the lock to
increment handshakes_since_cache_flush. Then, in 1/255 connections, it
does the same thing again to flush stale sessions.

Merge the first two into one lock. In doing so, move ssl_update_cache to
ssl_session.cc, so it can access a newly-extracted add_session_locked.
Also remove the mode parameter (the SSL knows if it's a client or
server), and move the established_session != session check to the
caller, which more directly knows whether there was a new session.

Also add some TSan coverage for this path in the tests. In an earlier
iteration of this patch, I managed to introduce a double-locking bug
because we weren't testing it at all. Confirmed this test catches both
double-locking and insufficient locking. (It doesn't seem able to catch
using a read lock instead of a write lock in SSL_CTX_flush_sessions,
however. I suspect the hash table is distributing the cells each thread
touches.)

Update-Note: This reshuffles some locks around the session cache.
(Hopefully for the better.)

Change-Id: I78dca53fda74e036b90110cca7fbcc306a5c8ebe
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/48133
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 0fba737..d5ccafc 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -1801,10 +1801,8 @@
 
   // Note TLS 1.2 resumptions with ticket renewal have both |ssl->session| (the
   // resumed session) and |hs->new_session| (the session with the new ticket).
-  if (hs->new_session == nullptr) {
-    assert(ssl->session != nullptr);
-    ssl->s3->established_session = UpRef(ssl->session);
-  } else {
+  bool has_new_session = hs->new_session != nullptr;
+  if (has_new_session) {
     // When False Start is enabled, the handshake reports completion early. The
     // caller may then have passed the (then unresuable) |hs->new_session| to
     // another thread via |SSL_get0_session| for resumption. To avoid potential
@@ -1821,11 +1819,16 @@
     }
 
     hs->new_session.reset();
+  } else {
+    assert(ssl->session != nullptr);
+    ssl->s3->established_session = UpRef(ssl->session);
   }
 
   hs->handshake_finalized = true;
   ssl->s3->initial_handshake_complete = true;
-  ssl_update_cache(hs, SSL_SESS_CACHE_CLIENT);
+  if (has_new_session) {
+    ssl_update_cache(ssl);
+  }
 
   hs->state = state_done;
   return ssl_hs_ok;
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index ec5c8b9..74ac133 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -1791,18 +1791,21 @@
     ssl->ctx->x509_method->session_clear(hs->new_session.get());
   }
 
-  if (hs->new_session == nullptr) {
-    assert(ssl->session != nullptr);
-    ssl->s3->established_session = UpRef(ssl->session);
-  } else {
+  bool has_new_session = hs->new_session != nullptr;
+  if (has_new_session) {
     assert(ssl->session == nullptr);
     ssl->s3->established_session = std::move(hs->new_session);
     ssl->s3->established_session->not_resumable = false;
+  } else {
+    assert(ssl->session != nullptr);
+    ssl->s3->established_session = UpRef(ssl->session);
   }
 
   hs->handshake_finalized = true;
   ssl->s3->initial_handshake_complete = true;
-  ssl_update_cache(hs, SSL_SESS_CACHE_SERVER);
+  if (has_new_session) {
+    ssl_update_cache(ssl);
+  }
 
   hs->state = state12_done;
   return ssl_hs_ok;
diff --git a/ssl/internal.h b/ssl/internal.h
index ad199aa..ba1dde3 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3147,7 +3147,7 @@
 void ssl_session_renew_timeout(SSL *ssl, SSL_SESSION *session,
                                uint32_t timeout);
 
-void ssl_update_cache(SSL_HANDSHAKE *hs, int mode);
+void ssl_update_cache(SSL *ssl);
 
 void ssl_send_alert(SSL *ssl, int level, int desc);
 int ssl_send_alert_impl(SSL *ssl, int level, int desc);
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index ae9883d..3c9fc90 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -272,50 +272,6 @@
   return ret;
 }
 
-void ssl_update_cache(SSL_HANDSHAKE *hs, int mode) {
-  SSL *const ssl = hs->ssl;
-  SSL_CTX *ctx = ssl->session_ctx.get();
-  if (!SSL_SESSION_is_resumable(ssl->s3->established_session.get()) ||
-      (ctx->session_cache_mode & mode) != mode) {
-    return;
-  }
-
-  // Clients never use the internal session cache.
-  int use_internal_cache = ssl->server && !(ctx->session_cache_mode &
-                                            SSL_SESS_CACHE_NO_INTERNAL_STORE);
-  if (ssl->s3->established_session.get() != ssl->session.get()) {
-    if (use_internal_cache) {
-      SSL_CTX_add_session(ctx, ssl->s3->established_session.get());
-    }
-    if (ctx->new_session_cb != NULL) {
-      UniquePtr<SSL_SESSION> ref = UpRef(ssl->s3->established_session);
-      if (ctx->new_session_cb(ssl, ref.get())) {
-        // |new_session_cb|'s return value signals whether it took ownership.
-        ref.release();
-      }
-    }
-  }
-
-  if (use_internal_cache &&
-      !(ctx->session_cache_mode & SSL_SESS_CACHE_NO_AUTO_CLEAR)) {
-    // Automatically flush the internal session cache every 255 connections.
-    int flush_cache = 0;
-    CRYPTO_MUTEX_lock_write(&ctx->lock);
-    ctx->handshakes_since_cache_flush++;
-    if (ctx->handshakes_since_cache_flush >= 255) {
-      flush_cache = 1;
-      ctx->handshakes_since_cache_flush = 0;
-    }
-    CRYPTO_MUTEX_unlock_write(&ctx->lock);
-
-    if (flush_cache) {
-      struct OPENSSL_timeval now;
-      ssl_get_current_time(ssl, &now);
-      SSL_CTX_flush_sessions(ctx, now.tv_sec);
-    }
-  }
-}
-
 static bool cbb_add_hex(CBB *cbb, Span<const uint8_t> in) {
   static const char hextable[] = "0123456789abcdef";
   uint8_t *out;
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index b0f0892..76e6dc4 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -163,7 +163,6 @@
 
 static void SSL_SESSION_list_remove(SSL_CTX *ctx, SSL_SESSION *session);
 static void SSL_SESSION_list_add(SSL_CTX *ctx, SSL_SESSION *session);
-static int remove_session_lock(SSL_CTX *ctx, SSL_SESSION *session, int lock);
 
 UniquePtr<SSL_SESSION> ssl_session_new(const SSL_X509_METHOD *x509_method) {
   return MakeUnique<SSL_SESSION>(x509_method);
@@ -754,34 +753,36 @@
   return ssl_hs_ok;
 }
 
-static int remove_session_lock(SSL_CTX *ctx, SSL_SESSION *session, int lock) {
-  int ret = 0;
-
-  if (session != NULL && session->session_id_length != 0) {
-    if (lock) {
-      CRYPTO_MUTEX_lock_write(&ctx->lock);
-    }
-    SSL_SESSION *found_session = lh_SSL_SESSION_retrieve(ctx->sessions,
-                                                         session);
-    if (found_session == session) {
-      ret = 1;
-      found_session = lh_SSL_SESSION_delete(ctx->sessions, session);
-      SSL_SESSION_list_remove(ctx, session);
-    }
-
-    if (lock) {
-      CRYPTO_MUTEX_unlock_write(&ctx->lock);
-    }
-
-    if (ret) {
-      if (ctx->remove_session_cb != NULL) {
-        ctx->remove_session_cb(ctx, found_session);
-      }
-      SSL_SESSION_free(found_session);
-    }
+static bool remove_session(SSL_CTX *ctx, SSL_SESSION *session, bool lock) {
+  if (session == nullptr || session->session_id_length == 0) {
+    return false;
   }
 
-  return ret;
+  if (lock) {
+    CRYPTO_MUTEX_lock_write(&ctx->lock);
+  }
+
+  SSL_SESSION *found_session = lh_SSL_SESSION_retrieve(ctx->sessions, session);
+  bool found = found_session == session;
+  if (found) {
+    found_session = lh_SSL_SESSION_delete(ctx->sessions, session);
+    SSL_SESSION_list_remove(ctx, session);
+  }
+
+  if (lock) {
+    CRYPTO_MUTEX_unlock_write(&ctx->lock);
+  }
+
+  if (found) {
+    // TODO(https://crbug.com/boringssl/251): Callbacks should not be called
+    // under a lock.
+    if (ctx->remove_session_cb != nullptr) {
+      ctx->remove_session_cb(ctx, found_session);
+    }
+    SSL_SESSION_free(found_session);
+  }
+
+  return found;
 }
 
 void ssl_set_session(SSL *ssl, SSL_SESSION *session) {
@@ -839,6 +840,98 @@
   }
 }
 
+static bool add_session_locked(SSL_CTX *ctx, UniquePtr<SSL_SESSION> session) {
+  SSL_SESSION *new_session = session.get();
+  SSL_SESSION *old_session;
+  if (!lh_SSL_SESSION_insert(ctx->sessions, &old_session, new_session)) {
+    return false;
+  }
+  // |ctx->sessions| took ownership of |new_session| and gave us back a
+  // reference to |old_session|. (|old_session| may be the same as
+  // |new_session|, in which case we traded identical references with
+  // |ctx->sessions|.)
+  session.release();
+  session.reset(old_session);
+
+  if (old_session != nullptr) {
+    if (old_session == new_session) {
+      // |session| was already in the cache. There are no linked list pointers
+      // to update.
+      return false;
+    }
+
+    // There was a session ID collision. |old_session| was replaced with
+    // |session| in the hash table, so |old_session| must be removed from the
+    // linked list to match.
+    SSL_SESSION_list_remove(ctx, old_session);
+  }
+
+  // This does not increment the reference count. Although |session| is inserted
+  // into two structures (a doubly-linked list and the hash table), |ctx| only
+  // takes one reference.
+  SSL_SESSION_list_add(ctx, new_session);
+
+  // Enforce any cache size limits.
+  if (SSL_CTX_sess_get_cache_size(ctx) > 0) {
+    while (lh_SSL_SESSION_num_items(ctx->sessions) >
+           SSL_CTX_sess_get_cache_size(ctx)) {
+      if (!remove_session(ctx, ctx->session_cache_tail,
+                          /*lock=*/false)) {
+        break;
+      }
+    }
+  }
+
+  return true;
+}
+
+void ssl_update_cache(SSL *ssl) {
+  SSL_CTX *ctx = ssl->session_ctx.get();
+  SSL_SESSION *session = ssl->s3->established_session.get();
+  int mode = SSL_is_server(ssl) ? SSL_SESS_CACHE_SERVER : SSL_SESS_CACHE_CLIENT;
+  if (!SSL_SESSION_is_resumable(session) ||
+      (ctx->session_cache_mode & mode) != mode) {
+    return;
+  }
+
+  // Clients never use the internal session cache.
+  if (ssl->server &&
+      !(ctx->session_cache_mode & SSL_SESS_CACHE_NO_INTERNAL_STORE)) {
+    UniquePtr<SSL_SESSION> ref = UpRef(session);
+    bool remove_expired_sessions = false;
+    {
+      MutexWriteLock lock(&ctx->lock);
+      add_session_locked(ctx, std::move(ref));
+
+      if (!(ctx->session_cache_mode & SSL_SESS_CACHE_NO_AUTO_CLEAR)) {
+        // Automatically flush the internal session cache every 255 connections.
+        ctx->handshakes_since_cache_flush++;
+        if (ctx->handshakes_since_cache_flush >= 255) {
+          remove_expired_sessions = true;
+          ctx->handshakes_since_cache_flush = 0;
+        }
+      }
+    }
+
+    if (remove_expired_sessions) {
+      // |SSL_CTX_flush_sessions| takes the lock we just released. We could
+      // merge the critical sections, but we'd then call user code under a
+      // lock, or compute |now| earlier, even when not flushing.
+      OPENSSL_timeval now;
+      ssl_get_current_time(ssl, &now);
+      SSL_CTX_flush_sessions(ctx, now.tv_sec);
+    }
+  }
+
+  if (ctx->new_session_cb != nullptr) {
+    UniquePtr<SSL_SESSION> ref = UpRef(session);
+    if (ctx->new_session_cb(ssl, ref.get())) {
+      // |new_session_cb|'s return value signals whether it took ownership.
+      ref.release();
+    }
+  }
+}
+
 BSSL_NAMESPACE_END
 
 using namespace bssl;
@@ -1121,51 +1214,13 @@
 }
 
 int SSL_CTX_add_session(SSL_CTX *ctx, SSL_SESSION *session) {
-  // Although |session| is inserted into two structures (a doubly-linked list
-  // and the hash table), |ctx| only takes one reference.
   UniquePtr<SSL_SESSION> owned_session = UpRef(session);
-
-  SSL_SESSION *old_session;
   MutexWriteLock lock(&ctx->lock);
-  if (!lh_SSL_SESSION_insert(ctx->sessions, &old_session, session)) {
-    return 0;
-  }
-  // |ctx->sessions| took ownership of |session| and gave us back a reference to
-  // |old_session|. (|old_session| may be the same as |session|, in which case
-  // we traded identical references with |ctx->sessions|.)
-  owned_session.release();
-  owned_session.reset(old_session);
-
-  if (old_session != NULL) {
-    if (old_session == session) {
-      // |session| was already in the cache. There are no linked list pointers
-      // to update.
-      return 0;
-    }
-
-    // There was a session ID collision. |old_session| was replaced with
-    // |session| in the hash table, so |old_session| must be removed from the
-    // linked list to match.
-    SSL_SESSION_list_remove(ctx, old_session);
-  }
-
-  SSL_SESSION_list_add(ctx, session);
-
-  // Enforce any cache size limits.
-  if (SSL_CTX_sess_get_cache_size(ctx) > 0) {
-    while (lh_SSL_SESSION_num_items(ctx->sessions) >
-           SSL_CTX_sess_get_cache_size(ctx)) {
-      if (!remove_session_lock(ctx, ctx->session_cache_tail, 0)) {
-        break;
-      }
-    }
-  }
-
-  return 1;
+  return add_session_locked(ctx, std::move(owned_session));
 }
 
 int SSL_CTX_remove_session(SSL_CTX *ctx, SSL_SESSION *session) {
-  return remove_session_lock(ctx, session, 1);
+  return remove_session(ctx, session, /*lock=*/true);
 }
 
 int SSL_set_session(SSL *ssl, SSL_SESSION *session) {
@@ -1219,10 +1274,11 @@
   if (param->time == 0 ||
       session->time + session->timeout < session->time ||
       param->time > (session->time + session->timeout)) {
-    // The reason we don't call SSL_CTX_remove_session() is to
-    // save on locking overhead
+    // TODO(davidben): This can probably just call |remove_session|.
     (void) lh_SSL_SESSION_delete(param->cache, session);
     SSL_SESSION_list_remove(param->ctx, session);
+    // TODO(https://crbug.com/boringssl/251): Callbacks should not be called
+    // under a lock.
     if (param->ctx->remove_session_cb != NULL) {
       param->ctx->remove_session_cb(param->ctx, session);
     }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index d6edcfe..4bb9c32 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -5427,7 +5427,8 @@
     }
   }
 
-  // Hit the maximum session cache size across multiple threads
+  // Hit the maximum session cache size across multiple threads, to test the
+  // size enforcement logic.
   size_t limit = SSL_CTX_sess_number(server_ctx_.get()) + 2;
   SSL_CTX_sess_set_cache_size(server_ctx_.get(), limit);
   {
@@ -5443,6 +5444,59 @@
     }
     EXPECT_EQ(SSL_CTX_sess_number(server_ctx_.get()), limit);
   }
+
+  // Reset the session cache, this time with a mock clock.
+  ASSERT_NO_FATAL_FAILURE(ResetContexts());
+  SSL_CTX_set_options(server_ctx_.get(), SSL_OP_NO_TICKET);
+  SSL_CTX_set_session_cache_mode(client_ctx_.get(), SSL_SESS_CACHE_BOTH);
+  SSL_CTX_set_session_cache_mode(server_ctx_.get(), SSL_SESS_CACHE_BOTH);
+  SSL_CTX_set_current_time_cb(server_ctx_.get(), CurrentTimeCallback);
+
+  // Make some sessions at an arbitrary start time. Then expire them.
+  g_current_time.tv_sec = 1000;
+  bssl::UniquePtr<SSL_SESSION> expired_session1 =
+      CreateClientSession(client_ctx_.get(), server_ctx_.get());
+  ASSERT_TRUE(expired_session1);
+  bssl::UniquePtr<SSL_SESSION> expired_session2 =
+      CreateClientSession(client_ctx_.get(), server_ctx_.get());
+  ASSERT_TRUE(expired_session2);
+  g_current_time.tv_sec += 100 * SSL_DEFAULT_SESSION_TIMEOUT;
+
+  session1 = CreateClientSession(client_ctx_.get(), server_ctx_.get());
+  ASSERT_TRUE(session1);
+
+  // Every 255 connections, we flush stale sessions from the session cache. Test
+  // this logic is correctly synchronized with other connection attempts.
+  static const int kNumConnections = 256;
+  {
+    std::vector<std::thread> threads;
+    threads.emplace_back([&] {
+      for (int i = 0; i < kNumConnections; i++) {
+        connect_with_session(nullptr);
+      }
+    });
+    threads.emplace_back([&] {
+      for (int i = 0; i < kNumConnections; i++) {
+        connect_with_session(nullptr);
+      }
+    });
+    threads.emplace_back([&] {
+      // Never connect with |expired_session2|. The session cache eagerly
+      // removes expired sessions when it sees them. Leaving |expired_session2|
+      // untouched ensures it is instead cleared by periodic flushing.
+      for (int i = 0; i < kNumConnections; i++) {
+        connect_with_session(expired_session1.get());
+      }
+    });
+    threads.emplace_back([&] {
+      for (int i = 0; i < kNumConnections; i++) {
+        connect_with_session(session1.get());
+      }
+    });
+    for (auto &thread : threads) {
+      thread.join();
+    }
+  }
 }
 
 TEST_P(SSLVersionTest, SessionTicketThreads) {