#883: active: Adapt codebase after changes to broadcast
JacobDomagala committed Feb 23, 2021
1 parent e2663cb commit ae558ea
Showing 11 changed files with 84 additions and 83 deletions.
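Taken together, the changes below drop the manual "invoke locally after broadcast" step: the test updates at the bottom (num_msgs_sent*(num_nodes-1) becoming num_msgs_sent*num_nodes) indicate that broadcastMsg now also delivers on the sending node. A minimal before/after sketch of that pattern, with MyMsg and myHandler as hypothetical placeholders (not code from this repository):

  // Before the broadcast change: the broadcast reached all *other* nodes,
  // so the sender had to run the handler locally by hand.
  auto msg = vt::makeMessage<MyMsg>();
  vt::theMsg()->broadcastMsg<MyMsg, myHandler>(msg);
  auto local_msg = vt::makeMessage<MyMsg>();
  myHandler(local_msg.get());  // manual local delivery, now removed

  // After the broadcast change: one broadcast reaches every node, including
  // the sender; handlers that must not run twice on the root instead add an
  // explicit guard (see the is_root checks added in manager.impl.h below).
  auto msg2 = vt::makeMessage<MyMsg>();
  vt::theMsg()->broadcastMsg<MyMsg, myHandler>(msg2);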
6 changes: 3 additions & 3 deletions examples/collection/transpose.cc
@@ -178,8 +178,8 @@ struct Block : vt::Collection<Block, vt::Index1D> {
auto proxy_msg = vt::makeMessage<ProxyMsg>(proxy.getProxy());
vt::theMsg()->broadcastMsg<SetupGroup,ProxyMsg>(proxy_msg);
// Invoke it locally: broadcast sends to all other nodes
auto proxy_msg_local = vt::makeMessage<ProxyMsg>(proxy.getProxy());
SetupGroup()(proxy_msg_local.get());
// auto proxy_msg_local = vt::makeMessage<ProxyMsg>(proxy.getProxy());
// SetupGroup()(proxy_msg_local.get());
}
}

@@ -314,7 +314,7 @@ static void solveGroupSetup(vt::NodeType this_node, vt::VirtualProxyType coll_pr

vt::theGroup()->newGroupCollective(
is_even_node, [=](vt::GroupType group_id){
fmt::print("Group is created: id={:x}\n", group_id);
fmt::print("{}: Group is created: id={:x}\n", this_node, group_id);
if (this_node == 1) {
auto msg = vt::makeMessage<SubSolveMsg>(coll_proxy);
vt::envelopeSetGroup(msg->env, group_id);
1 change: 0 additions & 1 deletion src/vt/collective/barrier/barrier.cc
@@ -235,7 +235,6 @@ void Barrier::barrierUp(
"barrierDown: barrier={}\n", barrier
);
theMsg()->broadcastMsg<BarrierMsg, barrierDown>(msg);
barrierDown(is_named, is_wait, barrier);
}
}
}
4 changes: 2 additions & 2 deletions src/vt/group/global/group_default.cc
@@ -192,8 +192,8 @@ namespace vt { namespace group { namespace global {
});
}

if (is_root) {
*deliver = true;
if (is_root && !envelopeIsTerm(msg->env)) {
*deliver = true;
}

// If not the root of the spanning tree, send to the root to propagate to
6 changes: 0 additions & 6 deletions src/vt/objgroup/manager.static.h
@@ -83,13 +83,7 @@ void invoke(messaging::MsgPtrThief<MsgT> msg, HandlerType han, NodeType dest_nod

template <typename MsgT>
void broadcast(MsgSharedPtr<MsgT> msg, HandlerType han) {
// Get the current epoch for the message
auto const cur_epoch = theMsg()->setupEpochMsg(msg);
// Broadcast the message
auto msg_hold = promoteMsg(msg.get()); // for scheduling
theMsg()->broadcastMsg<MsgT>(han, msg, no_tag);
// Schedule delivery on this node for the objgroup
scheduleMsg(msg_hold.template toVirtual<ShortMessage>(), han, cur_epoch);
}

}} /* end namespace vt::objgroup */
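With the epoch setup and the local scheduleMsg call gone, the objgroup broadcast helper reduces to a single call. Reconstructed from the surviving context lines above (a sketch of the post-commit state, not authoritative):

  template <typename MsgT>
  void broadcast(MsgSharedPtr<MsgT> msg, HandlerType han) {
    // No local scheduling needed anymore: the broadcast itself now covers
    // the sending node as well.
    theMsg()->broadcastMsg<MsgT>(han, msg, no_tag);
  }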
4 changes: 3 additions & 1 deletion src/vt/pipe/callback/callback_handler_bcast.h
@@ -73,10 +73,12 @@ struct CallbackBcast : CallbackBase<signal::Signal<MsgT>> {

private:
void trigger_(SignalDataType* data) override {
fmt::print("{}: trigger_ include_root_={}\n", include_root_)
theMsg()->broadcastMsg<MsgT,f>(data);
assert(include_root_);
if (include_root_) {
auto nmsg = makeMessage<SignalDataType*>(*data);
f(nmsg.get());
//f(nmsg.get());
}
}

10 changes: 0 additions & 10 deletions src/vt/pipe/callback/handler_bcast/callback_bcast.impl.h
@@ -93,9 +93,6 @@ CallbackBcast<MsgT>::triggerDispatch(SignalDataType* data, PipeType const& pid)
);
auto msg = makeMessage<CallbackMsg>(pid);
theMsg()->broadcastMsg<CallbackMsg>(handler_, msg);
if (include_sender_) {
runnable::RunnableVoid::run(handler_,this_node);
}
}

template <typename MsgT>
@@ -109,13 +106,6 @@ CallbackBcast<MsgT>::triggerDispatch(SignalDataType* data, PipeType const& pid)
this_node, include_sender_
);
theMsg()->broadcastMsg<SignalDataType>(handler_, data);
auto msg_group = envelopeGetGroup(data->env);
bool const is_default = msg_group == default_group;
if (include_sender_ and is_default) {
auto nmsg = makeMessage<SignalDataType>(*data);
auto short_msg = nmsg.template to<ShortMessage>.get();
runnable::Runnable<ShortMessage>::run(handler_,nullptr,short_msg,this_node);
}
}

}}} /* end namespace vt::pipe::callback */
6 changes: 3 additions & 3 deletions src/vt/pipe/callback/handler_bcast/callback_bcast_tl.cc
@@ -61,16 +61,16 @@ CallbackBcastTypeless::CallbackBcastTypeless(

void CallbackBcastTypeless::triggerVoid(PipeType const& pipe) {
auto const& this_node = theContext()->getNode();
vt_debug_print(
pipe, node,
fmt::print(
"CallbackBcast: (void) trigger_: pipe={:x}, this_node={}, "
"include_sender_={}\n",
pipe, this_node, include_sender_
);
auto msg = makeMessage<CallbackMsg>(pipe);
theMsg()->broadcastMsg<CallbackMsg>(handler_, msg);
assert(include_sender_);
if (include_sender_) {
runnable::RunnableVoid::run(handler_,this_node);
//runnable::RunnableVoid::run(handler_,this_node);
}
}

7 changes: 0 additions & 7 deletions src/vt/pipe/callback/handler_bcast/callback_bcast_tl.impl.h
@@ -74,13 +74,6 @@ void CallbackBcastTypeless::trigger(MsgT* msg, PipeType const& pipe) {

auto pmsg = promoteMsg(msg);
theMsg()->broadcastMsg<MsgT>(handler_, pmsg);

auto msg_group = envelopeGetGroup(msg->env);
bool const is_default = msg_group == default_group;
if (include_sender_ and is_default) {
auto nmsg = makeMessage<MsgT>(*msg); // create copy (?)
runnable::Runnable<MsgT>::run(handler_, nullptr, nmsg.get(), this_node);
}
}

}}} /* end namespace vt::pipe::callback */
109 changes: 63 additions & 46 deletions src/vt/vrt/collection/manager.impl.h
@@ -160,6 +160,15 @@ template <typename SysMsgT>
using IndexT = typename SysMsgT::IndexType;
using BaseIdxType = vt::index::BaseIndex;

auto const is_bcast = envelopeIsBcast(msg->env);
auto const dest = envelopeGetDest(msg->env);
auto const this_node = theContext()->getNode();
const auto is_root = is_bcast && (this_node == dest);

if(is_root){
return;
}

auto const num_nodes = theContext()->getNumNodes();

auto& info = msg->info;
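The early return added here (and repeated in collectionGroupFinishedHan and updateInsertEpochHandler further down) keeps the handler from running a second time on the node that issued the broadcast. The same check written as a standalone helper, purely for illustration (hypothetical; the commit keeps it inline):

  // Hypothetical helper, not part of this commit. Uses only the envelope
  // accessors that appear in the diff above.
  template <typename MsgT>
  bool isBcastRoot(MsgT* msg) {
    auto const is_bcast  = envelopeIsBcast(msg->env);
    auto const dest      = envelopeGetDest(msg->env);
    auto const this_node = theContext()->getNode();
    return is_bcast && this_node == dest;
  }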
@@ -186,11 +195,10 @@ template <typename SysMsgT>
// Total count across the statically sized collection
std::size_t num_elms = info.range_.getSize();

vt_debug_print(
vrt_coll, node,
"running foreach: size={}, range={}, map_han={}\n",
num_elms, range, map_han
);
// fmt::print(
// "{}: running foreach: size={}, range={}, map_han={}\n",
// theContext()->getNode(), num_elms, range, map_han
// );

range.foreach([&](IndexT cur_idx) mutable {
vt_debug_print_verbose(
@@ -606,6 +614,13 @@ template <typename>
/*static*/ void CollectionManager::collectionGroupFinishedHan(
CollectionGroupMsg* msg
) {
// THIS SEEMS TO BE CALLED ON NON-DEFAULT GRP
const auto is_root = envelopeIsBcast(msg->env) && (theContext()->getNode() == envelopeGetDest(msg->env));
fmt::print("{}: CollectionManager::collectionGroupFinishedHan bcast={}\n", theContext()->getNode(), is_root);
// if(is_root){
// return;
// }

auto const& proxy = msg->getProxy();
theCollection()->addToState(proxy, BufferReleaseEnum::AfterGroupReady);
theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Reduce);
@@ -618,10 +633,9 @@ template <typename>
auto const& proxy = msg->proxy;
theCollection()->constructed_.insert(proxy);
theCollection()->addToState(proxy, BufferReleaseEnum::AfterFullyConstructed);
vt_debug_print(
vrt_coll, node,
"addToState: proxy={:x}, AfterCons\n", proxy
);
// fmt::print(
// "{}: addToState: proxy={:x}, AfterCons\n", theContext()->getNode(), proxy
// );
theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Broadcast);
theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Send);
theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Reduce);
@@ -631,10 +645,9 @@ template <typename>
/*static*/ void CollectionManager::collectionGroupReduceHan(
CollectionGroupMsg* msg
) {
vt_debug_print(
vrt_coll, node,
"collectionGroupReduceHan: proxy={:x}, root={}, group={}\n",
msg->proxy, msg->isRoot(), msg->getGroup()
fmt::print(
"{}: collectionGroupReduceHan: proxy={:x}, root={}, group={}\n",
theContext()->getNode(), msg->proxy, msg->isRoot(), msg->getGroup()
);
if (msg->isRoot()) {
auto nmsg = makeMessage<CollectionGroupMsg>(*msg);
@@ -1035,10 +1048,6 @@ messaging::PendingSend CollectionManager::broadcastFromRoot(MsgT* raw_msg) {
msg
);

if (!send_group) {
collectionBcastHandler<ColT,IndexT,MsgT>(msg_hold.get());
}

theMsg()->popEpoch(cur_epoch);

return ret;
@@ -1328,12 +1337,15 @@ messaging::PendingSend CollectionManager::reduceMsgExpr(

auto msg = promoteMsg(raw_msg);



vt_debug_print(
vrt_coll, node,
"reduceMsg: msg={}\n", print_ptr(raw_msg)
);

auto const col_proxy = proxy.getProxy();
//fmt::print("reduceMsg: msg={}\n", print_ptr(raw_msg));
auto const cur_epoch = theMsg()->getEpochContextMsg(msg);

return bufferOpOrExecute<ColT>(
@@ -1382,6 +1394,10 @@ messaging::PendingSend CollectionManager::reduceMsgExpr(

auto ret_stamp = r->reduceImmediate<MsgT,f>(root_node, msg.get(), cur_stamp, num_elms);

// fmt::print("{}: reduceMsg: col_proxy={:x}, num_elms={}\n",
// theContext()->getNode(), col_proxy, num_elms
// );

vt_debug_print(
vrt_coll, node,
"reduceMsg: col_proxy={:x}, num_elms={}\n",
@@ -1611,21 +1627,21 @@ messaging::PendingSend CollectionManager::sendMsgUntypedHandler(
msg->setProxy(toProxy);

auto idx = elm_proxy.getIndex();
vt_debug_print(
vrt_coll, node,
"sendMsgUntypedHandler: col_proxy={:x}, cur_epoch={:x}, idx={}, "
"handler={}, imm_context={}\n",
col_proxy, cur_epoch, idx, handler, imm_context
);
// fmt::print(
// "{}: sendMsgUntypedHandler: col_proxy={:x}, cur_epoch={:x}, idx={}, "
// "handler={}, imm_context={}\n",
// theContext()->getNode(), col_proxy, cur_epoch, idx, handler, imm_context
// );

return schedule(
msg, !imm_context, cur_epoch, [=]{
bufferOpOrExecute<ColT>(
col_proxy,
BufferTypeEnum::Send,
BufferReleaseEnum::AfterMetaDataKnown,
static_cast<BufferReleaseEnum>(AfterFullyConstructed | AfterMetaDataKnown),
cur_epoch,
[=]() -> messaging::PendingSend {
fmt::print("{}: Executing PendingSend\n", theContext()->getNode());
auto home_node = getMapped<ColT>(col_proxy, idx);
// route the message to the destination using the location manager
auto lm = theLocMan()->getCollectionLM<ColT, IdxT>(col_proxy);
@@ -1650,11 +1666,10 @@ bool CollectionManager::insertCollectionElement(
) {
auto holder = findColHolder<ColT, IndexT>(proxy);

vt_debug_print(
vrt_coll, node,
"insertCollectionElement: proxy={:x}, map_han={}, idx={}, max_idx={}\n",
proxy, map_han, print_index(idx), print_index(max_idx)
);
// fmt::print(
// "{}: insertCollectionElement: proxy={:x}, map_han={}, idx={}, max_idx={}\n",
// theContext()->getNode(), proxy, map_han, print_index(idx), print_index(max_idx)
// );

if (holder == nullptr) {
insertMetaCollection<ColT>(proxy,map_han,max_idx,is_static);
@@ -1890,8 +1905,7 @@ CollectionManager::constructCollectiveMap(
// Construct a underlying group for the collection
groupConstruction<ColT>(proxy, is_static);

vt_debug_print(
vrt_coll, node,
fmt::print(
"constructCollectiveMap: entering wait for constructed_\n"
);

@@ -2167,10 +2181,10 @@ template <typename ColT, typename... Args>
);
};

vt_debug_print(
vrt_coll, node,
"addToState: proxy={:x}, AfterMeta\n", proxy
);
// fmt::print(
// "{}: addToState: proxy={:x}, AfterMeta\n", theContext()->getNode(), proxy
// );

theCollection()->addToState(proxy, BufferReleaseEnum::AfterMetaDataKnown);
theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Send);
theCollection()->triggerReadyOps(proxy, BufferTypeEnum::Broadcast);
@@ -2189,10 +2203,9 @@ template <typename ColT>
auto msg = makeMessage<CollectionConsMsg>(proxy);
theMsg()->markAsCollectionMessage(msg);
auto const& root = 0;
vt_debug_print(
vrt_coll, node,
"reduceConstruction: invoke reduce: proxy={:x}\n", proxy
);
// fmt::print(
// "{}: reduceConstruction: invoke reduce: proxy={:x}\n", theContext()->getNode(), proxy
// );

using collective::reduce::makeStamp;
using collective::reduce::StrongUserID;
@@ -2288,11 +2301,6 @@ CollectionManager::constructMap(

create_msg->info = info;

vt_debug_print(
vrt_coll, node,
"construct_map: range={}\n", range.toString().c_str()
);

theMsg()->broadcastMsg<MsgType,distConstruct<MsgType>>(
create_msg
);
@@ -2338,6 +2346,13 @@ template <typename ColT, typename IndexT>
/*static*/ void CollectionManager::updateInsertEpochHandler(
UpdateInsertMsg<ColT,IndexT>* msg
) {
const auto is_root = envelopeIsBcast(msg->env) && (theContext()->getNode() == envelopeGetDest(msg->env));
fmt::print("{}: CollectionManager::updateInsertEpochHandler bcast={}\n", theContext()->getNode(), is_root);
if(is_root){
return;
}


auto const& untyped_proxy = msg->proxy_.getProxy();
UniversalIndexHolder<>::insertSetEpoch(untyped_proxy,msg->epoch_);

@@ -2974,8 +2989,6 @@ void CollectionManager::destroy(
theMsg()->markAsCollectionMessage(msg);
auto msg_hold = promoteMsg(msg.get()); // keep after bcast
theMsg()->broadcastMsg<DestroyMsgType, DestroyHandlers::destroyNow>(msg);

DestroyHandlers::destroyNow(msg_hold.get());
}

template <typename ColT, typename IndexT>
Expand Down Expand Up @@ -3284,12 +3297,15 @@ messaging::PendingSend CollectionManager::bufferOpOrExecute(
VirtualProxyType proxy, BufferTypeEnum type, BufferReleaseEnum release,
EpochType epoch, ActionPendingType action
) {

if (checkReady(proxy, release)) {
//fmt::print("{}: CollectionManager::bufferOpOrExecute action ready\n", theContext()->getNode());
theMsg()->pushEpoch(epoch);
auto ps = action();
theMsg()->popEpoch(epoch);
return ps;
} else {
//fmt::print("{}: CollectionManager::bufferOpOrExecute buffer action\n", theContext()->getNode());
return bufferOp<ColT>(proxy, type, release, epoch, action);
}
}
@@ -3320,6 +3336,7 @@ template <typename MsgT>
messaging::PendingSend CollectionManager::schedule(
MsgT msg, bool execute_now, EpochType cur_epoch, ActionType action
) {
fmt::print("{}: CollectionManager::schedule\n", theContext()->getNode());
theTerm()->produce(cur_epoch);
return messaging::PendingSend(msg, [=](MsgVirtualPtr<BaseMsgType> inner_msg){
auto fn = [=]{
4 changes: 2 additions & 2 deletions tests/unit/memory/test_memory_lifetime.cc
@@ -149,7 +149,7 @@ TEST_F(TestMemoryLifetime, test_active_bcast_serial_lifetime) {
});

EXPECT_EQ(SerialTrackMsg::alloc_count, 0);
EXPECT_EQ(local_count, num_msgs_sent*(num_nodes-1));
EXPECT_EQ(local_count, num_msgs_sent*num_nodes);
}
}

@@ -210,7 +210,7 @@ TEST_F(TestMemoryLifetime, test_active_bcast_normal_lifetime_msgptr) {
}

theTerm()->addAction([=]{
EXPECT_EQ(local_count, num_msgs_sent*(num_nodes-1));
EXPECT_EQ(local_count, num_msgs_sent*num_nodes);
});
}
}
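The updated expectations follow directly from the sender now receiving its own broadcasts; as illustrative arithmetic (example numbers assumed, not taken from the test):

  // Each of num_msgs_sent broadcasts is now handled on all num_nodes ranks:
  //   old: local_count == num_msgs_sent * (num_nodes - 1)  // sender excluded
  //   new: local_count == num_msgs_sent * num_nodes        // sender included
  // e.g. 8 messages across 4 nodes: 8 * 4 = 32 handler runs instead of 8 * 3 = 24.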