Fix double-free under load.

When 1024 threads perform concurrent private operations on one RSA key,
they can race to append a BN_BLINDING to the BN_BLINDING cache while it
is just short of the maximum length. The cache then ends up one (or
more) elements longer than the maximum length, so the index of one of
the cache elements is _equal to_ the supposed maximum length. But that
index value is treated as a magic number indicating that a BN_BLINDING
isn't from the cache and thus needs to be freed after use. That
BN_BLINDING is then double-freed when the cache itself is freed.

See internal bug b/147126942.

Since someone hitting this bug means that 1024 threads really do work
on a single RSA key in practice, take the opportunity to grow the cache
by doubling rather than one element at a time. With cache extensions
now much rarer, the trick of unlocking to keep a few allocations
outside of the lock (which caused the problem) can be discarded.

Change-Id: I32dd16d825b702b31ee9b776414c4e6afe883724
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/39324
Reviewed-by: Adam Langley <agl@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/rsa/rsa_impl.c b/crypto/fipsmodule/rsa/rsa_impl.c
index 74bf098..6fb8ce7 100644
--- a/crypto/fipsmodule/rsa/rsa_impl.c
+++ b/crypto/fipsmodule/rsa/rsa_impl.c
@@ -360,80 +360,84 @@
   assert(rsa->mont_n != NULL);
 
   BN_BLINDING *ret = NULL;
-  BN_BLINDING **new_blindings;
-  uint8_t *new_blindings_inuse;
-  char overflow = 0;
-
   CRYPTO_MUTEX_lock_write(&rsa->lock);
 
