Add lh_FOO_retrieve_key to avoid stack-allocating SSL_SESSION.

lh_FOO_retrieve is often called with a dummy instance of FOO that has
only a few fields filled in. This works fine for C, but a C++
SSL_SESSION with destructors is a bit more of a nuisance here.

Instead, teach LHASH to allow queries by some external key type. This
avoids stack-allocating SSL_SESSION. Along the way, fix the
make_macros.sh script.

Change-Id: Ie0b482d4ffe1027049d49db63274c7c17f9398fa
Reviewed-on: https://boringssl-review.googlesource.com/29586
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/lhash/lhash.c b/crypto/lhash/lhash.c
index 7bfd289..e4fc3fd 100644
--- a/crypto/lhash/lhash.c
+++ b/crypto/lhash/lhash.c
@@ -141,14 +141,12 @@
 static LHASH_ITEM **get_next_ptr_and_hash(const _LHASH *lh, uint32_t *out_hash,
                                           const void *data) {
   const uint32_t hash = lh->hash(data);
-  LHASH_ITEM *cur, **ret;
-
   if (out_hash != NULL) {
     *out_hash = hash;
   }
 
-  ret = &lh->buckets[hash % lh->num_buckets];
-  for (cur = *ret; cur != NULL; cur = *ret) {
+  LHASH_ITEM **ret = &lh->buckets[hash % lh->num_buckets];
+  for (LHASH_ITEM *cur = *ret; cur != NULL; cur = *ret) {
     if (lh->comp(cur->data, data) == 0) {
       break;
     }
@@ -158,16 +156,32 @@
   return ret;
 }
 
-void *lh_retrieve(const _LHASH *lh, const void *data) {
-  LHASH_ITEM **next_ptr;
-
-  next_ptr = get_next_ptr_and_hash(lh, NULL, data);
-
-  if (*next_ptr == NULL) {
-    return NULL;
+// get_next_ptr_by_key behaves like |get_next_ptr_and_hash| but takes a key
+// which may be a different type from the values stored in |lh|.
+static LHASH_ITEM **get_next_ptr_by_key(const _LHASH *lh, const void *key,
+                                        uint32_t key_hash,
+                                        int (*cmp_key)(const void *key,
+                                                       const void *value)) {
+  LHASH_ITEM **ret = &lh->buckets[key_hash % lh->num_buckets];
+  for (LHASH_ITEM *cur = *ret; cur != NULL; cur = *ret) {
+    if (cmp_key(key, cur->data) == 0) {
+      break;
+    }
+    ret = &cur->next;
   }
 
-  return (*next_ptr)->data;
+  return ret;
+}
+
+void *lh_retrieve(const _LHASH *lh, const void *data) {
+  LHASH_ITEM **next_ptr = get_next_ptr_and_hash(lh, NULL, data);
+  return *next_ptr == NULL ? NULL : (*next_ptr)->data;
+}
+
+void *lh_retrieve_key(const _LHASH *lh, const void *key, uint32_t key_hash,
+                      int (*cmp_key)(const void *key, const void *value)) {
+  LHASH_ITEM **next_ptr = get_next_ptr_by_key(lh, key, key_hash, cmp_key);
+  return *next_ptr == NULL ? NULL : (*next_ptr)->data;
 }
 
 // lh_rebucket allocates a new array of |new_num_buckets| pointers and
diff --git a/crypto/lhash/lhash_test.cc b/crypto/lhash/lhash_test.cc
index 0859eeb..a2f61f6 100644
--- a/crypto/lhash/lhash_test.cc
+++ b/crypto/lhash/lhash_test.cc
@@ -103,6 +103,17 @@
         std::unique_ptr<char[]> key = RandString();
         void *value = lh_retrieve(lh.get(), key.get());
         EXPECT_EQ(Lookup(&dummy_lh, key.get()), value);
+
+        // Do the same lookup with |lh_retrieve_key|.
+        value = lh_retrieve_key(
+            lh.get(), &key, lh_strhash(key.get()),
+            [](const void *key_ptr, const void *data) -> int {
+              const char *key_data =
+                  reinterpret_cast<const std::unique_ptr<char[]> *>(key_ptr)
+                      ->get();
+              return strcmp(key_data, reinterpret_cast<const char *>(data));
+            });
+        EXPECT_EQ(Lookup(&dummy_lh, key.get()), value);
         break;
       }
 
diff --git a/crypto/lhash/make_macros.sh b/crypto/lhash/make_macros.sh
index 8a876af..1418539 100644
--- a/crypto/lhash/make_macros.sh
+++ b/crypto/lhash/make_macros.sh
@@ -28,7 +28,7 @@
   type=$1
 
   cat >> $out << EOF
-/* ${type} */
+// ${type}
 #define lh_${type}_new(hash, comp)\\
 ((LHASH_OF(${type})*) lh_new(CHECKED_CAST(lhash_hash_func, uint32_t (*) (const ${type} *), hash), CHECKED_CAST(lhash_cmp_func, int (*) (const ${type} *a, const ${type} *b), comp)))
 
@@ -41,6 +41,9 @@
 #define lh_${type}_retrieve(lh, data)\\
   ((${type}*) lh_retrieve(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), CHECKED_CAST(void*, ${type}*, data)))
 
+#define lh_${type}_retrieve_key(lh, key, key_hash, cmp_key)\\
+  ((${type}*) lh_retrieve_key(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), key, key_hash, CHECKED_CAST(int (*)(const void *, const void *), int (*)(const void *, const ${type} *), cmp_key)))
+
 #define lh_${type}_insert(lh, old_data, data)\\
   lh_insert(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), CHECKED_CAST(void**, ${type}**, old_data), CHECKED_CAST(void*, ${type}*, data))
 
