Use thread-local storage for PRNG states if fork-unsafe buffering is enabled.

We switched from thread-local storage to a mutex-pool in 82639e6f53
because, for highly-threaded processes, the memory used by all the
states could be quite large. I had judged that a mutex-pool should be
fine, but had underestimated the PRNG requirements of some of our jobs.

This change makes rand.c support using either thread-locals or a
mutex-pool. Thread-locals are used if fork-unsafe buffering is enabled.
While not strictly related to fork-safety, we already have the
fork-unsafe control, and it's already set by jobs that care a lot about
PRNG performance, so it fits quite nicely here.

Change-Id: Iaf1e0171c70d4c8dbe1e42283ea13df5b613cb2d
Reviewed-on: https://boringssl-review.googlesource.com/c/31564
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/fipsmodule/rand/rand.c b/crypto/fipsmodule/rand/rand.c
index e6b4bb4..bff2471 100644
--- a/crypto/fipsmodule/rand/rand.c
+++ b/crypto/fipsmodule/rand/rand.c
@@ -103,10 +103,14 @@
 
 #endif
 
-// rand_state contains an RNG state.
+// rand_state contains an RNG state. State objects are managed in one of two
+// ways, depending on whether |RAND_enable_fork_unsafe_buffering| has been
+// called: if it has been called then thread-local storage is used to keep a
+// per-thread state. Otherwise a mutex-protected pool of state objects is used.
 struct rand_state {
   CTR_DRBG_STATE drbg;
-  // next forms a NULL-terminated linked-list of all free |rand_state| objects.
+  // next forms a NULL-terminated linked-list of all free |rand_state| objects
+  // in a pool. This is unused if using thread-local states.
   struct rand_state *next;
   // calls is the number of generate calls made on |drbg| since it was last
   // (re)seeded. This is bound by
@@ -114,10 +118,10 @@
   size_t calls;
 
 #if defined(BORINGSSL_FIPS)
-  // next_all forms another NULL-terminated linked-list, this time of all
-  // |rand_state| objects that have been allocated including those that might
-  // currently be in use.
-  struct rand_state *next_all;
+  // prev_all and next_all form another NULL-terminated linked-list, this time
+  // of all |rand_state| objects that have been allocated including those that
+  // might currently be in use.
+  struct rand_state *prev_all, *next_all;
   // last_block contains the previous block from |CRYPTO_sysrand|.
   uint8_t last_block[CRNGT_BLOCK_SIZE];
   // last_block_valid is non-zero iff |last_block| contains data from
@@ -187,14 +191,14 @@
 #endif
 
 // rand_state_free_list is a list of currently free, |rand_state| structures.
-// When a thread needs a |rand_state| it picks the head element of this list and
-// allocs a new one if the list is empty. Once it's finished, it pushes the
-// state back onto the front of the list.
+// (It is only used if a mutex-pool is being used to manage |rand_state|
+// objects.) When a thread needs a |rand_state| it picks the head element of
+// this list and allocs a new one if the list is empty. Once it's finished, it
+// pushes the state back onto the front of the list.
 //
-// Previously we used a thread-local state but for processes with large numbers
-// of threads this can result in excessive memory usage. Since we don't free
-// |rand_state| objects, the number of objects in memory will eventually equal
-// the maximum concurrency of |RAND_bytes|.
+// Since we don't free |rand_state| objects, the number of objects in memory
+// will eventually equal the maximum concurrency of |RAND_bytes| in the
+// mutex-pool model.
 DEFINE_BSS_GET(struct rand_state *, rand_state_free_list);
 
 // rand_state_lock protects |rand_state_free_list| (and |rand_state_all_list|,
@@ -226,6 +230,33 @@
 }
 #endif
 
