Skip to content

Commit

Permalink
Merge pull request #1961 from DARMA-tasking/1941-move-epoch-stack-to-…
Browse files Browse the repository at this point in the history
…terminationdetector

1941 move EpochStack to TerminationDetector
  • Loading branch information
lifflander authored and cz4rs committed Sep 28, 2022
2 parents 77a3e12 + e45e989 commit 0ee2365
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 182 deletions.
48 changes: 12 additions & 36 deletions src/vt/context/runnable_context/td.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,77 +62,53 @@ TD::TD(EpochType in_ep)
}

void TD::begin() {
theMsg()->pushEpoch(ep_);
theTerm()->pushEpoch(ep_);

auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();

vt_debug_print(
verbose, context,
"TD::begin: top={:x}, size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size()
);

base_epoch_stack_size_ = epoch_stack.size();
}

void TD::end() {
auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();

vt_debug_print(
verbose, context,
"TD::end: top={:x}, size={}, base_size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size(), base_epoch_stack_size_
);

vtAssert(
base_epoch_stack_size_ <= epoch_stack.size(),
"Epoch stack popped below preceding push size in handler"
);

while (epoch_stack.size() > base_epoch_stack_size_) {
theMsg()->popEpoch();
theTerm()->popEpoch();
}

theMsg()->popEpoch(ep_);
theTerm()->popEpoch(ep_);
}

void TD::suspend() {
auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();

vt_debug_print(
verbose, context,
"TD::suspend: top={:x}, size={}, base_size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size(), base_epoch_stack_size_
);

while (epoch_stack.size() > base_epoch_stack_size_) {
suspended_epochs_.push_back(theMsg()->getEpoch());
theMsg()->popEpoch();
suspended_epochs_.push_back(theTerm()->getEpoch());
theTerm()->popEpoch();
}

theMsg()->popEpoch(ep_);
theTerm()->popEpoch(ep_);
}

