From e1528a71e2ff9ad023f1ba37667960ee7dcd9a3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steven=20Dee=20=28J=C5=8Dshin=29?= Date: Sat, 31 Aug 2024 14:00:56 -0400 Subject: [PATCH] Basic CTL shared_ptr implementation (#1267) --- ctl/conditional.h | 3 + ctl/is_void.h | 3 + ctl/shared_ptr.h | 454 ++++++++++++++++++++++++++++++++++++ test/ctl/shared_ptr_test.cc | 248 ++++++++++++++++++++ 4 files changed, 708 insertions(+) create mode 100644 ctl/shared_ptr.h create mode 100644 test/ctl/shared_ptr_test.cc diff --git a/ctl/conditional.h b/ctl/conditional.h index 976143a1d04..5b63eaa8561 100644 --- a/ctl/conditional.h +++ b/ctl/conditional.h @@ -17,6 +17,9 @@ struct conditional typedef F type; }; +template +using conditional_t = typename conditional::type; + } // namespace ctl #endif // CTL_CONDITIONAL_H_ diff --git a/ctl/is_void.h b/ctl/is_void.h index 04c33145cb6..275848d81ac 100644 --- a/ctl/is_void.h +++ b/ctl/is_void.h @@ -19,6 +19,9 @@ template struct is_void : public is_void_::type>::type {}; +template +inline constexpr bool is_void_v = is_void::value; + } // namespace ctl #endif // CTL_IS_VOID_H_ diff --git a/ctl/shared_ptr.h b/ctl/shared_ptr.h new file mode 100644 index 00000000000..40e7a1a7d81 --- /dev/null +++ b/ctl/shared_ptr.h @@ -0,0 +1,454 @@ +// -*-mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8-*- +// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi +#ifndef CTL_SHARED_PTR_H_ +#define CTL_SHARED_PTR_H_ + +#include "exception.h" +#include "is_convertible.h" +#include "remove_extent.h" +#include "unique_ptr.h" + +// XXX currently needed to use placement-new syntax (move to cxx.inc?) +void* +operator new(size_t, void*) noexcept; + +namespace ctl { + +class bad_weak_ptr : public exception +{ + public: + const char* what() const noexcept override + { + return "ctl::bad_weak_ptr"; + } +}; + +namespace __ { + +template +struct ptr_ref +{ + using type = T&; +}; + +template<> +struct ptr_ref +{ + using type = void; +}; + +static inline __attribute__((always_inline)) void +incref(size_t* r) noexcept +{ +#ifdef NDEBUG + __atomic_fetch_add(r, 1, __ATOMIC_RELAXED); +#else + size_t refs = __atomic_fetch_add(r, 1, __ATOMIC_RELAXED); + if (refs > ((size_t)-1) >> 1) + __builtin_trap(); +#endif +} + +static inline __attribute__((always_inline)) bool +decref(size_t* r) noexcept +{ + if (!__atomic_fetch_sub(r, 1, __ATOMIC_RELEASE)) { + __atomic_thread_fence(__ATOMIC_ACQUIRE); + return true; + } + return false; +} + +class shared_ref +{ + public: + constexpr shared_ref() noexcept = default; + shared_ref(const shared_ref&) = delete; + shared_ref& operator=(const shared_ref&) = delete; + + virtual ~shared_ref() = default; + + void keep_shared() noexcept + { + incref(&shared); + } + + void drop_shared() noexcept + { + if (decref(&shared)) { + dispose(); + drop_weak(); + } + } + + void keep_weak() noexcept + { + incref(&weak); + } + + void drop_weak() noexcept + { + if (decref(&weak)) { + delete this; + } + } + + size_t use_count() const noexcept + { + return shared + 1; + } + + size_t weak_count() const noexcept + { + return weak; + } + + private: + virtual void dispose() noexcept = 0; + + size_t shared = 0; + size_t weak = 0; +}; + +template +class shared_pointer : public shared_ref +{ + public: + static shared_pointer* make(T* const p, D d) + { + return make(unique_ptr(p, move(d))); + } + + static shared_pointer* make(unique_ptr p) + { + return new shared_pointer(p.release(), move(p.get_deleter())); + } + + private: + shared_pointer(T* const p, D d) noexcept : p(p), d(move(d)) + { + } + + void dispose() noexcept override + { + move(d)(p); + } + + T* const p; + [[no_unique_address]] D d; +}; + +template +class shared_emplace : public shared_ref +{ + public: + union + { + T t; + }; + + template + void construct(Args&&... args) + { + ::new (&t) T(forward(args)...); + } + + static unique_ptr make() + { + return unique_ptr(new shared_emplace()); + } + + private: + explicit constexpr shared_emplace() noexcept = default; + + void dispose() noexcept override + { + t.~T(); + } +}; + +template +concept shared_ptr_compatible = is_convertible_v; + +} // namespace __ + +template +class weak_ptr; + +template +class shared_ptr +{ + public: + using element_type = remove_extent_t; + using weak_type = weak_ptr; + + constexpr shared_ptr() noexcept = default; + constexpr shared_ptr(nullptr_t) noexcept + { + } + + template + requires __::shared_ptr_compatible + explicit shared_ptr(U* const p) : shared_ptr(p, default_delete()) + { + } + + template + requires __::shared_ptr_compatible + shared_ptr(U* const p, D d) + : p(p), rc(__::shared_pointer::make(p, move(d))) + { + } + + template + shared_ptr(const shared_ptr& r, element_type* p) noexcept + : p(p), rc(r.rc) + { + if (rc) + rc->keep_shared(); + } + + template + shared_ptr(shared_ptr&& r, element_type* p) noexcept : p(p), rc(r.rc) + { + r.p = nullptr; + r.rc = nullptr; + } + + template + requires __::shared_ptr_compatible + shared_ptr(const shared_ptr& r) noexcept : p(r.p), rc(r.rc) + { + if (rc) + rc->keep_shared(); + } + + template + requires __::shared_ptr_compatible + shared_ptr(shared_ptr&& r) noexcept : p(r.p), rc(r.rc) + { + r.p = nullptr; + r.rc = nullptr; + } + + shared_ptr(const shared_ptr& r) noexcept : p(r.p), rc(r.rc) + { + if (rc) + rc->keep_shared(); + } + + shared_ptr(shared_ptr&& r) noexcept : p(r.p), rc(r.rc) + { + r.p = nullptr; + r.rc = nullptr; + } + + template + requires __::shared_ptr_compatible + explicit shared_ptr(const weak_ptr& r) : p(r.p), rc(r.rc) + { + if (r.expired()) { + throw bad_weak_ptr(); + } + rc->keep_shared(); + } + + template + requires __::shared_ptr_compatible + shared_ptr(unique_ptr&& r) + : p(r.p), rc(__::shared_pointer::make(move(r))) + { + } + + ~shared_ptr() + { + if (rc) + rc->drop_shared(); + } + + shared_ptr& operator=(shared_ptr r) noexcept + { + swap(r); + return *this; + } + + template + requires __::shared_ptr_compatible + shared_ptr& operator=(shared_ptr r) noexcept + { + shared_ptr(move(r)).swap(*this); + return *this; + } + + void reset() noexcept + { + shared_ptr().swap(*this); + } + + template + requires __::shared_ptr_compatible + void reset(U* const p2) + { + shared_ptr(p2).swap(*this); + } + + template + requires __::shared_ptr_compatible + void reset(U* const p2, D d) + { + shared_ptr(p2, d).swap(*this); + } + + void swap(shared_ptr& r) noexcept + { + using ctl::swap; + swap(p, r.p); + swap(rc, r.rc); + } + + element_type* get() const noexcept + { + return p; + } + + typename __::ptr_ref::type operator*() const noexcept + { + if (!p) + __builtin_trap(); + return *p; + } + + T* operator->() const noexcept + { + if (!p) + __builtin_trap(); + return p; + } + + long use_count() const noexcept + { + return rc ? rc->use_count() : 0; + } + + explicit operator bool() const noexcept + { + return p; + } + + template + bool owner_before(const shared_ptr& r) const noexcept + { + return p < r.p; + } + + template + bool owner_before(const weak_ptr& r) const noexcept + { + return !r.owner_before(*this); + } + + private: + template + friend class weak_ptr; + + template + friend class shared_ptr; + + template + friend shared_ptr make_shared(Args&&... args); + + element_type* p = nullptr; + __::shared_ref* rc = nullptr; +}; + +template +class weak_ptr +{ + public: + using element_type = remove_extent_t; + + constexpr weak_ptr() noexcept = default; + + template + requires __::shared_ptr_compatible + weak_ptr(const shared_ptr& r) noexcept : p(r.p), rc(r.rc) + { + if (rc) + rc->keep_weak(); + } + + ~weak_ptr() + { + if (rc) + rc->drop_weak(); + } + + long use_count() const noexcept + { + return rc ? rc->use_count() : 0; + } + + bool expired() const noexcept + { + return !use_count(); + } + + void reset() noexcept + { + weak_ptr().swap(*this); + } + + void swap(weak_ptr& r) noexcept + { + using ctl::swap; + swap(p, r.p); + swap(rc, r.rc); + } + + shared_ptr lock() const noexcept + { + if (expired()) + return nullptr; + shared_ptr r; + r.p = p; + r.rc = rc; + if (rc) + rc->keep_shared(); + return r; + } + + template + bool owner_before(const weak_ptr& r) const noexcept + { + return p < r.p; + } + + template + bool owner_before(const shared_ptr& r) const noexcept + { + return p < r.p; + } + + private: + template + friend class shared_ptr; + + element_type* p = nullptr; + __::shared_ref* rc = nullptr; +}; + +template +shared_ptr +make_shared(Args&&... args) +{ + auto rc = __::shared_emplace::make(); + rc->construct(forward(args)...); + shared_ptr r; + r.p = &rc->t; + r.rc = rc.release(); + return r; +} + +} // namespace ctl + +#endif // CTL_SHARED_PTR_H_ diff --git a/test/ctl/shared_ptr_test.cc b/test/ctl/shared_ptr_test.cc new file mode 100644 index 00000000000..c9f9f0516a0 --- /dev/null +++ b/test/ctl/shared_ptr_test.cc @@ -0,0 +1,248 @@ +// -*- mode:c++; indent-tabs-mode:nil; c-basic-offset:4; coding:utf-8 -*- +// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Justine Alexandra Roberts Tunney +// +// Permission to use, copy, modify, and/or distribute this software for +// any purpose with or without fee is hereby granted, provided that the +// above copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL +// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE +// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +// DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR +// PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +// TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +#include "ctl/shared_ptr.h" +#include "ctl/vector.h" +#include "libc/mem/leaks.h" + +// #include +// #include +// #define ctl std + +using ctl::bad_weak_ptr; +using ctl::make_shared; +using ctl::move; +using ctl::shared_ptr; +using ctl::unique_ptr; +using ctl::vector; +using ctl::weak_ptr; + +#undef ctl + +static int g = 0; + +struct ConstructG +{ + ConstructG() + { + ++g; + } +}; + +struct DestructG +{ + ~DestructG() + { + ++g; + } +}; + +struct CallG +{ + void operator()(auto*) const noexcept + { + ++g; + } +}; + +struct Base +{}; + +struct Derived : Base +{}; + +int +main() +{ + int a, b; + + { + // Shouldn't cause memory leaks. + shared_ptr x(new int(5)); + } + + { + // Objects get destroyed when the last shared_ptr is reset. + shared_ptr x(&a, CallG()); + shared_ptr y(x); + x.reset(); + if (g) + return 1; + y.reset(); + if (g != 1) + return 2; + } + + { + g = 0; + // Weak pointers don't prevent object destruction. + shared_ptr x(&a, CallG()); + weak_ptr y(x); + x.reset(); + if (g != 1) + return 3; + } + + { + g = 0; + // Weak pointers can be promoted to shared pointers. + shared_ptr x(&a, CallG()); + weak_ptr y(x); + auto z = y.lock(); + x.reset(); + if (g) + return 4; + y.reset(); + if (g) + return 5; + z.reset(); + if (g != 1) + return 6; + } + + { + // Shared null pointers are falsey. + shared_ptr x; + if (x) + return 7; + x.reset(new int); + if (!x) + return 8; + } + + { + // You can cast a shared pointer validly. + shared_ptr x(new Derived); + shared_ptr y(x); + // But not invalidly: + // shared_ptr x(new Derived); + // shared_ptr y(x); + } + + { + // You can cast a shared pointer to void to retain a reference. + shared_ptr x(new int); + shared_ptr y(x); + } + + { + // You can also create a shared pointer to void in the first place. + shared_ptr x(new int); + } + + { + // You can take a shared pointer to a subobject, and it will free the + // base object. + shared_ptr> x(new vector); + x->push_back(5); + shared_ptr y(x, &x->at(0)); + x.reset(); + if (*y != 5) + return 9; + } + + { + g = 0; + // You can create a shared_ptr from a unique_ptr. + unique_ptr x(&a, CallG()); + shared_ptr y(move(x)); + if (x) + return 10; + y.reset(); + if (g != 1) + return 11; + } + + { + g = 0; + // You can reassign shared_ptrs. + shared_ptr x(&a, CallG()); + shared_ptr y; + y = x; + x.reset(); + if (g) + return 12; + y.reset(); + if (g != 1) + return 13; + } + + { + // owner_before works across shared and weak pointers. + shared_ptr x(&a, CallG()); + shared_ptr y(&b, CallG()); + if (!x.owner_before(y)) + return 14; + if (!x.owner_before(weak_ptr(y))) + return 15; + } + + { + // Use counts work like you'd expect + shared_ptr x(new int); + if (x.use_count() != 1) + return 16; + shared_ptr y(x); + if (x.use_count() != 2 || y.use_count() != 2) + return 17; + x.reset(); + if (x.use_count() != 0 || y.use_count() != 1) + return 18; + } + + { + // There is a make_shared that will allocate an object for you safely. + auto x = make_shared(5); + if (!x) + return 19; + if (*x != 5) + return 20; + } + + { + // Expired weak pointers lock to nullptr, and throw when promoted to + // shared pointer by constructor. + auto x = make_shared(); + weak_ptr y(x); + x.reset(); + if (y.lock()) + return 21; + int caught = 0; + try { + shared_ptr z(y); + } catch (bad_weak_ptr& e) { + caught = 1; + } + if (!caught) + return 22; + } + + { + // nullptr is always expired. + shared_ptr x(nullptr); + weak_ptr y(x); + if (!y.expired()) + return 23; + } + + // TODO(mrdomino): exercise threads / races. The reference count should be + // atomically maintained. + + CheckForMemoryLeaks(); + return 0; +}