Add lh_FOO_retrieve_key to avoid stack-allocating SSL_SESSION.
lh_FOO_retrieve is often called with a dummy instance of FOO that has
only a few fields filled in. This works fine for C, but a C++
SSL_SESSION with destructors is a bit more of a nuisance here.
Instead, teach LHASH to allow queries by some external key type. This
avoids stack-allocating SSL_SESSION. Along the way, fix the
make_macros.sh script.
Change-Id: Ie0b482d4ffe1027049d49db63274c7c17f9398fa
Reviewed-on: https://boringssl-review.googlesource.com/29586
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/lhash/lhash.c b/crypto/lhash/lhash.c
index 7bfd289..e4fc3fd 100644
--- a/crypto/lhash/lhash.c
+++ b/crypto/lhash/lhash.c
@@ -141,14 +141,12 @@
static LHASH_ITEM **get_next_ptr_and_hash(const _LHASH *lh, uint32_t *out_hash,
const void *data) {
const uint32_t hash = lh->hash(data);
- LHASH_ITEM *cur, **ret;
-
if (out_hash != NULL) {
*out_hash = hash;
}
- ret = &lh->buckets[hash % lh->num_buckets];
- for (cur = *ret; cur != NULL; cur = *ret) {
+ LHASH_ITEM **ret = &lh->buckets[hash % lh->num_buckets];
+ for (LHASH_ITEM *cur = *ret; cur != NULL; cur = *ret) {
if (lh->comp(cur->data, data) == 0) {
break;
}
@@ -158,16 +156,32 @@
return ret;
}
-void *lh_retrieve(const _LHASH *lh, const void *data) {
- LHASH_ITEM **next_ptr;
-
- next_ptr = get_next_ptr_and_hash(lh, NULL, data);
-
- if (*next_ptr == NULL) {
- return NULL;
+// get_next_ptr_by_key behaves like |get_next_ptr_and_hash| but takes a key
+// which may be a different type from the values stored in |lh|.
+static LHASH_ITEM **get_next_ptr_by_key(const _LHASH *lh, const void *key,
+ uint32_t key_hash,
+ int (*cmp_key)(const void *key,
+ const void *value)) {
+ LHASH_ITEM **ret = &lh->buckets[key_hash % lh->num_buckets];
+ for (LHASH_ITEM *cur = *ret; cur != NULL; cur = *ret) {
+ if (cmp_key(key, cur->data) == 0) {
+ break;
+ }
+ ret = &cur->next;
}
- return (*next_ptr)->data;
+ return ret;
+}
+
+void *lh_retrieve(const _LHASH *lh, const void *data) {
+ LHASH_ITEM **next_ptr = get_next_ptr_and_hash(lh, NULL, data);
+ return *next_ptr == NULL ? NULL : (*next_ptr)->data;
+}
+
+void *lh_retrieve_key(const _LHASH *lh, const void *key, uint32_t key_hash,
+ int (*cmp_key)(const void *key, const void *value)) {
+ LHASH_ITEM **next_ptr = get_next_ptr_by_key(lh, key, key_hash, cmp_key);
+ return *next_ptr == NULL ? NULL : (*next_ptr)->data;
}
// lh_rebucket allocates a new array of |new_num_buckets| pointers and
diff --git a/crypto/lhash/lhash_test.cc b/crypto/lhash/lhash_test.cc
index 0859eeb..a2f61f6 100644
--- a/crypto/lhash/lhash_test.cc
+++ b/crypto/lhash/lhash_test.cc
@@ -103,6 +103,17 @@
std::unique_ptr<char[]> key = RandString();
void *value = lh_retrieve(lh.get(), key.get());
EXPECT_EQ(Lookup(&dummy_lh, key.get()), value);
+
+ // Do the same lookup with |lh_retrieve_key|.
+ value = lh_retrieve_key(
+ lh.get(), &key, lh_strhash(key.get()),
+ [](const void *key_ptr, const void *data) -> int {
+ const char *key_data =
+ reinterpret_cast<const std::unique_ptr<char[]> *>(key_ptr)
+ ->get();
+ return strcmp(key_data, reinterpret_cast<const char *>(data));
+ });
+ EXPECT_EQ(Lookup(&dummy_lh, key.get()), value);
break;
}
diff --git a/crypto/lhash/make_macros.sh b/crypto/lhash/make_macros.sh
index 8a876af..1418539 100644
--- a/crypto/lhash/make_macros.sh
+++ b/crypto/lhash/make_macros.sh
@@ -28,7 +28,7 @@
type=$1
cat >> $out << EOF
-/* ${type} */
+// ${type}
#define lh_${type}_new(hash, comp)\\
((LHASH_OF(${type})*) lh_new(CHECKED_CAST(lhash_hash_func, uint32_t (*) (const ${type} *), hash), CHECKED_CAST(lhash_cmp_func, int (*) (const ${type} *a, const ${type} *b), comp)))
@@ -41,6 +41,9 @@
#define lh_${type}_retrieve(lh, data)\\
((${type}*) lh_retrieve(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), CHECKED_CAST(void*, ${type}*, data)))
+#define lh_${type}_retrieve_key(lh, key, key_hash, cmp_key)\\
+ ((${type}*) lh_retrieve_key(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), key, key_hash, CHECKED_CAST(int (*)(const void *, const void *), int (*)(const void *, const ${type} *), cmp_key)))
+
#define lh_${type}_insert(lh, old_data, data)\\
lh_insert(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), CHECKED_CAST(void**, ${type}**, old_data), CHECKED_CAST(void*, ${type}*, data))
@@ -57,7 +60,7 @@
EOF
}
-lhash_types=$(cat ${include_dir}/lhash.h | grep '^ \* LHASH_OF:' | sed -e 's/.*LHASH_OF://' -e 's/ .*//')
+lhash_types=$(cat ${include_dir}/lhash.h | grep '^// LHASH_OF:' | sed -e 's/.*LHASH_OF://' -e 's/ .*//')
for type in $lhash_types; do
echo Hash of ${type}
diff --git a/include/openssl/lhash.h b/include/openssl/lhash.h
index 1ceeb69..287ad63 100644
--- a/include/openssl/lhash.h
+++ b/include/openssl/lhash.h
@@ -141,6 +141,15 @@
// it. If no such element exists, it returns NULL.
OPENSSL_EXPORT void *lh_retrieve(const _LHASH *lh, const void *data);
+// lh_retrieve_key finds an element matching |key|, given the specified hash and
+// comparison function. This differs from |lh_retrieve| in that the key may be a
+// different type than the values stored in |lh|. |key_hash| and |cmp_key| must
+// be compatible with the functions passed into |lh_new|.
+OPENSSL_EXPORT void *lh_retrieve_key(const _LHASH *lh, const void *key,
+ uint32_t key_hash,
+ int (*cmp_key)(const void *key,
+ const void *value));
+
// lh_insert inserts |data| into the hash table. If an existing element is
// equal to |data| (with respect to the comparison function) then |*old_data|
// will be set to that value and it will be replaced. Otherwise, or in the
diff --git a/include/openssl/lhash_macros.h b/include/openssl/lhash_macros.h
index 378c839..dd3e4dc 100644
--- a/include/openssl/lhash_macros.h
+++ b/include/openssl/lhash_macros.h
@@ -35,6 +35,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), \
CHECKED_CAST(void *, ASN1_OBJECT *, data)))
+#define lh_ASN1_OBJECT_retrieve_key(lh, key, key_hash, cmp_key) \
+ ((ASN1_OBJECT *)lh_retrieve_key( \
+ CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), key, key_hash, \
+ CHECKED_CAST(int (*)(const void *, const void *), \
+ int (*)(const void *, const ASN1_OBJECT *), cmp_key)))
+
#define lh_ASN1_OBJECT_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), \
CHECKED_CAST(void **, ASN1_OBJECT **, old_data), \
@@ -74,6 +80,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), \
CHECKED_CAST(void *, CONF_VALUE *, data)))
+#define lh_CONF_VALUE_retrieve_key(lh, key, key_hash, cmp_key) \
+ ((CONF_VALUE *)lh_retrieve_key( \
+ CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), key, key_hash, \
+ CHECKED_CAST(int (*)(const void *, const void *), \
+ int (*)(const void *, const CONF_VALUE *), cmp_key)))
+
#define lh_CONF_VALUE_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), \
CHECKED_CAST(void **, CONF_VALUE **, old_data), \
@@ -113,6 +125,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), \
CHECKED_CAST(void *, CRYPTO_BUFFER *, data)))
+#define lh_CRYPTO_BUFFER_retrieve_key(lh, key, key_hash, cmp_key) \
+ ((CRYPTO_BUFFER *)lh_retrieve_key( \
+ CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), key, key_hash, \
+ CHECKED_CAST(int (*)(const void *, const void *), \
+ int (*)(const void *, const CRYPTO_BUFFER *), cmp_key)))
+
#define lh_CRYPTO_BUFFER_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), \
CHECKED_CAST(void **, CRYPTO_BUFFER **, old_data), \
@@ -153,6 +171,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), \
CHECKED_CAST(void *, SSL_SESSION *, data)))
+#define lh_SSL_SESSION_retrieve_key(lh, key, key_hash, cmp_key) \
+ ((SSL_SESSION *)lh_retrieve_key( \
+ CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), key, key_hash, \
+ CHECKED_CAST(int (*)(const void *, const void *), \
+ int (*)(const void *, const SSL_SESSION *), cmp_key)))
+
#define lh_SSL_SESSION_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), \
CHECKED_CAST(void **, SSL_SESSION **, old_data), \
diff --git a/ssl/internal.h b/ssl/internal.h
index 05b967d..d66b418 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2909,6 +2909,10 @@
// error.
UniquePtr<SSL_SESSION> ssl_session_new(const SSL_X509_METHOD *x509_method);
+// ssl_hash_session_id returns a hash of |session_id|, suitable for a hash table
+// keyed on session IDs.
+uint32_t ssl_hash_session_id(Span<const uint8_t> session_id);
+
// SSL_SESSION_parse parses an |SSL_SESSION| from |cbs| and advances |cbs| over
// the parsed data.
UniquePtr<SSL_SESSION> SSL_SESSION_parse(CBS *cbs,
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 222a316..a153c60 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -567,22 +567,8 @@
}
static uint32_t ssl_session_hash(const SSL_SESSION *sess) {
- const uint8_t *session_id = sess->session_id;
-
- uint8_t tmp_storage[sizeof(uint32_t)];
- if (sess->session_id_length < sizeof(tmp_storage)) {
- OPENSSL_memset(tmp_storage, 0, sizeof(tmp_storage));
- OPENSSL_memcpy(tmp_storage, sess->session_id, sess->session_id_length);
- session_id = tmp_storage;
- }
-
- uint32_t hash =
- ((uint32_t)session_id[0]) |
- ((uint32_t)session_id[1] << 8) |
- ((uint32_t)session_id[2] << 16) |
- ((uint32_t)session_id[3] << 24);
-
- return hash;
+ return ssl_hash_session_id(
+ MakeConstSpan(sess->session_id, sess->session_id_length));
}
static int ssl_session_cmp(const SSL_SESSION *a, const SSL_SESSION *b) {
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index adec8dc..a2a1482 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -184,6 +184,26 @@
return session;
}
+uint32_t ssl_hash_session_id(Span<const uint8_t> session_id) {
+ // Take the first four bytes of |session_id|. Session IDs are generated by the
+ // server randomly, so we can assume even using the first four bytes results
+ // in a good distribution.
+ uint8_t tmp_storage[sizeof(uint32_t)];
+ if (session_id.size() < sizeof(tmp_storage)) {
+ OPENSSL_memset(tmp_storage, 0, sizeof(tmp_storage));
+ OPENSSL_memcpy(tmp_storage, session_id.data(), session_id.size());
+ session_id = tmp_storage;
+ }
+
+ uint32_t hash =
+ ((uint32_t)session_id[0]) |
+ ((uint32_t)session_id[1] << 8) |
+ ((uint32_t)session_id[2] << 16) |
+ ((uint32_t)session_id[3] << 24);
+
+ return hash;
+}
+
UniquePtr<SSL_SESSION> SSL_SESSION_dup(SSL_SESSION *session, int dup_flags) {
UniquePtr<SSL_SESSION> new_session = ssl_session_new(session->x509_method);
if (!new_session) {
@@ -657,11 +677,11 @@
// |*out_session| to an |SSL_SESSION| object if found.
static enum ssl_hs_wait_t ssl_lookup_session(
SSL_HANDSHAKE *hs, UniquePtr<SSL_SESSION> *out_session,
- const uint8_t *session_id, size_t session_id_len) {
+ Span<const uint8_t> session_id) {
SSL *const ssl = hs->ssl;
out_session->reset();
- if (session_id_len == 0 || session_id_len > SSL_MAX_SSL_SESSION_ID_LENGTH) {
+ if (session_id.empty() || session_id.size() > SSL_MAX_SSL_SESSION_ID_LENGTH) {
return ssl_hs_ok;
}
@@ -669,21 +689,26 @@
// Try the internal cache, if it exists.
if (!(ssl->session_ctx->session_cache_mode &
SSL_SESS_CACHE_NO_INTERNAL_LOOKUP)) {
- SSL_SESSION data;
- data.session_id_length = session_id_len;
- OPENSSL_memcpy(data.session_id, session_id, session_id_len);
-
+ uint32_t hash = ssl_hash_session_id(session_id);
+ auto cmp = [](const void *key, const SSL_SESSION *sess) -> int {
+ Span<const uint8_t> key_id =
+ *reinterpret_cast<const Span<const uint8_t> *>(key);
+ Span<const uint8_t> sess_id =
+ MakeConstSpan(sess->session_id, sess->session_id_length);
+ return key_id == sess_id ? 0 : 1;
+ };
MutexReadLock lock(&ssl->session_ctx->lock);
- // |lh_SSL_SESSION_retrieve| returns a non-owning pointer.
- session = UpRef(lh_SSL_SESSION_retrieve(ssl->session_ctx->sessions, &data));
+ // |lh_SSL_SESSION_retrieve_key| returns a non-owning pointer.
+ session = UpRef(lh_SSL_SESSION_retrieve_key(ssl->session_ctx->sessions,
+ &session_id, hash, cmp));
// TODO(davidben): This should probably move it to the front of the list.
}
// Fall back to the external cache, if it exists.
if (!session && ssl->session_ctx->get_session_cb != nullptr) {
int copy = 1;
- session.reset(ssl->session_ctx->get_session_cb(ssl, session_id,
- session_id_len, ©));
+ session.reset(ssl->session_ctx->get_session_cb(ssl, session_id.data(),
+ session_id.size(), ©));
if (!session) {
return ssl_hs_ok;
}
@@ -752,7 +777,8 @@
} else {
// The client didn't send a ticket, so the session ID is a real ID.
enum ssl_hs_wait_t lookup_ret = ssl_lookup_session(
- hs, &session, client_hello->session_id, client_hello->session_id_len);
+ hs, &session,
+ MakeConstSpan(client_hello->session_id, client_hello->session_id_len));
if (lookup_ret != ssl_hs_ok) {
return lookup_ret;
}