From b46c1cf593ae9c2614840c5acf05ef5372890100 Mon Sep 17 00:00:00 2001 From: Jakub Domagala Date: Fri, 18 Sep 2020 20:10:11 +0200 Subject: [PATCH] [#1024]: Implement collective bcast for collection (with member function hanlder) --- .../vrt/collection/broadcast/broadcastable.h | 9 ++++ .../collection/broadcast/broadcastable.impl.h | 23 +++++++++++ src/vt/vrt/collection/manager.h | 7 ++++ src/vt/vrt/collection/manager.impl.h | 41 +++++++++++++++++++ .../test_collection_group.extended.cc | 30 ++++++++++++++ 5 files changed, 110 insertions(+) diff --git a/src/vt/vrt/collection/broadcast/broadcastable.h b/src/vt/vrt/collection/broadcast/broadcastable.h index a93dbb39de..60d3002108 100644 --- a/src/vt/vrt/collection/broadcast/broadcastable.h +++ b/src/vt/vrt/collection/broadcast/broadcastable.h @@ -79,6 +79,15 @@ struct Broadcastable : BaseProxyT { typename MsgT, ActiveColMemberTypedFnType f, typename... Args > messaging::PendingSend broadcast(Args&&... args) const; + + template f> + messaging::PendingSend broadcastCollective(MsgT* msg) const; + template f> + messaging::PendingSend broadcastCollective(MsgSharedPtr msg) const; + template < + typename MsgT, ActiveColMemberTypedFnType f, typename... Args + > + messaging::PendingSend broadcastCollective(Args&&... args) const; }; }}} /* end namespace vt::vrt::collection */ diff --git a/src/vt/vrt/collection/broadcast/broadcastable.impl.h b/src/vt/vrt/collection/broadcast/broadcastable.impl.h index 0b5c9e0bcd..db03e40144 100644 --- a/src/vt/vrt/collection/broadcast/broadcastable.impl.h +++ b/src/vt/vrt/collection/broadcast/broadcastable.impl.h @@ -105,6 +105,29 @@ messaging::PendingSend Broadcastable::broadcast(MsgT* ms return theCollection()->broadcastMsg(proxy,msg); } +template +template f> +messaging::PendingSend +Broadcastable::broadcastCollective(MsgSharedPtr msg) const { + return broadcastCollective(msg.get()); +} + +template +template < + typename MsgT, ActiveColMemberTypedFnType f, typename... Args> +messaging::PendingSend +Broadcastable::broadcastCollective(Args&&... args) const { + return broadcastCollective(makeMessage(std::forward(args)...)); +} + +template +template f> +messaging::PendingSend +Broadcastable::broadcastCollective(MsgT* msg) const { + auto proxy = this->getProxy(); + return theCollection()->broadcastMsgCollective(proxy, msg); +} + }}} /* end namespace vt::vrt::collection */ #endif /*INCLUDED_VRT_COLLECTION_BROADCAST_BROADCASTABLE_IMPL_H*/ diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index 2872ba411d..c899f66d48 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -904,6 +904,13 @@ struct CollectionManager bool instrument ); + template < + typename MsgT, + ActiveColMemberTypedFnType f> + messaging::PendingSend broadcastMsgCollective( + CollectionProxyWrapType const& proxy, + MsgT* msg, bool instrument = true); + /** * \brief Broadcast a message with action function handler * diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index e4c9fecc9c..10c2ed0052 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -848,6 +848,47 @@ messaging::PendingSend CollectionManager::broadcastFromRoot(MsgT* raw_msg) { return ret; } + +template < + typename MsgT, + ActiveColMemberTypedFnType f> +messaging::PendingSend CollectionManager::broadcastMsgCollective( + CollectionProxyWrapType const& proxy, + MsgT* msg, bool instrument) { + + using ColT = typename MsgT::CollectionType; + using IndexT = typename ColT::IndexType; + + auto promoMsg = promoteMsg(msg); + + return messaging::PendingSend( + promoMsg, [proxy](MsgSharedPtr& msgIn) { + auto elm_holder = theCollection()->findElmHolder(proxy); + auto const node = theContext()->getNode(); + + auto col_msg = reinterpret_cast(msgIn.get()); + auto handler = + auto_registry::makeAutoHandlerCollectionMem(); + col_msg->setVrtHandler(handler); + + theMsg()->markAsCollectionMessage(col_msg); + + if (elm_holder) { + elm_holder->foreach ( + [node, msgIn, col_msg, + elm_holder](IndexT const& idx, CollectionBase* base) { + auto const from = col_msg->getFromNode(); + auto trace_event = trace::no_trace_event; + auto const hand = col_msg->getVrtHandler(); + + collectionAutoMsgDeliver< + ColT, IndexT, MsgT, typename MsgT::UserMsgType>( + col_msg, base, hand, true, from, trace_event); + }); + } + }); +} + template < typename MsgT, ActiveColMemberTypedFnType f diff --git a/tests/unit/collection/test_collection_group.extended.cc b/tests/unit/collection/test_collection_group.extended.cc index 021c499d5b..2ab78a5df6 100644 --- a/tests/unit/collection/test_collection_group.extended.cc +++ b/tests/unit/collection/test_collection_group.extended.cc @@ -54,6 +54,8 @@ namespace vt { namespace tests { namespace unit { +static int32_t elemCounter = 0; + struct MyReduceMsg : collective::ReduceTMsg { explicit MyReduceMsg(int const in_num) : collective::ReduceTMsg(in_num) @@ -63,6 +65,12 @@ struct MyReduceMsg : collective::ReduceTMsg { struct ColA : Collection { struct TestMsg : CollectionMessage { }; + struct TestDataMsg : CollectionMessage { + TestDataMsg(int32_t value) : value_(value) {} + + int32_t value_ = -1; + }; + void finishedReduce(MyReduceMsg* m) { fmt::print("at root: final num={}\n", m->getVal()); finished = true; @@ -75,6 +83,12 @@ struct ColA : Collection { proxy.reduce>(reduce_msg.get(),cb); } + void memberHanlder(TestDataMsg* msg) { + EXPECT_EQ(msg->value_, theContext()->getNode()); + --elemCounter; + finished = true; + } + virtual ~ColA() { EXPECT_TRUE(finished); } @@ -96,4 +110,20 @@ TEST_F(TestCollectionGroup, test_collection_group_1) { } } +TEST_F(TestCollectionGroup, test_collection_group_2){ + auto const my_node = theContext()->getNode(); + + auto const range = Index1D(8); + auto const proxy = theCollection()->constructCollective( + range, [](vt::Index1D idx) { + ++elemCounter; + return std::make_unique(); + }); + + auto msg = ::vt::makeMessage(my_node); + proxy.broadcastCollective(msg.get()); + + EXPECT_EQ(elemCounter, 0); +} + }}} // end namespace vt::tests::unit