@@ -57,7 +60,7 @@
 EOF
 }
 
-lhash_types=$(cat ${include_dir}/lhash.h | grep '^ \* LHASH_OF:' | sed -e 's/.*LHASH_OF://' -e 's/ .*//')
+lhash_types=$(cat ${include_dir}/lhash.h | grep '^// LHASH_OF:' | sed -e 's/.*LHASH_OF://' -e 's/ .*//')
 
 for type in $lhash_types; do
   echo Hash of ${type}
diff --git a/include/openssl/lhash.h b/include/openssl/lhash.h
index 1ceeb69..287ad63 100644
--- a/include/openssl/lhash.h
+++ b/include/openssl/lhash.h
@@ -141,6 +141,15 @@
 // it. If no such element exists, it returns NULL.
 OPENSSL_EXPORT void *lh_retrieve(const _LHASH *lh, const void *data);
 
+// lh_retrieve_key finds an element matching |key|, given the specified hash and
+// comparison function. This differs from |lh_retrieve| in that the key may be a
+// different type than the values stored in |lh|. |key_hash| and |cmp_key| must
+// be compatible with the functions passed into |lh_new|.
+OPENSSL_EXPORT void *lh_retrieve_key(const _LHASH *lh, const void *key,
+                                     uint32_t key_hash,
+                                     int (*cmp_key)(const void *key,
+                                                    const void *value));
+
 // lh_insert inserts |data| into the hash table. If an existing element is
 // equal to |data| (with respect to the comparison function) then |*old_data|
 // will be set to that value and it will be replaced. Otherwise, or in the
diff --git a/include/openssl/lhash_macros.h b/include/openssl/lhash_macros.h
index 378c839..dd3e4dc 100644
--- a/include/openssl/lhash_macros.h
+++ b/include/openssl/lhash_macros.h
@@ -35,6 +35,12 @@
       CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), \
       CHECKED_CAST(void *, ASN1_OBJECT *, data)))
 
+#define lh_ASN1_OBJECT_retrieve_key(lh, key, key_hash, cmp_key)           \
+  ((ASN1_OBJECT *)lh_retrieve_key(                                        \
+      CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), key, key_hash, \
+      CHECKED_CAST(int (*)(const void *, const void *),                   \
+                   int (*)(const void *, const ASN1_OBJECT *), cmp_key)))
+
 #define lh_ASN1_OBJECT_insert(lh, old_data, data)                \
   lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), \
             CHECKED_CAST(void **, ASN1_OBJECT **, old_data),     \
