Skip to content

Commit

Permalink
Workaround dynamic cast for private bases.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Jul 21, 2021
1 parent 3856f12 commit 9110638
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class resource : public cuda::memory_resource<cuda::memory_kind::pinned> {
void do_deallocate(void *, size_t, size_t) {
}

#ifndef _LIBCUDACXX_NO_RTTI
#ifdef _LIBCUDACXX_EXT_RTTI_ENABLED
bool do_is_equal(const cuda::memory_resource<cuda::memory_kind::pinned> &other) const noexcept override {
fprintf(stderr, "Comparison start: %p %p\n", this, &other);
if (auto *other_ptr = dynamic_cast<const resource *>(&other)) {
Expand All @@ -48,7 +48,7 @@ struct tag2;


int main(int argc, char **argv) {
#if !defined(__CUDA_ARCH__) && !defined(_LIBCUDACXX_NO_RTTI)
#if !defined(__CUDA_ARCH__) && defined(_LIBCUDACXX_EXT_RTTI_ENABLED)
resource<tag1> r1, r2, r3;
resource<tag2> r4;
r1.value = 42;
Expand All @@ -68,7 +68,7 @@ int main(int argc, char **argv) {
assert(v1 == v2);
assert(v1 != v3);
assert(v1 != v4);
// assert(v2 != v3); - cannot compare
// assert(v2 != v3); - cannot compare - incompatible views
assert(v2 != v4);
assert(v3 != v4);
assert(v4 == v4);
Expand Down
89 changes: 79 additions & 10 deletions include/cuda/memory_resource
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include "std/version"
#include "stream_view"

#ifdef _LIBCUDACXX_EXT_RTTI_ENABLED
#include <typeinfo>
#endif

#if _LIBCUDACXX_STD_VER > 14
#include <memory_resource>
#endif // _LIBCUDACXX_STD_VER > 14
Expand Down Expand Up @@ -64,6 +68,44 @@ namespace memory_kind {
struct any_context{};

namespace detail {

namespace __fallback_typeid {

template <class _Tp>
struct _LIBCUDACXX_TEMPLATE_VIS __unique_typeinfo { static constexpr int __id = 0; };
template <class _Tp> constexpr int __unique_typeinfo<_Tp>::__id;

template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY
constexpr const void* __get_fallback_typeid() {
return &__unique_typeinfo<std::decay_t<_Tp>>::__id;
}

template <typename _Tp>
const ::std::type_info *__get_typeid() {
#ifdef _LIBCUDACXX_EXT_RTTI_ENABLED
return &typeid(_Tp);
#else
return nullptr;
#endif
}

bool __compare_type(const ::std::type_info *__ti1, const void *__fallback_ti1,
const ::std::type_info *__ti2, const void *__fallback_ti2) {
#ifdef _LIBCUDACXX_EXT_RTTI_ENABLED
if (__ti1 && __ti2 && *__ti1 == *__ti2)
return true;
#endif
return __fallback_ti1 == __fallback_ti2;
}

template <typename _Tp>
bool __is_type(const ::std::type_info *__ti1, const void *__fallback_ti1) {
return __compare_type(__ti1, __fallback_ti1, __get_typeid<_Tp>(), __get_fallback_typeid<_Tp>());
}

} // namespace __fallback_typeid

template <typename _Context>
class __get_context_impl {
protected:
Expand Down Expand Up @@ -121,6 +163,9 @@ namespace memory_location {
struct host;
}

template <typename _MemoryKind, typename _Context = any_context>
class memory_resource;

namespace detail {

class memory_resource_base {
Expand Down Expand Up @@ -168,6 +213,26 @@ public:
do_deallocate(__mem, __bytes, __alignment);
}

/*!
* \brief Tries to cast the resource to a resource of given kind
*/
template <typename _Kind>
memory_resource<_Kind> *as_kind() noexcept {
return static_cast<memory_resource<_Kind> *>(
__do_as_kind(detail::__fallback_typeid::__get_typeid<_Kind*>(),
detail::__fallback_typeid::__get_fallback_typeid<_Kind*>()));
}

/*!
* \brief Tries to cast the resource to a resource of given kind
*/
template <typename _Kind>
const memory_resource<_Kind> *as_kind() const noexcept {
return static_cast<const memory_resource<_Kind> *>(
__do_as_kind(detail::__fallback_typeid::__get_typeid<_Kind*>(),
detail::__fallback_typeid::__get_fallback_typeid<_Kind*>()));
}

protected:
virtual void *do_allocate(size_t __bytes, size_t __alignment) = 0;
virtual void do_deallocate(void *__mem, size_t __bytes, size_t __alignment) = 0;
Expand All @@ -180,6 +245,8 @@ protected:

template <typename _ResourcePointer, typename... _Properties>
friend class cuda::basic_resource_view;

virtual void *__do_as_kind(const ::std::type_info *__kind_type_id, const void *__kind_type_fallback_id) const noexcept = 0;
};

class stream_ordered_memory_resource_base : public virtual memory_resource_base {
Expand Down Expand Up @@ -309,12 +376,12 @@ protected:
* \tparam _Context The execution context on which the storage may be used
* without synchronization
*/
template <typename _MemoryKind, typename _Context = any_context>
template <typename _MemoryKind, typename _Context>
class memory_resource : private virtual detail::memory_resource_base, private detail::__get_context_impl<_Context> {
public:
using memory_kind = _MemoryKind;
using context = _Context;
static constexpr _CUDA_VSTD::size_t default_alignment = memory_resource_base::default_alignment;
static constexpr std::size_t default_alignment = memory_resource_base::default_alignment;

virtual ~memory_resource() = default;

Expand Down Expand Up @@ -396,15 +463,17 @@ private:
return this == &__other;
}

virtual void *__do_as_kind(const ::std::type_info *__kind_type_id, const void *__kind_type_fallback_id) const noexcept {
return detail::__fallback_typeid::__is_type<memory_kind*>(__kind_type_id, __kind_type_fallback_id)
? const_cast<memory_resource*>(this) : nullptr;
}

bool is_equal_base(const detail::memory_resource_base &__other) const noexcept final override {
#ifdef _LIBCUDACXX_NO_RTTI
return this == &__other;
#else
if (auto *__other_res = dynamic_cast<const memory_resource *>(&__other))
return do_is_equal(*__other_res);
else
return false;
#endif
if (auto *__other_res = __other.as_kind<memory_kind>()) {
return do_is_equal(*__other_res);
} else {
return false;
}
}

};
Expand Down

0 comments on commit 9110638

Please sign in to comment.