[try_put_and_wait] Part 11: Add implementation of try_put_and_wait feature for continue nodes #1453
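For context, a minimal usage sketch of the feature this PR implements (not part of the diff): `try_put_and_wait` on a `continue_node` blocks only until the work associated with the submitted message completes, while unrelated work spawned into the graph still needs `wait_for_all`. The sketch assumes the preview is enabled at compile time, e.g. via `TBB_PREVIEW_FLOW_GRAPH_FEATURES`, which is presumed to gate `__TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT`.

```cpp
// Hypothetical usage sketch, not part of this PR.
#define TBB_PREVIEW_FLOW_GRAPH_FEATURES 1  // assumed preview-enabling macro; the exact gate may differ
#include <oneapi/tbb/flow_graph.h>
#include <iostream>

int main() {
    tbb::flow::graph g;

    // A continue_node fires its body once it has received a signal from each predecessor.
    tbb::flow::continue_node<tbb::flow::continue_msg> start(g,
        [](tbb::flow::continue_msg) { std::cout << "start\n"; });

    tbb::flow::continue_node<tbb::flow::continue_msg> finish(g,
        [](tbb::flow::continue_msg) { std::cout << "finish\n"; });

    tbb::flow::make_edge(start, finish);

    // Blocks until the work associated with this particular message
    // (tracked through the metainfo plumbing added in this PR) is done.
    start.try_put_and_wait(tbb::flow::continue_msg{});

    g.wait_for_all(); // still needed to drain any unrelated work left in the graph
}
```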

7 changes: 7 additions & 0 deletions include/oneapi/tbb/detail/_flow_graph_body_impl.h
@@ -398,7 +398,14 @@ class threshold_regulator<T, continue_msg, void> : public continue_receiver, no_

T *my_node;

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
// Intentionally ignore the metainformation.
// If there are more items associated with the passed metainfo to be processed,
// they should be stored in the buffer before the limiter_node.
graph_task* execute(const message_metainfo&) override {
#else
graph_task* execute() override {
#endif
return my_node->decrement_counter( 1 );
}

26 changes: 21 additions & 5 deletions include/oneapi/tbb/detail/_flow_graph_node_impl.h
@@ -753,18 +753,25 @@ class continue_input : public continue_receiver {
virtual broadcast_cache<output_type > &successors() = 0;

friend class apply_body_task_bypass< class_type, continue_msg >;
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
friend class apply_body_task_bypass< class_type, continue_msg, trackable_messages_graph_task >;
#endif

//! Applies the body to the provided input
graph_task* apply_body_bypass( input_type __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo&) ) {
graph_task* apply_body_bypass( input_type __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo) ) {
// There is an extra copy needed to capture the
// body execution without the try_put
fgt_begin_body( my_body );
output_type v = (*my_body)( continue_msg() );
fgt_end_body( my_body );
return successors().try_put_task( v );
return successors().try_put_task( v __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo) );
}

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
graph_task* execute(const message_metainfo& metainfo) override {
#else
graph_task* execute() override {
#endif
if(!is_graph_active(my_graph_ref)) {
return nullptr;
}
@@ -776,12 +783,21 @@ class continue_input : public continue_receiver {
#if _MSC_VER && !__INTEL_COMPILER
#pragma warning (pop)
#endif
return apply_body_bypass( continue_msg() __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{}) );
return apply_body_bypass( continue_msg() __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo) );
}
else {
d1::small_object_allocator allocator{};
typedef apply_body_task_bypass<class_type, continue_msg> task_type;
graph_task* t = allocator.new_object<task_type>( graph_reference(), allocator, *this, continue_msg(), my_priority );
graph_task* t = nullptr;
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
if (!metainfo.empty()) {
using task_type = apply_body_task_bypass<class_type, continue_msg, trackable_messages_graph_task>;
t = allocator.new_object<task_type>( graph_reference(), allocator, *this, continue_msg(), my_priority, metainfo );
} else
#endif
{
using task_type = apply_body_task_bypass<class_type, continue_msg>;
t = allocator.new_object<task_type>( graph_reference(), allocator, *this, continue_msg(), my_priority );
}
return t;
}
}
45 changes: 40 additions & 5 deletions include/oneapi/tbb/flow_graph.h
@@ -401,30 +401,61 @@ class continue_receiver : public receiver< continue_msg > {
template< typename R, typename B > friend class run_and_put_task;
template<typename X, typename Y> friend class broadcast_cache;
template<typename X, typename Y> friend class round_robin_cache;

private:
// The execute body is supposed to be too small to be worth creating a task for.
graph_task* try_put_task( const input_type & ) override {
graph_task* try_put_task_impl( const input_type& __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo) ) {
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
message_metainfo predecessor_metainfo;
#endif
{
spin_mutex::scoped_lock l(my_mutex);
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
// Prolong the wait and store the metainfo until receiving signals from all the predecessors
for (auto waiter : metainfo.waiters()) {
waiter->reserve(1);
}
my_current_metainfo.merge(metainfo);
#endif
if ( ++my_current_count < my_predecessor_count )
return SUCCESSFULLY_ENQUEUED;
else
else {
my_current_count = 0;
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
predecessor_metainfo = my_current_metainfo;
my_current_metainfo = message_metainfo{};
#endif
}
}
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
graph_task* res = execute(predecessor_metainfo);
for (auto waiter : predecessor_metainfo.waiters()) {
waiter->release(1);
}
#else
graph_task* res = execute();
#endif
return res? res : SUCCESSFULLY_ENQUEUED;
}

protected:
graph_task* try_put_task( const input_type& input ) override {
return try_put_task_impl(input __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{}));
}

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
// TODO: add metainfo support for continue_receiver
graph_task* try_put_task( const input_type& input, const message_metainfo& ) override {
return try_put_task(input);
graph_task* try_put_task( const input_type& input, const message_metainfo& metainfo ) override {
return try_put_task_impl(input, metainfo);
}
#endif

spin_mutex my_mutex;
int my_predecessor_count;
int my_current_count;
int my_initial_predecessor_count;
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
message_metainfo my_current_metainfo;
#endif
node_priority_t my_priority;
// the friend declaration in the base class did not eliminate the "protected class"
// error in gcc 4.1.2
@@ -440,7 +471,11 @@ class continue_receiver : public receiver< continue_msg > {
//! Does whatever should happen when the threshold is reached
/** This should be very fast or else spawn a task. This is
called while the sender is blocked in the try_put(). */
#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
virtual graph_task* execute(const message_metainfo& metainfo) = 0;
#else
virtual graph_task* execute() = 0;
#endif
template<typename TT, typename M> friend class successor_cache;
bool is_continue_receiver() override { return true; }

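The continue_receiver changes above are the core of this part: each incoming predecessor signal reserves the wait contexts carried in its metainfo (so the caller's try_put_and_wait keeps blocking), the metainfo of all predecessors is merged under the lock, and the references are released once the threshold is reached and the body has been dispatched. Below is a simplified, standalone model of that counting scheme; it is plain C++ rather than TBB internals, and all names (wait_context_model, metainfo_model, continue_receiver_model) are illustrative stand-ins, not real library types.

```cpp
// Simplified model of the reserve/merge/release scheme used by continue_receiver
// with try_put_and_wait. Not TBB code; types and names are invented for illustration.
#include <vector>
#include <mutex>
#include <atomic>
#include <cassert>

struct wait_context_model {                       // stands in for the real wait context
    std::atomic<int> ref{0};
    void reserve(int n) { ref += n; }
    void release(int n) { ref -= n; }
    bool done() const { return ref.load() == 0; }
};

struct metainfo_model {                           // stands in for message_metainfo
    std::vector<wait_context_model*> waiters;
    void merge(const metainfo_model& other) {
        waiters.insert(waiters.end(), other.waiters.begin(), other.waiters.end());
    }
};

class continue_receiver_model {
    std::mutex m;
    int current = 0;
    const int predecessors;
    metainfo_model accumulated;
public:
    explicit continue_receiver_model(int n) : predecessors(n) {}

    // Called once per predecessor signal; returns true when the body was "executed".
    bool signal(const metainfo_model& info) {
        metainfo_model ready;
        {
            std::lock_guard<std::mutex> l(m);
            for (auto* w : info.waiters) w->reserve(1);  // prolong the wait
            accumulated.merge(info);
            if (++current < predecessors) return false;  // still waiting for other predecessors
            current = 0;
            ready = accumulated;
            accumulated = metainfo_model{};
        }
        // ... here the real node executes or spawns its body with `ready` attached ...
        for (auto* w : ready.waiters) w->release(1);     // the wait can now complete
        return true;
    }
};

int main() {
    wait_context_model wc;
    continue_receiver_model node(2);                     // two predecessors
    metainfo_model info{{&wc}};
    node.signal(info);                                   // first signal: wait stays alive
    assert(!wc.done());
    node.signal(metainfo_model{});                       // second signal: body runs, wait released
    assert(wc.done());
    return 0;
}
```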
177 changes: 177 additions & 0 deletions test/tbb/test_continue_node.cpp
@@ -354,6 +354,176 @@ void test_successor_cache_specialization() {
"Wrong number of messages is passed via continue_node");
}

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
void test_try_put_and_wait_default() {
tbb::task_arena arena(1);

arena.execute([&]{
tbb::flow::graph g;

int processed_items = 0;

tbb::flow::continue_node<tbb::flow::continue_msg>* start_node = nullptr;

tbb::flow::continue_node<tbb::flow::continue_msg> cont(g,
[&](tbb::flow::continue_msg) noexcept {
static bool put_ten_msgs = true;
if (put_ten_msgs) {
for (std::size_t i = 0; i < 10; ++i) {
start_node->try_put(tbb::flow::continue_msg{});
}
put_ten_msgs = false;
}
});

start_node = &cont;

tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight> writer(g,
[&](tbb::flow::continue_msg) noexcept {
++processed_items;
});

tbb::flow::make_edge(cont, writer);

cont.try_put_and_wait(tbb::flow::continue_msg{});

// Only 1 item should be processed, with the additional 10 items having been spawned
CHECK_MESSAGE(processed_items == 1, "Unexpected items processing");

g.wait_for_all();

// The additional 10 items should be processed
CHECK_MESSAGE(processed_items == 11, "Unexpected items processing");
});
}

void test_try_put_and_wait_lightweight() {
tbb::task_arena arena(1);

arena.execute([&]{
tbb::flow::graph g;

std::vector<int> start_work_items;
std::vector<int> processed_items;
std::vector<int> new_work_items;

int wait_message = 10;

for (int i = 0; i < wait_message; ++i) {
start_work_items.emplace_back(i);
new_work_items.emplace_back(i + 1 + wait_message);
}

tbb::flow::continue_node<int, tbb::flow::lightweight>* start_node = nullptr;

tbb::flow::continue_node<int, tbb::flow::lightweight> cont(g,
[&](tbb::flow::continue_msg) noexcept {
static int counter = 0;
int i = counter++;
if (i == wait_message) {
for (auto item : new_work_items) {
(void)item;
start_node->try_put(tbb::flow::continue_msg{});
}
}
return i;
});

start_node = &cont;

tbb::flow::function_node<int, int, tbb::flow::lightweight> writer(g, tbb::flow::unlimited,
[&](int input) noexcept {
processed_items.emplace_back(input);
return 0;
});

tbb::flow::make_edge(cont, writer);

for (auto item : start_work_items) {
(void)item;
cont.try_put(tbb::flow::continue_msg{});
}

cont.try_put_and_wait(tbb::flow::continue_msg{});

CHECK_MESSAGE(processed_items.size() == start_work_items.size() + new_work_items.size() + 1,
"Unexpected number of elements processed");

std::size_t check_index = 0;

// For lightweight continue_node, start_work_items are expected to be processed first
// while putting items into the first node
for (auto item : start_work_items) {
CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing");
}

for (auto item : new_work_items) {
CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing");
}
// wait_message would be processed only after new_work_items
CHECK_MESSAGE(processed_items[check_index++] == wait_message, "Unexpected items processing");

g.wait_for_all();

CHECK(check_index == processed_items.size());
});
}

void test_metainfo_buffering() {
tbb::task_arena arena(1);

arena.execute([&] {
tbb::flow::graph g;

std::vector<char> call_order;

tbb::flow::continue_node<tbb::flow::continue_msg>* b_ptr = nullptr;

tbb::flow::continue_node<tbb::flow::continue_msg> a(g,
[&](tbb::flow::continue_msg) noexcept {
call_order.push_back('A');
static std::once_flag flag; // Send a signal to B only in the first call
std::call_once(flag, [&]{ b_ptr->try_put(tbb::flow::continue_msg{}); });
});

tbb::flow::continue_node<tbb::flow::continue_msg> b(g,
[&](tbb::flow::continue_msg) noexcept {
call_order.push_back('B');
a.try_put(tbb::flow::continue_msg{});
});

b_ptr = &b;

tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight> c(g,
[&](tbb::flow::continue_msg) noexcept {
call_order.push_back('C');
});

tbb::flow::make_edge(a, c);
tbb::flow::make_edge(b, c);

a.try_put_and_wait(tbb::flow::continue_msg{});

// Inside the first call of A, we send a signal to B.
// Both of them send signals to C. Since C is lightweight, it is processed immediately
// upon receiving signals from both predecessors. This completes the wait.
CHECK(call_order == std::vector<char>{'A', 'B', 'C'});

g.wait_for_all();

// B previously sent a signal to A, which has now been processed.
// A sends a signal to C, which is not processed because no signal is received from B this time.
CHECK(call_order == std::vector<char>{'A', 'B', 'C', 'A'});
});
}

void test_try_put_and_wait() {
test_try_put_and_wait_default();
test_try_put_and_wait_lightweight();
test_metainfo_buffering();
}
#endif // __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT

//! Test concurrent continue_node for correctness
//! \brief \ref error_guessing
TEST_CASE("Concurrency testing") {
@@ -418,3 +588,10 @@ TEST_CASE("constraints for continue_node body") {
static_assert(!can_call_continue_node_ctor<output_type, WrongReturnOperatorRoundBrackets<output_type>>);
}
#endif // __TBB_CPP20_CONCEPTS_PRESENT

#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT
//! \brief \ref error_guessing
TEST_CASE("test continue_node try_put_and_wait") {
test_try_put_and_wait();
}
#endif