diff --git a/src/vt/messaging/active.cc b/src/vt/messaging/active.cc index 8dbbc47f21..71506793b2 100644 --- a/src/vt/messaging/active.cc +++ b/src/vt/messaging/active.cc @@ -299,7 +299,7 @@ void ActiveMessenger::handleChunkedMultiMsg(MultiMsg* msg) { recvDataDirect(nchunks, buf, tag, sender, size, 0, nullptr, fn); } -void ActiveMessenger::sendMsgMPI( +EventType ActiveMessenger::sendMsgMPI( NodeType const& dest, MsgSharedPtr const& base, MsgSizeType const& msg_size, TagType const& send_tag ) { @@ -307,11 +307,19 @@ void ActiveMessenger::sendMsgMPI( char* untyped_msg = reinterpret_cast(base_typed_msg); + vt_debug_print( + active, node, + "sendMsgMPI: dest={}, msg_size={}, send_tag={}\n", + dest, msg_size, send_tag + ); + if (static_cast(msg_size) < max_per_send) { auto const event_id = theEvent()->createMPIEvent(this_node_); auto& holder = theEvent()->getEventHolder(event_id); auto mpi_event = holder.get_event(); + mpi_event->setManagedMessage(base.to()); + int small_msg_size = static_cast(msg_size); { VT_ALLOW_MPI_CALLS; @@ -335,6 +343,8 @@ void ActiveMessenger::sendMsgMPI( } #endif } + + return event_id; } else { vt_debug_print( active, node, @@ -352,6 +362,8 @@ void ActiveMessenger::sendMsgMPI( auto m = makeMessage(info, this_node, msg_size); sendMsg(dest, m); + + return event_id; } } @@ -381,17 +393,15 @@ EventType ActiveMessenger::sendMsgBytes( dest >= theContext()->getNumNodes() || dest < 0, "Invalid destination: {}" ); - { - if (is_bcast) { - bcastsSentCount.increment(1); - } - if (is_term) { - tdSentCount.increment(1); - } - amSentCounterGauge.incrementUpdate(msg_size, 1); - - sendMsgMPI(dest, base, msg_size, send_tag); + if (is_bcast) { + bcastsSentCount.increment(1); + } + if (is_term) { + tdSentCount.increment(1); } + amSentCounterGauge.incrementUpdate(msg_size, 1); + + EventType const event_id = sendMsgMPI(dest, base, msg_size, send_tag); if (not is_term) { theTerm()->produce(epoch,1,dest); @@ -402,7 +412,7 @@ EventType ActiveMessenger::sendMsgBytes( l->send(dest, msg_size, is_bcast); } - return no_event; + return event_id; } #if vt_check_enabled(trace_enabled) diff --git a/src/vt/messaging/active.h b/src/vt/messaging/active.h index d5b659a191..1eb487f9b4 100644 --- a/src/vt/messaging/active.h +++ b/src/vt/messaging/active.h @@ -1496,8 +1496,10 @@ struct ActiveMessenger : runtime::component::PollableComponent * \param[in] base the message base pointer * \param[in] msg_size the size of the message * \param[in] send_tag the send tag on the message + * + * \return the event to test/wait for completion */ - void sendMsgMPI( + EventType sendMsgMPI( NodeType const& dest, MsgSharedPtr const& base, MsgSizeType const& msg_size, TagType const& send_tag );