diff --git a/src/vt/rdma/collection/rdma_collection.cc b/src/vt/rdma/collection/rdma_collection.cc index 57a7cf78fb..a4770f8408 100644 --- a/src/vt/rdma/collection/rdma_collection.cc +++ b/src/vt/rdma/collection/rdma_collection.cc @@ -197,7 +197,8 @@ namespace vt { namespace rdma { auto send_payload = [&](Active::SendFnType send){ auto ret = send(put_ret, put_node, no_tag); - msg->mpi_tag_to_recv = std::get<1>(ret); + msg->mpi_tag_to_recv = ret.getTag(); + msg->nchunks = ret.getNumChunks(); }; if (tag != no_tag) { diff --git a/src/vt/rdma/rdma.cc b/src/vt/rdma/rdma.cc index 5227d764e1..5df85f3e29 100644 --- a/src/vt/rdma/rdma.cc +++ b/src/vt/rdma/rdma.cc @@ -87,7 +87,8 @@ RDMAManager::RDMAManager() auto send_payload = [&](Active::SendFnType send){ auto ret = send(data, recv_node, no_tag); - new_msg->mpi_tag_to_recv = std::get<1>(ret); + new_msg->mpi_tag_to_recv = ret.getTag(); + new_msg->nchunks = ret.getNumChunks(); vt_debug_print( rdma, node, "data is sending: tag={}, node={}\n", @@ -126,7 +127,7 @@ RDMAManager::RDMAManager() if (get_ptr == nullptr) { theMsg()->recvDataMsg( - msg->mpi_tag_to_recv, msg->send_back, + msg->nchunks, msg->mpi_tag_to_recv, msg->send_back, [=](RDMA_GetType ptr, ActionType deleter){ theRDMA()->triggerGetRecvData( op_id, msg_tag, std::get<0>(ptr), std::get<1>(ptr), deleter @@ -135,7 +136,8 @@ RDMAManager::RDMAManager() ); } else { theMsg()->recvDataMsgBuffer( - get_ptr, msg->mpi_tag_to_recv, msg->send_back, true, [get_ptr_action]{ + msg->nchunks, get_ptr, msg->mpi_tag_to_recv, msg->send_back, true, + [get_ptr_action]{ vt_debug_print( rdma, node, "recv_data_msg_buffer finished\n" @@ -209,6 +211,7 @@ RDMAManager::RDMAManager() if (put_ptr == nullptr) { theMsg()->recvDataMsg( + msg->nchunks, recv_tag, recv_node, [=](RDMA_GetType ptr, ActionType deleter){ vt_debug_print( rdma, node, @@ -238,7 +241,7 @@ RDMAManager::RDMAManager() msg->offset != no_byte ? static_cast(put_ptr) + msg->offset : put_ptr; // do a direct recv into the user buffer theMsg()->recvDataMsgBuffer( - put_ptr_offset, recv_tag, recv_node, true, []{}, + msg->nchunks, put_ptr_offset, recv_tag, recv_node, true, []{}, [=](RDMA_GetType ptr, ActionType deleter){ vt_debug_print( rdma, node, @@ -686,7 +689,8 @@ void RDMAManager::putData( auto send_payload = [&](Active::SendFnType send){ auto ret = send(RDMA_GetType{ptr, num_bytes}, put_node, no_tag); - msg->mpi_tag_to_recv = std::get<1>(ret); + msg->mpi_tag_to_recv = ret.getTag(); + msg->nchunks = ret.getNumChunks(); }; if (tag != no_tag) { diff --git a/src/vt/rdma/rdma_msg.h b/src/vt/rdma/rdma_msg.h index 8e46310d18..2a985abe6d 100644 --- a/src/vt/rdma/rdma_msg.h +++ b/src/vt/rdma/rdma_msg.h @@ -82,12 +82,13 @@ struct SendDataMessage : ActiveMessage { NodeType const& back = uninitialized_destination, NodeType const& in_recv_node = uninitialized_destination, bool const in_packed_direct = false - ) : ActiveMessage(), - rdma_handle(in_han), send_back(back), recv_node(in_recv_node), - mpi_tag_to_recv(in_mpi_tag), op_id(in_op), num_bytes(in_num_bytes), - offset(in_offset), packed_direct(in_packed_direct) + ) : rdma_handle(in_han), send_back(back), + recv_node(in_recv_node), mpi_tag_to_recv(in_mpi_tag), op_id(in_op), + num_bytes(in_num_bytes), offset(in_offset), + packed_direct(in_packed_direct) { } + int nchunks = 0; RDMA_HandleType rdma_handle = no_rdma_handle; NodeType send_back = uninitialized_destination; NodeType recv_node = uninitialized_destination; diff --git a/src/vt/serialization/messaging/serialized_data_msg.h b/src/vt/serialization/messaging/serialized_data_msg.h index c3dd8800eb..54abc563ef 100644 --- a/src/vt/serialization/messaging/serialized_data_msg.h +++ b/src/vt/serialization/messaging/serialized_data_msg.h @@ -64,6 +64,7 @@ struct SerializedDataMsgAny : MessageT { HandlerType handler = uninitialized_handler; TagType data_recv_tag = no_tag; NodeType from_node = uninitialized_destination; + int nchunks = 0; }; template diff --git a/src/vt/serialization/messaging/serialized_messenger.impl.h b/src/vt/serialization/messaging/serialized_messenger.impl.h index c2868ce2e3..b3d52ba9fb 100644 --- a/src/vt/serialization/messaging/serialized_messenger.impl.h +++ b/src/vt/serialization/messaging/serialized_messenger.impl.h @@ -100,6 +100,8 @@ template ) { auto const handler = sys_msg->handler; auto const& recv_tag = sys_msg->data_recv_tag; + auto const& nchunks = sys_msg->nchunks; + auto const& len = sys_msg->ptr_size; auto const epoch = envelopeGetEpoch(sys_msg->env); vt_debug_print( @@ -116,8 +118,8 @@ template } auto node = sys_msg->from_node; - theMsg()->recvDataMsg( - recv_tag, sys_msg->from_node, + theMsg()->recvDataDirect( + nchunks, recv_tag, sys_msg->from_node, len, [handler,recv_tag,node,epoch,is_valid_epoch] (RDMA_GetType ptr, ActionType action){ // be careful here not to use "sys_msg", it is no longer valid @@ -360,9 +362,11 @@ template auto sys_msg = makeMessage>(); auto send_serialized = [=](Active::SendFnType send){ auto ret = send(RDMA_GetType{ptr, ptr_size}, dest, no_tag); - EventType event = std::get<0>(ret); + EventType event = ret.getEvent(); theEvent()->attachAction(event, [=]{ std::free(ptr); }); - sys_msg->data_recv_tag = std::get<1>(ret); + sys_msg->data_recv_tag = ret.getTag(); + sys_msg->nchunks = ret.getNumChunks(); + sys_msg->ptr_size = ptr_size; }; // wrap metadata