Make an internal RefCounted base class for libssl
This is still a bit more tedious than I'd like, but we've got three of
these and I'm about to add a fourth. Add something like Chromium's base
class. But where Chromium integrates the base class directly with
scoped_refptr (giving a place for a static_assert that you did the
subclassing right), we don't quite have that since we need to integrate
with the external C API.
Instead, use the "passkey" pattern and have RefCounted<T>'s protected
constructor take a struct that only T can construct. The passkey ensures
that only T can construct RefCounted<T>, and the protectedness ensures
that T subclassed RefCounted<T>. (I think the latter already comes from
the static_cast in DecRef, but may as well.)
Change-Id: Icf4cbc7d4168010ee46dfa3a7b0a2e7c20aaf383
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/66369
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/encrypted_client_hello.cc b/ssl/encrypted_client_hello.cc
index a5492e9..8c4a42c 100644
--- a/ssl/encrypted_client_hello.cc
+++ b/ssl/encrypted_client_hello.cc
@@ -1012,18 +1012,12 @@
SSL_ECH_KEYS *SSL_ECH_KEYS_new() { return New<SSL_ECH_KEYS>(); }
-void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) {
- CRYPTO_refcount_inc(&keys->references);
-}
+void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) { keys->UpRefInternal(); }
void SSL_ECH_KEYS_free(SSL_ECH_KEYS *keys) {
- if (keys == nullptr ||
- !CRYPTO_refcount_dec_and_test_zero(&keys->references)) {
- return;
+ if (keys != nullptr) {
+ keys->DecRefInternal();
}
-
- keys->~ssl_ech_keys_st();
- OPENSSL_free(keys);
}
int SSL_ECH_KEYS_add(SSL_ECH_KEYS *configs, int is_retry_config,
diff --git a/ssl/internal.h b/ssl/internal.h
index f1d02a0..dcc546b 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -460,6 +460,48 @@
return fixed_names.size() + objects.size();
}
+// RefCounted is a common base for ref-counted types. This is an instance of the
+// C++ curiously-recurring template pattern, so a type Foo must subclass
+// RefCounted<Foo>. It additionally must friend RefCounted<Foo> to allow calling
+// the destructor.
+template <typename Derived>
+class RefCounted {
+ public:
+ RefCounted(const RefCounted &) = delete;
+ RefCounted &operator=(const RefCounted &) = delete;
+
+ // These methods are intentionally named differently from `bssl::UpRef` to
+ // avoid a collision. Only the implementations of `FOO_up_ref` and `FOO_free`
+ // should call these.
+ void UpRefInternal() { CRYPTO_refcount_inc(&references_); }
+ void DecRefInternal() {
+ if (CRYPTO_refcount_dec_and_test_zero(&references_)) {
+ Derived *d = static_cast<Derived *>(this);
+ d->~Derived();
+ OPENSSL_free(d);
+ }
+ }
+
+ protected:
+ // Ensure that only `Derived`, which must inherit from `RefCounted<Derived>`,
+ // can call the constructor. This catches bugs where someone inherited from
+ // the wrong base.
+ class CheckSubClass {
+ private:
+ friend Derived;
+ CheckSubClass() = default;
+ };
+ RefCounted(CheckSubClass) {
+ static_assert(std::is_base_of<RefCounted, Derived>::value,
+ "Derived must subclass RefCounted<Derived>");
+ }
+
+ ~RefCounted() = default;
+
+ private:
+ CRYPTO_refcount_t references_ = 1;
+};
+
// Protocol versions.
//
@@ -3446,7 +3488,7 @@
const bssl::SSL_X509_METHOD *x509_method;
};
-struct ssl_ctx_st {
+struct ssl_ctx_st : public bssl::RefCounted<ssl_ctx_st> {
explicit ssl_ctx_st(const SSL_METHOD *ssl_method);
ssl_ctx_st(const ssl_ctx_st &) = delete;
ssl_ctx_st &operator=(const ssl_ctx_st &) = delete;
@@ -3516,8 +3558,6 @@
SSL_SESSION *(*get_session_cb)(SSL *ssl, const uint8_t *data, int len,
int *copy) = nullptr;
- CRYPTO_refcount_t references = 1;
-
// if defined, these override the X509_verify_cert() calls
int (*app_verify_callback)(X509_STORE_CTX *store_ctx, void *arg) = nullptr;
void *app_verify_arg = nullptr;
@@ -3754,8 +3794,8 @@
bool aes_hw_override_value : 1;
private:
+ friend RefCounted;
~ssl_ctx_st();
- friend OPENSSL_EXPORT void SSL_CTX_free(SSL_CTX *);
};
struct ssl_st {
@@ -3847,13 +3887,11 @@
bool enable_early_data : 1;
};
-struct ssl_session_st {
+struct ssl_session_st : public bssl::RefCounted<ssl_session_st> {
explicit ssl_session_st(const bssl::SSL_X509_METHOD *method);
ssl_session_st(const ssl_session_st &) = delete;
ssl_session_st &operator=(const ssl_session_st &) = delete;
- CRYPTO_refcount_t references = 1;
-
// ssl_version is the (D)TLS version that established the session.
uint16_t ssl_version = 0;
@@ -3996,21 +4034,18 @@
bssl::Array<uint8_t> quic_early_data_context;
private:
+ friend RefCounted;
~ssl_session_st();
- friend OPENSSL_EXPORT void SSL_SESSION_free(SSL_SESSION *);
};
-struct ssl_ech_keys_st {
- ssl_ech_keys_st() = default;
- ssl_ech_keys_st(const ssl_ech_keys_st &) = delete;
- ssl_ech_keys_st &operator=(const ssl_ech_keys_st &) = delete;
+struct ssl_ech_keys_st : public bssl::RefCounted<ssl_ech_keys_st> {
+ ssl_ech_keys_st() : RefCounted(CheckSubClass()) {}
bssl::GrowableArray<bssl::UniquePtr<bssl::ECHServerConfig>> configs;
- CRYPTO_refcount_t references = 1;
private:
+ friend RefCounted;
~ssl_ech_keys_st() = default;
- friend OPENSSL_EXPORT void SSL_ECH_KEYS_free(SSL_ECH_KEYS *);
};
#endif // OPENSSL_HEADER_SSL_INTERNAL_H
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 58b68e6..91741fd 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -523,7 +523,8 @@
}
ssl_ctx_st::ssl_ctx_st(const SSL_METHOD *ssl_method)
- : method(ssl_method->method),
+ : RefCounted(CheckSubClass()),
+ method(ssl_method->method),
x509_method(ssl_method->x509_method),
retain_only_sha256_of_client_certs(false),
quiet_shutdown(false),
@@ -589,18 +590,14 @@
}
int SSL_CTX_up_ref(SSL_CTX *ctx) {
- CRYPTO_refcount_inc(&ctx->references);
+ ctx->UpRefInternal();
return 1;
}
void SSL_CTX_free(SSL_CTX *ctx) {
- if (ctx == NULL ||
- !CRYPTO_refcount_dec_and_test_zero(&ctx->references)) {
- return;
+ if (ctx != nullptr) {
+ ctx->DecRefInternal();
}
-
- ctx->~ssl_ctx_st();
- OPENSSL_free(ctx);
}
ssl_st::ssl_st(SSL_CTX *ctx_arg)
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index 979ac59..5275b69 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -935,7 +935,8 @@
using namespace bssl;
ssl_session_st::ssl_session_st(const SSL_X509_METHOD *method)
- : x509_method(method),
+ : RefCounted(CheckSubClass()),
+ x509_method(method),
extended_master_secret(false),
peer_sha256_valid(false),
not_resumable(false),
@@ -957,18 +958,14 @@
}
int SSL_SESSION_up_ref(SSL_SESSION *session) {
- CRYPTO_refcount_inc(&session->references);
+ session->UpRefInternal();
return 1;
}
void SSL_SESSION_free(SSL_SESSION *session) {
- if (session == NULL ||
- !CRYPTO_refcount_dec_and_test_zero(&session->references)) {
- return;
+ if (session != nullptr) {
+ session->DecRefInternal();
}
-
- session->~ssl_session_st();
- OPENSSL_free(session);
}
const uint8_t *SSL_SESSION_get_id(const SSL_SESSION *session,