Skip to content

Commit

Permalink
Added try_wait options
Browse files Browse the repository at this point in the history
  • Loading branch information
ogiroux authored and wmaxey committed Jul 27, 2021
1 parent 69e7f77 commit 6e72dc6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
19 changes: 12 additions & 7 deletions include/cuda/std/barrier
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ class barrier : public std::__barrier_base<_CompletionF, _Sco> {
template<thread_scope>
friend class pipeline;

using std::__barrier_base<_CompletionF, _Sco>::__try_wait;

public:
barrier() = default;

Expand Down Expand Up @@ -77,6 +75,13 @@ _LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

template<class __Barrier>
inline _LIBCUDACXX_INLINE_VISIBILITY
bool barrier_try_wait_parity(__Barrier const* __this, bool __parity)
{
return __this->__try_wait_parity(__parity);
}

template<class __Barrier>
struct __barrier_poll_tester_parity {
__Barrier const* __this;
Expand All @@ -91,15 +96,15 @@ struct __barrier_poll_tester_parity {
inline _LIBCUDACXX_INLINE_VISIBILITY
bool operator()() const
{
return __this->__try_wait_parity(__parity);
return barrier_try_wait_parity(__this, __parity);
}
};

template<class __Barrier>
inline _LIBCUDACXX_INLINE_VISIBILITY
void barrier_wait_for_parity(__Barrier const* __self, bool __parity)
void barrier_wait_parity(__Barrier const* __this, bool __parity)
{
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(__barrier_poll_tester_parity<__Barrier>(__self, __parity));
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(__barrier_poll_tester_parity<__Barrier>(__this, __parity));
}

template<>
Expand All @@ -114,7 +119,7 @@ public:
using arrival_token = typename __barrier_base::arrival_token;

_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait(arrival_token __phase) const {
bool try_wait(arrival_token __phase) const {
#if __CUDA_ARCH__ >= 800
if (__isShared(&__barrier)) {
int __ready = 0;
Expand All @@ -131,7 +136,7 @@ public:
else
#endif
{
return __barrier.__try_wait(std::move(__phase));
return __barrier.try_wait(std::move(__phase));
}
}

Expand Down
46 changes: 28 additions & 18 deletions libcxx/include/barrier
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ class __barrier_base {
_LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<ptrdiff_t, _Sco> __expected, __arrived;
_LIBCUDACXX_BARRIER_ALIGNMENTS _CompletionF __completion;
_LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<bool, _Sco> __phase;

_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_phase(bool __old_phase) const
{
return __phase.load(memory_order_acquire) != __old_phase;
}
public:
using arrival_token = bool;

Expand Down Expand Up @@ -241,11 +247,15 @@ public:
return __old_phase;
}
_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait(arrival_token __old_phase) const
bool try_wait(arrival_token __old) const
{
return __phase != __old_phase;
return __try_wait_phase(__old);
}
_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_parity(bool __parity) const
{
return __try_wait_phase(__parity);
}

_LIBCUDACXX_INLINE_VISIBILITY
void wait(arrival_token&& __old_phase) const
{
Expand Down Expand Up @@ -281,10 +291,10 @@ struct __barrier_poll_tester {
, __phase(_CUDA_VSTD::move(__phase_))
{}

inline _LIBCUDACXX_INLINE_VISIBILITY
_LIBCUDACXX_INLINE_VISIBILITY
bool operator()() const
{
return __this->__try_wait(__phase);
return __this->try_wait(__phase);
}
};

Expand All @@ -303,12 +313,18 @@ public:
using arrival_token = uint64_t;

private:
static inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
uint64_t __init(ptrdiff_t __count) _NOEXCEPT
{
return (((1u << 31) - __count) << 32)
| ((1u << 31) - __count);
}
_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_phase(uint64_t __phase) const
{
uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
return ((__current & __phase_bit) != __phase);
}

public:
__barrier_base() = default;
Expand All @@ -323,19 +339,13 @@ public:
__barrier_base(__barrier_base const&) = delete;
__barrier_base& operator=(__barrier_base const&) = delete;

inline _LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_phase(uint64_t __phase) const
{
uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
return ((__current & __phase_bit) != __phase);
}
inline _LIBCUDACXX_INLINE_VISIBILITY
_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_parity(bool __parity) const
{
return __try_wait_phase(__parity ? __phase_bit : 0);
}
inline _LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait(arrival_token __old) const
_LIBCUDACXX_INLINE_VISIBILITY
bool try_wait(arrival_token __old) const
{
return __try_wait_phase(__old & __phase_bit);
}
Expand All @@ -351,17 +361,17 @@ public:
}
return __old & __phase_bit;
}
inline _LIBCUDACXX_INLINE_VISIBILITY
_LIBCUDACXX_INLINE_VISIBILITY
void wait(arrival_token&& __phase) const
{
__libcpp_thread_poll_with_backoff(__barrier_poll_tester<__barrier_base<__empty_completion, _Sco>>(this, _CUDA_VSTD::move(__phase)));
}
inline _LIBCUDACXX_INLINE_VISIBILITY
_LIBCUDACXX_INLINE_VISIBILITY
void arrive_and_wait()
{
wait(arrive());
}
inline _LIBCUDACXX_INLINE_VISIBILITY
_LIBCUDACXX_INLINE_VISIBILITY
void arrive_and_drop()
{
__phase_arrived_expected.fetch_add(__expected_unit, memory_order_relaxed);
Expand Down

0 comments on commit 6e72dc6

Please sign in to comment.