+// rand_state_free frees a |rand_state|. This is called when a thread exits if
+// we're using thread-local states.
+static void rand_state_free(void *state_in) {
+  struct rand_state *state = state_in;
+  if (state_in == NULL) {
+    return;
+  }
+
+#if defined(BORINGSSL_FIPS)
+  CRYPTO_STATIC_MUTEX_lock_write(rand_state_lock_bss_get());
+  if (state->prev_all != NULL) {
+    state->prev_all->next_all = state->next_all;
+  } else {
+    *rand_state_all_list_bss_get() = state->next_all;
+  }
+
+  if (state->next_all != NULL) {
+    state->next_all->prev_all = state->prev_all;
+  }
+  CRYPTO_STATIC_MUTEX_unlock_write(rand_state_lock_bss_get());
+
+  CTR_DRBG_clear(&state->drbg);
+#endif
+
+  OPENSSL_free(state);
+}
+
 // rand_state_init seeds a |rand_state|.
 static void rand_state_init(struct rand_state *state) {
   OPENSSL_memset(state, 0, sizeof(struct rand_state));
@@ -236,17 +267,30 @@
   }
 }
 
-// rand_state_get pops a |rand_state| from the head of
+// rand_state_get returns a usable |rand_state|, or NULL if memory is exhausted.
+//
+// If a pool is being used, it pops a |rand_state| from the head of
 // |rand_state_free_list| and returns it. If the list is empty, it
 // creates a fresh |rand_state| and returns that instead.
