Skip to content

Commit

Permalink
Update mdspan corresponding to mdspan PR $360
Browse files Browse the repository at this point in the history
  • Loading branch information
crtrott committed Oct 4, 2024
1 parent d11e42f commit 04882f8
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 55 deletions.
3 changes: 3 additions & 0 deletions tpls/mdspan/include/experimental/__p0009_bits/layout_left.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ class layout_left::mapping {

// Not really public, but currently needed to implement fully constexpr useable submdspan:
template<size_t N, class SizeType, size_t ... E, size_t ... Idx>
MDSPAN_INLINE_FUNCTION
constexpr index_type __get_stride(MDSPAN_IMPL_STANDARD_NAMESPACE::extents<SizeType, E...>,std::integer_sequence<size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT((Idx<N? __extents.template __extent<Idx>():1),1);
}
template<size_t N>
MDSPAN_INLINE_FUNCTION
constexpr index_type __stride() const noexcept {
return __get_stride<N>(__extents, std::make_index_sequence<extents_type::rank()>());
}
Expand All @@ -255,6 +257,7 @@ class layout_left::mapping {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,12 @@ class layout_right::mapping {

// Not really public, but currently needed to implement fully constexpr useable submdspan:
template<size_t N, class SizeType, size_t ... E, size_t ... Idx>
MDSPAN_INLINE_FUNCTION
constexpr index_type __get_stride(MDSPAN_IMPL_STANDARD_NAMESPACE::extents<SizeType, E...>,std::integer_sequence<size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT((Idx>N? __extents.template __extent<Idx>():1),1);
}
template<size_t N>
MDSPAN_INLINE_FUNCTION
constexpr index_type __stride() const noexcept {
return __get_stride<N>(__extents, std::make_index_sequence<extents_type::rank()>());
}
Expand All @@ -252,6 +254,7 @@ class layout_right::mapping {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand Down
36 changes: 18 additions & 18 deletions tpls/mdspan/include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,28 +197,22 @@ struct layout_stride {
}

template<class IntegralType>
MDSPAN_INLINE_FUNCTION
static constexpr const __strides_storage_t fill_strides(const std::array<IntegralType,extents_type::rank()>& s) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
}

MDSPAN_TEMPLATE_REQUIRES(
class IntegralType,
// The is_convertible condition is added to make sfinae valid
// the extents_type::rank() > 0 is added to avoid use of non-standard zero length c-array
(std::is_convertible<IntegralType, typename extents_type::index_type>::value && (extents_type::rank() > 0))
(std::is_convertible<IntegralType, typename extents_type::index_type>::value)
)
MDSPAN_INLINE_FUNCTION
// despite the requirement some compilers still complain about zero length array during parsing
// making it length 1 now, but since the thing can't be instantiated due to requirement the actual
// instantiation of strides_storage will not fail despite mismatching length
// Need to avoid zero length c-array
static constexpr const __strides_storage_t fill_strides(mdspan_non_standard_tag, const IntegralType (&s)[extents_type::rank()>0?extents_type::rank():1]) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
}

#ifdef __cpp_lib_span
template<class IntegralType>
MDSPAN_INLINE_FUNCTION
static constexpr const __strides_storage_t fill_strides(const std::span<IntegralType,extents_type::rank()>& s) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
}
Expand All @@ -242,10 +236,13 @@ struct layout_stride {
// Can't use defaulted parameter in the __deduction_workaround template because of a bug in MSVC warning C4348.
using __impl = __deduction_workaround<std::make_index_sequence<Extents::rank()>>;

MDSPAN_FUNCTION
static constexpr __strides_storage_t strides_storage(detail::with_rank<0>) {
return {};
}

template <std::size_t N>
MDSPAN_FUNCTION
static constexpr __strides_storage_t strides_storage(detail::with_rank<N>) {
__strides_storage_t s{};

Expand Down Expand Up @@ -273,7 +270,7 @@ struct layout_stride {

//--------------------------------------------------------------------------------

MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping() noexcept
MDSPAN_INLINE_FUNCTION constexpr mapping() noexcept
#if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
: __members{
#else
Expand All @@ -299,7 +296,6 @@ struct layout_stride {
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
constexpr
mapping(
extents_type const& e,
Expand Down Expand Up @@ -333,19 +329,16 @@ struct layout_stride {
// MSVC 19.32 does not like using index_type here, requires the typename Extents::index_type
// error C2641: cannot deduce template arguments for 'MDSPAN_IMPL_STANDARD_NAMESPACE::layout_stride::mapping'
_MDSPAN_TRAIT(std::is_convertible, const std::remove_const_t<IntegralTypes>&, typename Extents::index_type) &&
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&) &&
(Extents::rank() > 0)
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
constexpr
mapping(
mdspan_non_standard_tag,
extents_type const& e,
// despite the requirement some compilers still complain about zero length array during parsing
// making it length 1 now, but since the thing can't be instantiated due to requirement the actual
// instantiation of strides_storage will not fail despite mismatching length
IntegralTypes (&s)[extents_type::rank()>0?extents_type::rank():1]
// Need to avoid zero-length c-array
const IntegralTypes (&s)[extents_type::rank()>0?extents_type::rank():1]
) noexcept
#if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
: __members{
Expand Down Expand Up @@ -379,7 +372,6 @@ struct layout_stride {
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
constexpr
mapping(
extents_type const& e,
Expand Down Expand Up @@ -476,7 +468,8 @@ struct layout_stride {
MDSPAN_INLINE_FUNCTION
constexpr index_type required_span_size() const noexcept {
index_type span_size = 1;
for(unsigned r = 0; r < extents_type::rank(); r++) {
// using int here to avoid warning about pointless comparison to 0
for(int r = 0; r < static_cast<int>(extents_type::rank()); r++) {
// Return early if any of the extents are zero
if(extents().extent(r)==0) return 0;
span_size += ( static_cast<index_type>(extents().extent(r) - 1 ) * __strides_storage()[r]);
Expand Down Expand Up @@ -509,15 +502,18 @@ struct layout_stride {
MDSPAN_INLINE_FUNCTION static constexpr bool is_unique() noexcept { return true; }

private:
MDSPAN_INLINE_FUNCTION
constexpr bool exhaustive_for_nonzero_span_size() const
{
return required_span_size() == __get_size(extents(), std::make_index_sequence<extents_type::rank()>());
}

MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<0>) const
{
return true;
}
MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<1>) const
{
if (required_span_size() != static_cast<index_type>(0)) {
Expand All @@ -526,6 +522,7 @@ struct layout_stride {
return stride(0) == 1;
}
template <std::size_t N>
MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<N>) const
{
if (required_span_size() != static_cast<index_type>(0)) {
Expand Down Expand Up @@ -627,6 +624,7 @@ struct layout_stride {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand All @@ -637,10 +635,12 @@ struct layout_stride {
namespace detail {

template <class Layout, class Extents, class Mapping>
MDSPAN_INLINE_FUNCTION
constexpr void validate_strides(with_rank<0>, Layout, const Extents&, const Mapping&)
{}

template <std::size_t N, class Layout, class Extents, class Mapping>
MDSPAN_INLINE_FUNCTION
constexpr void validate_strides(with_rank<N>, Layout, const Extents& ext, const Mapping& other)
{
static_assert(std::is_same<typename Mapping::layout_type, layout_stride>::value &&
Expand Down
100 changes: 100 additions & 0 deletions tpls/mdspan/include/experimental/__p0009_bits/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <cstddef>
#include <type_traits>
#include <array>
#include <utility>

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace detail {
Expand Down Expand Up @@ -64,6 +66,104 @@ constexpr struct
}
} stride;

// same as std::integral_constant but with __host__ __device__ annotations on
// the implicit conversion function and the call operator
template <class T, T v>
struct integral_constant {
using value_type = T;
using type = integral_constant<T, v>;

static constexpr T value = v;

MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr integral_constant() = default;

// These interop functions work, because other than the value_type operator
// everything of std::integral_constant works on device (defaulted functions)
MDSPAN_FUNCTION
constexpr integral_constant(std::integral_constant<T,v>) {};

MDSPAN_FUNCTION constexpr operator std::integral_constant<T,v>() const noexcept {
return std::integral_constant<T,v>{};
}

MDSPAN_FUNCTION constexpr operator value_type() const noexcept {
return value;
}

MDSPAN_FUNCTION constexpr value_type operator()() const noexcept {
return value;
}
};

// The tuple implementation only comes in play when using capabilities
// such as submdspan which require C++17 anyway
#if MDSPAN_HAS_CXX_17
template<class T, size_t Idx>
struct tuple_member {
using type = T;
static constexpr size_t idx = Idx;
T val;
MDSPAN_FUNCTION constexpr T& get() { return val; }
MDSPAN_FUNCTION constexpr const T& get() const { return val; }
};

// A helper class which will be used via a fold expression to
// select the type with the correct Idx in a pack of tuple_member
template<size_t SearchIdx, size_t Idx, class T>
struct tuple_idx_matcher {
using type = tuple_member<T, Idx>;
template<class Other>
MDSPAN_FUNCTION
constexpr auto operator | (Other v) const {
if constexpr (Idx == SearchIdx) { return *this; }
else { return v; }
}
};

template<class IdxSeq, class ... Elements>
struct tuple_impl;

template<size_t ... Idx, class ... Elements>
struct tuple_impl<std::index_sequence<Idx...>, Elements...>: public tuple_member<Elements, Idx> ... {

MDSPAN_FUNCTION
constexpr tuple_impl(Elements ... vals):tuple_member<Elements, Idx>{vals}... {}

template<size_t N>
MDSPAN_FUNCTION
constexpr auto& get() {
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() | ...) );
return base_t::type::get();
}
template<size_t N>
MDSPAN_FUNCTION
constexpr const auto& get() const {
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() | ...) );
return base_t::type::get();
}
};

// A simple tuple-like class for representing slices internally and is compatible with device code
// This doesn't support type access since we don't need it
// This is not meant as an external API
template<class ... Elements>
struct tuple: public tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements...> {
MDSPAN_FUNCTION
constexpr tuple(Elements ... vals):tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements ...>(vals ...) {}
};

template<size_t Idx, class ... Args>
MDSPAN_FUNCTION
constexpr auto& get(tuple<Args...>& vals) { return vals.template get<Idx>(); }

template<size_t Idx, class ... Args>
MDSPAN_FUNCTION
constexpr const auto& get(const tuple<Args...>& vals) { return vals.template get<Idx>(); }

template<class ... Elements>
tuple(Elements ...) -> tuple<Elements...>;
#endif
} // namespace detail

constexpr struct mdspan_non_standard_tag {
Expand Down
Loading

0 comments on commit 04882f8

Please sign in to comment.