diff --git a/src/vt/termination/termination.h b/src/vt/termination/termination.h index 42128bfdf1..4baf0ca292 100644 --- a/src/vt/termination/termination.h +++ b/src/vt/termination/termination.h @@ -71,6 +71,18 @@ namespace vt { namespace term { using DijkstraScholtenTerm = term::ds::StateDS; +struct EpochStack { + using DataType = epoch::detail::EpochImplType; + + void push(DataType in) { stack_[cur_++] = in; } + DataType top() const { return stack_[cur_-1]; } + void pop() { cur_--; } + int size() const { return cur_; } + + int cur_ = 0; + std::array stack_; +}; + /** * \struct TerminationDetector * @@ -102,6 +114,7 @@ struct TerminationDetector : using SuccessorBagType = EpochDependency::SuccessorBagType; using EpochGraph = termination::graph::EpochGraph; using EpochGraphMsg = termination::graph::EpochGraphMsg; + using EpochStackType = EpochStack; /** * \internal \brief Construct a termination detector @@ -407,11 +420,9 @@ struct TerminationDetector : * \param[in] state the epoch state * \param[in] num_units number of units * \param[in] produce whether its a produce or consume - * \param[in] node the node producing to or consuming from */ - void produceConsumeState( - TermStateType& state, TermCounterType const num_units, bool produce, - NodeType node + inline void produceConsumeState( + TermStateType& state, TermCounterType const num_units, bool produce ); /** @@ -422,7 +433,7 @@ struct TerminationDetector : * \param[in] produce whether its a produce or consume * \param[in] node the node producing to or consuming from */ - void produceConsume( + inline void produceConsume( EpochType epoch = any_epoch_sentinel, TermCounterType num_units = 1, bool produce = true, NodeType node = uninitialized_destination ); @@ -772,11 +783,26 @@ struct TerminationDetector : */ static void epochContinueHandler(TermMsg* msg); -private: +public: + inline EpochType getEpoch() const; + inline EpochType getGlobalEpoch() const; + inline void pushEpoch(EpochType const& epoch); + inline EpochType popEpoch(EpochType const& epoch = no_epoch); + + inline void pushEpochFast(EpochType epoch) { + epoch_stack_.push(epoch.get()); + } + inline void popEpochFast(EpochType epoch) { + epoch_stack_.pop(); + } + + inline EpochStackType& getEpochStack() { return epoch_stack_; } + // global termination state TermStateType any_epoch_state_; // hang detector termination state TermStateType hang_; +private: // epoch termination state EpochContainerType epoch_state_ = {}; // ready epoch list (misnomer: finishedEpoch was invoked) @@ -785,6 +811,8 @@ struct TerminationDetector : std::unordered_set epoch_wait_status_ = {}; // has printed epoch graph during abort bool has_printed_epoch_graph = false; + NodeType this_node_ = uninitialized_destination; + EpochStackType epoch_stack_; }; }} // end namespace vt::term