@@ -74,6 +80,12 @@
       CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), \
       CHECKED_CAST(void *, CONF_VALUE *, data)))
 
+#define lh_CONF_VALUE_retrieve_key(lh, key, key_hash, cmp_key)           \
+  ((CONF_VALUE *)lh_retrieve_key(                                        \
+      CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), key, key_hash, \
+      CHECKED_CAST(int (*)(const void *, const void *),                  \
+                   int (*)(const void *, const CONF_VALUE *), cmp_key)))
+
 #define lh_CONF_VALUE_insert(lh, old_data, data)                \
   lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), \
             CHECKED_CAST(void **, CONF_VALUE **, old_data),     \
@@ -113,6 +125,12 @@
       CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), \
       CHECKED_CAST(void *, CRYPTO_BUFFER *, data)))
 
+#define lh_CRYPTO_BUFFER_retrieve_key(lh, key, key_hash, cmp_key)           \
+  ((CRYPTO_BUFFER *)lh_retrieve_key(                                        \
+      CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), key, key_hash, \
+      CHECKED_CAST(int (*)(const void *, const void *),                     \
+                   int (*)(const void *, const CRYPTO_BUFFER *), cmp_key)))
+
 #define lh_CRYPTO_BUFFER_insert(lh, old_data, data)                \
   lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), \
             CHECKED_CAST(void **, CRYPTO_BUFFER **, old_data),     \
@@ -153,6 +171,12 @@
       CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), \
       CHECKED_CAST(void *, SSL_SESSION *, data)))
 
+#define lh_SSL_SESSION_retrieve_key(lh, key, key_hash, cmp_key)           \
+  ((SSL_SESSION *)lh_retrieve_key(                                        \
+      CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), key, key_hash, \
+      CHECKED_CAST(int (*)(const void *, const void *),                   \
+                   int (*)(const void *, const SSL_SESSION *), cmp_key)))
+
 #define lh_SSL_SESSION_insert(lh, old_data, data)                \
   lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), \
             CHECKED_CAST(void **, SSL_SESSION **, old_data),     \
diff --git a/ssl/internal.h b/ssl/internal.h
index 05b967d..d66b418 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2909,6 +2909,10 @@
 // error.
 UniquePtr<SSL_SESSION> ssl_session_new(const SSL_X509_METHOD *x509_method);
 
+// ssl_hash_session_id returns a hash of |session_id|, suitable for a hash table
+// keyed on session IDs.
+uint32_t ssl_hash_session_id(Span<const uint8_t> session_id);
+
 // SSL_SESSION_parse parses an |SSL_SESSION| from |cbs| and advances |cbs| over
 // the parsed data.
 UniquePtr<SSL_SESSION> SSL_SESSION_parse(CBS *cbs,
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 222a316..a153c60 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -567,22 +567,8 @@
 }
 
 static uint32_t ssl_session_hash(const SSL_SESSION *sess) {
-  const uint8_t *session_id = sess->session_id;
-
-  uint8_t tmp_storage[sizeof(uint32_t)];
-  if (sess->session_id_length < sizeof(tmp_storage)) {
-    OPENSSL_memset(tmp_storage, 0, sizeof(tmp_storage));
-    OPENSSL_memcpy(tmp_storage, sess->session_id, sess->session_id_length);
-    session_id = tmp_storage;
-  }
-
-  uint32_t hash =
-      ((uint32_t)session_id[0]) |
-      ((uint32_t)session_id[1] << 8) |
-      ((uint32_t)session_id[2] << 16) |
-      ((uint32_t)session_id[3] << 24);
-
-  return hash;
+  return ssl_hash_session_id(
+      MakeConstSpan(sess->session_id, sess->session_id_length));
 }
 
 static int ssl_session_cmp(const SSL_SESSION *a, const SSL_SESSION *b) {
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index adec8dc..a2a1482 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -184,6 +184,26 @@
   return session;
 }
 