void TD::resume() {
theMsg()->pushEpoch(ep_);
theTerm()->pushEpoch(ep_);

auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();
base_epoch_stack_size_ = epoch_stack.size();

vt_debug_print(
verbose, context,
"TD::resume: top={:x}, size={}, base_size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size(), base_epoch_stack_size_
);

for (auto it = suspended_epochs_.rbegin();
it != suspended_epochs_.rend();
++it) {
theMsg()->pushEpoch(*it);
theTerm()->pushEpoch(*it);
}

suspended_epochs_.clear();
Expand Down
21 changes: 1 addition & 20 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ ActiveMessenger::ActiveMessenger()
# endif
this_node_(theContext()->getNode())
{
/*
* Push the default epoch into the stack so it is always at the bottom of the
* stack during execution until the AM's destructor is invoked
*/
pushEpoch(term::any_epoch_sentinel);

// Register counters for AM/DM message sends and number of bytes
amSentCounterGauge = diagnostic::CounterGauge{
registerCounter("AM_sent", "active messages sent"),
Expand Down Expand Up @@ -171,20 +165,7 @@ void ActiveMessenger::startup() {
#endif
}

/*virtual*/ ActiveMessenger::~ActiveMessenger() {
// Pop all extraneous epochs off the stack greater than 1
auto stack_size = epoch_stack_.size();
while (stack_size > 1) {
stack_size = (epoch_stack_.pop(), epoch_stack_.size());
}
// Pop off the last epoch: term::any_epoch_sentinel
auto const ret_epoch = popEpoch(term::any_epoch_sentinel);
vtAssertInfo(
ret_epoch == term::any_epoch_sentinel, "Last pop must be any epoch",
ret_epoch, term::any_epoch_sentinel, epoch_stack_.size()
);
vtAssertExpr(epoch_stack_.size() == 0);
}
/*virtual*/ ActiveMessenger::~ActiveMessenger() {}

trace::TraceEventIDType ActiveMessenger::makeTraceCreationSend(
HandlerType const handler, ByteType serialized_msg_size, bool is_bcast
Expand Down
20 changes: 0 additions & 20 deletions src/vt/messaging/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
using ReadyHanTagType = std::tuple<HandlerType, TagType>;
using MaybeReadyType = std::vector<ReadyHanTagType>;
using HandlerManagerType = HandlerManager;
using EpochStackType = std::stack<EpochType>;
using PendingSendType = PendingSend;

/**
Expand Down Expand Up @@ -1516,17 +1515,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
MsgSizeType const& msg_size, TagType const& send_tag
);

/**
* \internal
* \brief Get the current global epoch
*
* \c Returns the top epoch on the stack iff \c epoch_stack.size() > 0, else it
* returns \c vt::no_epoch
*
* \return the current global epoch
*/
inline EpochType getGlobalEpoch() const;

/**
* \internal
* \brief Push an epoch on the stack
Expand Down Expand Up @@ -1563,12 +1551,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
*/
inline EpochType getEpoch() const;

/**
* \internal
* \brief Access the epoch stack
*/
inline EpochStackType& getEpochStack() { return epoch_stack_; }

/**
* \internal
* \brief Get the epoch for a message based on the current context so an
Expand Down Expand Up @@ -1644,7 +1626,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
| pending_handler_msgs_
| pending_recvs_
| cur_direct_buffer_tag_
| epoch_stack_
| in_progress_active_msg_irecv
| in_progress_data_irecv
| in_progress_ops
Expand Down Expand Up @@ -1755,7 +1736,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
ContWaitType pending_handler_msgs_ = {};
ContainerPendingType pending_recvs_ = {};
TagType cur_direct_buffer_tag_ = starting_direct_buffer_tag;
EpochStackType epoch_stack_;
RequestHolder<InProgressIRecv> in_progress_active_msg_irecv;
RequestHolder<InProgressDataIRecv> in_progress_data_irecv;
RequestHolder<AsyncOpWrapper> in_progress_ops;
Expand Down
44 changes: 3 additions & 41 deletions src/vt/messaging/active.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,54 +445,16 @@ ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgAuto(
);
}

inline EpochType ActiveMessenger::getGlobalEpoch() const {
vtAssertInfo(
epoch_stack_.size() > 0, "Epoch stack size must be greater than zero",
epoch_stack_.size()
);
return epoch_stack_.size() ? epoch_stack_.top() : term::any_epoch_sentinel;
}

inline void ActiveMessenger::pushEpoch(EpochType const& epoch) {
/*
* pushEpoch(epoch) pushes any epoch onto the local stack iff epoch !=
* no_epoch; the epoch stack includes all locally pushed epochs and the
* current contexts pushed, transitively causally related active message
* handlers.
*/
vtAssertInfo(
epoch != no_epoch, "Do not push no_epoch onto the epoch stack",
epoch, no_epoch, epoch_stack_.size(),
epoch_stack_.size() > 0 ? epoch_stack_.top() : no_epoch
);
if (epoch != no_epoch) {
epoch_stack_.push(epoch);
}
return theTerm()->pushEpoch(epoch);
}

inline EpochType ActiveMessenger::popEpoch(EpochType const& epoch) {
/*
* popEpoch(epoch) shall remove the top entry from epoch_size_, iif the size
* is non-zero and the `epoch' passed, if `epoch != no_epoch', is equal to the
* top of the `epoch_stack_.top()'; else, it shall remove any entry from the
* top of the stack.
*/
auto const& non_zero = epoch_stack_.size() > 0;
vtAssertExprInfo(
non_zero and (epoch_stack_.top() == epoch or epoch == no_epoch),
epoch, non_zero, epoch_stack_.top()
);
if (epoch == no_epoch) {
return non_zero ? epoch_stack_.pop(),epoch_stack_.top() : no_epoch;
} else {
return non_zero && epoch == epoch_stack_.top() ?
epoch_stack_.pop(),epoch :
no_epoch;
}
return theTerm()->popEpoch(epoch);
}

inline EpochType ActiveMessenger::getEpoch() const {
return getGlobalEpoch();
return theTerm()->getEpoch();
}

template <typename MsgT>
Expand Down
63 changes: 7 additions & 56 deletions src/vt/termination/termination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ namespace vt { namespace term {

TerminationDetector::TerminationDetector()
: collective::tree::Tree(collective::tree::tree_cons_tag_t),
any_epoch_state_(any_epoch_sentinel, false, true, getNumChildren()),
hang_(no_epoch, true, false, getNumChildren())
{ }
any_epoch_state_(any_epoch_sentinel, false, true, getNumChildren()),
hang_(no_epoch, true, false, getNumChildren()),
this_node_(theContext()->getNode())
{
pushEpoch(term::any_epoch_sentinel);
}

/*static*/ void TerminationDetector::makeRootedHandler(TermMsg* msg) {
theTerm()->makeRootedHan(msg->new_epoch, false);
Expand Down Expand Up @@ -112,7 +115,7 @@ void TerminationDetector::setLocalTerminated(
any_epoch_state_.notifyLocalTerminated(local_terminated);

if (local_terminated && !no_propagate) {
theTerm()->maybePropagate();
maybePropagate();
}
}

Expand Down Expand Up @@ -141,26 +144,6 @@ TerminationDetector::findOrCreateState(EpochType const& epoch, bool is_ready) {
return epoch_iter->second;
}

void TerminationDetector::produceConsumeState(
TermStateType& state, TermCounterType const num_units, bool produce,
NodeType node
) {
auto& counter = produce ? state.l_prod : state.l_cons;
counter += num_units;

vt_debug_print(
verbose, term,
"produceConsumeState: epoch={:x}, event_count={}, l_prod={}, l_cons={}, "
"num_units={}, produce={}, node={}\n",
state.getEpoch(), state.getRecvChildCount(), state.l_prod, state.l_cons, num_units,
print_bool(produce), node
);

if (state.readySubmitParent()) {
propagateEpoch(state);
}
}

TerminationDetector::TermStateDSType*
TerminationDetector::getDSTerm(EpochType epoch, bool is_root) {
vt_debug_print(
Expand Down Expand Up @@ -188,38 +171,6 @@ TerminationDetector::getDSTerm(EpochType epoch, bool is_root) {
}
}

void TerminationDetector::produceConsume(
EpochType epoch, TermCounterType num_units, bool produce, NodeType node
) {
vt_debug_print(
normal, term,
"produceConsume: epoch={:x}, rooted={}, ds={}, count={}, produce={}, "
"node={}\n",
epoch, isRooted(epoch), isDS(epoch), num_units, produce, node
);

// If a node is not passed, use the current node (self-prod/cons)
if (node == uninitialized_destination) {
node = theContext()->getNode();
}

produceConsumeState(any_epoch_state_, num_units, produce, node);

if (epoch != any_epoch_sentinel) {
if (isDS(epoch)) {
auto ds_term = getDSTerm(epoch);
if (produce) {
ds_term->msgSent(node,num_units);
} else {
ds_term->msgProcessed(node,num_units);
}
} else {
auto& state = findOrCreateState(epoch, false);
produceConsumeState(state, num_units, produce, node);
}
}
}

void TerminationDetector::maybePropagate() {
if (any_epoch_state_.readySubmitParent()) {
propagateEpoch(any_epoch_state_);
Expand Down
Loading

0 comments on commit 0ee2365

Please sign in to comment.