This repository has been archived by the owner on Mar 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 187
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
mdspan
reference implementation (#299)
* Implement `std::mdspan` Co-authored-by: Michael Schellenberger Costa <[email protected]>
- Loading branch information
Showing
224 changed files
with
17,025 additions
and
8 deletions.
There are no files selected for viewing
228 changes: 228 additions & 0 deletions
228
.upstream-tests/test/std/containers/views/mdspan/foo_customizations.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
#ifndef _FOO_CUSTOMIZATIONS_HPP | ||
#define _FOO_CUSTOMIZATIONS_HPP | ||
|
||
// Taken from the reference implementation repo | ||
|
||
namespace Foo { | ||
template<class T> | ||
struct foo_ptr { | ||
T* data; | ||
__MDSPAN_HOST_DEVICE | ||
constexpr foo_ptr(T* ptr):data(ptr) {} | ||
}; | ||
|
||
template<class T> | ||
struct foo_accessor { | ||
using offset_policy = foo_accessor; | ||
using element_type = T; | ||
using reference = T&; | ||
using data_handle_type = foo_ptr<T>; | ||
|
||
__MDSPAN_INLINE_FUNCTION | ||
constexpr foo_accessor(int* ptr = nullptr) noexcept { flag = ptr; } | ||
|
||
template<class OtherElementType> | ||
__MDSPAN_INLINE_FUNCTION | ||
constexpr foo_accessor(cuda::std::default_accessor<OtherElementType>) noexcept { flag = nullptr; } | ||
|
||
template<class OtherElementType> | ||
__MDSPAN_INLINE_FUNCTION | ||
constexpr foo_accessor(foo_accessor<OtherElementType> other) noexcept { flag = other.flag; } | ||
|
||
|
||
constexpr reference access(data_handle_type p, size_t i) const noexcept { | ||
return p.data[i]; | ||
} | ||
|
||
constexpr data_handle_type offset(data_handle_type p, size_t i) const noexcept { | ||
return data_handle_type(p.data+i); | ||
} | ||
int* flag; | ||
|
||
friend constexpr void swap(foo_accessor& x, foo_accessor& y) { | ||
x.flag[0] = 99; | ||
y.flag[0] = 77; | ||
cuda::std::swap(x.flag, y.flag); | ||
} | ||
}; | ||
|
||
struct layout_foo { | ||
template<class Extents> | ||
class mapping; | ||
}; | ||
|
||
template <class Extents> | ||
class layout_foo::mapping { | ||
public: | ||
using extents_type = Extents; | ||
using index_type = typename extents_type::index_type; | ||
using size_type = typename extents_type::size_type; | ||
using rank_type = typename extents_type::rank_type; | ||
using layout_type = layout_foo; | ||
private: | ||
|
||
static_assert(cuda::std::__detail::__is_extents_v<extents_type>, | ||
"layout_foo::mapping must be instantiated with a specialization of cuda::std::extents."); | ||
static_assert(extents_type::rank() < 3, "layout_foo only supports 0D, 1D and 2D"); | ||
|
||
template <class> | ||
friend class mapping; | ||
|
||
public: | ||
|
||
//-------------------------------------------------------------------------------- | ||
|
||
__MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping() noexcept = default; | ||
__MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping(mapping const&) noexcept = default; | ||
|
||
__MDSPAN_HOST_DEVICE | ||
constexpr mapping(extents_type const& __exts) noexcept | ||
:__extents(__exts) | ||
{ } | ||
|
||
__MDSPAN_TEMPLATE_REQUIRES( | ||
class OtherExtents, | ||
/* requires */ ( | ||
__MDSPAN_TRAIT(cuda::std::is_constructible, extents_type, OtherExtents) | ||
) | ||
) | ||
__MDSPAN_CONDITIONAL_EXPLICIT((!cuda::std::is_convertible<OtherExtents, extents_type>::value)) // needs two () due to comma | ||
__MDSPAN_INLINE_FUNCTION constexpr | ||
mapping(mapping<OtherExtents> const& other) noexcept // NOLINT(google-explicit-constructor) | ||
:__extents(other.extents()) | ||
{ | ||
/* | ||
* TODO: check precondition | ||
* other.required_span_size() is a representable value of type index_type | ||
*/ | ||
} | ||
|
||
__MDSPAN_TEMPLATE_REQUIRES( | ||
class OtherExtents, | ||
/* requires */ ( | ||
__MDSPAN_TRAIT(cuda::std::is_constructible, extents_type, OtherExtents) | ||
) | ||
) | ||
__MDSPAN_CONDITIONAL_EXPLICIT((!cuda::std::is_convertible<OtherExtents, extents_type>::value)) // needs two () due to comma | ||
__MDSPAN_INLINE_FUNCTION constexpr | ||
mapping(cuda::std::layout_right::mapping<OtherExtents> const& other) noexcept // NOLINT(google-explicit-constructor) | ||
:__extents(other.extents()) | ||
{} | ||
|
||
__MDSPAN_TEMPLATE_REQUIRES( | ||
class OtherExtents, | ||
/* requires */ ( | ||
__MDSPAN_TRAIT(cuda::std::is_constructible, extents_type, OtherExtents) && | ||
(extents_type::rank() <= 1) | ||
) | ||
) | ||
__MDSPAN_CONDITIONAL_EXPLICIT((!cuda::std::is_convertible<OtherExtents, extents_type>::value)) // needs two () due to comma | ||
__MDSPAN_INLINE_FUNCTION constexpr | ||
mapping(cuda::std::layout_left::mapping<OtherExtents> const& other) noexcept // NOLINT(google-explicit-constructor) | ||
:__extents(other.extents()) | ||
{} | ||
|
||
__MDSPAN_TEMPLATE_REQUIRES( | ||
class OtherExtents, | ||
/* requires */ ( | ||
__MDSPAN_TRAIT(cuda::std::is_constructible, extents_type, OtherExtents) | ||
) | ||
) | ||
__MDSPAN_CONDITIONAL_EXPLICIT((extents_type::rank() > 0)) | ||
__MDSPAN_INLINE_FUNCTION constexpr | ||
mapping(cuda::std::layout_stride::mapping<OtherExtents> const& other) // NOLINT(google-explicit-constructor) | ||
:__extents(other.extents()) | ||
{ | ||
/* | ||
* TODO: check precondition | ||
* other.required_span_size() is a representable value of type index_type | ||
*/ | ||
#ifndef __CUDA_ARCH__ | ||
size_t stride = 1; | ||
for(rank_type r=__extents.rank(); r>0; r--) { | ||
assert(stride == other.stride(r-1)); | ||
//if(stride != other.stride(r-1)) | ||
// throw std::runtime_error("Assigning layout_stride to layout_foo with invalid strides."); | ||
stride *= __extents.extent(r-1); | ||
} | ||
#endif | ||
} | ||
|
||
__MDSPAN_INLINE_FUNCTION_DEFAULTED __MDSPAN_CONSTEXPR_14_DEFAULTED mapping& operator=(mapping const&) noexcept = default; | ||
|
||
__MDSPAN_INLINE_FUNCTION | ||
constexpr const extents_type& extents() const noexcept { | ||
return __extents; | ||
} | ||
|
||
__MDSPAN_INLINE_FUNCTION | ||
constexpr index_type required_span_size() const noexcept { | ||
index_type value = 1; | ||
for(rank_type r=0; r != extents_type::rank(); ++r) value*=__extents.extent(r); | ||
return value; | ||
} | ||
|
||
//-------------------------------------------------------------------------------- | ||
|
||
__MDSPAN_INLINE_FUNCTION | ||
constexpr index_type operator() () const noexcept { return index_type(0); } | ||
|
||
template<class Indx0> | ||
__MDSPAN_INLINE_FUNCTION | ||
constexpr index_type operator()(Indx0 idx0) const noexcept { | ||
return static_cast<index_type>(idx0); | ||
} | ||
|
||
template<class Indx0, class Indx1> | ||
__MDSPAN_INLINE_FUNCTION | ||
constexpr index_type operator()(Indx0 idx0, Indx1 idx1) const noexcept { | ||
return static_cast<index_type>(idx0 * __extents.extent(0) + idx1); | ||
} | ||
|
||
__MDSPAN_INLINE_FUNCTION static constexpr bool is_always_unique() noexcept { return true; } | ||
__MDSPAN_INLINE_FUNCTION static constexpr bool is_always_exhaustive() noexcept { return true; } | ||
__MDSPAN_INLINE_FUNCTION static constexpr bool is_always_strided() noexcept { return true; } | ||
__MDSPAN_INLINE_FUNCTION constexpr bool is_unique() const noexcept { return true; } | ||
__MDSPAN_INLINE_FUNCTION constexpr bool is_exhaustive() const noexcept { return true; } | ||
__MDSPAN_INLINE_FUNCTION constexpr bool is_strided() const noexcept { return true; } | ||
|
||
__MDSPAN_INLINE_FUNCTION | ||
constexpr index_type stride(rank_type i) const noexcept { | ||
index_type value = 1; | ||
for(rank_type r=extents_type::rank()-1; r>i; r--) value*=__extents.extent(r); | ||
return value; | ||
} | ||
|
||
template<class OtherExtents> | ||
__MDSPAN_INLINE_FUNCTION | ||
friend constexpr bool operator==(mapping const& lhs, mapping<OtherExtents> const& rhs) noexcept { | ||
return lhs.extents() == rhs.extents(); | ||
} | ||
|
||
// In C++ 20 the not equal exists if equal is found | ||
#if !(__MDSPAN_HAS_CXX_20) | ||
template<class OtherExtents> | ||
__MDSPAN_INLINE_FUNCTION | ||
friend constexpr bool operator!=(mapping const& lhs, mapping<OtherExtents> const& rhs) noexcept { | ||
return lhs.extents() != rhs.extents(); | ||
} | ||
#endif | ||
|
||
// Not really public, but currently needed to implement fully constexpr useable submdspan: | ||
template<size_t N, class SizeType, size_t ... E, size_t ... Idx> | ||
constexpr index_type __get_stride(cuda::std::extents<SizeType, E...>, cuda::std::integer_sequence<size_t, Idx...>) const { | ||
return __MDSPAN_FOLD_TIMES_RIGHT((Idx>N? __extents.template __extent<Idx>():1),1); | ||
} | ||
template<size_t N> | ||
constexpr index_type __stride() const noexcept { | ||
return __get_stride<N>(__extents, std::make_index_sequence<extents_type::rank()>()); | ||
} | ||
|
||
private: | ||
__MDSPAN_NO_UNIQUE_ADDRESS extents_type __extents{}; | ||
|
||
}; | ||
|
||
} | ||
#endif | ||
|
27 changes: 27 additions & 0 deletions
27
...am-tests/test/std/containers/views/mdspan/mdspan.accessor.default.members/access.pass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
//UNSUPPORTED: c++11, nvrtc && nvcc-12.0, nvrtc && nvcc-12.1 | ||
|
||
#include <cuda/std/mdspan> | ||
#include <cuda/std/cassert> | ||
|
||
int main(int, char**) | ||
{ | ||
{ | ||
using element_t = int; | ||
cuda::std::array<element_t, 2> d{42,43}; | ||
cuda::std::default_accessor<element_t> a; | ||
|
||
assert( a.access( d.data(), 0 ) == 42 ); | ||
assert( a.access( d.data(), 1 ) == 43 ); | ||
} | ||
|
||
return 0; | ||
} |
30 changes: 30 additions & 0 deletions
30
...ream-tests/test/std/containers/views/mdspan/mdspan.accessor.default.members/copy.pass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
//UNSUPPORTED: c++11, nvrtc && nvcc-12.0, nvrtc && nvcc-12.1 | ||
|
||
#include <cuda/std/mdspan> | ||
#include <cuda/std/cassert> | ||
|
||
int main(int, char**) | ||
{ | ||
{ | ||
using element_t = int; | ||
cuda::std::array<element_t, 2> d{42,43}; | ||
cuda::std::default_accessor<element_t> a0; | ||
cuda::std::default_accessor<element_t> a(a0); | ||
|
||
assert( a.access( d.data(), 0 ) == 42 ); | ||
assert( a.access( d.data(), 1 ) == 43 ); | ||
assert( a.offset( d.data(), 0 ) == d.data() ); | ||
assert( a.offset( d.data(), 1 ) == d.data() + 1 ); | ||
} | ||
|
||
return 0; | ||
} |
27 changes: 27 additions & 0 deletions
27
...am-tests/test/std/containers/views/mdspan/mdspan.accessor.default.members/offset.pass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
//UNSUPPORTED: c++11, nvrtc && nvcc-12.0, nvrtc && nvcc-12.1 | ||
|
||
#include <cuda/std/mdspan> | ||
#include <cuda/std/cassert> | ||
|
||
int main(int, char**) | ||
{ | ||
{ | ||
using element_t = int; | ||
cuda::std::array<element_t, 2> d{42,43}; | ||
cuda::std::default_accessor<element_t> a; | ||
|
||
assert( a.offset( d.data(), 0 ) == d.data() ); | ||
assert( a.offset( d.data(), 1 ) == d.data() + 1 ); | ||
} | ||
|
||
return 0; | ||
} |
58 changes: 58 additions & 0 deletions
58
.upstream-tests/test/std/containers/views/mdspan/mdspan.extents.cmp/compare.pass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
//UNSUPPORTED: c++11, nvrtc && nvcc-12.0, nvrtc && nvcc-12.1 | ||
|
||
#include <cuda/std/mdspan> | ||
#include <cuda/std/cassert> | ||
|
||
constexpr auto dyn = cuda::std::dynamic_extent; | ||
|
||
int main(int, char**) | ||
{ | ||
{ | ||
using index_t = size_t; | ||
|
||
cuda::std::extents< index_t, 10 > e0; | ||
cuda::std::extents< index_t, 10 > e1; | ||
|
||
assert( e0 == e1 ); | ||
} | ||
|
||
{ | ||
using index_t = size_t; | ||
|
||
cuda::std::extents< index_t, 10 > e0; | ||
cuda::std::extents< index_t, dyn > e1{ 10 }; | ||
|
||
assert( e0 == e1 ); | ||
} | ||
|
||
{ | ||
using index_t = size_t; | ||
|
||
cuda::std::extents< index_t, 10 > e0; | ||
cuda::std::extents< index_t, 10, 10 > e1; | ||
|
||
assert( e0 != e1 ); | ||
} | ||
|
||
{ | ||
using index0_t = size_t; | ||
using index1_t = uint8_t; | ||
|
||
cuda::std::extents< index0_t, 10 > e0; | ||
cuda::std::extents< index1_t, 10 > e1; | ||
|
||
assert( e0 == e1 ); | ||
} | ||
|
||
|
||
return 0; | ||
} |
Oops, something went wrong.