From 167400c4c782b78dbb6a559e14696fe10e36d6c8 Mon Sep 17 00:00:00 2001 From: lzhangzz Date: Fri, 18 Mar 2022 00:35:22 +0800 Subject: [PATCH] [Enhancement] Switch to statically typed Value::Any (#209) * replace std::any with StaticAny * fix __compare_typeid * remove fallback id support * constraint on traits::TypeId::value * fix includes --- csrc/core/mpl/static_any.h | 489 ++++++++++++++++++++++++++++ csrc/core/value.h | 20 +- tests/test_csrc/core/test_value.cpp | 10 + 3 files changed, 513 insertions(+), 6 deletions(-) create mode 100644 csrc/core/mpl/static_any.h diff --git a/csrc/core/mpl/static_any.h b/csrc/core/mpl/static_any.h new file mode 100644 index 0000000000..a027fd63d3 --- /dev/null +++ b/csrc/core/mpl/static_any.h @@ -0,0 +1,489 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_CSRC_CORE_MPL_STATIC_ANY_H_ +#define MMDEPLOY_CSRC_CORE_MPL_STATIC_ANY_H_ + +#include +#include +#include +#include +#include + +// re-implementation of std::any, relies on static type id instead of RTTI. +// adjusted from libc++-10 + +namespace mmdeploy { + +namespace traits { + +using type_id_t = uint64_t; + +template +struct TypeId { + static constexpr type_id_t value = 0; +}; + +template <> +struct TypeId { + static constexpr auto value = static_cast(-1); +}; + +// ! This only works when calling inside mmdeploy namespace +#define MMDEPLOY_REGISTER_TYPE_ID(type, id) \ + namespace traits { \ + template <> \ + struct TypeId { \ + static constexpr type_id_t value = id; \ + }; \ + } + +} // namespace traits + +namespace detail { + +template +struct is_in_place_type_impl : std::false_type {}; + +template +struct is_in_place_type_impl> : std::true_type {}; + +template +struct is_in_place_type : public is_in_place_type_impl {}; + +} // namespace detail + +class BadAnyCast : public std::bad_cast { + public: + const char* what() const noexcept override { return "BadAnyCast"; } +}; + +[[noreturn]] inline void ThrowBadAnyCast() { +#if __cpp_exceptions + throw BadAnyCast{}; +#else + std::abort(); +#endif +} + +// Forward declarations +class StaticAny; + +template +std::add_pointer_t> static_any_cast(const StaticAny*) noexcept; + +template +std::add_pointer_t static_any_cast(StaticAny*) noexcept; + +namespace __static_any_impl { + +using _Buffer = std::aligned_storage_t<3 * sizeof(void*), std::alignment_of_v>; + +template +using _IsSmallObject = + std::integral_constant % std::alignment_of_v == 0 && + std::is_nothrow_move_constructible_v>; + +enum class _Action { _Destroy, _Copy, _Move, _Get, _TypeInfo }; + +union _Ret { + void* ptr_; + traits::type_id_t type_id_; +}; + +template +struct _SmallHandler; +template +struct _LargeHandler; + +template +inline bool __compare_typeid(traits::type_id_t __id) { + if (__id && __id == traits::TypeId::value) { + return true; + } + return false; +} + +template +using _Handler = std::conditional_t<_IsSmallObject::value, _SmallHandler, _LargeHandler>; + +} // namespace __static_any_impl + +class StaticAny { + public: + constexpr StaticAny() noexcept : h_(nullptr) {} + + StaticAny(const StaticAny& other) : h_(nullptr) { + if (other.h_) { + other.__call(_Action::_Copy, this); + } + } + + StaticAny(StaticAny&& other) noexcept : h_(nullptr) { + if (other.h_) { + other.__call(_Action::_Move, this); + } + } + + template , + class = std::enable_if_t< + !std::is_same::value && !detail::is_in_place_type::value && + std::is_copy_constructible::value && traits::TypeId::value>> + explicit StaticAny(ValueType&& value); + + template < + class ValueType, class... Args, class T = std::decay_t, + class = std::enable_if_t::value && + std::is_copy_constructible::value && traits::TypeId::value>> + explicit StaticAny(std::in_place_type_t, Args&&... args); + + template , + class = std::enable_if_t< + std::is_constructible&, Args...>::value && + std::is_copy_constructible::value && traits::TypeId::value>> + explicit StaticAny(std::in_place_type_t, std::initializer_list, Args&&... args); + + ~StaticAny() { this->reset(); } + + StaticAny& operator=(const StaticAny& rhs) { + StaticAny(rhs).swap(*this); + return *this; + } + + StaticAny& operator=(StaticAny&& rhs) noexcept { + StaticAny(std::move(rhs)).swap(*this); + return *this; + } + + template < + class ValueType, class T = std::decay_t, + class = std::enable_if_t::value && + std::is_copy_constructible::value && traits::TypeId::value>> + StaticAny& operator=(ValueType&& v); + + template < + class ValueType, class... Args, class T = std::decay_t, + class = std::enable_if_t::value && + std::is_copy_constructible::value && traits::TypeId::value>> + T& emplace(Args&&... args); + + template , + class = std::enable_if_t< + std::is_constructible&, Args...>::value && + std::is_copy_constructible::value && traits::TypeId::value>> + T& emplace(std::initializer_list, Args&&...); + + void reset() noexcept { + if (h_) { + this->__call(_Action::_Destroy); + } + } + + void swap(StaticAny& rhs) noexcept; + + bool has_value() const noexcept { return h_ != nullptr; } + + traits::type_id_t type() const noexcept { + if (h_) { + return this->__call(_Action::_TypeInfo).type_id_; + } else { + return traits::TypeId::value; + } + } + + private: + using _Action = __static_any_impl::_Action; + using _Ret = __static_any_impl::_Ret; + using _HandleFuncPtr = _Ret (*)(_Action, const StaticAny*, StaticAny*, traits::type_id_t info); + + union _Storage { + constexpr _Storage() : ptr_(nullptr) {} + void* ptr_; + __static_any_impl::_Buffer buf_; + }; + + _Ret __call(_Action a, StaticAny* other = nullptr, traits::type_id_t info = 0) const { + return h_(a, this, other, info); + } + + _Ret __call(_Action a, StaticAny* other = nullptr, traits::type_id_t info = 0) { + return h_(a, this, other, info); + } + + template + friend struct __static_any_impl::_SmallHandler; + + template + friend struct __static_any_impl::_LargeHandler; + + template + friend std::add_pointer_t> static_any_cast(const StaticAny*) noexcept; + + template + friend std::add_pointer_t static_any_cast(StaticAny*) noexcept; + + _HandleFuncPtr h_ = nullptr; + _Storage s_; +}; + +namespace __static_any_impl { + +template +struct _SmallHandler { + static _Ret __handle(_Action action, const StaticAny* self, StaticAny* other, + traits::type_id_t info) { + _Ret ret; + ret.ptr_ = nullptr; + switch (action) { + case _Action::_Destroy: + __destroy(const_cast(*self)); + break; + case _Action::_Copy: + __copy(*self, *other); + break; + case _Action::_Move: + __move(const_cast(*self), *other); + break; + case _Action::_Get: + ret.ptr_ = __get(const_cast(*self), info); + break; + case _Action::_TypeInfo: + ret.type_id_ = __type_info(); + break; + } + return ret; + } + + template + static T& __create(StaticAny& dest, Args&&... args) { + T* ret = ::new (static_cast(&dest.s_.buf_)) T(std::forward(args)...); + dest.h_ = &_SmallHandler::__handle; + return *ret; + } + + private: + template + static void __destroy(StaticAny& self) { + T& value = *static_cast(static_cast(&self.s_.buf_)); + value.~T(); + self.h_ = nullptr; + } + + template + static void __copy(const StaticAny& self, StaticAny& dest) { + _SmallHandler::__create(dest, *static_cast(static_cast(&self.s_.buf_))); + } + + static void __move(StaticAny& self, StaticAny& dest) { + _SmallHandler::__create(dest, std::move(*static_cast(static_cast(&self.s_.buf_)))); + __destroy(self); + } + + static void* __get(StaticAny& self, traits::type_id_t info) { + if (__static_any_impl::__compare_typeid(info)) { + return static_cast(&self.s_.buf_); + } + return nullptr; + } + + static traits::type_id_t __type_info() { return traits::TypeId::value; } +}; + +template +struct _LargeHandler { + static _Ret __handle(_Action action, const StaticAny* self, StaticAny* other, + traits::type_id_t info) { + _Ret ret; + ret.ptr_ = nullptr; + switch (action) { + case _Action::_Destroy: + __destroy(const_cast(*self)); + break; + case _Action::_Copy: + __copy(*self, *other); + break; + case _Action::_Move: + __move(const_cast(*self), *other); + break; + case _Action::_Get: + ret.ptr_ = __get(const_cast(*self), info); + break; + case _Action::_TypeInfo: + ret.type_id_ = __type_info(); + break; + } + return ret; + } + + template + static T& __create(StaticAny& dest, Args&&... args) { + using _Alloc = std::allocator; + _Alloc alloc; + auto dealloc = [&](T* p) { alloc.deallocate(p, 1); }; + std::unique_ptr hold(alloc.allocate(1), dealloc); + T* ret = ::new ((void*)hold.get()) T(std::forward(args)...); + dest.s_.ptr_ = hold.release(); + dest.h_ = &_LargeHandler::__handle; + return *ret; + } + + private: + static void __destroy(StaticAny& self) { + delete static_cast(self.s_.ptr_); + self.h_ = nullptr; + } + + static void __copy(const StaticAny& self, StaticAny& dest) { + _LargeHandler::__create(dest, *static_cast(self.s_.ptr_)); + } + + static void __move(StaticAny& self, StaticAny& dest) { + dest.s_.ptr_ = self.s_.ptr_; + dest.h_ = &_LargeHandler::__handle; + self.h_ = nullptr; + } + + static void* __get(StaticAny& self, traits::type_id_t info) { + if (__static_any_impl::__compare_typeid(info)) { + return static_cast(self.s_.ptr_); + } + return nullptr; + } + + static traits::type_id_t __type_info() { return traits::TypeId::value; } +}; + +} // namespace __static_any_impl + +template +StaticAny::StaticAny(ValueType&& v) : h_(nullptr) { + __static_any_impl::_Handler::__create(*this, std::forward(v)); +} + +template +StaticAny::StaticAny(std::in_place_type_t, Args&&... args) { + __static_any_impl::_Handler::__create(*this, std::forward(args)...); +} + +template +StaticAny::StaticAny(std::in_place_type_t, std::initializer_list il, Args&&... args) { + __static_any_impl::_Handler::__create(*this, il, std::forward(args)...); +} + +template +inline StaticAny& StaticAny::operator=(ValueType&& v) { + StaticAny(std::forward(v)).swap(*this); + return *this; +} + +template +inline T& StaticAny::emplace(Args&&... args) { + reset(); + return __static_any_impl::_Handler::__create(*this, std::forward(args)...); +} + +template +inline T& StaticAny::emplace(std::initializer_list il, Args&&... args) { + reset(); + return __static_any_impl::_Handler::_create(*this, il, std::forward(args)...); +} + +inline void StaticAny::swap(StaticAny& rhs) noexcept { + if (this == &rhs) { + return; + } + if (h_ && rhs.h_) { + StaticAny tmp; + rhs.__call(_Action::_Move, &tmp); + this->__call(_Action::_Move, &rhs); + tmp.__call(_Action::_Move, this); + } else if (h_) { + this->__call(_Action::_Move, &rhs); + } else if (rhs.h_) { + rhs.__call(_Action::_Move, this); + } +} + +inline void swap(StaticAny& lhs, StaticAny& rhs) noexcept { lhs.swap(rhs); } + +template +inline StaticAny make_static_any(Args&&... args) { + return StaticAny(std::in_place_type, std::forward(args)...); +} + +template +StaticAny make_static_any(std::initializer_list il, Args&&... args) { + return StaticAny(std::in_place_type, il, std::forward(args)...); +} + +template +ValueType static_any_cast(const StaticAny& v) { + using _RawValueType = std::remove_cv_t>; + static_assert(std::is_constructible::value, + "ValueType is required to be a const lvalue reference " + "or a CopyConstructible type"); + auto tmp = static_any_cast>(&v); + if (tmp == nullptr) { + ThrowBadAnyCast(); + } + return static_cast(*tmp); +} + +template +inline ValueType static_any_cast(StaticAny& v) { + using _RawValueType = std::remove_cv_t>; + static_assert(std::is_constructible::value, + "ValueType is required to be an lvalue reference " + "or a CopyConstructible type"); + auto tmp = static_any_cast<_RawValueType>(&v); + if (tmp == nullptr) { + ThrowBadAnyCast(); + } + return static_cast(*tmp); +} + +template +inline ValueType static_any_cast(StaticAny&& v) { + using _RawValueType = std::remove_cv_t>; + static_assert(std::is_constructible::value, + "ValueType is required to be an rvalue reference " + "or a CopyConstructible type"); + auto tmp = static_any_cast<_RawValueType>(&v); + if (tmp == nullptr) { + ThrowBadAnyCast(); + } + return static_cast(std::move(*tmp)); +} + +template +inline std::add_pointer_t> static_any_cast( + const StaticAny* __any) noexcept { + static_assert(!std::is_reference::value, "ValueType may not be a reference."); + return static_any_cast(const_cast(__any)); +} + +template +inline RetType __pointer_or_func_test(void* p, std::false_type) noexcept { + return static_cast(p); +} + +template +inline RetType __pointer_or_func_test(void*, std::true_type) noexcept { + return nullptr; +} + +template +std::add_pointer_t static_any_cast(StaticAny* any) noexcept { + using __static_any_impl::_Action; + static_assert(!std::is_reference::value, "ValueType may not be a reference."); + using ReturnType = std::add_pointer_t; + if (any && any->h_) { + void* p = any->__call(_Action::_Get, nullptr, traits::TypeId::value).ptr_; + return __pointer_or_func_test(p, std::is_function{}); + } + return nullptr; +} + +} // namespace mmdeploy + +#endif // MMDEPLOY_CSRC_CORE_MPL_STATIC_ANY_H_ diff --git a/csrc/core/value.h b/csrc/core/value.h index 3241330565..e716be5c81 100644 --- a/csrc/core/value.h +++ b/csrc/core/value.h @@ -3,7 +3,6 @@ #ifndef MMDEPLOY_TYPES_VALUE_H_ #define MMDEPLOY_TYPES_VALUE_H_ -#include #include #include #include @@ -16,6 +15,7 @@ #include "core/logger.h" #include "core/status_code.h" #include "mpl/priority_tag.h" +#include "mpl/static_any.h" #include "mpl/type_traits.h" namespace mmdeploy { @@ -164,6 +164,14 @@ struct is_cast_by_erasure : std::true_type {}; template <> struct is_cast_by_erasure : std::true_type {}; +MMDEPLOY_REGISTER_TYPE_ID(Device, 1); +MMDEPLOY_REGISTER_TYPE_ID(Buffer, 2); +MMDEPLOY_REGISTER_TYPE_ID(Stream, 3); +MMDEPLOY_REGISTER_TYPE_ID(Event, 4); +MMDEPLOY_REGISTER_TYPE_ID(Model, 5); +MMDEPLOY_REGISTER_TYPE_ID(Tensor, 6); +MMDEPLOY_REGISTER_TYPE_ID(Mat, 7); + template struct is_value : std::is_same {}; @@ -204,8 +212,8 @@ class Value { using Array = std::vector; using Object = std::map; using Pointer = std::shared_ptr; - using Dynamic = mmdeploy::Dynamic; - using Any = std::any; + using Dynamic = ::mmdeploy::Dynamic; + using Any = ::mmdeploy::StaticAny; using ValueRef = detail::ValueRef; static constexpr const auto kNull = ValueType::kNull; @@ -349,7 +357,7 @@ class Value { if constexpr (std::is_void_v) { return true; } else { - return typeid(T) == data_.any->type(); + return traits::TypeId::value == data_.any->type(); } } @@ -440,11 +448,11 @@ class Value { template T* get_erased_ptr(EraseType*) noexcept { - return _is_any() ? std::any_cast(data_.any) : nullptr; + return _is_any() ? static_any_cast(data_.any) : nullptr; } template const T* get_erased_ptr(const EraseType*) const noexcept { - return _is_any() ? std::any_cast(const_cast(data_.any)) : nullptr; + return _is_any() ? static_any_cast(const_cast(data_.any)) : nullptr; } template diff --git a/tests/test_csrc/core/test_value.cpp b/tests/test_csrc/core/test_value.cpp index 0ecc1c629b..f5cdf0075b 100644 --- a/tests/test_csrc/core/test_value.cpp +++ b/tests/test_csrc/core/test_value.cpp @@ -283,12 +283,22 @@ struct Doge { int value; }; +namespace mmdeploy { + +MMDEPLOY_REGISTER_TYPE_ID(Meow, 1234); +MMDEPLOY_REGISTER_TYPE_ID(Doge, 3456); + +} // namespace mmdeploy + template <> struct mmdeploy::is_cast_by_erasure : std::true_type {}; TEST_CASE("test dynamic interface for value", "[value]") { Value meow(Meow{100}); REQUIRE(meow.is_any()); + REQUIRE(meow.is_any()); + REQUIRE_FALSE(meow.is_any()); + REQUIRE_FALSE(meow.is_any()); REQUIRE(meow.get().value == 100); REQUIRE(meow.get_ref().value == 100); REQUIRE(meow.get_ptr() == &meow.get_ref());