Make an internal RefCounted base class for libssl

This is still a bit more tedious than I'd like, but we've got three of
these and I'm about to add a fourth. Add something like Chromium's base
class. But where Chromium integrates the base class directly with
scoped_refptr (giving a place for a static_assert that you did the
subclassing right), we don't quite have that since we need to integrate
with the external C API.

Instead, use the "passkey" pattern and have RefCounted<T>'s protected
constructor take a struct that only T can construct. The passkey ensures
that only T can construct RefCounted<T>, and the protectedness ensures
that T subclassed RefCounted<T>. (I think the latter already comes from
the static_cast in DecRef, but may as well.)

Change-Id: Icf4cbc7d4168010ee46dfa3a7b0a2e7c20aaf383
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/66369
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/encrypted_client_hello.cc b/ssl/encrypted_client_hello.cc
index a5492e9..8c4a42c 100644
--- a/ssl/encrypted_client_hello.cc
+++ b/ssl/encrypted_client_hello.cc
@@ -1012,18 +1012,12 @@
 
 SSL_ECH_KEYS *SSL_ECH_KEYS_new() { return New<SSL_ECH_KEYS>(); }
 
-void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) {
-  CRYPTO_refcount_inc(&keys->references);
-}
+void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) { keys->UpRefInternal(); }
 
 void SSL_ECH_KEYS_free(SSL_ECH_KEYS *keys) {
-  if (keys == nullptr ||
-      !CRYPTO_refcount_dec_and_test_zero(&keys->references)) {
-    return;
+  if (keys != nullptr) {
+    keys->DecRefInternal();
   }
-
-  keys->~ssl_ech_keys_st();
-  OPENSSL_free(keys);
 }
 
 int SSL_ECH_KEYS_add(SSL_ECH_KEYS *configs, int is_retry_config,
diff --git a/ssl/internal.h b/ssl/internal.h
index f1d02a0..dcc546b 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -460,6 +460,48 @@
   return fixed_names.size() + objects.size();
 }
 
+// RefCounted is a common base for ref-counted types. This is an instance of the
+// C++ curiously-recurring template pattern, so a type Foo must subclass
+// RefCounted<Foo>. It additionally must friend RefCounted<Foo> to allow calling
+// the destructor.
+template <typename Derived>
+class RefCounted {
+ public:
+  RefCounted(const RefCounted &) = delete;
+  RefCounted &operator=(const RefCounted &) = delete;
+
+  // These methods are intentionally named differently from `bssl::UpRef` to
+  // avoid a collision. Only the implementations of `FOO_up_ref` and `FOO_free`
+  // should call these.
+  void UpRefInternal() { CRYPTO_refcount_inc(&references_); }
+  void DecRefInternal() {
+    if (CRYPTO_refcount_dec_and_test_zero(&references_)) {
+      Derived *d = static_cast<Derived *>(this);
+      d->~Derived();
+      OPENSSL_free(d);
+    }
+  }
+
+ protected:
+  // Ensure that only `Derived`, which must inherit from `RefCounted<Derived>`,
+  // can call the constructor. This catches bugs where someone inherited from
+  // the wrong base.
+  class CheckSubClass {
+   private:
+    friend Derived;
+    CheckSubClass() = default;
+  };
+  RefCounted(CheckSubClass) {
+    static_assert(std::is_base_of<RefCounted, Derived>::value,
+                  "Derived must subclass RefCounted<Derived>");
+  }
+
+  ~RefCounted() = default;
+
+ private:
+  CRYPTO_refcount_t references_ = 1;
+};
+
 
 // Protocol versions.
 //
@@ -3446,7 +3488,7 @@
   const bssl::SSL_X509_METHOD *x509_method;
 };
 
