Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce ldg_ptr to Enable __ldg in Data Stores and simple_ptr_holder #1802

Merged
merged 17 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions include/gridtools/common/ldg_ptr.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* GridTools
*
* Copyright (c) 2014-2023, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once

#include <cstddef>
#include <type_traits>
#include <utility>

#include "defs.hpp"
#include "host_device.hpp"

#ifdef GT_CUDACC
#include "cuda_type_traits.hpp"
#endif

namespace gridtools {

#ifdef GT_CUDACC
namespace impl_ {

/**
 * Pointer-like aggregate wrapping a `T const *`.
 *
 * Dereferencing returns the pointee *by value*: via `__ldg` when `GT_CUDA_ARCH` is defined,
 * via an ordinary load otherwise. `T` must satisfy `is_texture_type`.
 *
 * Supported operations: `+=`, `-=`, pre/post `++`/`--`, `ptr + n`, `ptr - n`, `ptr - ptr`,
 * and `==`/`!=` against another `ldg_ptr` or a raw `T const *` (in either order).
 *
 * The type deliberately has no user-declared constructors so it can be brace-initialised
 * from a raw pointer.
 */
template <class T>
struct ldg_ptr {
    T const *m_ptr;

    static_assert(is_texture_type<T>::value);

    /// Loads the pointee; the result is a copy, never a reference.
    GT_FUNCTION constexpr T operator*() const {
#ifdef GT_CUDA_ARCH
        return __ldg(m_ptr);
#else
        return *m_ptr;
#endif
    }

    /// Advances the wrapped pointer by `diff` elements.
    GT_FUNCTION constexpr ldg_ptr &operator+=(std::ptrdiff_t diff) {
        m_ptr = m_ptr + diff;
        return *this;
    }

    /// Moves the wrapped pointer back by `diff` elements.
    GT_FUNCTION constexpr ldg_ptr &operator-=(std::ptrdiff_t diff) {
        m_ptr = m_ptr - diff;
        return *this;
    }

    // Equality compares the wrapped addresses only.
    friend GT_FUNCTION constexpr bool operator==(ldg_ptr const &a, ldg_ptr const &b) {
        return a.m_ptr == b.m_ptr;
    }
    friend GT_FUNCTION constexpr bool operator==(ldg_ptr const &a, T const *b) { return a.m_ptr == b; }
    friend GT_FUNCTION constexpr bool operator==(T const *a, ldg_ptr const &b) { return a == b.m_ptr; }

    // Inequality is the exact negation of the address comparison.
    friend GT_FUNCTION constexpr bool operator!=(ldg_ptr const &a, ldg_ptr const &b) {
        return !(a.m_ptr == b.m_ptr);
    }
    friend GT_FUNCTION constexpr bool operator!=(ldg_ptr const &a, T const *b) { return !(a.m_ptr == b); }
    friend GT_FUNCTION constexpr bool operator!=(T const *a, ldg_ptr const &b) { return !(a == b.m_ptr); }

    /// Pre-increment: steps to the next element and returns the updated object.
    friend GT_FUNCTION constexpr ldg_ptr &operator++(ldg_ptr &ptr) {
        ptr.m_ptr += 1;
        return ptr;
    }

    /// Pre-decrement: steps to the previous element and returns the updated object.
    friend GT_FUNCTION constexpr ldg_ptr &operator--(ldg_ptr &ptr) {
        ptr.m_ptr -= 1;
        return ptr;
    }

    /// Post-increment: returns a copy holding the old address, then advances `ptr`.
    friend GT_FUNCTION constexpr ldg_ptr operator++(ldg_ptr &ptr, int) { return ldg_ptr{ptr.m_ptr++}; }

    /// Post-decrement: returns a copy holding the old address, then moves `ptr` back.
    friend GT_FUNCTION constexpr ldg_ptr operator--(ldg_ptr &ptr, int) { return ldg_ptr{ptr.m_ptr--}; }

    /// Returns a new `ldg_ptr` offset forward by `diff` elements; `ptr` is not modified.
    friend GT_FUNCTION constexpr ldg_ptr operator+(ldg_ptr const &ptr, std::ptrdiff_t diff) {
        ldg_ptr shifted = ptr;
        shifted += diff;
        return shifted;
    }

    /// Returns a new `ldg_ptr` offset backward by `diff` elements; `ptr` is not modified.
    friend GT_FUNCTION constexpr ldg_ptr operator-(ldg_ptr const &ptr, std::ptrdiff_t diff) {
        ldg_ptr shifted = ptr;
        shifted -= diff;
        return shifted;
    }