-static struct rand_state *rand_state_get(void) {
+//
+// Alternatively, if thread-local states are being used, it returns the current
+// thread's state object and creates it if needed.
+static struct rand_state *rand_state_get(const int fork_unsafe_buffering) {
   struct rand_state *state = NULL;
-  CRYPTO_STATIC_MUTEX_lock_write(rand_state_lock_bss_get());
-  state = *rand_state_free_list_bss_get();
-  if (state != NULL) {
-    *rand_state_free_list_bss_get() = state->next;
+  if (fork_unsafe_buffering) {
+    // Thread-local storage is used in this case. This is unrelated to fork-
+    // safety and we are overloading this global control to also identify
+    // processes that really care about PRNG speed.
+    state = CRYPTO_get_thread_local(OPENSSL_THREAD_LOCAL_RAND);
+  } else {
+    // Otherwise a mutex-protected pool of states is used.
+    CRYPTO_STATIC_MUTEX_lock_write(rand_state_lock_bss_get());
+    state = *rand_state_free_list_bss_get();
+    if (state != NULL) {
+      *rand_state_free_list_bss_get() = state->next;
+    }
+    CRYPTO_STATIC_MUTEX_unlock_write(rand_state_lock_bss_get());
   }
-  CRYPTO_STATIC_MUTEX_unlock_write(rand_state_lock_bss_get());
 
   if (state != NULL) {
     return state;
@@ -262,14 +306,25 @@
 #if defined(BORINGSSL_FIPS)
   CRYPTO_STATIC_MUTEX_lock_write(rand_state_lock_bss_get());
   state->next_all = *rand_state_all_list_bss_get();
+  if (state->next_all) {
+    state->next_all->prev_all = state;
+  }
   *rand_state_all_list_bss_get() = state;
   CRYPTO_STATIC_MUTEX_unlock_write(rand_state_lock_bss_get());
 #endif
 
+  if (fork_unsafe_buffering &&
+      !CRYPTO_set_thread_local(OPENSSL_THREAD_LOCAL_RAND, state,
+                               rand_state_free)) {
+    rand_state_free(state);
+    return NULL;
+  }
+
   return state;
 }
 
-// rand_state_put pushes |state| onto |rand_state_free_list|.
+// rand_state_put pushes |state| onto |rand_state_free_list|. May only be
+// called if the pool is being used to manage |rand_state| objects.
 static void rand_state_put(struct rand_state *state) {
   CRYPTO_STATIC_MUTEX_lock_write(rand_state_lock_bss_get());
   state->next = *rand_state_free_list_bss_get();
@@ -283,6 +338,8 @@
     return;
   }
 
+  const int fork_unsafe_buffering = rand_fork_unsafe_buffering_enabled();
+
   // Additional data is mixed into every CTR-DRBG call to protect, as best we
   // can, against forks & VM clones. We do not over-read this information and
   // don't reseed with it so, from the point of view of FIPS, this doesn't
@@ -293,7 +350,7 @@
     // entropy is used. This can be expensive (one read per |RAND_bytes| call)
     // and so can be disabled by applications that we have ensured don't fork
     // and aren't at risk of VM cloning.
-    if (!rand_fork_unsafe_buffering_enabled()) {
+    if (!fork_unsafe_buffering) {
       CRYPTO_sysrand(additional_data, sizeof(additional_data));
     } else {
       OPENSSL_memset(additional_data, 0, sizeof(additional_data));
@@ -305,7 +362,7 @@
   }
 
   struct rand_state stack_state;
-  struct rand_state *state = rand_state_get();
+  struct rand_state *state = rand_state_get(fork_unsafe_buffering);
 
   if (state == NULL) {
     // If the system is out of memory, use an ephemeral state on the
@@ -366,7 +423,7 @@
   CRYPTO_STATIC_MUTEX_unlock_read(rand_drbg_lock_bss_get());
 #endif
 
-  if (state != &stack_state) {
+  if (!fork_unsafe_buffering && state != &stack_state) {
     rand_state_put(state);
   }
 }
diff --git a/crypto/internal.h b/crypto/internal.h
index 52799e8..3d7b5c1 100644
--- a/crypto/internal.h
+++ b/crypto/internal.h
@@ -553,6 +553,7 @@
 // stored.
 typedef enum {
   OPENSSL_THREAD_LOCAL_ERR = 0,
+  OPENSSL_THREAD_LOCAL_RAND,
   OPENSSL_THREAD_LOCAL_TEST,
   NUM_OPENSSL_THREAD_LOCALS,
 } thread_local_data_t;
diff --git a/crypto/rand_extra/forkunsafe.c b/crypto/rand_extra/forkunsafe.c
index 0f1ecec..27921f0 100644
--- a/crypto/rand_extra/forkunsafe.c
+++ b/crypto/rand_extra/forkunsafe.c
@@ -21,9 +21,12 @@
 
 // g_buffering_enabled is true if fork-unsafe buffering has been enabled.
 static int g_buffering_enabled = 0;
+static CRYPTO_once_t g_buffering_enabled_once = CRYPTO_ONCE_INIT;
+static int g_buffering_enabled_pending = 0;
 
-// g_lock protects |g_buffering_enabled|.
-static struct CRYPTO_STATIC_MUTEX g_lock = CRYPTO_STATIC_MUTEX_INIT;
+static void g_buffer_enabled_init(void) {
+  g_buffering_enabled = g_buffering_enabled_pending;
+}
 
 #if !defined(OPENSSL_WINDOWS)
 void RAND_enable_fork_unsafe_buffering(int fd) {
@@ -32,15 +35,16 @@
     abort();
   }
 
-  CRYPTO_STATIC_MUTEX_lock_write(&g_lock);
-  g_buffering_enabled = 1;
-  CRYPTO_STATIC_MUTEX_unlock_write(&g_lock);
+  g_buffering_enabled_pending = 1;
+  CRYPTO_once(&g_buffering_enabled_once, g_buffer_enabled_init);
+  if (g_buffering_enabled != 1) {
+    // RAND_bytes has been called before this function.
+    abort();
+  }
 }
 #endif
 
 int rand_fork_unsafe_buffering_enabled(void) {
-  CRYPTO_STATIC_MUTEX_lock_read(&g_lock);
-  const int ret = g_buffering_enabled;
-  CRYPTO_STATIC_MUTEX_unlock_read(&g_lock);
-  return ret;
+  CRYPTO_once(&g_buffering_enabled_once, g_buffer_enabled_init);
+  return g_buffering_enabled;
 }
diff --git a/include/openssl/rand.h b/include/openssl/rand.h
index 5d02e12..c6527b3 100644
--- a/include/openssl/rand.h
+++ b/include/openssl/rand.h
@@ -57,6 +57,10 @@
 // ownership of |fd|. If |fd| is negative then /dev/urandom will be opened and
 // any error from open(2) crashes the address space.
 //
+// Setting this also enables thread-local PRNG state, which can reduce lock
+// contention in highly-threaded applications although at the cost of yet more
+// memory.
+//
 // It has an unusual name because the buffer is unsafe across calls to |fork|.
 // Hence, this function should never be called by libraries.
 OPENSSL_EXPORT void RAND_enable_fork_unsafe_buffering(int fd);