Skip to content

Commit

Permalink
Add forward iterator support to async_for_each
Browse files Browse the repository at this point in the history
async_for_each originally assumed random access iterators, but we
can also support forward iterators at the cost of some complexity.

The main limitation with forward iterators is that we cannot do
arithmetic on the iterators to determine the size of the range and
to pre-calculate end iterators. So we use counted loops instead when
those iterators are used (we dispatch to a separate path for random
access iterators).
  • Loading branch information
travisdowns committed Feb 22, 2024
1 parent e602960 commit c667749
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 48 deletions.
69 changes: 53 additions & 16 deletions src/v/ssx/async_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,62 @@ static ssize_t remaining(const C& c) {
*/
constexpr ssize_t FIXED_COST = 1;

/**
 * Result of a bounded iteration: where the iteration stopped and how many
 * elements were visited before stopping.
 */
template<typename Iterator>
struct iter_size {
    /// Iterator to the first element that was not visited.
    Iterator iter;
    /// Number of elements that were visited.
    ssize_t count;
};

/**
 * A mix of for_each and for_each_n: iterates from begin to end, or until
 * limit elements have been visited, whichever comes first, applying f to
 * each element.
 *
 * Returns the number of elements visited as well as the iterator to the first
 * unvisited element. This can be implemented more efficiently with random
 * access iterators since we can calculate the exact end iterator up front and
 * so do an efficient loop with a single sentinel. The forward iterator version
 * must increment a count in the loop and check both the end iterator and the
 * counter as the termination condition.
 *
 * A non-positive limit visits nothing and returns {begin, 0}.
 */
template<std::random_access_iterator I, typename Fn>
iter_size<I> for_each_limit(const I begin, const I end, ssize_t limit, Fn f) {
    // Convert the distance explicitly: the iterator's difference_type is not
    // required to be ssize_t, and std::min needs both arguments to agree.
    const auto available = static_cast<ssize_t>(end - begin);
    // Clamp at zero so a negative limit can never place chunk_end before
    // begin, which would hand std::for_each an invalid range.
    const ssize_t chunk_size = std::max<ssize_t>(
      std::min(limit, available), 0);
    const I chunk_end = begin
                        + static_cast<std::iter_difference_t<I>>(chunk_size);
    std::for_each(begin, chunk_end, std::move(f));
    return {chunk_end, chunk_size};
}

/**
 * Forward-iterator flavour of for_each_limit: applies f to successive
 * elements starting at begin, stopping as soon as either end is reached or
 * limit elements have been visited.
 *
 * Returns the iterator to the first unvisited element together with the
 * number of elements visited.
 */
template<std::forward_iterator I, typename Fn>
iter_size<I> for_each_limit(const I begin, const I end, ssize_t limit, Fn f) {
    I cursor = begin;
    ssize_t visited = 0;
    // Both the sentinel and the counter are checked on every step because the
    // distance to end cannot be computed up front here.
    for (; cursor != end && visited < limit; ++visited) {
        f(*cursor);
        ++cursor;
    }
    return {cursor, visited};
}

template<
typename Traits,
typename Counter,
typename Fn,
std::random_access_iterator Iterator>
std::forward_iterator Iterator>
ss::future<>
async_for_each_coro(Counter counter, Iterator begin, Iterator end, Fn f) {
do {
auto chunk_size = std::min(remaining<Traits>(counter), end - begin);
Iterator chunk_end = begin + chunk_size;
std::for_each(begin, chunk_end, f);
begin = chunk_end;
counter.count += chunk_size;
auto new_begin = for_each_limit(
begin, end, remaining<Traits>(counter), f);
begin = new_begin.iter;
counter.count += new_begin.count;
if (counter.count >= Traits::interval) {
co_await ss::coroutine::maybe_yield();
counter.count = 0;
Traits::yield_called();
}
} while (begin != end);

counter.count += FIXED_COST;
}

/**
Expand All @@ -112,21 +147,23 @@ template<
typename Traits = async_algo_traits,
typename Counter,
typename Fn,
std::random_access_iterator Iterator>
std::forward_iterator Iterator>
ss::future<>
async_for_each_fast(Counter counter, Iterator begin, Iterator end, Fn f) {
// This first part is an important optimization: if the input range is small
// enough, we don't want to create a coroutine frame as that's costly, so
// this function is not coroutine and we do the whole iteration here (as we
// won't yield), otherwise we defer to the coroutine-based helper.
if (auto total_size = (end - begin) + FIXED_COST;
total_size <= detail::remaining<Traits>(counter)) {
std::for_each(begin, end, std::move(f));
counter.count += total_size;

ssize_t limit = detail::remaining<Traits>(counter);
auto new_begin = for_each_limit(begin, end, limit, f);
counter.count += new_begin.count + FIXED_COST;
if (new_begin.iter == end && counter.count < Traits::interval) [[likely]] {
return ss::make_ready_future();
}

return async_for_each_coro<Traits>(counter, begin, end, std::move(f));
return async_for_each_coro<Traits>(
counter, new_begin.iter, end, std::move(f));
}

} // namespace detail
Expand All @@ -152,7 +189,7 @@ async_for_each_fast(Counter counter, Iterator begin, Iterator end, Fn f) {
template<
typename Traits = async_algo_traits,
typename Fn,
std::random_access_iterator Iterator>
std::forward_iterator Iterator>
ss::future<> async_for_each(Iterator begin, Iterator end, Fn f) {
return async_for_each_fast<Traits>(
detail::internal_counter{}, begin, end, std::move(f));
Expand Down Expand Up @@ -208,7 +245,7 @@ ss::future<> async_for_each(Iterator begin, Iterator end, Fn f) {
template<
typename Traits = async_algo_traits,
typename Fn,
std::random_access_iterator Iterator>
std::forward_iterator Iterator>
ss::future<> async_for_each_counter(
async_counter& counter, Iterator begin, Iterator end, Fn f) {
return detail::async_for_each_fast<Traits>(
Expand Down
112 changes: 80 additions & 32 deletions src/v/ssx/tests/async_algorithm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <unordered_map>

namespace ssx {

Expand All @@ -36,8 +37,12 @@ async_push_back(auto& container, ssize_t n, const auto& val) {
}
}

const auto add_one = [](int& x) { x++; };
[[maybe_unused]] const auto add_one_slow = [](int& x) {
/**
 * Functor that increments an element in place. It is overloaded so the same
 * callable can be applied to plain int elements and to map entries, where the
 * mapped value (second) is incremented and the key is left alone.
 */
struct add_one {
    void operator()(int& x) const { ++x; }
    void operator()(std::pair<const int, int>& x) const { ++x.second; }
};

const auto add_one_slow = [](int& x) {
static volatile int sink = 0;
for (int i = 0; i < 100; i++) {
sink = sink + 1;
Expand Down Expand Up @@ -98,75 +103,110 @@ struct task_counter {

} // namespace

TEST(AsyncAlgo, async_for_each_same_result) {
/**
 * Builds a vector holding the values 0, 1, ..., elems - 1 in that order.
 *
 * The unnamed second parameter is only a tag that selects this overload by
 * container type; its value is ignored.
 */
std::vector<int> make_container(size_t elems, std::vector<int>) {
    std::vector<int> ret;
    ret.reserve(elems);
    // Count in size_t so the loop bound comparison is not signed/unsigned.
    for (size_t i = 0; i < elems; ++i) {
        ret.push_back(static_cast<int>(i));
    }
    return ret;
}

/**
 * Builds a map with keys 0, 1, ..., elems - 1, each mapped to itself.
 *
 * The unnamed second parameter is only a tag that selects this overload by
 * container type; its value is ignored.
 */
std::unordered_map<int, int>
make_container(size_t elems, std::unordered_map<int, int>) {
    std::unordered_map<int, int> ret;
    ret.reserve(elems);
    // Count in size_t so the loop bound comparison is not signed/unsigned.
    for (size_t i = 0; i < elems; ++i) {
        const auto key = static_cast<int>(i);
        ret.emplace(key, key);
    }
    return ret;
}

/**
 * Typed-test fixture parameterised on the container under test. It exposes
 * the container type and a factory that builds a container of the requested
 * size through the make_container overload for that type.
 */
template<typename Container>
struct AsyncAlgo : public testing::Test {
    using container = Container;

    /// Returns a container of type Container built by make_container.
    static Container make(size_t size) {
        return make_container(size, Container{});
    }
};

// Container types that every AsyncAlgo typed test is instantiated with.
// NOTE(review): the commented-out alias below looks like a leftover for
// running the suite against std::vector only - confirm before removing it.
// using container_types = ::testing::Types<std::vector<int>>;
using container_types
= ::testing::Types<std::vector<int>, std::unordered_map<int, int>>;
TYPED_TEST_SUITE(AsyncAlgo, container_types);

// Sanity check for the fixture factory: a container made with size 2 holds
// exactly the entries 0 and 1.
TYPED_TEST(AsyncAlgo, make_container) {
    auto c = this->make(2);
    ASSERT_EQ(2U, c.size());
    // at() is available on both std::vector (index) and std::unordered_map
    // (key) and, unlike operator[] on the map, never inserts a missing entry,
    // so the check cannot silently modify the container it is inspecting.
    ASSERT_EQ(c.at(0), 0);
    ASSERT_EQ(c.at(1), 1);
}

TYPED_TEST(AsyncAlgo, async_for_each_same_result) {
// basic checks
check_same_result(std::vector<int>{}, add_one);
check_same_result(std::vector<int>{1, 2, 3, 4}, add_one);
check_same_result(this->make(0), add_one{});
check_same_result(this->make(4), add_one{});
}

TEST(AsyncAlgo, yield_count) {
TYPED_TEST(AsyncAlgo, yield_count) {
// helper to check that async_for_each results in the same final state
// as std::for_each

std::vector<int> v{1, 2, 3, 4, 5};
auto v = this->make(5);

task_counter c;
ssx::async_for_each<test_traits<1>>(v.begin(), v.end(), add_one).get();
ssx::async_for_each<test_traits<1>>(v.begin(), v.end(), add_one{}).get();
EXPECT_EQ(5, c.yield_delta());

c = {};
ssx::async_for_each<test_traits<2>>(v.begin(), v.end(), add_one).get();
ssx::async_for_each<test_traits<2>>(v.begin(), v.end(), add_one{}).get();
// floor(5/2), as we don't yield on partial intervals
EXPECT_EQ(2, c.yield_delta());
}

TEST(AsyncAlgo, yield_count_counter) {
TYPED_TEST(AsyncAlgo, yield_count_counter) {
async_counter a_counter;

std::vector<int> v{1, 2};
auto v = this->make(2);

task_counter t_counter;
ssx::async_for_each_counter<test_traits<3>>(
a_counter, v.begin(), v.end(), add_one)
ssx::async_for_each_counter<test_traits<4>>(
a_counter, v.begin(), v.end(), add_one{})
.get();
EXPECT_EQ(0, t_counter.yield_delta());
EXPECT_EQ(3, a_counter.count);

// now we should get a yield since we carry over the 2 ops
// from above
t_counter = {};
ssx::async_for_each_counter<test_traits<3>>(
a_counter, v.begin(), v.end(), add_one)
ssx::async_for_each_counter<test_traits<4>>(
a_counter, v.begin(), v.end(), add_one{})
.get();
EXPECT_EQ(1, t_counter.yield_delta());

v = {1, 2, 3};
v = this->make(3);
t_counter = {};
a_counter = {};
ssx::async_for_each_counter<test_traits<2>>(
a_counter, v.begin(), v.end(), add_one)
a_counter, v.begin(), v.end(), add_one{})
.get();
// 3 elems - 2 interval, overflow by 1 + FIXED_COST = 2
EXPECT_EQ(2, a_counter.count);
EXPECT_EQ(1, a_counter.count);
EXPECT_EQ(1, t_counter.yield_delta());

t_counter = {};
ssx::async_for_each_counter<test_traits<2>>(
a_counter, v.begin(), v.end(), add_one)
a_counter, v.begin(), v.end(), add_one{})
.get();
EXPECT_EQ(2, a_counter.count);
EXPECT_EQ(0, a_counter.count);
EXPECT_EQ(2, t_counter.yield_delta());
}

TEST(AsyncAlgo, yield_count_counter_empty) {
TYPED_TEST(AsyncAlgo, yield_count_counter_empty) {
async_counter a_counter;
task_counter t_counter;

std::vector<int> empty;
TypeParam empty;

auto call = [&] {
ssx::async_for_each_counter<test_traits<2>>(
a_counter, empty.begin(), empty.end(), add_one)
a_counter, empty.begin(), empty.end(), add_one{})
.get();
};

Expand All @@ -175,8 +215,8 @@ TEST(AsyncAlgo, yield_count_counter_empty) {
EXPECT_EQ(0, t_counter.yield_delta());

call();
EXPECT_EQ(2, a_counter.count);
EXPECT_EQ(0, t_counter.yield_delta());
EXPECT_EQ(0, a_counter.count);
EXPECT_EQ(1, t_counter.yield_delta());

call();
EXPECT_EQ(1, a_counter.count);
Expand Down Expand Up @@ -209,20 +249,28 @@ TEST(AsyncAlgo, async_for_each_large_container) {
EXPECT_GT(tasks.task_delta(), 2); // in practice it's > 100
}

TEST(AsyncAlgo, async_for_each_move_correctness) {
std::deque<int> v;
async_push_back(v, 10, 0).get();
// Extracts the int payload of an element so a single generic callback can be
// used with both container kinds: a plain int is its own value ...
int value(int element) {
    return element;
}

// ... while for a map entry the payload is the mapped value, not the key.
int value(std::pair<const int, int> entry) {
    return std::get<1>(entry);
}

auto func = [canary = move_canary{}](const int& i) {
if (i != -1) {
// Checks, for every container type in the suite, that the functor invoked by
// async_for_each on each element does not observe its captured canary as
// moved-from.
TYPED_TEST(AsyncAlgo, async_for_each_move_correctness) {
auto v = this->make(10);

// The functor carries a move_canary. For any element whose value is not the
// -1 sentinel it expects the canary to be intact; it also returns the
// canary's moved-from state so the assertions below can probe it directly
// (passing -1 skips the EXPECT_FALSE).
auto func = [canary = move_canary{}](const auto& i) {
if (value(i) != -1) {
EXPECT_FALSE(canary.is_moved_from());
}
return canary.is_moved_from();
};

// A freshly constructed functor must not report itself as moved-from.
ASSERT_FALSE(func(0));
ASSERT_FALSE(func(-1));

// func is passed as an lvalue, so async_for_each receives a copy and func
// itself stays usable afterwards.
async_for_each<test_traits<2>>(v.begin(), v.end(), func).get();

// test that the canary is working
auto other_func = std::move(func);
// The moved-from source must now report true, and the move target false.
// NOLINTNEXTLINE(bugprone-use-after-move)
ASSERT_TRUE(func(-1));
ASSERT_FALSE(other_func(-1));
}

} // namespace ssx

0 comments on commit c667749

Please sign in to comment.