From a9a551c048d7ecdba9518e78181bab3ed33930f2 Mon Sep 17 00:00:00 2001 From: Jonathan Lifflander Date: Tue, 13 Oct 2020 16:57:27 -0700 Subject: [PATCH] #1111: active: implement fixes for large message sends --- src/vt/configs/arguments/args.cc | 12 + src/vt/configs/arguments/args.h | 1 + src/vt/configs/types/types_type.h | 2 +- src/vt/messaging/active.cc | 488 ++++++++++++------ src/vt/messaging/active.h | 212 ++++++-- src/vt/messaging/irecv_holder.h | 6 +- src/vt/messaging/send_info.h | 103 ++++ src/vt/rdma/collection/rdma_collection.cc | 3 +- src/vt/rdma/rdma.cc | 17 +- src/vt/rdma/rdma_msg.h | 9 +- src/vt/runtime/runtime.cc | 17 + .../messaging/serialized_data_msg.h | 1 + .../messaging/serialized_messenger.impl.h | 12 +- src/vt/utils/memory/memory_units.cc | 29 ++ src/vt/utils/memory/memory_units.h | 1 + tests/unit/active/test_active_send_large.cc | 173 +++++++ 16 files changed, 879 insertions(+), 207 deletions(-) create mode 100644 src/vt/messaging/send_info.h create mode 100644 tests/unit/active/test_active_send_large.cc diff --git a/src/vt/configs/arguments/args.cc b/src/vt/configs/arguments/args.cc index 598e93099a..f51f19fa77 100644 --- a/src/vt/configs/arguments/args.cc +++ b/src/vt/configs/arguments/args.cc @@ -165,6 +165,8 @@ namespace vt { namespace arguments { /*static*/ bool ArgConfig::parsed = false; +/*static*/ std::size_t ArgConfig::vt_max_mpi_send_size = 1ull << 30; + static std::unique_ptr<char*[]> new_argv = nullptr; /*static*/ int ArgConfig::parse(int& argc, char**& argv) { @@ -549,6 +551,16 @@ static std::unique_ptr<char*[]> new_argv = nullptr; hca->group(schedulerGroup); kca->group(schedulerGroup); + /* * Options for configuring the runtime */ + auto a1x = app.add_option( "--vt_max_mpi_send_size", config_.vt_max_mpi_send_size, max_size, true ); + + auto configRuntime = "Runtime"; + a1x->group(configRuntime); + /* * Run the parser! */
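The new option takes a plain byte count and defaults to 1ull << 30 (1 GiB); printStartupBanner() below clamps it to the range [256 B, 1 GiB]. As a minimal sketch of what the cap means for a payload — the helper name here is illustrative, not part of the patch:

    #include <cstddef>

    // With --vt_max_mpi_send_size=16384 (the value the new unit test uses), a
    // 1 MiB payload is carried by ceil(2^20 / 2^14) = 64 MPI_Isend chunks.
    std::size_t chunksNeeded(std::size_t payload_bytes, std::size_t max_per_send) {
      return (payload_bytes + max_per_send - 1) / max_per_send; // ceiling division
    }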
diff --git a/src/vt/configs/arguments/args.h b/src/vt/configs/arguments/args.h index 5b209facc3..eebe228d90 100644 --- a/src/vt/configs/arguments/args.h +++ b/src/vt/configs/arguments/args.h @@ -154,6 +154,7 @@ struct ArgConfig { static bool vt_debug_objgroup; static bool vt_debug_print_flush; + static std::size_t vt_max_mpi_send_size; static bool vt_user_1; static bool vt_user_2; diff --git a/src/vt/configs/types/types_type.h b/src/vt/configs/types/types_type.h index eb589d7738..2670eaa7cd 100644 --- a/src/vt/configs/types/types_type.h +++ b/src/vt/configs/types/types_type.h @@ -77,7 +77,7 @@ using VirtualElmOnlyProxyType = uint64_t; using VirtualElmCountType = int64_t; using UniqueIndexBitType = uint64_t; using GroupType = uint64_t; -using MsgSizeType = int32_t; +using MsgSizeType = int64_t; using PhaseType = uint64_t; using PipeType = uint64_t; using ObjGroupProxyType = uint64_t; diff --git a/src/vt/messaging/active.cc b/src/vt/messaging/active.cc index 7cce284871..7f413acf66 100644 --- a/src/vt/messaging/active.cc +++ b/src/vt/messaging/active.cc @@ -159,9 +159,9 @@ EventType ActiveMessenger::sendMsgBytesWithPut( } else { auto const& env_tag = envelopeGetPutTag(msg->env); auto const& ret = sendData( - RDMA_GetType{put_ptr,put_size}, dest, env_tag + PtrLenPairType{put_ptr,put_size}, dest, env_tag ); - auto const& ret_tag = std::get<1>(ret); + auto const& ret_tag = ret.getTag(); if (ret_tag != env_tag) { envelopeSetPutTag(msg->env, ret_tag); } @@ -177,6 +177,127 @@ EventType ActiveMessenger::sendMsgBytesWithPut( return no_event; } +struct MultiMsg : vt::Message { + MultiMsg() = default; + MultiMsg(SendInfo in_info, NodeType in_from, MsgSizeType in_size) + : info(in_info), + from(in_from), + size(in_size) + { } + + SendInfo getInfo() const { return info; } + NodeType getFrom() const { return from; } + MsgSizeType getSize() const { return size; } + +private: + SendInfo info; + NodeType from = uninitialized_destination; + MsgSizeType size = 0; +}; + +/*static*/ void ActiveMessenger::chunkedMultiMsg(MultiMsg* msg) { + theMsg()->handleChunkedMultiMsg(msg); } + +void ActiveMessenger::handleChunkedMultiMsg(MultiMsg* msg) { + auto buf = +#if vt_check_enabled(memory_pool) + static_cast<char*>(thePool()->alloc(msg->getSize())); +#else + static_cast<char*>(std::malloc(msg->getSize())); +#endif + + auto const size = msg->getSize(); + auto const info = msg->getInfo(); + auto const sender = msg->getFrom(); + auto const nchunks = info.getNumChunks(); + auto const tag = info.getTag(); + + auto fn = [buf,sender,size,tag,this](PtrLenPairType,ActionType){ + vt_debug_print( active, node, "handleChunkedMultiMsg: all chunks arrived tag={}, size={}, from={}\n", tag, size, sender ); + InProgressIRecv irecv(buf, size, sender); + finishPendingActiveMsgAsyncRecv(&irecv); + }; + + recvDataDirect(nchunks, buf, tag, sender, size, 0, nullptr, fn, false); +} + +EventType ActiveMessenger::sendMsgMPI( + NodeType const& dest, MsgSharedPtr<BaseMsgType> const& base, + MsgSizeType const& msg_size, TagType const& send_tag +) { + BaseMsgType* base_typed_msg = base.get(); + + char* untyped_msg = reinterpret_cast<char*>(base_typed_msg); + + vt_debug_print( active, node, "sendMsgMPI: dest={}, msg_size={}, send_tag={}\n", dest, msg_size, send_tag ); + + auto const max_per_send = theConfig()->vt_max_mpi_send_size; + if (static_cast<std::size_t>(msg_size) < max_per_send) { + auto const event_id = theEvent()->createMPIEvent(this_node_); + auto& holder = theEvent()->getEventHolder(event_id); + auto mpi_event = holder.get_event(); +
mpi_event->setManagedMessage(base.to<ShortMessage>()); + + int small_msg_size = static_cast<int>(msg_size); + { + VT_ALLOW_MPI_CALLS; + #if vt_check_enabled(trace_enabled) + double tr_begin = 0; + if (theConfig()->vt_trace_mpi) { + tr_begin = vt::timing::Timing::getCurrentTime(); + } + #endif + int const ret = MPI_Isend( + untyped_msg, small_msg_size, MPI_BYTE, dest, send_tag, + theContext()->getComm(), mpi_event->getRequest() + ); + vtAssertMPISuccess(ret, "MPI_Isend"); + + #if vt_check_enabled(trace_enabled) + if (theConfig()->vt_trace_mpi) { + auto tr_end = vt::timing::Timing::getCurrentTime(); + auto tr_note = fmt::format("Isend(AM): dest={}, bytes={}", dest, msg_size); + trace::addUserBracketedNote(tr_begin, tr_end, tr_note, trace_isend); + } + #endif + } + + return event_id; + } else { + vt_debug_print( active, node, "sendMsgMPI: (multi): size={}\n", msg_size ); + auto tag = allocateNewTag(); + auto this_node = theContext()->getNode(); + + // Send the actual data in multiple chunks + PtrLenPairType tup = std::make_tuple(untyped_msg, msg_size); + SendInfo info = sendData(tup, dest, tag); + + auto event_id = info.getEvent(); + auto& holder = theEvent()->getEventHolder(event_id); + auto mpi_event = holder.get_event(); + mpi_event->setManagedMessage(base.to<ShortMessage>()); + + // Send the control message to receive the multiple chunks of data + auto m = makeMessage<MultiMsg>(info, this_node, msg_size); + sendMsg<MultiMsg, chunkedMultiMsg>(dest, m); + + return event_id; + } +} EventType ActiveMessenger::sendMsgBytes( NodeType const& dest, MsgSharedPtr<BaseMsgType> const& base, MsgSizeType const& msg_size, TagType const& send_tag @@ -189,10 +310,6 @@ EventType ActiveMessenger::sendMsgBytes( auto const is_term = envelopeIsTerm(msg->env); auto const is_bcast = envelopeIsBcast(msg->env); - auto const event_id = theEvent()->createMPIEvent(this_node_); - auto& holder = theEvent()->getEventHolder(event_id); - auto mpi_event = holder.get_event(); - if (!is_term || backend_check_enabled(print_term_msgs)) { debug_print( active, node, ); } - if (is_shared) { - mpi_event->setManagedMessage(base.to<ShortMessage>()); - } - vtWarnIf( !(dest != theContext()->getNode() || is_bcast), "Destination {} should != this node" ); vtAbortIf( dest >= theContext()->getNumNodes() || dest < 0, "Invalid destination: {}" ); - { - VT_ALLOW_MPI_CALLS; - #if backend_check_enabled(trace_enabled) - double tr_begin = 0; - if (ArgType::vt_trace_mpi) { - tr_begin = vt::timing::Timing::getCurrentTime(); - } - #endif - - const int ret = MPI_Isend( - msg, msg_size, MPI_BYTE, dest, send_tag, theContext()->getComm(), - mpi_event->getRequest() - ); - vtAssertMPISuccess(ret, "MPI_Isend"); - - #if backend_check_enabled(trace_enabled) - if (ArgType::vt_trace_mpi) { - auto tr_end = vt::timing::Timing::getCurrentTime(); - auto tr_note = fmt::format("Isend(AM): dest={}, bytes={}", dest, msg_size); - trace::addUserBracketedNote(tr_begin, tr_end, tr_note, trace_isend); - } - #endif - } + if (is_term) { + tdSentCount.increment(1); + } + amSentCounterGauge.incrementUpdate(msg_size, 1); + + EventType const event_id = sendMsgMPI(dest, base, msg_size, send_tag); if (not is_term) { theTerm()->produce(epoch,1,dest); @@ -311,8 +408,19 @@ EventType ActiveMessenger::sendMsgSized( return ret_event; }
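sendMsgMPI() above is a two-part protocol for oversized messages: the payload leaves via sendData() as a series of Isends on a freshly allocated tag, and a small MultiMsg control message tells the receiver which tag to listen on, how many chunks to expect, and the total size. A standalone sketch of the bookkeeping the receiver needs, with illustrative names (SendInfo's accessors come from the new send_info.h below):

    // Everything required to reassemble one chunked payload on the receiver.
    struct ChunkedTransfer {
      vt::TagType tag;            // MPI tag all chunks were sent with
      int nchunks;                // number of MPI_Isend chunks posted
      vt::MsgSizeType total_size; // bytes to allocate for reassembly
    };

    ChunkedTransfer describeTransfer(
      vt::messaging::SendInfo const& info, vt::MsgSizeType size
    ) {
      return ChunkedTransfer{info.getTag(), info.getNumChunks(), size};
    }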
-ActiveMessenger::SendDataRetType ActiveMessenger::sendData( - RDMA_GetType const& ptr, NodeType const& dest, TagType const& tag +MPI_TagType ActiveMessenger::allocateNewTag() { + auto const max_tag = util::MPI_Attr::getMaxTag(); + if (cur_direct_buffer_tag_ == max_tag) { + cur_direct_buffer_tag_ = starting_direct_buffer_tag; + } + auto const ret_tag = cur_direct_buffer_tag_++; + + return ret_tag; +} + +SendInfo ActiveMessenger::sendData( + PtrLenPairType const& ptr, NodeType const& dest, TagType const& tag ) { auto const& data_ptr = std::get<0>(ptr); auto const& num_bytes = std::get<1>(ptr); @@ -321,19 +429,9 @@ ActiveMessenger::SendDataRetType ActiveMessenger::sendData( if (tag != no_tag) { send_tag = tag; } else { - auto const max_tag = util::MPI_Attr::getMaxTag(); - - if (cur_direct_buffer_tag_ == max_tag) { - cur_direct_buffer_tag_ = starting_direct_buffer_tag; - } - send_tag = cur_direct_buffer_tag_++; + send_tag = allocateNewTag(); } - auto const event_id = theEvent()->createMPIEvent(this_node_); - auto& holder = theEvent()->getEventHolder(event_id); - - auto mpi_event = holder.get_event(); - debug_print( active, node, "sendData: ptr={}, num_bytes={} dest={}, tag={}, send_tag={}\n", data_ptr, num_bytes, dest, tag, send_tag ); @@ -349,29 +447,9 @@ vtWarnIf( dest >= theContext()->getNumNodes() || dest < 0, "Invalid destination: {}" ); - { - VT_ALLOW_MPI_CALLS; - #if backend_check_enabled(trace_enabled) - double tr_begin = 0; - if (ArgType::vt_trace_mpi) { - tr_begin = vt::timing::Timing::getCurrentTime(); - } - #endif - - const int ret = MPI_Isend( - data_ptr, num_bytes, MPI_BYTE, dest, send_tag, theContext()->getComm(), - mpi_event->getRequest() - ); - vtAssertMPISuccess(ret, "MPI_Isend"); - - #if backend_check_enabled(trace_enabled) - if (ArgType::vt_trace_mpi) { - auto tr_end = vt::timing::Timing::getCurrentTime(); - auto tr_note = fmt::format("Isend(Data): dest={}, bytes={}", dest, num_bytes); - trace::addUserBracketedNote(tr_begin, tr_end, tr_note, trace_isend); - } - #endif - } + auto ret = sendDataMPI(ptr, dest, send_tag); + EventType event_id = std::get<0>(ret); + int num = std::get<1>(ret); // Assume that any raw data send/recv is paired with a message with an epoch // if required to inhibit early termination of that epoch theTerm()->produce(term::any_epoch_sentinel,1,dest); theTerm()->hangDetectSend(); l->send(dest, num_bytes, false); } - return SendDataRetType{event_id,send_tag}; + return SendInfo{event_id, send_tag, num}; +} + +std::tuple<EventType, int> ActiveMessenger::sendDataMPI( + PtrLenPairType const& payload, NodeType const& dest, TagType const& tag +) { + auto ptr = static_cast<char*>(std::get<0>(payload)); + auto remainder = std::get<1>(payload); + int num_sends = 0; + std::vector<EventType> events; + EventType ret_event = no_event; + auto const max_per_send = theConfig()->vt_max_mpi_send_size; + while (remainder > 0) { + auto const event_id = theEvent()->createMPIEvent(this_node_); + auto& holder = theEvent()->getEventHolder(event_id); + auto mpi_event = holder.get_event(); + auto subsize = static_cast<int>( + std::min(static_cast<std::size_t>(remainder), max_per_send) + ); + { + #if vt_check_enabled(trace_enabled) + double tr_begin = 0; + if (theConfig()->vt_trace_mpi) { + tr_begin = vt::timing::Timing::getCurrentTime(); + } + #endif + + vt_debug_print( active, node, "sendDataMPI: remainder={}, node={}, tag={}, num_sends={}, subsize={}," "total size={}\n", remainder, dest, tag, num_sends, subsize, std::get<1>(payload) ); + + VT_ALLOW_MPI_CALLS; + int const ret = MPI_Isend( + ptr, subsize, MPI_BYTE, dest, tag, theContext()->getComm(), + mpi_event->getRequest() + ); + vtAssertMPISuccess(ret, "MPI_Isend"); + + #if vt_check_enabled(trace_enabled) + if (theConfig()->vt_trace_mpi) { + auto tr_end = vt::timing::Timing::getCurrentTime(); + auto tr_note = fmt::format("Isend(Data): 
dest={}, bytes={}", dest, subsize); + trace::addUserBracketedNote(tr_begin, tr_end, tr_note, trace_isend); + } + #endif + } + ptr += subsize; + remainder -= subsize; + num_sends++; + events.push_back(event_id); + } + + if (events.size() > 1) { + ret_event = theEvent()->createParentEvent(theContext()->getNode()); + auto& holder = theEvent()->getEventHolder(ret_event); + for (auto&& child_event : events) { + holder.get_event()->addEventToList(child_event); + } + } else { + vtAssert(events.size() > 0, "Must contain at least one event"); + ret_event = events.back(); + } + + return std::make_tuple(ret_event, num_sends); } bool ActiveMessenger::recvDataMsgPriority( - PriorityType priority, TagType const& tag, NodeType const& node, - RDMA_ContinuationDeleteType next + int nchunks, PriorityType priority, TagType const& tag, NodeType const& node, + ContinuationDeleterType next ) { - return recvDataMsg(priority, tag, node, true, next); + return recvDataMsg(nchunks, priority, tag, node, true, next); } bool ActiveMessenger::recvDataMsg( - TagType const& tag, NodeType const& node, RDMA_ContinuationDeleteType next + int nchunks, TagType const& tag, NodeType const& node, + ContinuationDeleterType next ) { - return recvDataMsg(default_priority, tag, node, true, next); + return recvDataMsg(nchunks, default_priority, tag, node, true, next); } bool ActiveMessenger::tryProcessDataMsgRecv() { @@ -403,10 +548,10 @@ bool ActiveMessenger::tryProcessDataMsgRecv() { auto iter = pending_recvs_.begin(); for (; iter != pending_recvs_.end(); ++iter) { + auto& elm = iter->second; auto const done = recvDataMsgBuffer( - iter->second.user_buf, iter->second.priority, iter->first, - iter->second.recv_node, false, iter->second.dealloc_user_buf, - iter->second.cont + elm.nchunks, elm.user_buf, elm.priority, iter->first, elm.sender, false, + elm.dealloc_user_buf, elm.cont, elm.is_user_buf ); if (done) { erase = true; @@ -423,17 +568,20 @@ bool ActiveMessenger::tryProcessDataMsgRecv() { } bool ActiveMessenger::recvDataMsgBuffer( - void* const user_buf, TagType const& tag, + int nchunks, void* const user_buf, TagType const& tag, NodeType const& node, bool const& enqueue, ActionType dealloc, - RDMA_ContinuationDeleteType next + ContinuationDeleterType next, bool is_user_buf ) { - return recvDataMsgBuffer(user_buf, no_priority, tag, node, enqueue, dealloc, next); + return recvDataMsgBuffer( + nchunks, user_buf, no_priority, tag, node, enqueue, dealloc, next, + is_user_buf + ); } bool ActiveMessenger::recvDataMsgBuffer( - void* const user_buf, PriorityType priority, TagType const& tag, + int nchunks, void* const user_buf, PriorityType priority, TagType const& tag, NodeType const& node, bool const& enqueue, ActionType dealloc_user_buf, - RDMA_ContinuationDeleteType next + ContinuationDeleterType next, bool is_user_buf ) { if (not enqueue) { CountType num_probe_bytes; @@ -465,53 +613,10 @@ bool ActiveMessenger::recvDataMsgBuffer( NodeType const sender = stat.MPI_SOURCE; - MPI_Request req; - - { - VT_ALLOW_MPI_CALLS; - - #if backend_check_enabled(trace_enabled) - double tr_begin = 0; - if (ArgType::vt_trace_mpi) { - tr_begin = vt::timing::Timing::getCurrentTime(); - } - #endif - - const int recv_ret = MPI_Irecv( - buf, num_probe_bytes, MPI_BYTE, stat.MPI_SOURCE, stat.MPI_TAG, - theContext()->getComm(), &req - ); - vtAssertMPISuccess(recv_ret, "MPI_Irecv"); - - #if backend_check_enabled(trace_enabled) - if (ArgType::vt_trace_mpi) { - auto tr_end = vt::timing::Timing::getCurrentTime(); - auto tr_note = fmt::format( - "Irecv(Data): 
from={}, bytes={}", - stat.MPI_SOURCE, num_probe_bytes - ); - trace::addUserBracketedNote(tr_begin, tr_end, tr_note, trace_irecv); - } - #endif - } - - InProgressDataIRecv recv_holder{ - buf, num_probe_bytes, sender, req, user_buf, dealloc_user_buf, next, - priority - }; - - int recv_flag = 0; - { - VT_ALLOW_MPI_CALLS; - MPI_Status recv_stat; - MPI_Test(&recv_holder.req, &recv_flag, &recv_stat); - } - - if (recv_flag == 1) { - finishPendingDataMsgAsyncRecv(&recv_holder); - } else { - in_progress_data_irecv.emplace(std::move(recv_holder)); - } + recvDataDirect( + nchunks, buf, tag, sender, num_probe_bytes, priority, dealloc_user_buf, + next, is_user_buf + ); return true; } else { @@ -520,21 +625,105 @@ bool ActiveMessenger::recvDataMsgBuffer( } else { debug_print( active, node, - "recvDataMsgBuffer: node={}, tag={}, enqueue={}, priority={:x}\n", - node, tag, print_bool(enqueue), priority + "recvDataMsgBuffer: nchunks={}, node={}, tag={}, enqueue={}, " + "priority={:x} buffering, is_user_buf={}\n", + nchunks, node, tag, print_bool(enqueue), priority, is_user_buf ); pending_recvs_.emplace( std::piecewise_construct, std::forward_as_tuple(tag), std::forward_as_tuple( - PendingRecvType{user_buf,next,dealloc_user_buf,node,priority} + PendingRecvType{ + nchunks, user_buf, next, dealloc_user_buf, node, priority, + is_user_buf + } ) ); return false; } } +void ActiveMessenger::recvDataDirect( + int nchunks, TagType const tag, NodeType const from, MsgSizeType len, + ContinuationDeleterType next +) { + char* buf = + #if vt_check_enabled(memory_pool) + static_cast(thePool()->alloc(len)); + #else + static_cast(std::malloc(len)); + #endif + + recvDataDirect( + nchunks, buf, tag, from, len, default_priority, nullptr, next, false + ); +} + +void ActiveMessenger::recvDataDirect( + int nchunks, void* const buf, TagType const tag, NodeType const from, + MsgSizeType len, PriorityType prio, ActionType dealloc, + ContinuationDeleterType next, bool is_user_buf +) { + vtAssert(nchunks > 0, "Must have at least one chunk"); + + std::vector reqs; + reqs.resize(nchunks); + + char* cbuf = static_cast(buf); + MsgSizeType remainder = len; + auto const max_per_send = theConfig()->vt_max_mpi_send_size; + for (int i = 0; i < nchunks; i++) { + auto sublen = static_cast( + std::min(static_cast(remainder), max_per_send) + ); + + #if vt_check_enabled(trace_enabled) + double tr_begin = 0; + if (theConfig()->vt_trace_mpi) { + tr_begin = vt::timing::Timing::getCurrentTime(); + } + #endif + + { + VT_ALLOW_MPI_CALLS; + int const ret = MPI_Irecv( + cbuf+(i*max_per_send), sublen, MPI_BYTE, from, tag, + theContext()->getComm(), &reqs[i] + ); + vtAssertMPISuccess(ret, "MPI_Irecv"); + } + + dmPostedCounterGauge.incrementUpdate(len, 1); + + #if vt_check_enabled(trace_enabled) + if (theConfig()->vt_trace_mpi) { + auto tr_end = vt::timing::Timing::getCurrentTime(); + auto tr_note = fmt::format( + "Irecv(Data): from={}, bytes={}", + from, sublen + ); + trace::addUserBracketedNote(tr_begin, tr_end, tr_note, trace_irecv); + } + #endif + + remainder -= sublen; + } + + InProgressDataIRecv recv{ + cbuf, len, from, std::move(reqs), is_user_buf ? 
void ActiveMessenger::finishPendingDataMsgAsyncRecv(InProgressDataIRecv* irecv) { auto buf = irecv->buf; auto num_probe_bytes = irecv->probe_bytes; @@ -575,7 +764,7 @@ void ActiveMessenger::finishPendingDataMsgAsyncRecv(InProgressDataIRecv* irecv) } else { // If we have a continuation, schedule to run later auto run = [=]{ - next(RDMA_GetType{buf,num_probe_bytes}, dealloc_buf); + next(PtrLenPairType{buf,num_probe_bytes}, dealloc_buf); theTerm()->consume(term::any_epoch_sentinel,1,sender); theTerm()->hangDetectRecv(); }; @@ -584,11 +773,12 @@ } bool ActiveMessenger::recvDataMsg( - PriorityType priority, TagType const& tag, NodeType const& recv_node, - bool const& enqueue, RDMA_ContinuationDeleteType next + int nchunks, PriorityType priority, TagType const& tag, + NodeType const& sender, bool const& enqueue, + ContinuationDeleterType next ) { return recvDataMsgBuffer( - nullptr, priority, tag, recv_node, enqueue, nullptr, next + nchunks, nullptr, priority, tag, sender, enqueue, nullptr, next ); } @@ -776,16 +966,18 @@ bool ActiveMessenger::deliverActiveMsg( } bool ActiveMessenger::tryProcessIncomingActiveMsg() { - VT_ALLOW_MPI_CALLS; // MPI_Iprove, MPI_Irecv, MPI_Test - CountType num_probe_bytes; MPI_Status stat; int flag; - MPI_Iprobe( - MPI_ANY_SOURCE, static_cast<MPI_TagType>(MPITag::ActiveMsgTag), - theContext()->getComm(), &flag, &stat - ); + { + VT_ALLOW_MPI_CALLS; + + MPI_Iprobe( + MPI_ANY_SOURCE, static_cast<MPI_TagType>(MPITag::ActiveMsgTag), + theContext()->getComm(), &flag, &stat + ); + } if (flag == 1) { MPI_Get_count(&stat, MPI_BYTE, &num_probe_bytes); @@ -808,6 +1000,7 @@ } #endif + VT_ALLOW_MPI_CALLS; MPI_Irecv( buf, num_probe_bytes, MPI_BYTE, sender, stat.MPI_TAG, theContext()->getComm(), &req @@ -827,10 +1020,9 @@ InProgressIRecv recv_holder{buf, num_probe_bytes, sender, req}; - int recv_flag = 0; - MPI_Status recv_stat; - MPI_Test(&recv_holder.req, &recv_flag, &recv_stat); - if (recv_flag == 1) { + auto done = recv_holder.test(); + + if (done) { finishPendingActiveMsgAsyncRecv(&recv_holder); } else { in_progress_active_msg_irecv.emplace(std::move(recv_holder)); } @@ -896,8 +1088,8 @@ void ActiveMessenger::finishPendingActiveMsgAsyncRecv(InProgressIRecv* irecv) { put_finished = true; } else { /*bool const put_delivered = */recvDataMsg( - put_tag, sender, - [=](RDMA_GetType ptr, ActionType deleter){ + 1, put_tag, sender, + [=](PtrLenPairType ptr, ActionType deleter){ envelopeSetPutPtr(base->env, std::get<0>(ptr), std::get<1>(ptr)); scheduleActiveMsg(base, sender, num_probe_bytes, true, deleter); } diff --git a/src/vt/messaging/active.h b/src/vt/messaging/active.h index 02fadf97b4..d8c8d8b865 100644 --- a/src/vt/messaging/active.h +++ b/src/vt/messaging/active.h @@ -57,6 +57,7 @@ #include "vt/messaging/pending_send.h" #include "vt/messaging/listener.h" #include "vt/messaging/irecv_holder.h" +#include "vt/messaging/send_info.h" #include "vt/event/event.h" #include "vt/registry/registry.h" #include "vt/registry/auto/auto_registry_interface.h" @@ -74,6 +75,17 @@ #include #include +namespace vt { + +/// A pair of a void* and number of bytes (length) for sending data +using PtrLenPairType = std::tuple<void*, ByteType>; + +/// A continuation function with an allocated pointer with a deleter function +using ContinuationDeleterType = + std::function<void(PtrLenPairType, ActionType)>; + +} /* end namespace vt */
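A minimal sketch of a continuation matching the ContinuationDeleterType alias above, assuming the usual vt headers are in scope; the function name is illustrative. The continuation receives the (pointer, length) payload and invokes the deleter when it is done with the buffer:

    #include <tuple>

    void exampleContinuation(vt::PtrLenPairType data, vt::ActionType deleter) {
      char* buf = static_cast<char*>(std::get<0>(data));
      auto len = std::get<1>(data);
      (void)buf; (void)len; // ... consume buf[0..len) here ...
      if (deleter) {
        deleter(); // hand the buffer back for deallocation
      }
    }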
namespace vt { namespace messaging { using MPI_TagType = int; @@ -90,55 +102,96 @@ static constexpr TagType const starting_direct_buffer_tag = 1000; static constexpr MsgSizeType const max_pack_direct_size = 512; struct PendingRecv { + int nchunks = 0; void* user_buf = nullptr; - RDMA_ContinuationDeleteType cont = nullptr; + ContinuationDeleterType cont = nullptr; ActionType dealloc_user_buf = nullptr; - NodeType recv_node = uninitialized_destination; + NodeType sender = uninitialized_destination; PriorityType priority = no_priority; + bool is_user_buf = false; PendingRecv( - void* in_user_buf, RDMA_ContinuationDeleteType in_cont, + int in_nchunks, void* in_user_buf, ContinuationDeleterType in_cont, ActionType in_dealloc_user_buf, NodeType node, - PriorityType in_priority - ) : user_buf(in_user_buf), cont(in_cont), - dealloc_user_buf(in_dealloc_user_buf), recv_node(node), - priority(in_priority) + PriorityType in_priority, bool in_is_user_buf + ) : nchunks(in_nchunks), user_buf(in_user_buf), cont(in_cont), + dealloc_user_buf(in_dealloc_user_buf), sender(node), + priority(in_priority), is_user_buf(in_is_user_buf) { } }; -struct InProgressIRecv { - using CountType = int32_t; - - InProgressIRecv( - char* in_buf, CountType in_probe_bytes, NodeType in_sender, - MPI_Request in_req +struct InProgressBase { + InProgressBase( + char* in_buf, MsgSizeType in_probe_bytes, NodeType in_sender ) : buf(in_buf), probe_bytes(in_probe_bytes), sender(in_sender), - req(in_req), valid(true) + valid(true) { } char* buf = nullptr; - CountType probe_bytes = 0; + MsgSizeType probe_bytes = 0; NodeType sender = uninitialized_destination; - MPI_Request req; bool valid = false; }; -struct InProgressDataIRecv : public InProgressIRecv { +struct InProgressIRecv : InProgressBase { + + InProgressIRecv( + char* in_buf, MsgSizeType in_probe_bytes, NodeType in_sender, + MPI_Request in_req = MPI_REQUEST_NULL + ) : InProgressBase(in_buf, in_probe_bytes, in_sender), + req(in_req) + { } + + bool test() { + VT_ALLOW_MPI_CALLS; // MPI_Test + + int flag = 0; + MPI_Status stat; + MPI_Test(&req, &flag, &stat); + return flag; + } + +private: + MPI_Request req = MPI_REQUEST_NULL; +}; + +struct InProgressDataIRecv : InProgressBase { InProgressDataIRecv( - char* in_buf, CountType in_probe_bytes, NodeType in_sender, - MPI_Request in_req, void* const in_user_buf, + char* in_buf, MsgSizeType in_probe_bytes, NodeType in_sender, + std::vector<MPI_Request> in_reqs, void* const in_user_buf, ActionType in_dealloc_user_buf, - RDMA_ContinuationDeleteType in_next, + ContinuationDeleterType in_next, PriorityType in_priority - ) : InProgressIRecv{in_buf, in_probe_bytes, in_sender, in_req}, + ) : InProgressBase{in_buf, in_probe_bytes, in_sender}, user_buf(in_user_buf), dealloc_user_buf(in_dealloc_user_buf), - next(in_next), priority(in_priority) + next(in_next), priority(in_priority), reqs(std::move(in_reqs)) { } + bool test() { + int flag = 0; + MPI_Status stat; + for ( ; cur < reqs.size(); cur++) { + VT_ALLOW_MPI_CALLS; // MPI_Test + + MPI_Test(&reqs[cur], &flag, &stat); + + if (flag == 0) { + return false; + } + } + + return true; + } + +public: void* user_buf = nullptr; ActionType dealloc_user_buf = nullptr; - RDMA_ContinuationDeleteType next = nullptr; + ContinuationDeleterType next = nullptr; PriorityType priority = no_priority; + +private: + std::size_t cur = 0; + std::vector<MPI_Request> reqs; };
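Note the cursor in test() above: each poll resumes at the first incomplete request instead of re-testing chunks that already finished. The same pattern in isolation (illustrative, not VT code):

    #include <mpi.h>
    #include <cstddef>
    #include <vector>

    struct ResumableTest {
      std::vector<MPI_Request> reqs;
      std::size_t cur = 0;

      bool done() {
        for ( ; cur < reqs.size(); cur++) {
          int flag = 0;
          MPI_Status stat;
          MPI_Test(&reqs[cur], &flag, &stat);
          if (flag == 0) {
            return false; // reqs[cur] still pending; resume here next poll
          }
        }
        return true; // every chunk has arrived
      }
    };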
struct BufferedActiveMsg { @@ -155,16 +208,16 @@ struct BufferedActiveMsg { { } }; +// forward-declare for header +struct MultiMsg; + struct ActiveMessenger { using BufferedMsgType = BufferedActiveMsg; using MessageType = ShortMessage*; using CountType = int32_t; using PendingRecvType = PendingRecv; using EventRecordType = event::AsyncEvent::EventRecordType; - using SendDataRetType = std::tuple<EventType, TagType>; - using SendFnType = std::function< - SendDataRetType(RDMA_GetType,NodeType,TagType) - >; + using SendFnType = std::function<SendInfo(PtrLenPairType, NodeType, TagType)>; using UserSendFnType = std::function<void(SendFnType)>; using ContainerPendingType = std::unordered_map<TagType, PendingRecvType>; using MsgContType = std::list<BufferedMsgType>; @@ -489,37 +542,75 @@ struct ActiveMessenger { MsgSizeType const& ptr_bytes ); - SendDataRetType sendData( - RDMA_GetType const& ptr, NodeType const& dest, TagType const& tag + SendInfo sendData( + PtrLenPairType const& ptr, NodeType const& dest, TagType const& tag + ); + + std::tuple<EventType, int> sendDataMPI( + PtrLenPairType const& ptr, NodeType const& dest, TagType const& tag ); bool recvDataMsgPriority( - PriorityType priority, TagType const& tag, NodeType const& node, - RDMA_ContinuationDeleteType next = nullptr + int nchunks, PriorityType priority, TagType const& tag, + NodeType const& node, ContinuationDeleterType next = nullptr ); bool recvDataMsg( - TagType const& tag, NodeType const& node, - RDMA_ContinuationDeleteType next = nullptr + int nchunks, TagType const& tag, NodeType const& node, + ContinuationDeleterType next = nullptr ); bool recvDataMsg( - PriorityType priority, TagType const& tag, NodeType const& recv_node, - bool const& enqueue, RDMA_ContinuationDeleteType next = nullptr + int nchunks, PriorityType priority, TagType const& tag, + NodeType const& sender, bool const& enqueue, + ContinuationDeleterType next = nullptr ); bool recvDataMsgBuffer( - void* const user_buf, PriorityType priority, TagType const& tag, + int nchunks, void* const user_buf, PriorityType priority, TagType const& tag, NodeType const& node = uninitialized_destination, bool const& enqueue = true, ActionType dealloc_user_buf = nullptr, - RDMA_ContinuationDeleteType next = nullptr + ContinuationDeleterType next = nullptr, bool is_user_buf = false ); bool recvDataMsgBuffer( - void* const user_buf, TagType const& tag, + int nchunks, void* const user_buf, TagType const& tag, NodeType const& node = uninitialized_destination, bool const& enqueue = true, ActionType dealloc_user_buf = nullptr, - RDMA_ContinuationDeleteType next = nullptr + ContinuationDeleterType next = nullptr, bool is_user_buf = false + ); + + /** * \brief Receive data from MPI in multiple chunks * * \param[in] nchunks the number of chunks to receive * \param[in] buf the receive buffer * \param[in] tag the MPI tag * \param[in] from the sender * \param[in] len the total length * \param[in] prio the priority for the continuation * \param[in] dealloc the action to deallocate the buffer * \param[in] next the continuation that gets passed the data when ready * \param[in] is_user_buf whether the buffer is a user buffer requiring user deallocation */ + void recvDataDirect( + int nchunks, void* const buf, TagType const tag, NodeType const from, + MsgSizeType len, PriorityType prio, ActionType dealloc = nullptr, + ContinuationDeleterType next = nullptr, bool is_user_buf = false + ); + + /** * \brief Receive data from MPI in multiple chunks * * \param[in] nchunks the number of chunks to receive * \param[in] tag the MPI tag * \param[in] from the sender * \param[in] len the total length * \param[in] next the continuation that gets passed the data when ready */ + void recvDataDirect( + int nchunks, TagType const tag, NodeType const from, + MsgSizeType len, ContinuationDeleterType next );
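The is_user_buf flag threaded through these overloads distinguishes buffers VT allocated from buffers the caller owns. A sketch of the user-buffer path as the rdma.cc call sites below use it — the wrapper function is illustrative:

    void receiveIntoUserBuffer(
      int nchunks, void* buf, vt::TagType tag, vt::NodeType sender
    ) {
      vt::theMsg()->recvDataMsgBuffer(
        nchunks, buf, tag, sender, true,
        []{ /* caller-owned buffer: deallocate/recycle it here */ },
        nullptr, /*is_user_buf*/ true
      );
    }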
EventType sendMsgSized( @@ -585,6 +676,23 @@ struct ActiveMessenger { MsgSizeType const& msg_size, TagType const& send_tag ); + /** * \internal * \brief Send already-packed message bytes with MPI using multiple * sends if necessary * * \param[in] dest the destination of the message * \param[in] base the message base pointer * \param[in] msg_size the size of the message * \param[in] send_tag the send tag on the message * * \return the event to test/wait for completion */ + EventType sendMsgMPI( + NodeType const& dest, MsgSharedPtr<BaseMsgType> const& base, + MsgSizeType const& msg_size, TagType const& send_tag + ); + /* * setGlobalEpoch() is a shortcut for both pushing and popping epochs on the * stack depending on the value of the `epoch' passed as an argument. */ @@ -642,6 +750,32 @@ struct ActiveMessenger { } private: + /** * \internal \brief Allocate a new, unused tag. * * \note Wraps around when reaching max tag, determined by the MPI * implementation. * * \return a new MPI tag */ + MPI_TagType allocateNewTag(); + + /** * \internal \brief Handle a control message that coordinates the multiple * arriving chunks that constitute a contiguous payload * * \param[in] msg the message with control data */ + void handleChunkedMultiMsg(MultiMsg* msg); + + /** * \internal \brief Handle a control message; immediately calls * \c handleChunkedMultiMsg * * \param[in] msg the message with control data */ + static void chunkedMultiMsg(MultiMsg* msg); + bool testPendingActiveMsgAsyncRecv(); bool testPendingDataMsgAsyncRecv(); void finishPendingActiveMsgAsyncRecv(InProgressIRecv* irecv); diff --git a/src/vt/messaging/irecv_holder.h b/src/vt/messaging/irecv_holder.h index 679ff4ac5c..226fcda791 100644 --- a/src/vt/messaging/irecv_holder.h +++ b/src/vt/messaging/irecv_holder.h @@ -94,11 +94,9 @@ struct IRecvHolder { auto& e = holder_[i]; vtAssert(e.valid, "Must be valid"); - int flag = 0; - MPI_Status stat; - MPI_Test(&e.req, &flag, &stat); + auto done = e.test(); - if (flag == 0) { + if (not done) { ++i; continue; } diff --git a/src/vt/messaging/send_info.h b/src/vt/messaging/send_info.h new file mode 100644 index 0000000000..4aaf6d2399 --- /dev/null +++ b/src/vt/messaging/send_info.h @@ -0,0 +1,103 @@ +/* +//@HEADER +// ***************************************************************************** +// +// send_info.h +// DARMA Toolkit v. 1.0.0 +// DARMA/vt => Virtual Transport +// +// Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER +*/ + +#if !defined INCLUDED_VT_MESSAGING_SEND_INFO_H +#define INCLUDED_VT_MESSAGING_SEND_INFO_H + +#include "vt/config.h" + +namespace vt { namespace messaging { + +/** + * \struct SendInfo + * + * \brief Returned from a data send to be used to receive the data + */ +struct SendInfo { + + /** + * \internal + * \brief Construct a SendInfo + * + * \param[in] in_event the send event (parent event if multiple sends) + * \param[in] in_tag the MPI tag it was sent with + * \param[in] in_nchunks the number of send chunks for the entire payload + */ + SendInfo(EventType in_event, TagType in_tag, int in_nchunks) + : event(in_event), + tag(in_tag), + nchunks(in_nchunks) + { } + + /** + * \brief Get the send event (either an MPI event or a parent event containing + * multiple MPI events) + * + * \return the send event + */ + EventType getEvent() const { return event; } + + /** + * \brief The MPI tag that it was sent with + * + * \return the tag + */ + TagType getTag() const { return tag; } + + /** + * \brief The number of Isend chunks that make up the entire payload + * + * \return the number of chunks + */ + int getNumChunks() const { return nchunks; } + +private: + EventType const event = no_event; /**< The event for the send */ + TagType const tag = no_tag; /**< The MPI tag for the send */ + int const nchunks = 0; /**< The number of send chunks to receive */ +}; + +}} /* end namespace vt::messaging */ + +#endif /*INCLUDED_VT_MESSAGING_SEND_INFO_H*/ diff --git a/src/vt/rdma/collection/rdma_collection.cc b/src/vt/rdma/collection/rdma_collection.cc index 15dea098f6..66f6660911 100644 --- a/src/vt/rdma/collection/rdma_collection.cc +++ b/src/vt/rdma/collection/rdma_collection.cc @@ -197,7 +197,8 @@ namespace vt { namespace rdma { auto send_payload = [&](Active::SendFnType send){ auto ret = send(put_ret, put_node, no_tag); - msg->mpi_tag_to_recv = std::get<1>(ret); + msg->mpi_tag_to_recv = ret.getTag(); + msg->nchunks = ret.getNumChunks(); }; if (tag != no_tag) { diff --git a/src/vt/rdma/rdma.cc b/src/vt/rdma/rdma.cc index dde4cc6e6c..99f2afae89 100644 --- a/src/vt/rdma/rdma.cc +++ b/src/vt/rdma/rdma.cc @@ -87,7 +87,8 @@ RDMAManager::RDMAManager() auto send_payload = [&](Active::SendFnType send){ auto ret = send(data, recv_node, no_tag); - new_msg->mpi_tag_to_recv = std::get<1>(ret); + new_msg->mpi_tag_to_recv = ret.getTag(); + new_msg->nchunks = ret.getNumChunks(); debug_print( rdma, node, "data is sending: tag={}, node={}\n", @@ -125,7 +126,7 @@ RDMAManager::RDMAManager() if (get_ptr == nullptr) { theMsg()->recvDataMsg( - msg->mpi_tag_to_recv, msg->send_back, + msg->nchunks, 
msg->mpi_tag_to_recv, msg->send_back, [=](RDMA_GetType ptr, ActionType deleter){ theRDMA()->triggerGetRecvData( op_id, msg_tag, std::get<0>(ptr), std::get<1>(ptr), deleter ); } ); } else { theMsg()->recvDataMsgBuffer( - get_ptr, msg->mpi_tag_to_recv, msg->send_back, true, [get_ptr_action]{ + msg->nchunks, get_ptr, msg->mpi_tag_to_recv, msg->send_back, true, + [get_ptr_action]{ debug_print( rdma, node, "recv_data_msg_buffer finished\n" ); if (get_ptr_action) { get_ptr_action(); } - } + }, + nullptr, true ); } } @@ -208,6 +211,7 @@ RDMAManager::RDMAManager() if (put_ptr == nullptr) { theMsg()->recvDataMsg( + msg->nchunks, recv_tag, recv_node, [=](RDMA_GetType ptr, ActionType deleter){ debug_print( rdma, node, @@ -237,7 +241,7 @@ RDMAManager::RDMAManager() msg->offset != no_byte ? static_cast<char*>(put_ptr) + msg->offset : put_ptr; // do a direct recv into the user buffer theMsg()->recvDataMsgBuffer( - put_ptr_offset, recv_tag, recv_node, true, []{}, + msg->nchunks, put_ptr_offset, recv_tag, recv_node, true, []{}, [=](RDMA_GetType ptr, ActionType deleter){ debug_print( rdma, node, @@ -685,7 +689,8 @@ void RDMAManager::putData( auto send_payload = [&](Active::SendFnType send){ auto ret = send(RDMA_GetType{ptr, num_bytes}, put_node, no_tag); - msg->mpi_tag_to_recv = std::get<1>(ret); + msg->mpi_tag_to_recv = ret.getTag(); + msg->nchunks = ret.getNumChunks(); }; if (tag != no_tag) { diff --git a/src/vt/rdma/rdma_msg.h b/src/vt/rdma/rdma_msg.h index 8e46310d18..2a985abe6d 100644 --- a/src/vt/rdma/rdma_msg.h +++ b/src/vt/rdma/rdma_msg.h @@ -82,12 +82,13 @@ struct SendDataMessage : ActiveMessage<EnvelopeT> { NodeType const& back = uninitialized_destination, NodeType const& in_recv_node = uninitialized_destination, bool const in_packed_direct = false - ) : ActiveMessage<EnvelopeT>(), - rdma_handle(in_han), send_back(back), recv_node(in_recv_node), - mpi_tag_to_recv(in_mpi_tag), op_id(in_op), num_bytes(in_num_bytes), - offset(in_offset), packed_direct(in_packed_direct) + ) : rdma_handle(in_han), send_back(back), + recv_node(in_recv_node), mpi_tag_to_recv(in_mpi_tag), op_id(in_op), + num_bytes(in_num_bytes), offset(in_offset), + packed_direct(in_packed_direct) { } + int nchunks = 0; RDMA_HandleType rdma_handle = no_rdma_handle; NodeType send_back = uninitialized_destination; NodeType recv_node = uninitialized_destination; diff --git a/src/vt/runtime/runtime.cc b/src/vt/runtime/runtime.cc index fda9441086..44f8805f3c 100644 --- a/src/vt/runtime/runtime.cc +++ b/src/vt/runtime/runtime.cc @@ -874,6 +874,23 @@ void Runtime::printStartupBanner() { } } + // Limit to between 256 B and 1 GiB. If it's too small a VT envelope won't fit; + // if it's too large we overflow an integer passed to MPI. + if (ArgType::vt_max_mpi_send_size < 256) { + vtAbort("Max size for MPI send must be at least 256 B"); + } else if (ArgType::vt_max_mpi_send_size > 1ull << 30) { + vtAbort("Max size for MPI send must not be greater than 1 GiB (overflow)"); + } else { + auto const bytes = ArgType::vt_max_mpi_send_size; + auto const ret = util::memory::getBestMemoryUnit(bytes); + auto f_max = fmt::format( "Splitting messages greater than {} {}", std::get<1>(ret), std::get<0>(ret) ); + auto f_max_arg = opt_on("--vt_max_mpi_send_size", f_max); + fmt::print("{}\t{}{}", vt_pre, f_max_arg, reset); + }
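The 1 GiB ceiling exists because MPI count arguments are C ints: a single MPI_Isend of MPI_BYTE can move at most INT_MAX (about 2^31 - 1) bytes, so capping a chunk at 2^30 keeps the static_cast<int> conversions in sendMsgMPI()/sendDataMPI() safe. The invariant, as a compile-time check (illustrative):

    #include <climits>

    static_assert(
      (1ull << 30) <= static_cast<unsigned long long>(INT_MAX),
      "a 1 GiB chunk must fit in MPI's int count argument"
    );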
if (ArgType::vt_debug_all) { auto f11 = fmt::format("All debug prints are on (if enabled compile-time)"); auto f12 = opt_on("--vt_debug_all", f11); diff --git a/src/vt/serialization/messaging/serialized_data_msg.h b/src/vt/serialization/messaging/serialized_data_msg.h index c3dd8800eb..54abc563ef 100644 --- a/src/vt/serialization/messaging/serialized_data_msg.h +++ b/src/vt/serialization/messaging/serialized_data_msg.h @@ -64,6 +64,7 @@ struct SerializedDataMsgAny : MessageT { HandlerType handler = uninitialized_handler; TagType data_recv_tag = no_tag; NodeType from_node = uninitialized_destination; + int nchunks = 0; }; template diff --git a/src/vt/serialization/messaging/serialized_messenger.impl.h b/src/vt/serialization/messaging/serialized_messenger.impl.h index ad8ebb9b48..dbd5c2d581 100644 --- a/src/vt/serialization/messaging/serialized_messenger.impl.h +++ b/src/vt/serialization/messaging/serialized_messenger.impl.h @@ -91,6 +91,8 @@ template ) { auto const handler = sys_msg->handler; auto const& recv_tag = sys_msg->data_recv_tag; + auto const& nchunks = sys_msg->nchunks; + auto const& len = sys_msg->ptr_size; auto const epoch = envelopeGetEpoch(sys_msg->env); debug_print( @@ -107,8 +109,8 @@ } auto node = sys_msg->from_node; - theMsg()->recvDataMsg( - recv_tag, sys_msg->from_node, + theMsg()->recvDataDirect( + nchunks, recv_tag, sys_msg->from_node, len, [handler,recv_tag,node,epoch,is_valid_epoch] (RDMA_GetType ptr, ActionType action){ // be careful here not to use "msg", it is no longer valid @@ -352,9 +354,11 @@ template auto sys_msg = makeMessage<SerialWrapperMsgType<MsgT>>(); auto send_serialized = [=](Active::SendFnType send){ auto ret = send(RDMA_GetType{ptr, ptr_size}, dest, no_tag); - EventType event = std::get<0>(ret); + EventType event = ret.getEvent(); theEvent()->attachAction(event, [=]{ std::free(ptr); }); - sys_msg->data_recv_tag = std::get<1>(ret); + sys_msg->data_recv_tag = ret.getTag(); + sys_msg->nchunks = ret.getNumChunks(); + sys_msg->ptr_size = ptr_size; }; auto cur_ref = envelopeGetRef(sys_msg->env); sys_msg->env = msg->env; diff --git a/src/vt/utils/memory/memory_units.cc b/src/vt/utils/memory/memory_units.cc index d952d79755..391814e7c0 100644 --- a/src/vt/utils/memory/memory_units.cc +++ b/src/vt/utils/memory/memory_units.cc @@ -69,4 +69,33 @@ MemoryUnitEnum getUnitFromString(std::string const& unit) { return MemoryUnitEnum::Bytes; } +std::tuple<std::string, double> getBestMemoryUnit(std::size_t bytes) { + auto multiplier = static_cast<int8_t>(MemoryUnitEnum::Yottabytes); + for ( ; multiplier > 0; multiplier--) { + auto value_tmp = static_cast<double>(bytes); + for (int8_t i = 0; i < static_cast<int8_t>(multiplier); i++) { + value_tmp /= 1024.0; + } + if (value_tmp >= 1.) { + break; + } + } + + // We found a multiplier that results in a value over 1.0, use it + vtAssert( multiplier <= static_cast<int8_t>(MemoryUnitEnum::Yottabytes) and multiplier >= 0, "Must be a valid memory unit" ); + + auto unit_name = getMemoryUnitName(static_cast<MemoryUnitEnum>(multiplier)); + + auto new_value = static_cast<double>(bytes); + for (int8_t i = 0; i < static_cast<int8_t>(multiplier); i++) { + new_value /= 1024.0; + } + + return std::make_tuple(unit_name, new_value); +} + }}} /* end namespace vt::util::memory */
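Worked example of getBestMemoryUnit(): divide by 1024 once per unit step, starting from the largest unit, until the value reaches 1.0. A hypothetical driver for the 16384-byte cap the new unit test configures (the exact unit-name spelling comes from getMemoryUnitName()):

    #include "vt/utils/memory/memory_units.h"

    #include <cstdio>
    #include <tuple>

    int main() {
      // 16384 B -> 16384 / 1024 = 16, so this yields (unit-name-for-KiB, 16.0)
      auto ret = vt::util::memory::getBestMemoryUnit(16384);
      std::printf("%g %s\n", std::get<1>(ret), std::get<0>(ret).c_str());
      return 0;
    }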
{ + break; + } + } + + // We found a multiplier that results in a value over 1.0, use it + vtAssert( + multiplier <= static_cast(MemoryUnitEnum::Yottabytes) and + multiplier >= 0, + "Must be a valid memory unit" + ); + + auto unit_name = getMemoryUnitName(static_cast(multiplier)); + + auto new_value = static_cast(bytes); + for (int8_t i = 0; i < static_cast(multiplier); i++) { + new_value /= 1024.0; + } + + return std::make_tuple(unit_name, new_value); +} + }}} /* end namespace vt::util::memory */ diff --git a/src/vt/utils/memory/memory_units.h b/src/vt/utils/memory/memory_units.h index c1699c5b7f..d542a14637 100644 --- a/src/vt/utils/memory/memory_units.h +++ b/src/vt/utils/memory/memory_units.h @@ -60,6 +60,7 @@ enum struct MemoryUnitEnum : int8_t { std::string getMemoryUnitName(MemoryUnitEnum unit); MemoryUnitEnum getUnitFromString(std::string const& unit); +std::tuple getBestMemoryUnit(std::size_t bytes); }}} /* end namespace vt::util::memory */ diff --git a/tests/unit/active/test_active_send_large.cc b/tests/unit/active/test_active_send_large.cc new file mode 100644 index 0000000000..37b6d6b2df --- /dev/null +++ b/tests/unit/active/test_active_send_large.cc @@ -0,0 +1,173 @@ +/* +//@HEADER +// ***************************************************************************** +// +// test_active_send_large.cc +// DARMA Toolkit v. 1.0.0 +// DARMA/vt => Virtual Transport +// +// Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? 
Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER */ +#include <gtest/gtest.h> +#include <cstdint> + +#include "test_parallel_harness.h" +#include "data_message.h" + +namespace vt { namespace tests { namespace unit { namespace large { + +struct SerializedTag {}; +struct NonSerializedTag {}; + +using RecvMsg = vt::Message; + +struct CallbackMsg : vt::Message { + vt::Callback<RecvMsg> cb_; +}; + +template <NumBytesType nbytes, typename T, typename enabled = void> +struct LargeMsg; + +template <NumBytesType nbytes, typename T> +struct LargeMsg< + nbytes, + T, + typename std::enable_if_t<std::is_same<T, SerializedTag>::value> +> : TestStaticSerialBytesMsg<CallbackMsg, nbytes> { }; + +template <NumBytesType nbytes, typename T> +struct LargeMsg< + nbytes, + T, + typename std::enable_if_t<std::is_same<T, NonSerializedTag>::value> +> : TestStaticBytesMsg<CallbackMsg, nbytes> { }; + +template <typename T> +void fillMsg(T msg) { + auto arr = reinterpret_cast<int64_t*>(&msg->payload[0]); + for (std::size_t i = 0; i < msg->bytes / sizeof(int64_t); i++) { + arr[i] = i; + } +} + +template <typename T> +void checkMsg(T msg) { + auto arr = reinterpret_cast<int64_t*>(&msg->payload[0]); + for (std::size_t i = 0; i < msg->bytes / sizeof(int64_t); i++) { + EXPECT_EQ(arr[i], i); + } +} + +template <typename MsgT> +void myHandler(MsgT* m) { + checkMsg(m); + auto msg = makeMessage<RecvMsg>(); + m->cb_.send(msg.get()); +} + +template <typename T> +struct TestActiveSendLarge : TestParallelHarness { + using TagType = typename std::tuple_element<1,T>::type; + + // Set max size to 16 KiB for testing + void addAdditionalArgs() override { + new_arg = fmt::format("--vt_max_mpi_send_size=16384"); + addArgs(new_arg); + } + +private: + std::string new_arg; +}; + +TYPED_TEST_SUITE_P(TestActiveSendLarge); + +TYPED_TEST_P(TestActiveSendLarge, test_large_bytes_msg) { + using IntegralType = typename std::tuple_element<0,TypeParam>::type; + using TagType = typename std::tuple_element<1,TypeParam>::type; + + static constexpr NumBytesType const nbytes = 1ll << IntegralType::value; + + using LargeMsgType = LargeMsg<nbytes, TagType>; + + NodeType const this_node = theContext()->getNode(); + NodeType const num_nodes = theContext()->getNumNodes(); + + // over two nodes will allocate a lot of memory for the run + if (num_nodes != 2) { + return; + } + + int counter = 0; + auto e = pipe::LifetimeEnum::Once; + auto cb = theCB()->makeFunc<RecvMsg>(e, [&counter](RecvMsg*){ counter++; }); + + vt::runInEpochCollective([&]{ + NodeType next_node = (this_node + 1) % num_nodes; + auto msg = makeMessage<LargeMsgType>(); + fillMsg(msg); + msg->cb_ = cb; + theMsg()->sendMsg<LargeMsgType, myHandler<LargeMsgType>>(next_node, msg); + }); + + EXPECT_EQ(counter, 1); +} + +REGISTER_TYPED_TEST_SUITE_P(TestActiveSendLarge, test_large_bytes_msg); + +using NonSerTestTypes = testing::Types< + std::tuple, NonSerializedTag>, + std::tuple, NonSerializedTag> + // std::tuple, NonSerializedTag> +>; + +using SerTestTypes = testing::Types< + std::tuple, SerializedTag>, + std::tuple, SerializedTag> + // std::tuple, SerializedTag> +>; + +INSTANTIATE_TYPED_TEST_SUITE_P( + test_large_bytes_serialized, TestActiveSendLarge, SerTestTypes, + DEFAULT_NAME_GEN +); + +INSTANTIATE_TYPED_TEST_SUITE_P( + test_large_bytes_nonserialized, TestActiveSendLarge, NonSerTestTypes, + DEFAULT_NAME_GEN +); + +}}}} // end namespace vt::tests::unit::large
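Reduced to user-level calls, the path this test exercises looks like the following sketch; the message type and handler names are illustrative, and with the test's --vt_max_mpi_send_size=16384 a payload this size takes the multi-chunk sendMsgMPI() path transparently:

    #include <array>

    struct BigMsg : vt::Message {
      std::array<char, 1 << 20> payload; // 1 MiB, well over the 16 KiB cap
    };

    void bigHandler(BigMsg* msg) {
      // by the time the handler runs, payload has been fully reassembled
    }

    // e.g., from node 0 once VT is initialized:
    //   auto msg = vt::makeMessage<BigMsg>();
    //   vt::theMsg()->sendMsg<BigMsg, bigHandler>(1, msg);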