diff --git a/include/alpaka/exec/IndependentElements.hpp b/include/alpaka/exec/IndependentElements.hpp index 3af342bf249..cd489b2a4cf 100644 --- a/include/alpaka/exec/IndependentElements.hpp +++ b/include/alpaka/exec/IndependentElements.hpp @@ -146,9 +146,9 @@ namespace alpaka }; private: - const Idx first_; - const Idx stride_; - const Idx extent_; + Idx const first_; + Idx const stride_; + Idx const extent_; }; } // namespace detail @@ -311,11 +311,12 @@ namespace alpaka ALPAKA_FN_ACC inline const_iterator(Idx elements, Idx stride, Idx extent, Idx first) : elements_{elements} - , stride_{stride} + , + // we need to reduce the stride by one element range because index_ is later increased with each + // increment + stride_{stride - elements} , extent_{extent} - , first_{std::min(first, extent)} - , index_{first_} - , range_{std::min(first + elements, extent)} + , index_{std::min(first, extent)} { } @@ -328,22 +329,16 @@ namespace alpaka // pre-increment the iterator ALPAKA_FN_ACC inline const_iterator& operator++() { - // increment the index along the elements processed by the current thread + ++indexElem_; ++index_; - if(index_ < range_) - return *this; - - // increment the thread index with the block stride - first_ += stride_; - index_ = first_; - range_ = std::min(first_ + elements_, extent_); - if(index_ < extent_) - return *this; + if(indexElem_ >= elements_) + { + indexElem_ = Idx{0}; + index_ += stride_; + } + if(index_ >= extent_) + index_ = extent_; - // the iterator has reached or passed the end of the extent, clamp it to the extent - first_ = extent_; - index_ = extent_; - range_ = extent_; return *this; } @@ -357,7 +352,7 @@ namespace alpaka ALPAKA_FN_ACC inline bool operator==(const_iterator const& other) const { - return (index_ == other.index_) and (first_ == other.first_); + return (*(*this) == *other); } ALPAKA_FN_ACC inline bool operator!=(const_iterator const& other) const @@ -371,16 +366,15 @@ namespace alpaka Idx stride_; Idx extent_; // 
modified by the pre/post-increment operator - Idx first_; Idx index_; - Idx range_; + Idx indexElem_ = {0}; }; private: - const Idx elements_; - const Idx thread_; - const Idx stride_; - const Idx extent_; + Idx const elements_; + Idx const thread_; + Idx const stride_; + Idx const extent_; }; } // namespace detail