-  unsigned i;
-  for (i = 0; i < rsa->num_blindings; i++) {
-    if (rsa->blindings_inuse[i] == 0) {
-      rsa->blindings_inuse[i] = 1;
-      ret = rsa->blindings[i];
-      *index_used = i;
-      break;
+  uint8_t *const free_inuse_flag =
+      memchr(rsa->blindings_inuse, 0, rsa->num_blindings);
+  if (free_inuse_flag != NULL) {
+    *free_inuse_flag = 1;
+    *index_used = free_inuse_flag - rsa->blindings_inuse;
+    ret = rsa->blindings[*index_used];
+    goto out;
+  }
+
+  if (rsa->num_blindings >= MAX_BLINDINGS_PER_RSA) {
+    // No |BN_BLINDING| is free and nor can the cache be extended. This index
+    // value is magic and indicates to |rsa_blinding_release| that a
+    // |BN_BLINDING| was not inserted into the array.
+    *index_used = MAX_BLINDINGS_PER_RSA;
+    ret = BN_BLINDING_new();
+    goto out;
+  }
+
+  // Double the length of the cache.
+  OPENSSL_STATIC_ASSERT(MAX_BLINDINGS_PER_RSA < UINT_MAX / 2,
+                        "MAX_BLINDINGS_PER_RSA too large");
+  unsigned new_num_blindings = rsa->num_blindings * 2;
+  if (new_num_blindings == 0) {
+    new_num_blindings = 1;
+  }
+  if (new_num_blindings > MAX_BLINDINGS_PER_RSA) {
+    new_num_blindings = MAX_BLINDINGS_PER_RSA;
+  }
+  assert(new_num_blindings > rsa->num_blindings);
+
+  OPENSSL_STATIC_ASSERT(
+      MAX_BLINDINGS_PER_RSA < UINT_MAX / sizeof(BN_BLINDING *),
+      "MAX_BLINDINGS_PER_RSA too large");
+  BN_BLINDING **new_blindings =
+      OPENSSL_malloc(sizeof(BN_BLINDING *) * new_num_blindings);
+  uint8_t *new_blindings_inuse = OPENSSL_malloc(new_num_blindings);
+  if (new_blindings == NULL || new_blindings_inuse == NULL) {
+    goto err;
+  }
+
+  OPENSSL_memcpy(new_blindings, rsa->blindings,
+                 sizeof(BN_BLINDING *) * rsa->num_blindings);
+  OPENSSL_memcpy(new_blindings_inuse, rsa->blindings_inuse, rsa->num_blindings);
+
+  for (unsigned i = rsa->num_blindings; i < new_num_blindings; i++) {
+    new_blindings[i] = BN_BLINDING_new();
+    if (new_blindings[i] == NULL) {
+      for (unsigned j = rsa->num_blindings; j < i; j++) {
+        BN_BLINDING_free(new_blindings[j]);
+      }
+      goto err;
     }
   }
+  OPENSSL_memset(&new_blindings_inuse[rsa->num_blindings], 0,
+                 new_num_blindings - rsa->num_blindings);
 
-  if (ret != NULL) {
-    CRYPTO_MUTEX_unlock_write(&rsa->lock);
-    return ret;
-  }
-
-  overflow = rsa->num_blindings >= MAX_BLINDINGS_PER_RSA;
-
-  // We didn't find a free BN_BLINDING to use so increase the length of
-  // the arrays by one and use the newly created element.
-
-  CRYPTO_MUTEX_unlock_write(&rsa->lock);
-  ret = BN_BLINDING_new();
-  if (ret == NULL) {
-    return NULL;
-  }
-
-  if (overflow) {
-    // We cannot add any more cached BN_BLINDINGs so we use |ret|
-    // and mark it for destruction in |rsa_blinding_release|.
-    *index_used = MAX_BLINDINGS_PER_RSA;
-    return ret;
-  }
-
-  CRYPTO_MUTEX_lock_write(&rsa->lock);
-
-  new_blindings =
-      OPENSSL_malloc(sizeof(BN_BLINDING *) * (rsa->num_blindings + 1));
-  if (new_blindings == NULL) {
-    goto err1;
-  }
-  OPENSSL_memcpy(new_blindings, rsa->blindings,
-         sizeof(BN_BLINDING *) * rsa->num_blindings);
-  new_blindings[rsa->num_blindings] = ret;
-
-  new_blindings_inuse = OPENSSL_malloc(rsa->num_blindings + 1);
-  if (new_blindings_inuse == NULL) {
-    goto err2;
-  }
-  OPENSSL_memcpy(new_blindings_inuse, rsa->blindings_inuse, rsa->num_blindings);
   new_blindings_inuse[rsa->num_blindings] = 1;
   *index_used = rsa->num_blindings;
+  assert(*index_used != MAX_BLINDINGS_PER_RSA);
+  ret = new_blindings[rsa->num_blindings];
 
   OPENSSL_free(rsa->blindings);
   rsa->blindings = new_blindings;
   OPENSSL_free(rsa->blindings_inuse);
   rsa->blindings_inuse = new_blindings_inuse;
-  rsa->num_blindings++;
+  rsa->num_blindings = new_num_blindings;
 
-  CRYPTO_MUTEX_unlock_write(&rsa->lock);
-  return ret;
+  goto out;
 
-err2:
+err:
+  OPENSSL_free(new_blindings_inuse);
   OPENSSL_free(new_blindings);
 
-err1:
+out:
   CRYPTO_MUTEX_unlock_write(&rsa->lock);
-  BN_BLINDING_free(ret);
-  return NULL;
+  return ret;
 }
 
 // rsa_blinding_release marks the cached BN_BLINDING at the given index as free
diff --git a/crypto/rsa_extra/rsa_test.cc b/crypto/rsa_extra/rsa_test.cc
index 0fe0351..4218cdb 100644
--- a/crypto/rsa_extra/rsa_test.cc
+++ b/crypto/rsa_extra/rsa_test.cc
@@ -1117,4 +1117,38 @@
     thread.join();
   }
 }
-#endif
+
+#if defined(OPENSSL_X86_64)
+// This test might be excessively slow on slower CPUs.
+TEST(RSATest, BlindingCacheConcurrency) {
+  bssl::UniquePtr<RSA> rsa(
+      RSA_private_key_from_bytes(kKey1, sizeof(kKey1) - 1));
+  ASSERT_TRUE(rsa);
+
+  constexpr size_t kSignaturesPerThread = 100;
+  constexpr size_t kNumThreads = 2048;
+
+  const uint8_t kDummyHash[32] = {0};
+  auto worker = [&] {
+    uint8_t sig[256];
+    ASSERT_LE(RSA_size(rsa.get()), sizeof(sig));
+
+    for (size_t i = 0; i < kSignaturesPerThread; i++) {
+      unsigned sig_len = sizeof(sig);
+      EXPECT_TRUE(RSA_sign(NID_sha256, kDummyHash, sizeof(kDummyHash), sig,
+                           &sig_len, rsa.get()));
+    }
+  };
+
+  std::vector<std::thread> threads;
+  threads.reserve(kNumThreads);
+  for (size_t i = 0; i < kNumThreads; i++) {
+    threads.emplace_back(worker);
+  }
+  for (auto &thread : threads) {
+    thread.join();
+  }
+}
+#endif  // X86_64
+
+#endif  // THREADS