diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a25f8aa103..a418b81709 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,7 +20,7 @@ set(TOP_LEVEL_SUBDIRS activefn # Add single-directory components context event handler parameterization sequence termination - scheduler worker standalone runtime trace timing demangle + scheduler worker standalone runtime trace timing demangle rdmahandle ) set( PROJECT_SUBDIRS_LIST diff --git a/src/vt/collective/reduce/reduce.h b/src/vt/collective/reduce/reduce.h index aeeebb3e54..3e3f46e3b2 100644 --- a/src/vt/collective/reduce/reduce.h +++ b/src/vt/collective/reduce/reduce.h @@ -59,6 +59,7 @@ #include "vt/messaging/message.h" #include "vt/collective/tree/tree.h" #include "vt/utils/hash/hash_tuple.h" +#include "vt/messaging/pending_send.h" #include #include @@ -80,6 +81,7 @@ namespace vt { namespace collective { namespace reduce { struct Reduce : virtual collective::tree::Tree { using ReduceStateType = ReduceState; using ReduceNumType = typename ReduceStateType::ReduceNumType; + using PendingSendType = messaging::PendingSend; /** * \internal \brief Construct a new reducer instance @@ -106,6 +108,23 @@ struct Reduce : virtual collective::tree::Tree { */ detail::ReduceStamp generateNextID(); + /** + * \brief Reduce a message up the tree, possibly delayed through a pending send + * + * \param[in] root the root node where the final handler provides the result + * \param[in] msg the message to reduce on this node + * \param[in] id the reduction stamp (optional), provided if out-of-order + * \param[in] num_contrib number of expected contributions from this node + * + * \return the pending send corresponding to the reduce + */ + template * f> + PendingSendType reduce( + NodeType root, MsgT* const msg, + detail::ReduceStamp id = detail::ReduceStamp{}, + ReduceNumType num_contrib = 1 + ); + /** * \brief Reduce a message up the tree * @@ -117,7 +136,7 @@ struct Reduce : virtual collective::tree::Tree { * \return the next reduction stamp */ template * f> - detail::ReduceStamp reduce( + detail::ReduceStamp reduceImmediate( NodeType root, MsgT* const msg, detail::ReduceStamp id = detail::ReduceStamp{}, ReduceNumType num_contrib = 1 @@ -141,7 +160,31 @@ struct Reduce : virtual collective::tree::Tree { MsgT, OpT, collective::reduce::operators::ReduceCallback > > - detail::ReduceStamp reduce( + PendingSendType reduce( + NodeType const& root, MsgT* msg, Callback cb, + detail::ReduceStamp id = detail::ReduceStamp{}, + ReduceNumType const& num_contrib = 1 + ); + + /** + * \brief Reduce a message up the tree + * + * \param[in] root the root node where the final handler provides the result + * \param[in] msg the message to reduce on this node + * \param[in] cb the callback to trigger on the root node + * \param[in] id the reduction stamp (optional), provided if out-of-order + * \param[in] num_contrib number of expected contributions from this node + * + * \return the next reduction stamp + */ + template < + typename OpT, + typename MsgT, + ActiveTypedFnType *f = MsgT::template msgHandler< + MsgT, OpT, collective::reduce::operators::ReduceCallback + > + > + detail::ReduceStamp reduceImmediate( NodeType const& root, MsgT* msg, Callback cb, detail::ReduceStamp id = detail::ReduceStamp{}, ReduceNumType const& num_contrib = 1 @@ -163,7 +206,29 @@ struct Reduce : virtual collective::tree::Tree { typename MsgT, ActiveTypedFnType *f = MsgT::template msgHandler > - detail::ReduceStamp reduce( + PendingSendType reduce( + NodeType const& root, MsgT* msg, + detail::ReduceStamp id = detail::ReduceStamp{}, + ReduceNumType const& num_contrib = 1 + ); + + /** + * \brief Reduce a message up the tree with a target function on the root node + * + * \param[in] root the root node where the final handler provides the result + * \param[in] msg the message to reduce on this node + * \param[in] id the reduction stamp (optional), provided if out-of-order + * \param[in] num_contrib number of expected contributions from this node + * + * \return the next reduction stamp + */ + template < + typename OpT, + typename FunctorT, + typename MsgT, + ActiveTypedFnType *f = MsgT::template msgHandler + > + detail::ReduceStamp reduceImmediate( NodeType const& root, MsgT* msg, detail::ReduceStamp id = detail::ReduceStamp{}, ReduceNumType const& num_contrib = 1 diff --git a/src/vt/collective/reduce/reduce.impl.h b/src/vt/collective/reduce/reduce.impl.h index cdf16080ac..1f4728df17 100644 --- a/src/vt/collective/reduce/reduce.impl.h +++ b/src/vt/collective/reduce/reduce.impl.h @@ -80,7 +80,7 @@ void Reduce::reduceRootRecv(MsgT* msg) { } template *f> -detail::ReduceStamp Reduce::reduce( +Reduce::PendingSendType Reduce::reduce( NodeType const& root, MsgT* msg, Callback cb, detail::ReduceStamp id, ReduceNumType const& num_contrib ) { @@ -88,18 +88,48 @@ detail::ReduceStamp Reduce::reduce( return reduce(root,msg,id,num_contrib); } +template *f> +detail::ReduceStamp Reduce::reduceImmediate( + NodeType const& root, MsgT* msg, Callback cb, detail::ReduceStamp id, + ReduceNumType const& num_contrib +) { + msg->setCallback(cb); + return reduceImmediate(root,msg,id,num_contrib); +} + template < typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType *f > -detail::ReduceStamp Reduce::reduce( +Reduce::PendingSendType Reduce::reduce( NodeType const& root, MsgT* msg, detail::ReduceStamp id, ReduceNumType const& num_contrib ) { return reduce(root,msg,id,num_contrib); } +template < + typename OpT, typename FunctorT, typename MsgT, ActiveTypedFnType *f +> +detail::ReduceStamp Reduce::reduceImmediate( + NodeType const& root, MsgT* msg, detail::ReduceStamp id, + ReduceNumType const& num_contrib +) { + return reduceImmediate(root,msg,id,num_contrib); +} + +template * f> +Reduce::PendingSendType Reduce::reduce( + NodeType root, MsgT* const msg, detail::ReduceStamp id, + ReduceNumType num_contrib +) { + auto msg_ptr = promoteMsg(msg); + return PendingSendType{theMsg()->getEpochContextMsg(msg_ptr), [=](){ + reduceImmediate(root, msg_ptr.get(), id, num_contrib); + } }; +} + template * f> -detail::ReduceStamp Reduce::reduce( +detail::ReduceStamp Reduce::reduceImmediate( NodeType root, MsgT* const msg, detail::ReduceStamp id, ReduceNumType num_contrib ) { diff --git a/src/vt/messaging/collection_chain_set.h b/src/vt/messaging/collection_chain_set.h index 175e1bed7e..33963612bf 100644 --- a/src/vt/messaging/collection_chain_set.h +++ b/src/vt/messaging/collection_chain_set.h @@ -68,7 +68,7 @@ namespace vt { namespace messaging { */ template class CollectionChainSet final { - public: + public: CollectionChainSet() = default; CollectionChainSet(const CollectionChainSet&) = delete; CollectionChainSet(CollectionChainSet&&) = delete; @@ -81,7 +81,9 @@ class CollectionChainSet final { * \param[in] idx the index to add */ void addIndex(Index idx) { - vtAssert(chains_.find(idx) == chains_.end(), "Cannot add an already-present chain"); + vtAssert( + chains_.find(idx) == chains_.end(), + "Cannot add an already-present chain"); chains_[idx] = DependentSendChain(); } @@ -98,16 +100,17 @@ class CollectionChainSet final { void removeIndex(Index idx) { auto iter = chains_.find(idx); vtAssert(iter != chains_.end(), "Cannot remove a non-present chain"); - vtAssert(iter->second.isTerminated(), "Cannot remove a chain with pending work"); + vtAssert( + iter->second.isTerminated(), "Cannot remove a chain with pending work"); chains_.erase(iter); } /** - * \brief The next step to execute on all the chains resident in this + * \brief The next step to execute on all the chain indices in this * collection chain set * - * Goes through every resident chain and enqueues the action at the end of the + * Goes through every chain index and enqueues the action at the end of the * current chain when the preceding steps terminate. Creates a new rooted * epoch for this step to contain/track completion of all the causally related * messages. @@ -117,15 +120,15 @@ class CollectionChainSet final { * \c PendingSend */ void nextStep( - std::string const& label, std::function step_action - ) { - for (auto &entry : chains_) { + std::string const& label, std::function step_action) { + for (auto& entry : chains_) { auto& idx = entry.first; auto& chain = entry.second; // The parameter `true` here tells VT to use an efficient rooted DS-epoch // by default. This can still be overridden by command-line flags - EpochType new_epoch = theTerm()->makeEpochRooted(label, term::UseDS{true}); + EpochType new_epoch = + theTerm()->makeEpochRooted(label, term::UseDS{true}); vt::theMsg()->pushEpoch(new_epoch); chain.add(new_epoch, step_action(idx)); @@ -136,9 +139,14 @@ class CollectionChainSet final { } /** - * \brief The next step to execute on all the chains resident in this + * \brief The next step to execute on all the chain indices in this * collection chain set * + * Goes through every chain index and enqueues the action at the end of the + * current chain when the preceding steps terminate. Creates a new rooted + * epoch for this step to contain/track completion of all the causally related + * messages. + * * \param[in] step_action The action to perform as a function that returns a * \c PendingSend */ @@ -146,45 +154,24 @@ class CollectionChainSet final { return nextStep("", step_action); } -#if 0 - void nextStep(std::function step_action) { - for (auto &entry : chains_) { - auto& idx = entry.first; - auto& chain = entry.second; - chain.add(step_action(idx)); - } - } - - void nextStepConcurrent(std::vector> step_actions) { - for (auto &entry : chains_) { - auto& idx = entry.first; - auto& chain = entry.second; - chain.add(step_actions[0](idx)); - for (int i = 1; i < step_actions.size(); ++i) - chain.addConcurrent(step_actions[i](idx)); - } - } -#endif - /** - * \brief The next collective step to execute across all resident elements - * across all nodes. + * \brief The next collective step to execute for each index that is added + * to the CollectionChainSet on each node. * * Should be used for steps with internal recursive communication and global * inter-dependence. Creates a global (on the communicator), collective epoch * to track all the casually related messages and collectively wait for - * termination of all of the recursive sends.. + * termination of all of the recursive sends. * * \param[in] label Label for the epoch created for debugging * \param[in] step_action the next step to execute, returning a \c PendingSend */ void nextStepCollective( - std::string const& label, std::function step_action - ) { + std::string const& label, std::function step_action) { auto epoch = theTerm()->makeEpochCollective(label); vt::theMsg()->pushEpoch(epoch); - for (auto &entry : chains_) { + for (auto& entry : chains_) { auto& idx = entry.first; auto& chain = entry.second; chain.add(epoch, step_action(idx)); @@ -195,8 +182,13 @@ class CollectionChainSet final { } /** - * \brief The next collective step to execute across all resident elements - * across all nodes. + * \brief The next collective step to execute for each index that is added + * to the CollectionChainSet on each node. + * + * Should be used for steps with internal recursive communication and global + * inter-dependence. Creates a global (on the communicator), collective epoch + * to track all the casually related messages and collectively wait for + * termination of all of the recursive sends. * * \param[in] step_action the next step to execute, returning a \c PendingSend */ @@ -204,12 +196,74 @@ class CollectionChainSet final { return nextStepCollective("", step_action); } + /** + * \brief The next collective step of both CollectionChainSets + * to execute over all shared indices of the CollectionChainSets over all + * nodes. + * + * This function ensures that the step is dependent on the previous step + * of both chainsets a and b. Additionally any additional steps in each + * chainset will occur after the merged step. + * + * \pre Each index in CollectionChainset a must exist in CollectionChainset b + * + * \param[in] a the first chainset + * \param[in] b the second chainset + * \param[in] step_action the next step to be executed, dependent on the + * previous step of chainsets a and b + */ + static void mergeStepCollective( + CollectionChainSet& a, CollectionChainSet& b, + std::function step_action) { + mergeStepCollective("", a, b, step_action); + } + + /** + * \brief The next collective step of both CollectionChainSets + * to execute over all shared indices of the CollectionChainSets over all + * nodes. + * + * This function ensures that the step is dependent on the previous step + * of both chainsets a and b. Additionally any additional steps in each + * chainset will occur after the merged step. + * + * \pre Each index in CollectionChainset a must exist in CollectionChainset b + * + * \param[in] label the label for the step + * \param[in] a the first chainset + * \param[in] b the second chainset + * \param[in] step_action the next step to be executed, dependent on the + * previous step of chainsets a and b + */ + static void mergeStepCollective( + std::string const& label, CollectionChainSet& a, CollectionChainSet& b, + std::function step_action) { + auto epoch = theTerm()->makeEpochCollective(label); + vt::theMsg()->pushEpoch(epoch); + + for (auto& entry : a.chains_) { + auto& idx = entry.first; + auto& chaina = entry.second; + auto chainb_pos = b.chains_.find(entry.first); + vtAssert( + chainb_pos != b.chains_.end(), + fmt::format("index {} must be present in chainset b", entry.first)); + + auto& chainb = chainb_pos->second; + DependentSendChain::mergeChainStep( + chaina, chainb, epoch, step_action(idx)); + } + + vt::theMsg()->popEpoch(epoch); + theTerm()->finishedEpoch(epoch); + } + /** * \brief Indicate that the current phase is complete. Resets the state on * each \c DependentSendChain */ void phaseDone() { - for (auto &entry : chains_) { + for (auto& entry : chains_) { entry.second.done(); } } @@ -219,7 +273,7 @@ class CollectionChainSet final { */ std::unordered_set getSet() { std::unordered_set index_set; - for (auto &entry : chains_) { + for (auto& entry : chains_) { index_set.emplace(entry.first); } return index_set; @@ -228,13 +282,13 @@ class CollectionChainSet final { /** * \brief Run a lambda immediately on each element in the index set */ - void foreach(std::function fn) { - for (auto &entry : chains_) { + void foreach (std::function fn) { + for (auto& entry : chains_) { fn(entry.first); } } - private: + private: std::unordered_map chains_; }; diff --git a/src/vt/messaging/dependent_send_chain.h b/src/vt/messaging/dependent_send_chain.h index 52a5fb6bf5..dfca1c88e9 100644 --- a/src/vt/messaging/dependent_send_chain.h +++ b/src/vt/messaging/dependent_send_chain.h @@ -84,6 +84,32 @@ struct PendingClosure { PendingSend pending_; /**< The \c PendingSend to be released */ }; +/** + * \struct MergedClosure dependent_send_chain.h vt/messaging/dependent_send_chain.h + * + * \brief A copyable closure that holds a \c PendingSend that will be released + * when all shared instances of this closure are destroyed. + */ +struct MergedClosure { + /** + * \brief Construct from a shared pointer to a \c PendingSend + * \param[in] shared_state the \c PendingSend that will be released + */ + explicit MergedClosure(std::shared_ptr shared_state) + : shared_state_(shared_state) + {} + MergedClosure(MergedClosure const&) = default; + MergedClosure(MergedClosure&& in) = default; + + void operator()() { + shared_state_.reset(); + } + +private: + + std::shared_ptr shared_state_; +}; + /** * \struct DependentSendChain dependent_send_chain.h vt/messaging/dependent_send_chain.h * @@ -115,6 +141,34 @@ class DependentSendChain final { last_epoch_ = new_epoch; } + /** + * \brief Add a task that is dependent on two DependentSendChain instances + * + * \param[in] a the first DependentSendChain + * \param[in] b the second DependentSendChain + * \param[in] new_epoch the epoch the task is being added to + * \param[in] link the \c PendingSend to release when complete + */ + static void mergeChainStep(DependentSendChain &a, DependentSendChain &b, + EpochType new_epoch, PendingSend&& link) { + a.checkInit(); + b.checkInit(); + + theTerm()->addDependency(a.last_epoch_, new_epoch); + theTerm()->addDependency(b.last_epoch_, new_epoch); + + auto c1 = MergedClosure(std::make_shared(std::move(link))); + auto c2 = c1; + + // closure is intentionally copied here; basically the ref count will go down + // when all actions are completed and execute the PendingSend + theTerm()->addActionUnique(a.last_epoch_, std::move(c1)); + theTerm()->addActionUnique(b.last_epoch_, std::move(c2)); + + a.last_epoch_ = new_epoch; + b.last_epoch_ = new_epoch; + } + /** * \brief Indicate that the chain is complete and should be reset */ diff --git a/src/vt/messaging/pending_send.cc b/src/vt/messaging/pending_send.cc index 885424fd63..0b268fd150 100644 --- a/src/vt/messaging/pending_send.cc +++ b/src/vt/messaging/pending_send.cc @@ -47,6 +47,22 @@ namespace vt { namespace messaging { +PendingSend::PendingSend(EpochType ep, EpochActionType const& in_action) + : epoch_action_{in_action}, epoch_produced_(ep) { + if (epoch_produced_ != no_epoch) { + theTerm()->produce(epoch_produced_, 1); + } +} + +PendingSend::PendingSend(PendingSend&& in) noexcept + : msg_size_(std::move(in.msg_size_)), + epoch_produced_(std::move(in.epoch_produced_)) +{ + std::swap(msg_, in.msg_); + std::swap(epoch_action_, in.epoch_action_); + std::swap(send_action_, in.send_action_); +} + void PendingSend::sendMsg() { if (send_action_ == nullptr) { theMsg()->doMessageSend(msg_, msg_size_); @@ -58,7 +74,7 @@ void PendingSend::sendMsg() { send_action_ = nullptr; } -EpochType PendingSend::getProduceEpoch() const { +EpochType PendingSend::getProduceEpochFromMsg() const { if (msg_ == nullptr or envelopeIsTerm(msg_->env) or not envelopeIsEpochType(msg_->env)) { return no_epoch; @@ -67,9 +83,8 @@ EpochType PendingSend::getProduceEpoch() const { return envelopeGetEpoch(msg_->env); } - void PendingSend::produceMsg() { - epoch_produced_ = getProduceEpoch(); + epoch_produced_ = getProduceEpochFromMsg(); if (epoch_produced_ != no_epoch) { theTerm()->produce(epoch_produced_, 1); } @@ -81,4 +96,16 @@ void PendingSend::consumeMsg() { } } +void PendingSend::release() { + bool send_msg = msg_ != nullptr || send_action_ != nullptr; + vtAssert(!send_msg || !epoch_action_, "cannot have both a message and epoch action"); + if (send_msg) { + sendMsg(); + } else if ( epoch_action_ ) { + epoch_action_(); + epoch_action_ = {}; + consumeMsg(); + } +} + }} diff --git a/src/vt/messaging/pending_send.h b/src/vt/messaging/pending_send.h index 474e45c306..01294e7d95 100644 --- a/src/vt/messaging/pending_send.h +++ b/src/vt/messaging/pending_send.h @@ -71,6 +71,7 @@ namespace vt { namespace messaging { struct PendingSend final { /// Function for complex action on send---takes a message to operate on using SendActionType = std::function&)>; + using EpochActionType = std::function; /** * \brief Construct a pending send. @@ -137,36 +138,11 @@ struct PendingSend final { produceMsg(); } - /** - * \brief Get the epoch produced when holder was created - * - * This is required because the epoch on the envelope can change in some cases - * in between when this is created and actually released. - * - * \return the produce epoch - */ - EpochType getProduceEpoch() const; - - /** - * \brief Produce on the messages epoch to inhibit early termination - */ - void produceMsg(); - - /** - * \brief Consume on the messages epoch to inhibit early termination - */ - void consumeMsg(); + PendingSend(EpochType ep, EpochActionType const& in_action); explicit PendingSend(std::nullptr_t) { } - PendingSend(PendingSend&& in) - : msg_(std::move(in.msg_)), - msg_size_(std::move(in.msg_size_)), - send_action_(std::move(in.send_action_)), - epoch_produced_(std::move(in.epoch_produced_)) - { - in.msg_ = nullptr; - in.send_action_ = nullptr; - } + PendingSend(PendingSend&& in) noexcept; + PendingSend(const PendingSend&) = delete; PendingSend& operator=(PendingSend&& in) = delete; PendingSend& operator=(PendingSend& in) = delete; @@ -189,13 +165,30 @@ struct PendingSend final { /** * \brief Release the message, run action if needed */ - void release() { - if (msg_ != nullptr || send_action_ != nullptr) { - sendMsg(); - } - } + void release(); private: + + /** + * \brief Get the epoch produced when holder was created + * + * This is required because the epoch on the envelope can change in some cases + * in between when this is created and actually released. + * + * \return the produce epoch + */ + EpochType getProduceEpochFromMsg() const; + + /** + * \brief Produce on the messages epoch to inhibit early termination + */ + void produceMsg(); + + /** + * \brief Consume on the messages epoch to inhibit early termination + */ + void consumeMsg(); + /// Send the message saved directly or trigger the lambda for /// specialized sends from the pending holder void sendMsg(); @@ -203,7 +196,8 @@ struct PendingSend final { private: MsgPtr msg_ = nullptr; ByteType msg_size_ = no_byte; - SendActionType send_action_ = nullptr; + SendActionType send_action_ = {}; + EpochActionType epoch_action_ = {}; EpochType epoch_produced_ = no_epoch; }; diff --git a/src/vt/objgroup/manager.h b/src/vt/objgroup/manager.h index 0214fb2293..0082dcd510 100644 --- a/src/vt/objgroup/manager.h +++ b/src/vt/objgroup/manager.h @@ -56,6 +56,7 @@ #include "vt/objgroup/dispatch/dispatch.h" #include "vt/messaging/message/message.h" #include "vt/messaging/message/smart_ptr.h" +#include "vt/messaging/pending_send.h" #include #include @@ -93,6 +94,7 @@ struct ObjGroupManager : runtime::component::Component { using DispatchBasePtrType = std::unique_ptr; using MsgContainerType = std::vector>; using BaseProxyListType = std::set; + using PendingSendType = messaging::PendingSend; /** * \internal \brief Construct the ObjGroupManager @@ -231,9 +233,11 @@ struct ObjGroupManager : runtime::component::Component { * \param[in] proxy proxy to the object group * \param[in] msg reduction message * \param[in] stamp stamp to identify reduction across nodes + * + * \return the PendingSend corresponding to the reduce */ template *f> - void reduce( + PendingSendType reduce( ProxyType proxy, MsgSharedPtr msg, collective::reduce::ReduceStamp const& stamp ); diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h index 99e957805d..43c196519d 100644 --- a/src/vt/objgroup/manager.impl.h +++ b/src/vt/objgroup/manager.impl.h @@ -287,7 +287,7 @@ void ObjGroupManager::broadcast(MsgSharedPtr msg, HandlerType han) { } template *f> -void ObjGroupManager::reduce( +ObjGroupManager::PendingSendType ObjGroupManager::reduce( ProxyType proxy, MsgSharedPtr msg, collective::reduce::ReduceStamp const& stamp ) { @@ -295,7 +295,7 @@ void ObjGroupManager::reduce( auto const objgroup = proxy.getProxy(); auto r = theCollective()->getReducerObjGroup(objgroup); - r->template reduce(root, msg.get(), stamp); + return r->template reduce(root, msg.get(), stamp); } template diff --git a/src/vt/objgroup/proxy/proxy_objgroup.h b/src/vt/objgroup/proxy/proxy_objgroup.h index ea167ca232..7b8c2e36a4 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.h @@ -58,6 +58,7 @@ #include "vt/utils/static_checks/msg_ptr.h" #include "vt/rdmahandle/handle.fwd.h" #include "vt/rdmahandle/handle_set.fwd.h" +#include "vt/messaging/pending_send.h" namespace vt { namespace objgroup { namespace proxy { @@ -76,6 +77,8 @@ template struct Proxy { using ReduceStamp = collective::reduce::ReduceStamp; + using PendingSendType = messaging::PendingSend; + Proxy() = default; Proxy(Proxy const&) = default; Proxy(Proxy&&) = default; @@ -126,6 +129,8 @@ struct Proxy { * \param[in] msg the reduction message * \param[in] cb the callback to trigger after the reduction is finished * \param[in] stamp the stamp to identify the reduction + * + * \return the PendingSend associated with the reduce */ template < typename OpT = collective::None, @@ -135,7 +140,7 @@ struct Proxy { MsgT, OpT, collective::reduce::operators::ReduceCallback > > - void reduce( + PendingSendType reduce( MsgPtrT msg, Callback cb, ReduceStamp stamp = ReduceStamp{} ) const; @@ -145,6 +150,8 @@ struct Proxy { * * \param[in] msg the reduction message * \param[in] stamp the stamp to identify the reduction + * + * \return the PendingSend associated with the reduce */ template < typename OpT = collective::None, @@ -153,7 +160,7 @@ struct Proxy { typename MsgT = typename util::MsgPtrType::MsgType, ActiveTypedFnType *f = MsgT::template msgHandler > - void reduce(MsgPtrT msg, ReduceStamp stamp = ReduceStamp{}) const; + PendingSendType reduce(MsgPtrT msg, ReduceStamp stamp = ReduceStamp{}) const; /** * \brief Reduce over the objgroup instance on each node with target specified @@ -161,13 +168,15 @@ struct Proxy { * * \param[in] msg the reduction message * \param[in] stamp the stamp to identify the reduction + * + * \return the PendingSend associated with the reduce */ template < typename MsgPtrT, typename MsgT = typename util::MsgPtrType::MsgType, ActiveTypedFnType *f > - void reduce(MsgPtrT msg, ReduceStamp stamp = ReduceStamp{}) const; + PendingSendType reduce(MsgPtrT msg, ReduceStamp stamp = ReduceStamp{}) const; /** * \brief Get raw pointer to the local object instance residing on the current diff --git a/src/vt/objgroup/proxy/proxy_objgroup.impl.h b/src/vt/objgroup/proxy/proxy_objgroup.impl.h index cd19d975f2..1a44d77887 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.impl.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.impl.h @@ -80,7 +80,7 @@ template template < typename OpT, typename MsgPtrT, typename MsgT, ActiveTypedFnType *f > -void Proxy::reduce( +typename Proxy::PendingSendType Proxy::reduce( MsgPtrT inmsg, Callback cb, ReduceStamp stamp ) const { auto proxy = Proxy(*this); @@ -97,7 +97,7 @@ template < typename OpT, typename FunctorT, typename MsgPtrT, typename MsgT, ActiveTypedFnType *f > -void Proxy::reduce(MsgPtrT inmsg, ReduceStamp stamp) const { +typename Proxy::PendingSendType Proxy::reduce(MsgPtrT inmsg, ReduceStamp stamp) const { auto proxy = Proxy(*this); MsgPtr msg = promoteMsg(static_cast(inmsg)); return theObjGroup()->reduce(proxy,msg,stamp); @@ -105,7 +105,7 @@ void Proxy::reduce(MsgPtrT inmsg, ReduceStamp stamp) const { template template *f> -void Proxy::reduce(MsgPtrT inmsg, ReduceStamp stamp) const { +typename Proxy::PendingSendType Proxy::reduce(MsgPtrT inmsg, ReduceStamp stamp) const { auto proxy = Proxy(*this); MsgPtr msg = promoteMsg(inmsg); return theObjGroup()->reduce(proxy,msg,stamp); diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index 52c12d7ce6..718112013e 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -750,9 +750,11 @@ struct CollectionManager * the associated handler (if a callback is specified on a particular node, * the root will run the handler that triggers the callback at the appropriate * location) + * + * \return a PendingSend corresponding to the reduce */ template *f> - void reduceMsg( + messaging::PendingSend reduceMsg( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceStamp stamp = ReduceStamp{}, NodeType root_node = uninitialized_destination @@ -765,9 +767,11 @@ struct CollectionManager * \param[in] msg the reduce message * \param[in] stamp the reduce stamp * \param[in] idx the index of collection element being reduced + * + * \return a PendingSend corresponding to the reduce */ template *f> - void reduceMsg( + messaging::PendingSend reduceMsg( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceStamp stamp, typename ColT::IndexType const& idx ); @@ -784,9 +788,11 @@ struct CollectionManager * the associated handler (if a callback is specified on a particular node, * the root will run the handler that triggers the callback at the appropriate * location) + * + * \return a PendingSend corresponding to the reduce */ template *f> - void reduceMsgExpr( + messaging::PendingSend reduceMsgExpr( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceIdxFuncType expr_fn, ReduceStamp stamp = ReduceStamp{}, @@ -802,9 +808,11 @@ struct CollectionManager * \param[in] expr_fn expression function to pick indices * \param[in] stamp the reduce stamp * \param[in] idx the index of collection element being reduced + * + * \return a PendingSend corresponding to the reduce */ template *f> - void reduceMsgExpr( + messaging::PendingSend reduceMsgExpr( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceIdxFuncType expr_fn, ReduceStamp stamp, typename ColT::IndexType const& idx diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index 39c7e15e10..770d80f766 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -1054,7 +1054,7 @@ messaging::PendingSend CollectionManager::broadcastMsgUntypedHandler( } template *f> -void CollectionManager::reduceMsgExpr( +messaging::PendingSend CollectionManager::reduceMsgExpr( CollectionProxyWrapType const& proxy, MsgT *const raw_msg, ReduceIdxFuncType expr_fn, ReduceStamp stamp, NodeType root @@ -1071,7 +1071,7 @@ void CollectionManager::reduceMsgExpr( auto const col_proxy = proxy.getProxy(); auto const cur_epoch = theMsg()->getEpochContextMsg(msg); - bufferOpOrExecute( + return bufferOpOrExecute( col_proxy, BufferTypeEnum::Reduce, static_cast( @@ -1115,7 +1115,7 @@ void CollectionManager::reduceMsgExpr( r = theCollective()->getReducerVrtProxy(col_proxy); } - auto ret_stamp = r->reduce(root_node, msg.get(), cur_stamp, num_elms); + auto ret_stamp = r->reduceImmediate(root_node, msg.get(), cur_stamp, num_elms); vt_debug_print( vrt_coll, node, @@ -1143,7 +1143,7 @@ void CollectionManager::reduceMsgExpr( } template *f> -void CollectionManager::reduceMsg( +messaging::PendingSend CollectionManager::reduceMsg( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceStamp stamp, NodeType root ) { @@ -1151,7 +1151,7 @@ void CollectionManager::reduceMsg( } template *f> -void CollectionManager::reduceMsg( +messaging::PendingSend CollectionManager::reduceMsg( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceStamp stamp, typename ColT::IndexType const& idx ) { @@ -1159,7 +1159,7 @@ void CollectionManager::reduceMsg( } template *f> -void CollectionManager::reduceMsgExpr( +messaging::PendingSend CollectionManager::reduceMsgExpr( CollectionProxyWrapType const& proxy, MsgT *const msg, ReduceIdxFuncType expr_fn, ReduceStamp stamp, typename ColT::IndexType const& idx diff --git a/src/vt/vrt/collection/reducable/reducable.h b/src/vt/vrt/collection/reducable/reducable.h index ababfaa00e..40a2f7709f 100644 --- a/src/vt/vrt/collection/reducable/reducable.h +++ b/src/vt/vrt/collection/reducable/reducable.h @@ -75,7 +75,7 @@ struct Reducable : BaseProxyT { MsgT, OpT, collective::reduce::operators::ReduceCallback > > - void reduce( + messaging::PendingSend reduce( MsgT *const msg, Callback cb, ReduceStamp stamp = ReduceStamp{} ) const; @@ -85,25 +85,25 @@ struct Reducable : BaseProxyT { typename MsgT, ActiveTypedFnType *f = MsgT::template msgHandler > - void reduce(MsgT *const msg, ReduceStamp stamp = ReduceStamp{}) const; + messaging::PendingSend reduce(MsgT *const msg, ReduceStamp stamp = ReduceStamp{}) const; template *f> - void reduce( + messaging::PendingSend reduce( MsgT *const msg, ReduceStamp stamp = ReduceStamp{}, NodeType const& node = uninitialized_destination ) const; template *f> - void reduceExpr( + messaging::PendingSend reduceExpr( MsgT *const msg, ReduceIdxFuncType fn, ReduceStamp stamp = ReduceStamp{}, NodeType const& node = uninitialized_destination ) const; template *f> - void reduce(MsgT *const msg, ReduceStamp stamp, IndexT const& idx) const; + messaging::PendingSend reduce(MsgT *const msg, ReduceStamp stamp, IndexT const& idx) const; template *f> - void reduceExpr( + messaging::PendingSend reduceExpr( MsgT *const msg, ReduceIdxFuncType fn, ReduceStamp stamp, IndexT const& idx ) const; }; diff --git a/src/vt/vrt/collection/reducable/reducable.impl.h b/src/vt/vrt/collection/reducable/reducable.impl.h index 1b7a2d30e4..cb9b8c73a1 100644 --- a/src/vt/vrt/collection/reducable/reducable.impl.h +++ b/src/vt/vrt/collection/reducable/reducable.impl.h @@ -60,7 +60,7 @@ Reducable::Reducable(VirtualProxyType const in_proxy) template template *f> -void Reducable::reduce( +messaging::PendingSend Reducable::reduce( MsgT *const msg, Callback cb, ReduceStamp stamp ) const { auto const proxy = this->getProxy(); @@ -76,7 +76,7 @@ void Reducable::reduce( template template *f> -void Reducable::reduce( +messaging::PendingSend Reducable::reduce( MsgT *const msg, ReduceStamp stamp ) const { auto const proxy = this->getProxy(); @@ -86,7 +86,7 @@ void Reducable::reduce( template template *f> -void Reducable::reduce( +messaging::PendingSend Reducable::reduce( MsgT *const msg, ReduceStamp stamp, NodeType const& node ) const { auto const proxy = this->getProxy(); @@ -95,7 +95,7 @@ void Reducable::reduce( template template *f> -void Reducable::reduceExpr( +messaging::PendingSend Reducable::reduceExpr( MsgT *const msg, ReduceIdxFuncType fn, ReduceStamp stamp, NodeType const& node ) const { auto const proxy = this->getProxy(); @@ -104,7 +104,7 @@ void Reducable::reduceExpr( template template *f> -void Reducable::reduce( +messaging::PendingSend Reducable::reduce( MsgT *const msg, ReduceStamp stamp, IndexT const& idx ) const { auto const proxy = this->getProxy(); @@ -113,7 +113,7 @@ void Reducable::reduce( template template *f> -void Reducable::reduceExpr( +messaging::PendingSend Reducable::reduceExpr( MsgT *const msg, ReduceIdxFuncType fn, ReduceStamp stamp, IndexT const& idx ) const { auto const proxy = this->getProxy(); diff --git a/tests/unit/termination/test_term_chaining.cc b/tests/unit/termination/test_term_chaining.cc index 82e802edb0..9b5f329910 100644 --- a/tests/unit/termination/test_term_chaining.cc +++ b/tests/unit/termination/test_term_chaining.cc @@ -63,6 +63,14 @@ struct TestTermChaining : TestParallelHarness { static vt::messaging::DependentSendChain chain; static vt::EpochType epoch; + struct ChainReduceMsg : vt::collective::ReduceNoneMsg { + ChainReduceMsg(int in_num) + : num(in_num) + {} + + int num = 0; + }; + static void test_handler_reflector(TestMsg* msg) { fmt::print("reflector run\n"); @@ -98,6 +106,27 @@ struct TestTermChaining : TestParallelHarness { handler_count = 4; } + static void test_handler_set(TestMsg* msg) { + handler_count = 1; + } + + static void test_handler_reduce(ChainReduceMsg *msg) { + EXPECT_EQ(theContext()->getNode(), 0); + EXPECT_EQ(handler_count, 1); + auto n = theContext()->getNumNodes(); + EXPECT_EQ(msg->num, n * (n - 1)/2); + handler_count = 2; + } + + static void test_handler_bcast(TestMsg* msg) { + if (theContext()->getNode() == 0) { + EXPECT_EQ(handler_count, 2); + } else { + EXPECT_EQ(handler_count, 12); + } + handler_count = 3; + } + static void start_chain() { EpochType epoch1 = theTerm()->makeEpochRooted(); vt::theMsg()->pushEpoch(epoch1); @@ -119,6 +148,53 @@ struct TestTermChaining : TestParallelHarness { chain.done(); } + + static void chain_reduce() { + auto node = theContext()->getNode(); + + if (0 == node) { + EpochType epoch1 = theTerm()->makeEpochRooted(); + vt::theMsg()->pushEpoch(epoch1); + auto msg = makeMessage(); + chain.add( + epoch1, theMsg()->sendMsg(1, msg.get())); + vt::theMsg()->popEpoch(epoch1); + vt::theTerm()->finishedEpoch(epoch1); + } + + EpochType epoch2 = theTerm()->makeEpochCollective(); + vt::theMsg()->pushEpoch(epoch2); + auto msg2 = makeMessage(theContext()->getNode()); + auto cb = vt::theCB()->makeSend( 0 ); + chain.add(epoch2, theCollective()->global()->reduce< vt::collective::None >(0, msg2.get(), cb)); + vt::theMsg()->popEpoch(epoch2); + vt::theTerm()->finishedEpoch(epoch2); + + // Broadcast from both nodes, bcast wont send to itself + EpochType epoch3 = theTerm()->makeEpochRooted(); + vt::theMsg()->pushEpoch(epoch3); + auto msg3 = makeMessage(); + chain.add( + epoch3, theMsg()->broadcastMsg(msg3.get())); + vt::theMsg()->popEpoch(epoch3); + vt::theTerm()->finishedEpoch(epoch3); + + chain.done(); + } + + static void chain_reduce_single() { + handler_count = 1; + + EpochType epoch2 = theTerm()->makeEpochRooted(); + vt::theMsg()->pushEpoch(epoch2); + auto msg2 = makeMessage(theContext()->getNode()); + auto cb = vt::theCB()->makeSend( 0 ); + chain.add(epoch2, theCollective()->global()->reduce< vt::collective::None >(0, msg2.get(), cb)); + vt::theMsg()->popEpoch(epoch2); + vt::theTerm()->finishedEpoch(epoch2); + + chain.done(); + } }; /*static*/ int32_t TestTermChaining::handler_count = 0; @@ -134,6 +210,8 @@ TEST_F(TestTermChaining, test_termination_chaining_1) { epoch = theTerm()->makeEpochCollective(); + handler_count = 0; + fmt::print("global collective epoch {:x}\n", epoch); if (this_node == 0) { @@ -155,4 +233,19 @@ TEST_F(TestTermChaining, test_termination_chaining_1) { } } +TEST_F(TestTermChaining, test_termination_chaining_collective_1) { + auto const& num_nodes = theContext()->getNumNodes(); + + chain = vt::messaging::DependentSendChain{}; + handler_count = 0; + + if (num_nodes == 2) { + vt::runInEpochCollective( chain_reduce ); + EXPECT_EQ(handler_count, 3); + } else if (num_nodes == 1) { + vt::runInEpochCollective( chain_reduce_single ); + EXPECT_EQ(handler_count, 2); + } +} + }}} // end namespace vt::tests::unit diff --git a/tests/unit/termination/test_term_dep_send_chain.cc b/tests/unit/termination/test_term_dep_send_chain.cc index a8bdc30a2b..8066faa08c 100644 --- a/tests/unit/termination/test_term_dep_send_chain.cc +++ b/tests/unit/termination/test_term_dep_send_chain.cc @@ -508,6 +508,184 @@ struct PrintParam { } }; +struct MergeCol : vt::Collection { + MergeCol() = default; + MergeCol(NodeType num, double off) : offset( off ) { + idx = getIndex(); + } + + struct DataMsg : vt::CollectionMessage { + DataMsg() = default; + explicit DataMsg(double x_) : x(x_) { } + double x = 0.0; + }; + + struct GhostMsg : vt::CollectionMessage { + GhostMsg() = default; + explicit GhostMsg(vt::CollectionProxy proxy_) + : proxy(proxy_) + {} + vt::CollectionProxy proxy; + }; + + void initData(DataMsg* msg) { + EXPECT_EQ(msg->x, calcVal(1,idx)); + data = msg->x + offset; + } + + void ghost(GhostMsg* msg) { + msg->proxy(getIndex()).template send(data); + } + + void interact(DataMsg* msg ) { + data *= msg->x; + } + + void check(DataMsg *msg) { + EXPECT_EQ(msg->x, data); + } + + template + void serialize(SerializerT& s) { + vt::Collection::serialize(s); + s | idx | offset | data; + } + +private: + + vt::Index2D idx; + double offset = 0; + double data = 0.0; +}; + +struct MergeObjGroup +{ + MergeObjGroup() = default; + + void startup() { + // Produce on global epoch so on a single node it does not terminate early + vt::theTerm()->produce(vt::term::any_epoch_sentinel); + } + + void shutdown() { + // Consume on global epoch to match the startup produce + vt::theTerm()->consume(vt::term::any_epoch_sentinel); + } + + void makeVT() { + frontend_proxy = vt::theObjGroup()->makeCollective(this); + } + + void makeColl(NodeType num_nodes, int k, double offset) { + auto const node = theContext()->getNode(); + auto range = vt::Index2D(static_cast(num_nodes),k); + backend_proxy = vt::theCollection()->constructCollective( + range, [=](vt::Index2D idx) { + return std::make_unique(num_nodes, offset); + } + ); + + chains_ = std::make_unique>(); + + for (int i = 0; i < k; ++i) { + chains_->addIndex(vt::Index2D(static_cast(node), i)); + } + } + + void startUpdate() { + epoch_ = vt::theTerm()->makeEpochCollective(); + vt::theMsg()->pushEpoch(epoch_); + } + + void initData() { + chains_->nextStep("initData", [=](vt::Index2D idx) { + auto x = calcVal(1,idx); + return backend_proxy(idx).template send(x); + }); + } + + void interact( MergeObjGroup &other ) { + auto other_proxy = other.backend_proxy; + vt::messaging::CollectionChainSet::mergeStepCollective( "interact", + *chains_, + *other.chains_, + [=]( vt::Index2D idx) { + return backend_proxy(idx).template send(other_proxy); + }); + } + + void check( double offset, double other_offset, bool is_left ) { + chains_->nextStep("initData", [=](vt::Index2D idx) { + auto x = calcVal(1,idx) + offset; + if ( !is_left ) + x *= calcVal(1,idx) + other_offset; + return backend_proxy(idx).template send(x); + }); + } + + void finishUpdate() { + chains_->phaseDone(); + vt::theMsg()->popEpoch(epoch_); + vt::theTerm()->finishedEpoch(epoch_); + + vt::runSchedulerThrough(epoch_); + } + + private: + + // The current epoch for a given update + vt::EpochType epoch_ = vt::no_epoch; + // The backend collection proxy for managing the over decomposed workers + vt::CollectionProxy backend_proxy = {}; + // The proxy for this objgroup + vt::objgroup::proxy::Proxy frontend_proxy = {}; + // The current collection chains that are being managed here + std::unique_ptr> chains_ = nullptr; +}; + +TEST_P(TestTermDepSendChain, test_term_dep_send_chain_merge) { + auto const& num_nodes = theContext()->getNumNodes(); + auto const iter = 50; + auto const& tup = GetParam(); + auto const use_ds = std::get<0>(tup); + auto const k = std::get<1>(tup); + + vt::theConfig()->vt_term_rooted_use_wave = !use_ds; + vt::theConfig()->vt_term_rooted_use_ds = use_ds; + + auto obj_a = std::make_unique(); + obj_a->startup(); + obj_a->makeVT(); + obj_a->makeColl(num_nodes,k, 0.0); + + auto obj_b = std::make_unique(); + obj_b->startup(); + obj_b->makeVT(); + obj_b->makeColl(num_nodes,k, 1000.0); + + // Must have barrier here so op4Impl does not bounce early (invalid proxy)! + vt::theCollective()->barrier(); + + for (int t = 0; t < iter; t++) { + obj_a->startUpdate(); + obj_a->initData(); + + obj_b->startUpdate(); + obj_b->initData(); + + obj_a->interact( *obj_b ); + + obj_a->check(0.0, 1000.0, true); + obj_b->check(0.0, 1000.0, false); + + obj_b->finishUpdate(); + obj_a->finishUpdate(); + } + + obj_a->shutdown(); + obj_b->shutdown(); +} + // Test Wave-epoch with a narrower set of parameters since large k is very slow INSTANTIATE_TEST_SUITE_P( DepSendChainInputExplodeWave, TestTermDepSendChain,