-struct ssl_ctx_st {
+struct ssl_ctx_st : public bssl::RefCounted<ssl_ctx_st> {
   explicit ssl_ctx_st(const SSL_METHOD *ssl_method);
   ssl_ctx_st(const ssl_ctx_st &) = delete;
   ssl_ctx_st &operator=(const ssl_ctx_st &) = delete;
@@ -3516,8 +3558,6 @@
   SSL_SESSION *(*get_session_cb)(SSL *ssl, const uint8_t *data, int len,
                                  int *copy) = nullptr;
 
-  CRYPTO_refcount_t references = 1;
-
   // if defined, these override the X509_verify_cert() calls
   int (*app_verify_callback)(X509_STORE_CTX *store_ctx, void *arg) = nullptr;
   void *app_verify_arg = nullptr;
@@ -3754,8 +3794,8 @@
   bool aes_hw_override_value : 1;
 
  private:
+  friend RefCounted;
   ~ssl_ctx_st();
-  friend OPENSSL_EXPORT void SSL_CTX_free(SSL_CTX *);
 };
 
 struct ssl_st {
@@ -3847,13 +3887,11 @@
   bool enable_early_data : 1;
 };
 
-struct ssl_session_st {
+struct ssl_session_st : public bssl::RefCounted<ssl_session_st> {
   explicit ssl_session_st(const bssl::SSL_X509_METHOD *method);
   ssl_session_st(const ssl_session_st &) = delete;
   ssl_session_st &operator=(const ssl_session_st &) = delete;
 
-  CRYPTO_refcount_t references = 1;
-
   // ssl_version is the (D)TLS version that established the session.
   uint16_t ssl_version = 0;
 
@@ -3996,21 +4034,18 @@
   bssl::Array<uint8_t> quic_early_data_context;
 
  private:
+  friend RefCounted;
   ~ssl_session_st();
-  friend OPENSSL_EXPORT void SSL_SESSION_free(SSL_SESSION *);
 };
 
-struct ssl_ech_keys_st {
-  ssl_ech_keys_st() = default;
-  ssl_ech_keys_st(const ssl_ech_keys_st &) = delete;
-  ssl_ech_keys_st &operator=(const ssl_ech_keys_st &) = delete;
+struct ssl_ech_keys_st : public bssl::RefCounted<ssl_ech_keys_st> {
+  ssl_ech_keys_st() : RefCounted(CheckSubClass()) {}
 
   bssl::GrowableArray<bssl::UniquePtr<bssl::ECHServerConfig>> configs;
-  CRYPTO_refcount_t references = 1;
 
  private:
+  friend RefCounted;
   ~ssl_ech_keys_st() = default;
-  friend OPENSSL_EXPORT void SSL_ECH_KEYS_free(SSL_ECH_KEYS *);
 };
 
 #endif  // OPENSSL_HEADER_SSL_INTERNAL_H
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 58b68e6..91741fd 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -523,7 +523,8 @@
 }
 
 ssl_ctx_st::ssl_ctx_st(const SSL_METHOD *ssl_method)
-    : method(ssl_method->method),
+    : RefCounted(CheckSubClass()),
+      method(ssl_method->method),
       x509_method(ssl_method->x509_method),
       retain_only_sha256_of_client_certs(false),
       quiet_shutdown(false),
@@ -589,18 +590,14 @@
 }
 
 int SSL_CTX_up_ref(SSL_CTX *ctx) {
-  CRYPTO_refcount_inc(&ctx->references);
+  ctx->UpRefInternal();
   return 1;
 }
 
 void SSL_CTX_free(SSL_CTX *ctx) {
-  if (ctx == NULL ||
-      !CRYPTO_refcount_dec_and_test_zero(&ctx->references)) {
-    return;
+  if (ctx != nullptr) {
+    ctx->DecRefInternal();
   }
-
-  ctx->~ssl_ctx_st();
-  OPENSSL_free(ctx);
 }
 
 ssl_st::ssl_st(SSL_CTX *ctx_arg)
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index 979ac59..5275b69 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -935,7 +935,8 @@
 using namespace bssl;
 
 ssl_session_st::ssl_session_st(const SSL_X509_METHOD *method)
-    : x509_method(method),
+    : RefCounted(CheckSubClass()),
+      x509_method(method),
       extended_master_secret(false),
       peer_sha256_valid(false),
       not_resumable(false),
@@ -957,18 +958,14 @@
 }
 
 int SSL_SESSION_up_ref(SSL_SESSION *session) {
-  CRYPTO_refcount_inc(&session->references);
+  session->UpRefInternal();
   return 1;
 }
 
 void SSL_SESSION_free(SSL_SESSION *session) {
-  if (session == NULL ||
-      !CRYPTO_refcount_dec_and_test_zero(&session->references)) {
-    return;
+  if (session != nullptr) {
+    session->DecRefInternal();
   }
-
-  session->~ssl_session_st();
-  OPENSSL_free(session);
 }
 
 const uint8_t *SSL_SESSION_get_id(const SSL_SESSION *session,