    /// Element distance between the two wrapped addresses (`ptr - other`).
    friend GT_FUNCTION constexpr std::ptrdiff_t operator-(ldg_ptr const &ptr, ldg_ptr const &other) {
        std::ptrdiff_t const distance = ptr.m_ptr - other.m_ptr;
        return distance;
    }
};
} // namespace impl_

/**
 * Wraps a pointer-to-const `T` into an `impl_::ldg_ptr<T>`.
 *
 * The overload is removed from overload resolution (via `enable_if_t`) unless
 * `is_texture_type<T>::value` is true.
 *
 * @param ptr raw pointer to wrap; stored as-is, not dereferenced here.
 * @return an `impl_::ldg_ptr<T>` holding `ptr`.
 */
template <class T>
GT_FUNCTION constexpr std::enable_if_t<is_texture_type<T>::value, impl_::ldg_ptr<T>> as_ldg_ptr(T const *ptr) {
    return impl_::ldg_ptr<T>{ptr};
}

#endif

/**
 * Pass-through overload of `as_ldg_ptr`: perfectly forwards `value` and returns it unchanged.
 * Nothing is wrapped and no copy is made; the result keeps the value category of the argument.
 *
 * @param value any object (taken by forwarding reference).
 * @return `std::forward<T>(value)`.
 */
template <class T>
GT_FUNCTION constexpr T &&as_ldg_ptr(T &&value) {
return std::forward<T>(value);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I like this fallback. Doesn't it mean that if you wrap any pointer as an "ldg" pointer, which is not "ldg"-capable, it will silently do that? If this is the intent, then at least I don't like the naming.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn’t wrap anything, it just passes unsupported types through as is. We just call this function wherever we would like to use LDG when available. The name might be improvable though, so let me know if you have a better one …

}

} // namespace gridtools
3 changes: 2 additions & 1 deletion include/gridtools/fn/cartesian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <functional>

