Skip to content

Commit

Permalink
[#1024]: Implement collective bcast for collection (with member funct…
Browse files Browse the repository at this point in the history
…ion hanlder)
  • Loading branch information
JacobDomagala committed Sep 19, 2020
1 parent 14b814c commit b46c1cf
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/vt/vrt/collection/broadcast/broadcastable.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ struct Broadcastable : BaseProxyT {
typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f, typename... Args
>
messaging::PendingSend broadcast(Args&&... args) const;

template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
messaging::PendingSend broadcastCollective(MsgT* msg) const;
template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
messaging::PendingSend broadcastCollective(MsgSharedPtr<MsgT> msg) const;
template <
typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f, typename... Args
>
messaging::PendingSend broadcastCollective(Args&&... args) const;
};

}}} /* end namespace vt::vrt::collection */
Expand Down
23 changes: 23 additions & 0 deletions src/vt/vrt/collection/broadcast/broadcastable.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,29 @@ messaging::PendingSend Broadcastable<ColT,IndexT,BaseProxyT>::broadcast(MsgT* ms
return theCollection()->broadcastMsg<MsgT, f>(proxy,msg);
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
messaging::PendingSend
Broadcastable<ColT, IndexT, BaseProxyT>::broadcastCollective(MsgSharedPtr<MsgT> msg) const {
return broadcastCollective<MsgT, f>(msg.get());
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <
typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f, typename... Args>
messaging::PendingSend
Broadcastable<ColT, IndexT, BaseProxyT>::broadcastCollective(Args&&... args) const {
return broadcastCollective<MsgT, f>(makeMessage<MsgT>(std::forward<Args>(args)...));
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <typename MsgT, ActiveColMemberTypedFnType<MsgT, ColT> f>
messaging::PendingSend
Broadcastable<ColT, IndexT, BaseProxyT>::broadcastCollective(MsgT* msg) const {
auto proxy = this->getProxy();
return theCollection()->broadcastMsgCollective<MsgT, f>(proxy, msg);
}

}}} /* end namespace vt::vrt::collection */

#endif /*INCLUDED_VRT_COLLECTION_BROADCAST_BROADCASTABLE_IMPL_H*/
7 changes: 7 additions & 0 deletions src/vt/vrt/collection/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,13 @@ struct CollectionManager
bool instrument
);

template <
typename MsgT,
ActiveColMemberTypedFnType<MsgT, typename MsgT::CollectionType> f>
messaging::PendingSend broadcastMsgCollective(
CollectionProxyWrapType<typename MsgT::CollectionType> const& proxy,
MsgT* msg, bool instrument = true);

/**
* \brief Broadcast a message with action function handler
*
Expand Down
41 changes: 41 additions & 0 deletions src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,47 @@ messaging::PendingSend CollectionManager::broadcastFromRoot(MsgT* raw_msg) {
return ret;
}


template <
typename MsgT,
ActiveColMemberTypedFnType<MsgT, typename MsgT::CollectionType> f>
messaging::PendingSend CollectionManager::broadcastMsgCollective(
CollectionProxyWrapType<typename MsgT::CollectionType> const& proxy,
MsgT* msg, bool instrument) {

using ColT = typename MsgT::CollectionType;
using IndexT = typename ColT::IndexType;

auto promoMsg = promoteMsg(msg);

return messaging::PendingSend(
promoMsg, [proxy](MsgSharedPtr<BaseMsgType>& msgIn) {
auto elm_holder = theCollection()->findElmHolder<ColT, IndexT>(proxy);
auto const node = theContext()->getNode();

auto col_msg = reinterpret_cast<MsgT*>(msgIn.get());
auto handler =
auto_registry::makeAutoHandlerCollectionMem<ColT, MsgT, f>();
col_msg->setVrtHandler(handler);

theMsg()->markAsCollectionMessage(col_msg);

if (elm_holder) {
elm_holder->foreach (
[node, msgIn, col_msg,
elm_holder](IndexT const& idx, CollectionBase<ColT, IndexT>* base) {
auto const from = col_msg->getFromNode();
auto trace_event = trace::no_trace_event;
auto const hand = col_msg->getVrtHandler();

collectionAutoMsgDeliver<
ColT, IndexT, MsgT, typename MsgT::UserMsgType>(
col_msg, base, hand, true, from, trace_event);
});
}
});
}

template <
typename MsgT,
ActiveColMemberTypedFnType<MsgT,typename MsgT::CollectionType> f
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/collection/test_collection_group.extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@

namespace vt { namespace tests { namespace unit {

static int32_t elemCounter = 0;

struct MyReduceMsg : collective::ReduceTMsg<int> {
explicit MyReduceMsg(int const in_num)
: collective::ReduceTMsg<int>(in_num)
Expand All @@ -63,6 +65,12 @@ struct MyReduceMsg : collective::ReduceTMsg<int> {
struct ColA : Collection<ColA,Index1D> {
struct TestMsg : CollectionMessage<ColA> { };

struct TestDataMsg : CollectionMessage<ColA> {
TestDataMsg(int32_t value) : value_(value) {}

int32_t value_ = -1;
};

void finishedReduce(MyReduceMsg* m) {
fmt::print("at root: final num={}\n", m->getVal());
finished = true;
Expand All @@ -75,6 +83,12 @@ struct ColA : Collection<ColA,Index1D> {
proxy.reduce<collective::PlusOp<int>>(reduce_msg.get(),cb);
}

void memberHanlder(TestDataMsg* msg) {
EXPECT_EQ(msg->value_, theContext()->getNode());
--elemCounter;
finished = true;
}

virtual ~ColA() {
EXPECT_TRUE(finished);
}
Expand All @@ -96,4 +110,20 @@ TEST_F(TestCollectionGroup, test_collection_group_1) {
}
}

TEST_F(TestCollectionGroup, test_collection_group_2){
auto const my_node = theContext()->getNode();

auto const range = Index1D(8);
auto const proxy = theCollection()->constructCollective<ColA>(
range, [](vt::Index1D idx) {
++elemCounter;
return std::make_unique<ColA>();
});

auto msg = ::vt::makeMessage<ColA::TestDataMsg>(my_node);
proxy.broadcastCollective<ColA::TestDataMsg, &ColA::memberHanlder>(msg.get());

EXPECT_EQ(elemCounter, 0);
}

}}} // end namespace vt::tests::unit

0 comments on commit b46c1cf

Please sign in to comment.