Make an internal RefCounted base class for libssl This is still a bit more tedious than I'd like, but we've got three of these and I'm about to add a fourth. Add something like Chromium's base class. But where Chromium integrates the base class directly with scoped_refptr (giving a place for a static_assert that you did the subclassing right), we don't quite have that since we need to integrate with the external C API. Instead, use the "passkey" pattern and have RefCounted<T>'s protected constructor take a struct that only T can construct. The passkey ensures that only T can construct RefCounted<T>, and the protectedness ensures that T subclassed RefCounted<T>. (I think the latter already comes from the static_cast in DecRef, but may as well.) Change-Id: Icf4cbc7d4168010ee46dfa3a7b0a2e7c20aaf383 Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/66369 Reviewed-by: Bob Beck <bbe@google.com> Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/encrypted_client_hello.cc b/ssl/encrypted_client_hello.cc index a5492e9..8c4a42c 100644 --- a/ssl/encrypted_client_hello.cc +++ b/ssl/encrypted_client_hello.cc
@@ -1012,18 +1012,12 @@ SSL_ECH_KEYS *SSL_ECH_KEYS_new() { return New<SSL_ECH_KEYS>(); } -void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) { - CRYPTO_refcount_inc(&keys->references); -} +void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) { keys->UpRefInternal(); } void SSL_ECH_KEYS_free(SSL_ECH_KEYS *keys) { - if (keys == nullptr || - !CRYPTO_refcount_dec_and_test_zero(&keys->references)) { - return; + if (keys != nullptr) { + keys->DecRefInternal(); } - - keys->~ssl_ech_keys_st(); - OPENSSL_free(keys); } int SSL_ECH_KEYS_add(SSL_ECH_KEYS *configs, int is_retry_config,
diff --git a/ssl/internal.h b/ssl/internal.h index f1d02a0..dcc546b 100644 --- a/ssl/internal.h +++ b/ssl/internal.h
@@ -460,6 +460,48 @@ return fixed_names.size() + objects.size(); } +// RefCounted is a common base for ref-counted types. This is an instance of the +// C++ curiously-recurring template pattern, so a type Foo must subclass +// RefCounted<Foo>. It additionally must friend RefCounted<Foo> to allow calling +// the destructor. +template <typename Derived> +class RefCounted { + public: + RefCounted(const RefCounted &) = delete; + RefCounted &operator=(const RefCounted &) = delete; + + // These methods are intentionally named differently from `bssl::UpRef` to + // avoid a collision. Only the implementations of `FOO_up_ref` and `FOO_free` + // should call these. + void UpRefInternal() { CRYPTO_refcount_inc(&references_); } + void DecRefInternal() { + if (CRYPTO_refcount_dec_and_test_zero(&references_)) { + Derived *d = static_cast<Derived *>(this); + d->~Derived(); + OPENSSL_free(d); + } + } + + protected: + // Ensure that only `Derived`, which must inherit from `RefCounted<Derived>`, + // can call the constructor. This catches bugs where someone inherited from + // the wrong base. + class CheckSubClass { + private: + friend Derived; + CheckSubClass() = default; + }; + RefCounted(CheckSubClass) { + static_assert(std::is_base_of<RefCounted, Derived>::value, + "Derived must subclass RefCounted<Derived>"); + } + + ~RefCounted() = default; + + private: + CRYPTO_refcount_t references_ = 1; +}; + // Protocol versions. // @@ -3446,7 +3488,7 @@ const bssl::SSL_X509_METHOD *x509_method; }; -struct ssl_ctx_st { +struct ssl_ctx_st : public bssl::RefCounted<ssl_ctx_st> { explicit ssl_ctx_st(const SSL_METHOD *ssl_method); ssl_ctx_st(const ssl_ctx_st &) = delete; ssl_ctx_st &operator=(const ssl_ctx_st &) = delete; @@ -3516,8 +3558,6 @@ SSL_SESSION *(*get_session_cb)(SSL *ssl, const uint8_t *data, int len, int *copy) = nullptr; - CRYPTO_refcount_t references = 1; - // if defined, these override the X509_verify_cert() calls int (*app_verify_callback)(X509_STORE_CTX *store_ctx, void *arg) = nullptr; void *app_verify_arg = nullptr; @@ -3754,8 +3794,8 @@ bool aes_hw_override_value : 1; private: + friend RefCounted; ~ssl_ctx_st(); - friend OPENSSL_EXPORT void SSL_CTX_free(SSL_CTX *); }; struct ssl_st { @@ -3847,13 +3887,11 @@ bool enable_early_data : 1; }; -struct ssl_session_st { +struct ssl_session_st : public bssl::RefCounted<ssl_session_st> { explicit ssl_session_st(const bssl::SSL_X509_METHOD *method); ssl_session_st(const ssl_session_st &) = delete; ssl_session_st &operator=(const ssl_session_st &) = delete; - CRYPTO_refcount_t references = 1; - // ssl_version is the (D)TLS version that established the session. uint16_t ssl_version = 0; @@ -3996,21 +4034,18 @@ bssl::Array<uint8_t> quic_early_data_context; private: + friend RefCounted; ~ssl_session_st(); - friend OPENSSL_EXPORT void SSL_SESSION_free(SSL_SESSION *); }; -struct ssl_ech_keys_st { - ssl_ech_keys_st() = default; - ssl_ech_keys_st(const ssl_ech_keys_st &) = delete; - ssl_ech_keys_st &operator=(const ssl_ech_keys_st &) = delete; +struct ssl_ech_keys_st : public bssl::RefCounted<ssl_ech_keys_st> { + ssl_ech_keys_st() : RefCounted(CheckSubClass()) {} bssl::GrowableArray<bssl::UniquePtr<bssl::ECHServerConfig>> configs; - CRYPTO_refcount_t references = 1; private: + friend RefCounted; ~ssl_ech_keys_st() = default; - friend OPENSSL_EXPORT void SSL_ECH_KEYS_free(SSL_ECH_KEYS *); }; #endif // OPENSSL_HEADER_SSL_INTERNAL_H
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index 58b68e6..91741fd 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc
@@ -523,7 +523,8 @@ } ssl_ctx_st::ssl_ctx_st(const SSL_METHOD *ssl_method) - : method(ssl_method->method), + : RefCounted(CheckSubClass()), + method(ssl_method->method), x509_method(ssl_method->x509_method), retain_only_sha256_of_client_certs(false), quiet_shutdown(false), @@ -589,18 +590,14 @@ } int SSL_CTX_up_ref(SSL_CTX *ctx) { - CRYPTO_refcount_inc(&ctx->references); + ctx->UpRefInternal(); return 1; } void SSL_CTX_free(SSL_CTX *ctx) { - if (ctx == NULL || - !CRYPTO_refcount_dec_and_test_zero(&ctx->references)) { - return; + if (ctx != nullptr) { + ctx->DecRefInternal(); } - - ctx->~ssl_ctx_st(); - OPENSSL_free(ctx); } ssl_st::ssl_st(SSL_CTX *ctx_arg)
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc index 979ac59..5275b69 100644 --- a/ssl/ssl_session.cc +++ b/ssl/ssl_session.cc
@@ -935,7 +935,8 @@ using namespace bssl; ssl_session_st::ssl_session_st(const SSL_X509_METHOD *method) - : x509_method(method), + : RefCounted(CheckSubClass()), + x509_method(method), extended_master_secret(false), peer_sha256_valid(false), not_resumable(false), @@ -957,18 +958,14 @@ } int SSL_SESSION_up_ref(SSL_SESSION *session) { - CRYPTO_refcount_inc(&session->references); + session->UpRefInternal(); return 1; } void SSL_SESSION_free(SSL_SESSION *session) { - if (session == NULL || - !CRYPTO_refcount_dec_and_test_zero(&session->references)) { - return; + if (session != nullptr) { + session->DecRefInternal(); } - - session->~ssl_session_st(); - OPENSSL_free(session); } const uint8_t *SSL_SESSION_get_id(const SSL_SESSION *session,