#include "../common/ldg_ptr.hpp"
#include "../common/tuple_util.hpp"
#include "../sid/concept.hpp"
#include "./common_interface.hpp"
Expand Down Expand Up @@ -44,7 +45,7 @@ namespace gridtools::fn {

template <class Tag, class Ptr, class Strides>
GT_FUNCTION auto deref(iterator<Tag, Ptr, Strides> const &it) {
return *it.m_ptr;
return *as_ldg_ptr(it.m_ptr);
}

template <class Tag, class Ptr, class Strides, class Dim, class Offset, class... Offsets>
Expand Down
3 changes: 2 additions & 1 deletion include/gridtools/fn/neighbor_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <type_traits>

#include "../common/ldg_ptr.hpp"
#include "../common/tuple_util.hpp"
#include "../meta/logical.hpp"

Expand Down Expand Up @@ -56,7 +57,7 @@ namespace gridtools::fn::neighbor_table {

template <class T, std::enable_if_t<is_neighbor_list<T>::value, int> = 0>
GT_FUNCTION T const &neighbor_table_neighbors(T const *table, int index) {
return table[index];
return *as_ldg_ptr(&table[index]);
}

template <class NeighborTable>
Expand Down
3 changes: 2 additions & 1 deletion include/gridtools/fn/sid_neighbor_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <type_traits>

#include "../common/array.hpp"
#include "../common/ldg_ptr.hpp"
#include "../fn/unstructured.hpp"
#include "../sid/concept.hpp"

Expand Down Expand Up @@ -46,7 +47,7 @@ namespace gridtools::fn::sid_neighbor_table {

sid::shift(ptr, sid::get_stride<IndexDimension>(table.strides), index);
for (std::size_t element_idx = 0; element_idx < MaxNumNeighbors; ++element_idx) {
neighbors[element_idx] = *ptr;
neighbors[element_idx] = *as_ldg_ptr(ptr);
sid::shift(ptr, sid::get_stride<NeighborDimension>(table.strides), 1_c);
}
return neighbors;
Expand Down
3 changes: 2 additions & 1 deletion include/gridtools/fn/unstructured.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "../common/defs.hpp"
#include "../common/hymap.hpp"
#include "../common/ldg_ptr.hpp"
#include "../meta/logical.hpp"
#include "../sid/concept.hpp"
#include "../stencil/positional.hpp"
Expand Down Expand Up @@ -80,7 +81,7 @@ namespace gridtools::fn {
GT_FUNCTION constexpr auto deref(iterator<Tag, Ptr, Strides, Domain> const &it) {
GT_PROMISE(can_deref(it));
decltype(auto) stride = host_device::at_key<Tag>(sid::get_stride<dim::horizontal>(it.m_strides));
return *sid::shifted(it.m_ptr, stride, it.m_index);
return *as_ldg_ptr(sid::shifted(it.m_ptr, stride, it.m_index));
}

template <class Tag, class Ptr, class Strides, class Domain, class Conn, class Offset>
Expand Down
3 changes: 2 additions & 1 deletion include/gridtools/sid/simple_ptr_holder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "../common/defs.hpp"
#include "../common/host_device.hpp"
#include "../common/ldg_ptr.hpp"

#define GT_FILENAME <gridtools/sid/simple_ptr_holder.hpp>
#include GT_ITERATE_ON_TARGETS()
Expand All @@ -38,7 +39,7 @@ namespace gridtools {
simple_ptr_holder() = default;
GT_TARGET GT_FORCE_INLINE constexpr simple_ptr_holder(T const &ptr) : m_val{ptr} {}
#endif
GT_TARGET GT_FORCE_INLINE constexpr T const &operator()() const { return m_val; }
GT_TARGET GT_FORCE_INLINE constexpr decltype(auto) operator()() const { return as_ldg_ptr(m_val); }
};

template <class T>
Expand Down
9 changes: 4 additions & 5 deletions include/gridtools/stencil/gpu/entry_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "../../common/defs.hpp"
#include "../../common/hymap.hpp"
#include "../../common/integral_constant.hpp"
#include "../../common/ldg_ptr.hpp"
#include "../../common/tuple_util.hpp"
#include "../../meta.hpp"
#include "../../sid/allocator.hpp"
Expand Down Expand Up @@ -132,13 +133,11 @@ namespace gridtools {

template <class Keys>
struct deref_f {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
template <class Key, class T>
GT_FUNCTION std::enable_if_t<is_texture_type<T>::value && meta::st_contains<Keys, Key>::value, T>
operator()(Key, T const *ptr) const {
return __ldg(ptr);
GT_FUNCTION std::enable_if_t<meta::st_contains<Keys, Key>::value, T> operator()(
Key, T const *ptr) const {
return *as_ldg_ptr(ptr);
}
#endif
template <class Key, class Ptr>
GT_FUNCTION decltype(auto) operator()(Key, Ptr ptr) const {
return *ptr;
Expand Down
9 changes: 4 additions & 5 deletions include/gridtools/stencil/gpu_horizontal/entry_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "../../common/host_device.hpp"
#include "../../common/hymap.hpp"
#include "../../common/integral_constant.hpp"
#include "../../common/ldg_ptr.hpp"
#include "../../common/tuple_util.hpp"
#include "../../meta.hpp"
#include "../../sid/as_const.hpp"
Expand All @@ -41,13 +42,11 @@ namespace gridtools {
namespace gpu_horizontal_backend {
template <class Keys>
struct deref_f {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
template <class Key, class T>
GT_FUNCTION std::enable_if_t<is_texture_type<T>::value && meta::st_contains<Keys, Key>::value, T>
operator()(Key, T const *ptr) const {
return __ldg(ptr);
GT_FUNCTION std::enable_if_t<meta::st_contains<Keys, Key>::value, T> operator()(
Key, T const *ptr) const {
return *as_ldg_ptr(ptr);
}
#endif
template <class Key, class Ptr>
GT_FUNCTION decltype(auto) operator()(Key, Ptr ptr) const {
return *ptr;
Expand Down
3 changes: 2 additions & 1 deletion include/gridtools/storage/sid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "../common/hymap.hpp"
#include "../common/integral_constant.hpp"
#include "../common/layout_map.hpp"
#include "../common/ldg_ptr.hpp"
#include "../common/tuple.hpp"
#include "../common/tuple_util.hpp"
#include "../meta.hpp"
Expand All @@ -36,7 +37,7 @@ namespace gridtools {
template <class T>
struct ptr_holder {
T *m_val;
GT_FUNCTION constexpr T *operator()() const { return m_val; }
GT_FUNCTION constexpr auto operator()() const { return as_ldg_ptr(m_val); }

friend GT_FORCE_INLINE constexpr ptr_holder operator+(ptr_holder obj, int_t arg) {
return {obj.m_val + arg};
Expand Down
Loading
Loading