Skip to content

Commit

Permalink
Add context to as_kind.
Browse files Browse the repository at this point in the history
Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Oct 27, 2021
1 parent adc26ee commit 3ae680e
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions include/cuda/memory_resource
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ struct any_context{};

namespace detail {

template <typename...>
struct __type_pack {};

namespace __fallback_typeid {

template <class _Tp>
Expand Down Expand Up @@ -222,21 +225,23 @@ public:
/*!
* \brief Tries to cast the resource to a resource of given kind
*/
template <typename _Kind>
template <typename _Kind, typename _Context = any_context>
memory_resource<_Kind> *as_kind() noexcept {
using __tag = detail::__type_pack<_Kind, _Context>;
return static_cast<memory_resource<_Kind> *>(
__do_as_kind(detail::__fallback_typeid::__get_typeid<_Kind*>(),
detail::__fallback_typeid::__get_fallback_typeid<_Kind*>()));
__do_as_kind(detail::__fallback_typeid::__get_typeid<__tag>(),
detail::__fallback_typeid::__get_fallback_typeid<__tag>()));
}

/*!
* \brief Tries to cast the resource to a resource of given kind
*/
template <typename _Kind>
const memory_resource<_Kind> *as_kind() const noexcept {
return static_cast<const memory_resource<_Kind> *>(
__do_as_kind(detail::__fallback_typeid::__get_typeid<_Kind*>(),
detail::__fallback_typeid::__get_fallback_typeid<_Kind*>()));
template <typename _Kind, typename _Context = any_context>
const memory_resource<_Kind, _Context> *as_kind() const noexcept {
using __tag = detail::__type_pack<_Kind, _Context>;
return static_cast<const memory_resource<_Kind, _Context> *>(
__do_as_kind(detail::__fallback_typeid::__get_typeid<__tag>(),
detail::__fallback_typeid::__get_fallback_typeid<__tag>()));
}

protected:
Expand All @@ -252,7 +257,7 @@ protected:
template <typename _ResourcePointer, typename... _Properties>
friend class cuda::basic_resource_view;

virtual void *__do_as_kind(const ::std::type_info *__kind_type_id, const void *__kind_type_fallback_id) const noexcept = 0;
virtual void *__do_as_kind(const ::std::type_info *__tag_type_id, const void *__tag_type_fallback_id) const noexcept = 0;
};

class stream_ordered_memory_resource_base : public virtual memory_resource_base {
Expand Down Expand Up @@ -469,21 +474,27 @@ private:
return this == &__other;
}

void *__do_as_kind(const ::std::type_info *__kind_type_id, const void *__kind_type_fallback_id) const noexcept final {
return detail::__fallback_typeid::__is_type<memory_kind*>(__kind_type_id, __kind_type_fallback_id)
void *__do_as_kind(const ::std::type_info *__tag_type_id, const void *__tag_type_fallback_id) const noexcept final {
using __tag = detail::__type_pack<memory_kind, context>;
return detail::__fallback_typeid::__is_type<__tag>(__tag_type_id, __tag_type_fallback_id)
? const_cast<memory_resource*>(this) : nullptr;
}

bool is_equal_base(const detail::memory_resource_base &__other) const noexcept final {
if (auto *__other_res = __other.as_kind<memory_kind>()) {
if (auto *__other_res = __other.as_kind<memory_kind, context>()) {
return do_is_equal(*__other_res);
} else {
return false;
}
}

};

template <typename _Kind, typename _Context>
inline _LIBCUDACXX_INLINE_VISIBILITY
bool operator==(const memory_resource<_Kind, _Context> &__a, const memory_resource<_Kind, _Context> &__b) {
return __a.is_equal(__b);
}

/*!
* \brief Abstract interface for CUDA stream-ordered memory allocation.
*
Expand Down Expand Up @@ -793,6 +804,8 @@ public:

basic_resource_view() = default;

basic_resource_view(int) = delete;

basic_resource_view(std::nullptr_t) {}

/*!
Expand Down Expand Up @@ -832,7 +845,6 @@ public:
*/
_ResourcePointer operator->() const noexcept { return __pointer; }


template <typename _Ptr2, typename... _Props2>
bool operator==(const cuda::basic_resource_view<_Ptr2, _Props2...> &__v2) const noexcept {
using __view1_t = basic_resource_view;
Expand All @@ -847,6 +859,12 @@ public:
return !(*this == __v2);
}

/*!
* \brief Returns true if the underlying pointer is not null.
*/
constexpr explicit operator bool() const noexcept {
return !!__pointer;
}

private:
template <typename, typename...>
Expand Down

0 comments on commit 3ae680e

Please sign in to comment.