Skip to content

Commit

Permalink
FIX-1541: fuse_operations
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisYaroshevskiy committed Feb 7, 2023
1 parent 98fc1a1 commit 5336cc4
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 13 deletions.
16 changes: 16 additions & 0 deletions include/eve/module/algo/algo/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,22 @@ namespace eve::algo
//================================================================================================
inline constexpr auto expensive_callable = ::rbr::flag( expensive_callable_tag{} );

struct fuse_operations_tag {};
//================================================================================================
//! @addtogroup algorithms
//! @{
//! @var fuse_operations
//!
//! @brief Some algorithms (for example `transform_reduce`) can be implemented more efficient
//! if you fuse multiple operations provided in a single function.
//!
//! Example: if you want to use `fma` instead of doing `multiply` + `add`.
//!
//! This flag replaces the functions with their more parameter equivalents.
//! @}
//================================================================================================
inline constexpr auto fuse_operations = ::rbr::flag( fuse_operations_tag{} );

// getters -------------------

template <typename Traits>
Expand Down
22 changes: 16 additions & 6 deletions include/eve/module/algo/algo/transform_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,19 @@ template<typename TraitsSupport> struct transform_reduce_ : TraitsSupport
sums[0] = add_op(sums[0], init);
}

EVE_FORCEINLINE bool step(auto it, eve::relative_conditional_expr auto ignore, auto idx)
template <typename I>
EVE_FORCEINLINE bool step(I it, eve::relative_conditional_expr auto ignore, auto idx)
{
auto loaded = eve::load[ignore](it);
auto mapped = map_op(loaded);
auto cvt = eve::convert(mapped, eve::as<element_type_t<SumWide>> {});
sums[idx()] = add_op(sums[idx()], if_else(ignore, cvt, zero));
auto zero_ = as_value(zero, as<wide_value_type_t<I>>{});
auto loaded = eve::load[ignore.else_(zero_)](it);

if constexpr (traits_type::contains(fuse_operations)) {
sums[idx()] = map_op(loaded, sums[idx()]);
} else {
auto mapped = map_op(loaded);
auto cvt = eve::convert(mapped, eve::as<element_type_t<SumWide>> {});
sums[idx()] = add_op(sums[idx()], cvt);
}
return false;
}

Expand Down Expand Up @@ -107,7 +114,10 @@ template<typename TraitsSupport> struct transform_reduce_ : TraitsSupport
//! Due to the nature of how SIMD algorithms work, the reduce operation has to be paired with its,
//! neutral element. For example, for add you pass `{add, zero}` as zero is the identity for add.
//! Instead of zero it can be beneficial to pass eve's constants like `eve::zero`, `eve::one`
//! because sometimes the implementation can be improved
//! because sometimes the implementation can be improved.
//!
//! Supports `fuse_operations` trait => `map_op` starts to take an extra parameter: current sum.
//! Allows to use fma. Note that `add` operations and `zero` are still used and have to be correct.
//!
//! @note
//! * The interface differs from the standard because we felt this better matches our use case:
Expand Down
17 changes: 10 additions & 7 deletions test/doc/algo/transform_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@ int main()
{
std::vector<float> v = {1.0f, 2.0f, 3.0f, 4.0f};

std::cout << " -> v = "
std::cout << " -> v = "
<< tts::as_string(v)
<< "\n";

std::cout << " -> eve::algo::transform_reduce(v, [](auto x) { return x + x }, 0.) = "
std::cout << " -> eve::algo::transform_reduce(v, [](auto x) { return x + x }, 0.f) = "
<< eve::algo::transform_reduce(v, [](auto x) { return x + x; }, 0.) << "\n";

std::cout << " -> std::transform_reduce(v.begin(), v.end(), std::plus<>{}, 0., [](auto x) { return x + x }) = "
std::cout << " -> std::transform_reduce(v.begin(), v.end(), std::plus<>{}, 0.f, [](auto x) { return x + x }) = "
<< std::transform_reduce(v.begin(), v.end(), 0., std::plus<>{}, [](auto x) { return x + x; }) << "\n";

std::cout << " -> eve::algo::reduce(eve::views::map(v, [](auto x) { return x + x }), 0.) = "
<< eve::algo::reduce(eve::views::map(v, [](auto x) { return x + x; }), 0.) << "\n";
std::cout << " -> eve::algo::reduce(eve::views::map(v, [](auto x) { return x + x }), 0.f) = "
<< eve::algo::reduce(eve::views::map(v, [](auto x) { return x + x; }), 0.f) << "\n";

std::cout << " -> eve::algo::transform_reduce(v, [](auto x) { return x + x }, std::pair{eve::mul, eve::one}, 0.) = "
<< eve::algo::transform_reduce(v, [](auto x) { return x + x; }, std::pair{eve::mul, eve::one}, 1.) << "\n";
std::cout << " -> eve::algo::transform_reduce(v, [](auto x) { return x + x }, std::pair{eve::mul, eve::one}, 1.f) = "
<< eve::algo::transform_reduce(v, [](auto x) { return x + x; }, std::pair{eve::mul, eve::one}, 1.f) << "\n";

std::cout << " -> eve::algo::transform_reduce[eve::algo::fuse_operations](v, [](auto x, auto sum) { return eve::fma(x, 2.f, sum); }, 0.f) = "
<< eve::algo::transform_reduce[eve::algo::fuse_operations](v, [](auto x, auto sum) { return eve::fma(x, .5f, sum); }, 0.f) << "\n";
}
15 changes: 15 additions & 0 deletions test/unit/module/algo/algorithm/transform_reduce_generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,23 @@ template<typename TraitsSupport> struct transfrom_reduce_2_reduce_ : TraitsSuppo
inline constexpr auto transform_reduce_2_reduce = eve::algo::function_with_traits<
transfrom_reduce_2_reduce_>[eve::algo::transform_reduce.get_traits()];

template<typename TraitsSupport> struct transfrom_reduce_2_reduce_fuse_ : TraitsSupport
{
EVE_FORCEINLINE auto operator()(auto&& rng, auto init) const
{
return eve::algo::transform_reduce[TraitsSupport::get_traits()][eve::algo::fuse_operations](
EVE_FWD(rng), []<typename Sum, typename N>(auto x, eve::wide<Sum, N> sum) {
return eve::convert(x, eve::as<Sum>{}) + sum; }, init);
}
};

inline constexpr auto transfrom_reduce_2_reduce_fuse = eve::algo::function_with_traits<
transfrom_reduce_2_reduce_fuse_>[eve::algo::transform_reduce.get_traits()];


TTS_CASE_TPL("Check transform reduce generic", algo_test::selected_types)
<typename T>(tts::type<T>)
{
reduce_generic_all_test_cases(eve::as<T> {}, transform_reduce_2_reduce);
reduce_generic_all_test_cases(eve::as<T> {}, transfrom_reduce_2_reduce_fuse);
};

0 comments on commit 5336cc4

Please sign in to comment.