+uint32_t ssl_hash_session_id(Span<const uint8_t> session_id) {
+  // Take the first four bytes of |session_id|. Session IDs are generated by the
+  // server randomly, so we can assume even using the first four bytes results
+  // in a good distribution.
+  uint8_t tmp_storage[sizeof(uint32_t)];
+  if (session_id.size() < sizeof(tmp_storage)) {
+    OPENSSL_memset(tmp_storage, 0, sizeof(tmp_storage));
+    OPENSSL_memcpy(tmp_storage, session_id.data(), session_id.size());
+    session_id = tmp_storage;
+  }
+
+  uint32_t hash =
+      ((uint32_t)session_id[0]) |
+      ((uint32_t)session_id[1] << 8) |
+      ((uint32_t)session_id[2] << 16) |
+      ((uint32_t)session_id[3] << 24);
+
+  return hash;
+}
+
 UniquePtr<SSL_SESSION> SSL_SESSION_dup(SSL_SESSION *session, int dup_flags) {
   UniquePtr<SSL_SESSION> new_session = ssl_session_new(session->x509_method);
   if (!new_session) {
@@ -657,11 +677,11 @@
 // |*out_session| to an |SSL_SESSION| object if found.
 static enum ssl_hs_wait_t ssl_lookup_session(
     SSL_HANDSHAKE *hs, UniquePtr<SSL_SESSION> *out_session,
-    const uint8_t *session_id, size_t session_id_len) {
+    Span<const uint8_t> session_id) {
   SSL *const ssl = hs->ssl;
   out_session->reset();
 
-  if (session_id_len == 0 || session_id_len > SSL_MAX_SSL_SESSION_ID_LENGTH) {
+  if (session_id.empty() || session_id.size() > SSL_MAX_SSL_SESSION_ID_LENGTH) {
     return ssl_hs_ok;
   }
 
@@ -669,21 +689,26 @@
   // Try the internal cache, if it exists.
   if (!(ssl->session_ctx->session_cache_mode &
         SSL_SESS_CACHE_NO_INTERNAL_LOOKUP)) {
-    SSL_SESSION data;
-    data.session_id_length = session_id_len;
-    OPENSSL_memcpy(data.session_id, session_id, session_id_len);
-
+    uint32_t hash = ssl_hash_session_id(session_id);
+    auto cmp = [](const void *key, const SSL_SESSION *sess) -> int {
+      Span<const uint8_t> key_id =
+          *reinterpret_cast<const Span<const uint8_t> *>(key);
+      Span<const uint8_t> sess_id =
+          MakeConstSpan(sess->session_id, sess->session_id_length);
+      return key_id == sess_id ? 0 : 1;
+    };
     MutexReadLock lock(&ssl->session_ctx->lock);
-    // |lh_SSL_SESSION_retrieve| returns a non-owning pointer.
-    session = UpRef(lh_SSL_SESSION_retrieve(ssl->session_ctx->sessions, &data));
+    // |lh_SSL_SESSION_retrieve_key| returns a non-owning pointer.
+    session = UpRef(lh_SSL_SESSION_retrieve_key(ssl->session_ctx->sessions,
+                                                &session_id, hash, cmp));
     // TODO(davidben): This should probably move it to the front of the list.
   }
 
   // Fall back to the external cache, if it exists.
   if (!session && ssl->session_ctx->get_session_cb != nullptr) {
     int copy = 1;
-    session.reset(ssl->session_ctx->get_session_cb(ssl, session_id,
-                                                   session_id_len, &copy));
+    session.reset(ssl->session_ctx->get_session_cb(ssl, session_id.data(),
+                                                   session_id.size(), &copy));
     if (!session) {
       return ssl_hs_ok;
     }
@@ -752,7 +777,8 @@
   } else {
     // The client didn't send a ticket, so the session ID is a real ID.
     enum ssl_hs_wait_t lookup_ret = ssl_lookup_session(
-        hs, &session, client_hello->session_id, client_hello->session_id_len);
+        hs, &session,
+        MakeConstSpan(client_hello->session_id, client_hello->session_id_len));
     if (lookup_ret != ssl_hs_ok) {
       return lookup_ret;
     }