diff --git a/include/oneapi/tbb/detail/_flow_graph_body_impl.h b/include/oneapi/tbb/detail/_flow_graph_body_impl.h index cd4d81f94e..21da06ce03 100644 --- a/include/oneapi/tbb/detail/_flow_graph_body_impl.h +++ b/include/oneapi/tbb/detail/_flow_graph_body_impl.h @@ -398,7 +398,14 @@ class threshold_regulator : public continue_receiver, no_ T *my_node; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + // Intentionally ignore the metainformation + // If there are more items associated with passed metainfo to be processed + // They should be stored in the buffer before the limiter_node + graph_task* execute(const message_metainfo&) override { +#else graph_task* execute() override { +#endif return my_node->decrement_counter( 1 ); } diff --git a/include/oneapi/tbb/detail/_flow_graph_node_impl.h b/include/oneapi/tbb/detail/_flow_graph_node_impl.h index 86af80c667..336cb069c6 100644 --- a/include/oneapi/tbb/detail/_flow_graph_node_impl.h +++ b/include/oneapi/tbb/detail/_flow_graph_node_impl.h @@ -753,18 +753,25 @@ class continue_input : public continue_receiver { virtual broadcast_cache &successors() = 0; friend class apply_body_task_bypass< class_type, continue_msg >; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + friend class apply_body_task_bypass< class_type, continue_msg, trackable_messages_graph_task >; +#endif //! 
Applies the body to the provided input - graph_task* apply_body_bypass( input_type __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo&) ) { + graph_task* apply_body_bypass( input_type __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo) ) { // There is an extra copied needed to capture the // body execution without the try_put fgt_begin_body( my_body ); output_type v = (*my_body)( continue_msg() ); fgt_end_body( my_body ); - return successors().try_put_task( v ); + return successors().try_put_task( v __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo) ); } +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + graph_task* execute(const message_metainfo& metainfo) override { +#else graph_task* execute() override { +#endif if(!is_graph_active(my_graph_ref)) { return nullptr; } @@ -776,12 +783,21 @@ class continue_input : public continue_receiver { #if _MSC_VER && !__INTEL_COMPILER #pragma warning (pop) #endif - return apply_body_bypass( continue_msg() __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{}) ); + return apply_body_bypass( continue_msg() __TBB_FLOW_GRAPH_METAINFO_ARG(metainfo) ); } else { d1::small_object_allocator allocator{}; - typedef apply_body_task_bypass task_type; - graph_task* t = allocator.new_object( graph_reference(), allocator, *this, continue_msg(), my_priority ); + graph_task* t = nullptr; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + if (!metainfo.empty()) { + using task_type = apply_body_task_bypass; + t = allocator.new_object( graph_reference(), allocator, *this, continue_msg(), my_priority, metainfo ); + } else +#endif + { + using task_type = apply_body_task_bypass; + t = allocator.new_object( graph_reference(), allocator, *this, continue_msg(), my_priority ); + } return t; } } diff --git a/include/oneapi/tbb/flow_graph.h b/include/oneapi/tbb/flow_graph.h index 0e56aaf6ae..c2b3731253 100644 --- a/include/oneapi/tbb/flow_graph.h +++ b/include/oneapi/tbb/flow_graph.h @@ -401,23 +401,51 @@ class continue_receiver : public receiver< continue_msg > 
{ template< typename R, typename B > friend class run_and_put_task; template friend class broadcast_cache; template friend class round_robin_cache; + +private: // execute body is supposed to be too small to create a task for. - graph_task* try_put_task( const input_type & ) override { + graph_task* try_put_task_impl( const input_type& __TBB_FLOW_GRAPH_METAINFO_ARG(const message_metainfo& metainfo) ) { +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + message_metainfo predecessor_metainfo; +#endif { spin_mutex::scoped_lock l(my_mutex); +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + // Prolong the wait and store the metainfo until receiving signals from all the predecessors + for (auto waiter : metainfo.waiters()) { + waiter->reserve(1); + } + my_current_metainfo.merge(metainfo); +#endif if ( ++my_current_count < my_predecessor_count ) return SUCCESSFULLY_ENQUEUED; - else + else { my_current_count = 0; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + predecessor_metainfo = my_current_metainfo; + my_current_metainfo = message_metainfo{}; +#endif + } } +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + graph_task* res = execute(predecessor_metainfo); + for (auto waiter : predecessor_metainfo.waiters()) { + waiter->release(1); + } +#else graph_task* res = execute(); +#endif return res? 
res : SUCCESSFULLY_ENQUEUED; } +protected: + graph_task* try_put_task( const input_type& input ) override { + return try_put_task_impl(input __TBB_FLOW_GRAPH_METAINFO_ARG(message_metainfo{})); + } + #if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT - // TODO: add metainfo support for continue_receiver - graph_task* try_put_task( const input_type& input, const message_metainfo& ) override { - return try_put_task(input); + graph_task* try_put_task( const input_type& input, const message_metainfo& metainfo ) override { + return try_put_task_impl(input, metainfo); } #endif @@ -425,6 +453,9 @@ class continue_receiver : public receiver< continue_msg > { int my_predecessor_count; int my_current_count; int my_initial_predecessor_count; +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + message_metainfo my_current_metainfo; +#endif node_priority_t my_priority; // the friend declaration in the base class did not eliminate the "protected class" // error in gcc 4.1.2 @@ -440,7 +471,11 @@ class continue_receiver : public receiver< continue_msg > { //! Does whatever should happen when the threshold is reached /** This should be very fast or else spawn a task. This is called while the sender is blocked in the try_put(). 
*/ +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + virtual graph_task* execute(const message_metainfo& metainfo) = 0; +#else virtual graph_task* execute() = 0; +#endif template friend class successor_cache; bool is_continue_receiver() override { return true; } diff --git a/test/tbb/test_continue_node.cpp b/test/tbb/test_continue_node.cpp index 4b81c8ee94..1cfea3df43 100644 --- a/test/tbb/test_continue_node.cpp +++ b/test/tbb/test_continue_node.cpp @@ -354,6 +354,176 @@ void test_successor_cache_specialization() { "Wrong number of messages is passed via continue_node"); } +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT +void test_try_put_and_wait_default() { + tbb::task_arena arena(1); + + arena.execute([&]{ + tbb::flow::graph g; + + int processed_items = 0; + + tbb::flow::continue_node* start_node = nullptr; + + tbb::flow::continue_node cont(g, + [&](tbb::flow::continue_msg) noexcept { + static bool put_ten_msgs = true; + if (put_ten_msgs) { + for (std::size_t i = 0; i < 10; ++i) { + start_node->try_put(tbb::flow::continue_msg{}); + } + put_ten_msgs = false; + } + }); + + start_node = &cont; + + tbb::flow::continue_node writer(g, + [&](tbb::flow::continue_msg) noexcept { + ++processed_items; + }); + + tbb::flow::make_edge(cont, writer); + + cont.try_put_and_wait(tbb::flow::continue_msg{}); + + // Only 1 item should be processed, with the additional 10 items having been spawned + CHECK_MESSAGE(processed_items == 1, "Unexpected items processing"); + + g.wait_for_all(); + + // The additional 10 items should be processed + CHECK_MESSAGE(processed_items == 11, "Unexpected items processing"); + }); +} + +void test_try_put_and_wait_lightweight() { + tbb::task_arena arena(1); + + arena.execute([&]{ + tbb::flow::graph g; + + std::vector start_work_items; + std::vector processed_items; + std::vector new_work_items; + + int wait_message = 10; + + for (int i = 0; i < wait_message; ++i) { + start_work_items.emplace_back(i); + new_work_items.emplace_back(i + 1 + wait_message); + 
} + + tbb::flow::continue_node* start_node = nullptr; + + tbb::flow::continue_node cont(g, + [&](tbb::flow::continue_msg) noexcept { + static int counter = 0; + int i = counter++; + if (i == wait_message) { + for (auto item : new_work_items) { + (void)item; + start_node->try_put(tbb::flow::continue_msg{}); + } + } + return i; + }); + + start_node = &cont; + + tbb::flow::function_node writer(g, tbb::flow::unlimited, + [&](int input) noexcept { + processed_items.emplace_back(input); + return 0; + }); + + tbb::flow::make_edge(cont, writer); + + for (auto item : start_work_items) { + (void)item; + cont.try_put(tbb::flow::continue_msg{}); + } + + cont.try_put_and_wait(tbb::flow::continue_msg{}); + + CHECK_MESSAGE(processed_items.size() == start_work_items.size() + new_work_items.size() + 1, + "Unexpected number of elements processed"); + + std::size_t check_index = 0; + + // For lightweight continue_node, start_work_items are expected to be processed first + // while putting items into the first node + for (auto item : start_work_items) { + CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing"); + } + + for (auto item : new_work_items) { + CHECK_MESSAGE(processed_items[check_index++] == item, "Unexpected items processing"); + } + // wait_message would be processed only after new_work_items + CHECK_MESSAGE(processed_items[check_index++] == wait_message, "Unexpected items processing"); + + g.wait_for_all(); + + CHECK(check_index == processed_items.size()); + }); +} + +void test_metainfo_buffering() { + tbb::task_arena arena(1); + + arena.execute([&] { + tbb::flow::graph g; + + std::vector call_order; + + tbb::flow::continue_node* b_ptr = nullptr; + + tbb::flow::continue_node a(g, + [&](tbb::flow::continue_msg) noexcept { + call_order.push_back('A'); + static std::once_flag flag; // Send a signal to B only in the first call + std::call_once(flag, [&]{ b_ptr->try_put(tbb::flow::continue_msg{}); }); + }); + + tbb::flow::continue_node b(g, + 
[&](tbb::flow::continue_msg) noexcept { + call_order.push_back('B'); + a.try_put(tbb::flow::continue_msg{}); + }); + + b_ptr = &b; + + tbb::flow::continue_node c(g, + [&](tbb::flow::continue_msg) noexcept { + call_order.push_back('C'); + }); + + tbb::flow::make_edge(a, c); + tbb::flow::make_edge(b, c); + + a.try_put_and_wait(tbb::flow::continue_msg{}); + + // Inside the first call of A, we send a signal to B. + // Both of them send signals to C. Since C is lightweight, it is processed immediately + // upon receiving signals from both predecessors. This completes the wait. + CHECK(call_order == std::vector{'A', 'B', 'C'}); + + g.wait_for_all(); + + // B previously sent a signal to A, which has now been processed. + // A sends a signal to C, which is not processed because no signal is received from B this time. + CHECK(call_order == std::vector{'A', 'B', 'C', 'A'}); + }); +} + +void test_try_put_and_wait() { + test_try_put_and_wait_default(); + test_try_put_and_wait_lightweight(); + test_metainfo_buffering(); +} +#endif // __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT + //! Test concurrent continue_node for correctness //! \brief \ref error_guessing TEST_CASE("Concurrency testing") { @@ -418,3 +588,10 @@ TEST_CASE("constraints for continue_node body") { static_assert(!can_call_continue_node_ctor>); } #endif // __TBB_CPP20_CONCEPTS_PRESENT + +#if __TBB_PREVIEW_FLOW_GRAPH_TRY_PUT_AND_WAIT +//! \brief \ref error_guessing +TEST_CASE("test continue_node try_put_and_wait") { + test_try_put_and_wait(); +} +#endif