diff --git a/src/vt/context/runnable_context/td.cc b/src/vt/context/runnable_context/td.cc index 00958539db..1337dc37d4 100644 --- a/src/vt/context/runnable_context/td.cc +++ b/src/vt/context/runnable_context/td.cc @@ -62,29 +62,17 @@ TD::TD(EpochType in_ep) } void TD::begin() { - theMsg()->pushEpoch(ep_); + theTerm()->pushEpoch(ep_); - auto& epoch_stack = theMsg()->getEpochStack(); + auto& epoch_stack = theTerm()->getEpochStack(); - vt_debug_print( - verbose, context, - "TD::begin: top={:x}, size={}\n", - epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch, - epoch_stack.size() - ); base_epoch_stack_size_ = epoch_stack.size(); } void TD::end() { - auto& epoch_stack = theMsg()->getEpochStack(); + auto& epoch_stack = theTerm()->getEpochStack(); - vt_debug_print( - verbose, context, - "TD::end: top={:x}, size={}, base_size={}\n", - epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch, - epoch_stack.size(), base_epoch_stack_size_ - ); vtAssert( base_epoch_stack_size_ <= epoch_stack.size(), @@ -92,47 +80,35 @@ void TD::end() { ); while (epoch_stack.size() > base_epoch_stack_size_) { - theMsg()->popEpoch(); + theTerm()->popEpoch(); } - theMsg()->popEpoch(ep_); + theTerm()->popEpoch(ep_); } void TD::suspend() { - auto& epoch_stack = theMsg()->getEpochStack(); + auto& epoch_stack = theTerm()->getEpochStack(); - vt_debug_print( - verbose, context, - "TD::suspend: top={:x}, size={}, base_size={}\n", - epoch_stack.size() > 0 ? 
epoch_stack.top(): no_epoch, - epoch_stack.size(), base_epoch_stack_size_ - ); while (epoch_stack.size() > base_epoch_stack_size_) { - suspended_epochs_.push_back(theMsg()->getEpoch()); - theMsg()->popEpoch(); + suspended_epochs_.push_back(theTerm()->getEpoch()); + theTerm()->popEpoch(); } - theMsg()->popEpoch(ep_); + theTerm()->popEpoch(ep_); } void TD::resume() { - theMsg()->pushEpoch(ep_); + theTerm()->pushEpoch(ep_); - auto& epoch_stack = theMsg()->getEpochStack(); + auto& epoch_stack = theTerm()->getEpochStack(); base_epoch_stack_size_ = epoch_stack.size(); - vt_debug_print( - verbose, context, - "TD::resume: top={:x}, size={}, base_size={}\n", - epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch, - epoch_stack.size(), base_epoch_stack_size_ - ); for (auto it = suspended_epochs_.rbegin(); it != suspended_epochs_.rend(); ++it) { - theMsg()->pushEpoch(*it); + theTerm()->pushEpoch(*it); } suspended_epochs_.clear(); diff --git a/src/vt/messaging/active.cc b/src/vt/messaging/active.cc index 25db27fabf..b5c7de42dc 100644 --- a/src/vt/messaging/active.cc +++ b/src/vt/messaging/active.cc @@ -75,12 +75,6 @@ ActiveMessenger::ActiveMessenger() # endif this_node_(theContext()->getNode()) { - /* - * Push the default epoch into the stack so it is always at the bottom of the - * stack during execution until the AM's destructor is invoked - */ - pushEpoch(term::any_epoch_sentinel); - // Register counters for AM/DM message sends and number of bytes amSentCounterGauge = diagnostic::CounterGauge{ registerCounter("AM_sent", "active messages sent"), @@ -171,20 +165,7 @@ void ActiveMessenger::startup() { #endif } -/*virtual*/ ActiveMessenger::~ActiveMessenger() { - // Pop all extraneous epochs off the stack greater than 1 - auto stack_size = epoch_stack_.size(); - while (stack_size > 1) { - stack_size = (epoch_stack_.pop(), epoch_stack_.size()); - } - // Pop off the last epoch: term::any_epoch_sentinel - auto const ret_epoch = popEpoch(term::any_epoch_sentinel); - vtAssertInfo( - 
ret_epoch == term::any_epoch_sentinel, "Last pop must be any epoch", - ret_epoch, term::any_epoch_sentinel, epoch_stack_.size() - ); - vtAssertExpr(epoch_stack_.size() == 0); -} +/*virtual*/ ActiveMessenger::~ActiveMessenger() {} trace::TraceEventIDType ActiveMessenger::makeTraceCreationSend( HandlerType const handler, ByteType serialized_msg_size, bool is_bcast diff --git a/src/vt/messaging/active.h b/src/vt/messaging/active.h index 3b50f71b47..2b44025004 100644 --- a/src/vt/messaging/active.h +++ b/src/vt/messaging/active.h @@ -325,7 +325,6 @@ struct ActiveMessenger : runtime::component::PollableComponent using ReadyHanTagType = std::tuple; using MaybeReadyType = std::vector; using HandlerManagerType = HandlerManager; - using EpochStackType = std::stack; using PendingSendType = PendingSend; /** @@ -1516,17 +1515,6 @@ struct ActiveMessenger : runtime::component::PollableComponent MsgSizeType const& msg_size, TagType const& send_tag ); - /** - * \internal - * \brief Get the current global epoch - * - * \c Returns the top epoch on the stack iff \c epoch_stack.size() > 0, else it - * returns \c vt::no_epoch - * - * \return the current global epoch - */ - inline EpochType getGlobalEpoch() const; - /** * \internal * \brief Push an epoch on the stack @@ -1563,12 +1551,6 @@ struct ActiveMessenger : runtime::component::PollableComponent */ inline EpochType getEpoch() const; - /** - * \internal - * \brief Access the epoch stack - */ - inline EpochStackType& getEpochStack() { return epoch_stack_; } - /** * \internal * \brief Get the epoch for a message based on the current context so an @@ -1644,7 +1626,6 @@ struct ActiveMessenger : runtime::component::PollableComponent | pending_handler_msgs_ | pending_recvs_ | cur_direct_buffer_tag_ - | epoch_stack_ | in_progress_active_msg_irecv | in_progress_data_irecv | in_progress_ops @@ -1755,7 +1736,6 @@ struct ActiveMessenger : runtime::component::PollableComponent ContWaitType pending_handler_msgs_ = {}; ContainerPendingType 
pending_recvs_ = {}; TagType cur_direct_buffer_tag_ = starting_direct_buffer_tag; - EpochStackType epoch_stack_; RequestHolder in_progress_active_msg_irecv; RequestHolder in_progress_data_irecv; RequestHolder in_progress_ops; diff --git a/src/vt/messaging/active.impl.h b/src/vt/messaging/active.impl.h index 6161b73fe5..705dc17a4e 100644 --- a/src/vt/messaging/active.impl.h +++ b/src/vt/messaging/active.impl.h @@ -445,54 +445,16 @@ ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgAuto( ); } -inline EpochType ActiveMessenger::getGlobalEpoch() const { - vtAssertInfo( - epoch_stack_.size() > 0, "Epoch stack size must be greater than zero", - epoch_stack_.size() - ); - return epoch_stack_.size() ? epoch_stack_.top() : term::any_epoch_sentinel; -} - inline void ActiveMessenger::pushEpoch(EpochType const& epoch) { - /* - * pushEpoch(epoch) pushes any epoch onto the local stack iff epoch != - * no_epoch; the epoch stack includes all locally pushed epochs and the - * current contexts pushed, transitively causally related active message - * handlers. - */ - vtAssertInfo( - epoch != no_epoch, "Do not push no_epoch onto the epoch stack", - epoch, no_epoch, epoch_stack_.size(), - epoch_stack_.size() > 0 ? epoch_stack_.top() : no_epoch - ); - if (epoch != no_epoch) { - epoch_stack_.push(epoch); - } + return theTerm()->pushEpoch(epoch); } inline EpochType ActiveMessenger::popEpoch(EpochType const& epoch) { - /* - * popEpoch(epoch) shall remove the top entry from epoch_size_, iif the size - * is non-zero and the `epoch' passed, if `epoch != no_epoch', is equal to the - * top of the `epoch_stack_.top()'; else, it shall remove any entry from the - * top of the stack. - */ - auto const& non_zero = epoch_stack_.size() > 0; - vtAssertExprInfo( - non_zero and (epoch_stack_.top() == epoch or epoch == no_epoch), - epoch, non_zero, epoch_stack_.top() - ); - if (epoch == no_epoch) { - return non_zero ? 
epoch_stack_.pop(),epoch_stack_.top() : no_epoch; - } else { - return non_zero && epoch == epoch_stack_.top() ? - epoch_stack_.pop(),epoch : - no_epoch; - } + return theTerm()->popEpoch(epoch); } inline EpochType ActiveMessenger::getEpoch() const { - return getGlobalEpoch(); + return theTerm()->getEpoch(); } template diff --git a/src/vt/termination/termination.cc b/src/vt/termination/termination.cc index 96c0a1cfa9..bcd1dd6912 100644 --- a/src/vt/termination/termination.cc +++ b/src/vt/termination/termination.cc @@ -63,9 +63,12 @@ namespace vt { namespace term { TerminationDetector::TerminationDetector() : collective::tree::Tree(collective::tree::tree_cons_tag_t), - any_epoch_state_(any_epoch_sentinel, false, true, getNumChildren()), - hang_(no_epoch, true, false, getNumChildren()) -{ } + any_epoch_state_(any_epoch_sentinel, false, true, getNumChildren()), + hang_(no_epoch, true, false, getNumChildren()), + this_node_(theContext()->getNode()) +{ + pushEpoch(term::any_epoch_sentinel); +} /*static*/ void TerminationDetector::makeRootedHandler(TermMsg* msg) { theTerm()->makeRootedHan(msg->new_epoch, false); @@ -112,7 +115,7 @@ void TerminationDetector::setLocalTerminated( any_epoch_state_.notifyLocalTerminated(local_terminated); if (local_terminated && !no_propagate) { - theTerm()->maybePropagate(); + maybePropagate(); } } @@ -141,26 +144,6 @@ TerminationDetector::findOrCreateState(EpochType const& epoch, bool is_ready) { return epoch_iter->second; } -void TerminationDetector::produceConsumeState( - TermStateType& state, TermCounterType const num_units, bool produce, - NodeType node -) { - auto& counter = produce ? 
state.l_prod : state.l_cons; - counter += num_units; - - vt_debug_print( - verbose, term, - "produceConsumeState: epoch={:x}, event_count={}, l_prod={}, l_cons={}, " - "num_units={}, produce={}, node={}\n", - state.getEpoch(), state.getRecvChildCount(), state.l_prod, state.l_cons, num_units, - print_bool(produce), node - ); - - if (state.readySubmitParent()) { - propagateEpoch(state); - } -} - TerminationDetector::TermStateDSType* TerminationDetector::getDSTerm(EpochType epoch, bool is_root) { vt_debug_print( @@ -188,38 +171,6 @@ TerminationDetector::getDSTerm(EpochType epoch, bool is_root) { } } -void TerminationDetector::produceConsume( - EpochType epoch, TermCounterType num_units, bool produce, NodeType node -) { - vt_debug_print( - normal, term, - "produceConsume: epoch={:x}, rooted={}, ds={}, count={}, produce={}, " - "node={}\n", - epoch, isRooted(epoch), isDS(epoch), num_units, produce, node - ); - - // If a node is not passed, use the current node (self-prod/cons) - if (node == uninitialized_destination) { - node = theContext()->getNode(); - } - - produceConsumeState(any_epoch_state_, num_units, produce, node); - - if (epoch != any_epoch_sentinel) { - if (isDS(epoch)) { - auto ds_term = getDSTerm(epoch); - if (produce) { - ds_term->msgSent(node,num_units); - } else { - ds_term->msgProcessed(node,num_units); - } - } else { - auto& state = findOrCreateState(epoch, false); - produceConsumeState(state, num_units, produce, node); - } - } -} - void TerminationDetector::maybePropagate() { if (any_epoch_state_.readySubmitParent()) { propagateEpoch(any_epoch_state_); diff --git a/src/vt/termination/termination.h b/src/vt/termination/termination.h index 42128bfdf1..8e99a02b33 100644 --- a/src/vt/termination/termination.h +++ b/src/vt/termination/termination.h @@ -71,6 +71,18 @@ namespace vt { namespace term { using DijkstraScholtenTerm = term::ds::StateDS; +struct EpochStack { + using DataType = epoch::detail::EpochImplType; + + void push(DataType in) { 
stack_[cur_++] = in; } + DataType top() const { return stack_[cur_-1]; } + void pop() { cur_--; } + unsigned int size() const { return cur_; } + + int cur_ = 0; + std::array stack_; +}; + /** * \struct TerminationDetector * @@ -102,13 +114,26 @@ struct TerminationDetector : using SuccessorBagType = EpochDependency::SuccessorBagType; using EpochGraph = termination::graph::EpochGraph; using EpochGraphMsg = termination::graph::EpochGraphMsg; + using EpochStackType = EpochStack; /** * \internal \brief Construct a termination detector */ TerminationDetector(); - virtual ~TerminationDetector() {} + virtual ~TerminationDetector() { + //Pop all extraneous epochs off the stack greater than 1 + while (epoch_stack_.size() > 1) { + epoch_stack_.pop(); + } + // Pop off the last epoch: term::any_epoch_sentinel + auto const ret_epoch = popEpoch(term::any_epoch_sentinel); + vtAssertInfo( + ret_epoch == term::any_epoch_sentinel, "Last pop must be any epoch", + ret_epoch, term::any_epoch_sentinel, epoch_stack_.size() + ); + vtAssertExpr(epoch_stack_.size() == 0); + } std::string name() override { return "TerminationDetector"; } @@ -407,11 +432,9 @@ struct TerminationDetector : * \param[in] state the epoch state * \param[in] num_units number of units * \param[in] produce whether its a produce or consume - * \param[in] node the node producing to or consuming from */ - void produceConsumeState( - TermStateType& state, TermCounterType const num_units, bool produce, - NodeType node + inline void produceConsumeState( + TermStateType& state, TermCounterType const num_units, bool produce ); /** @@ -422,7 +445,7 @@ struct TerminationDetector : * \param[in] produce whether its a produce or consume * \param[in] node the node producing to or consuming from */ - void produceConsume( + inline void produceConsume( EpochType epoch = any_epoch_sentinel, TermCounterType num_units = 1, bool produce = true, NodeType node = uninitialized_destination ); @@ -772,11 +795,25 @@ struct TerminationDetector : 
*/ static void epochContinueHandler(TermMsg* msg); -private: +public: + inline EpochType getEpoch() const; + inline void pushEpoch(EpochType epoch); + inline EpochType popEpoch(EpochType epoch = no_epoch); + + inline void pushEpochFast(EpochType epoch) { + epoch_stack_.push(epoch.get()); + } + inline void popEpochFast() { + epoch_stack_.pop(); + } + + inline EpochStackType& getEpochStack() { return epoch_stack_; } + // global termination state TermStateType any_epoch_state_; // hang detector termination state TermStateType hang_; +private: // epoch termination state EpochContainerType epoch_state_ = {}; // ready epoch list (misnomer: finishedEpoch was invoked) @@ -785,6 +822,8 @@ struct TerminationDetector : std::unordered_set epoch_wait_status_ = {}; // has printed epoch graph during abort bool has_printed_epoch_graph = false; + NodeType this_node_ = uninitialized_destination; + EpochStackType epoch_stack_; }; }} // end namespace vt::term diff --git a/src/vt/termination/termination.impl.h b/src/vt/termination/termination.impl.h index a69ea91dd8..6b61c87e6f 100644 --- a/src/vt/termination/termination.impl.h +++ b/src/vt/termination/termination.impl.h @@ -83,6 +83,104 @@ inline bool TerminationDetector::isDS(EpochType epoch) { } } +inline void TerminationDetector::produceConsumeState( + TermStateType& state, TermCounterType const num_units, bool produce +) { + auto& counter = produce ? 
state.l_prod : state.l_cons; + counter += num_units; + + vt_debug_print( + verbose, term, + "produceConsumeState: epoch={:x}, event_count={}, l_prod={}, l_cons={}, " + "num_units={}, produce={}\n", + state.getEpoch(), state.getRecvChildCount(), state.l_prod, state.l_cons, num_units, + print_bool(produce) + ); + + if (state.readySubmitParent()) { + propagateEpoch(state); + } +} + +inline void TerminationDetector::produceConsume( + EpochType epoch, TermCounterType num_units, bool produce, NodeType node +) { + vt_debug_print( + normal, term, + "produceConsume: epoch={:x}, rooted={}, ds={}, count={}, produce={}, " + "node={}\n", + epoch, isRooted(epoch), isDS(epoch), num_units, produce, node + ); + + produceConsumeState(any_epoch_state_, num_units, produce); + + if (epoch != any_epoch_sentinel) { + if (isDS(epoch)) { + auto ds_term = getDSTerm(epoch); + + // If a node is not passed, use the current node (self-prod/cons) + if (node == uninitialized_destination) { + node = this_node_; + } + + if (produce) { + ds_term->msgSent(node,num_units); + } else { + ds_term->msgProcessed(node,num_units); + } + } else { + auto& state = findOrCreateState(epoch, false); + produceConsumeState(state, num_units, produce); + } + } +} + +inline EpochType TerminationDetector::getEpoch() const { + vtAssertInfo( + epoch_stack_.size() > 0, "Epoch stack size must be greater than zero", + epoch_stack_.size() + ); + return epoch_stack_.size() ? EpochType{epoch_stack_.top()} : term::any_epoch_sentinel; +} + inline void TerminationDetector::pushEpoch(EpochType epoch) { + /* + * pushEpoch(epoch) pushes any epoch onto the local stack iff epoch != + * no_epoch; the epoch stack includes all locally pushed epochs and the + * current contexts pushed, transitively causally related active message + * handlers. + */ + vtAssertInfo( + epoch != no_epoch, "Do not push no_epoch onto the epoch stack", + epoch, no_epoch, epoch_stack_.size(), + epoch_stack_.size() > 0 ? 
EpochType{epoch_stack_.top()} : no_epoch + ); + if (epoch != no_epoch) { + epoch_stack_.push(epoch.get()); + } +} + +inline EpochType TerminationDetector::popEpoch(EpochType epoch) { + /* + * popEpoch(epoch): when `epoch == no_epoch', pop the top entry and return the + * new top (no_epoch if that pop emptied the stack, since top() on an empty + * stack is invalid); otherwise pop only if `epoch' equals `epoch_stack_.top()' + * and return `epoch'. Returns no_epoch whenever nothing is popped. + */ + auto const& non_zero = epoch_stack_.size() > 0; + vtAssertExprInfo( + non_zero and (epoch_stack_.top() == epoch.get() or epoch == no_epoch), + epoch, non_zero, non_zero ? EpochType{epoch_stack_.top()} : no_epoch + ); + if (not non_zero) { return no_epoch; } + if (epoch != no_epoch and epoch != EpochType{epoch_stack_.top()}) { + return no_epoch; + } + epoch_stack_.pop(); + if (epoch != no_epoch) { return epoch; } + return epoch_stack_.size() > 0 ? EpochType{epoch_stack_.top()} : no_epoch; +} + + }} /* end namespace vt::term */ #endif /*INCLUDED_VT_TERMINATION_TERMINATION_IMPL_H*/ diff --git a/tests/unit/active/test_async_op_threads.cc b/tests/unit/active/test_async_op_threads.cc index 5fa85a3dde..7eb8770cad 100644 --- a/tests/unit/active/test_async_op_threads.cc +++ b/tests/unit/active/test_async_op_threads.cc @@ -70,7 +70,7 @@ struct MyObjGroup { } // get the epoch stack and store the original size - auto& epoch_stack = theMsg()->getEpochStack(); + auto& epoch_stack = theTerm()->getEpochStack(); std::size_t original_epoch_size = epoch_stack.size(); auto comm = theContext()->getComm(); @@ -95,7 +95,7 @@ struct MyObjGroup { done_ = true; // stack should be the size before running this method since we haven't // resumed the thread yet! - EXPECT_EQ(theMsg()->getEpochStack().size(), original_epoch_size - 2); + EXPECT_EQ(theTerm()->getEpochStack().size(), original_epoch_size - 2); } );