Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1941 move EpochStack to TerminationDetector #1961

Merged
merged 14 commits (source and target branch names not captured in this page extract)
Sep 22, 2022
Merged
48 changes: 12 additions & 36 deletions src/vt/context/runnable_context/td.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,77 +62,53 @@ TD::TD(EpochType in_ep)
}

void TD::begin() {
theMsg()->pushEpoch(ep_);
theTerm()->pushEpoch(ep_);

auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();

vt_debug_print(
verbose, context,
"TD::begin: top={:x}, size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size()
);

base_epoch_stack_size_ = epoch_stack.size();
}

void TD::end() {
auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();

vt_debug_print(
verbose, context,
"TD::end: top={:x}, size={}, base_size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size(), base_epoch_stack_size_
);

vtAssert(
base_epoch_stack_size_ <= epoch_stack.size(),
"Epoch stack popped below preceding push size in handler"
);

while (epoch_stack.size() > base_epoch_stack_size_) {
theMsg()->popEpoch();
theTerm()->popEpoch();
}

theMsg()->popEpoch(ep_);
theTerm()->popEpoch(ep_);
}

void TD::suspend() {
auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();

vt_debug_print(
verbose, context,
"TD::suspend: top={:x}, size={}, base_size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size(), base_epoch_stack_size_
);

while (epoch_stack.size() > base_epoch_stack_size_) {
suspended_epochs_.push_back(theMsg()->getEpoch());
theMsg()->popEpoch();
suspended_epochs_.push_back(theTerm()->getEpoch());
theTerm()->popEpoch();
}

theMsg()->popEpoch(ep_);
theTerm()->popEpoch(ep_);
}

void TD::resume() {
theMsg()->pushEpoch(ep_);
theTerm()->pushEpoch(ep_);

auto& epoch_stack = theMsg()->getEpochStack();
auto& epoch_stack = theTerm()->getEpochStack();
base_epoch_stack_size_ = epoch_stack.size();

vt_debug_print(
verbose, context,
"TD::resume: top={:x}, size={}, base_size={}\n",
epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch,
epoch_stack.size(), base_epoch_stack_size_
);

for (auto it = suspended_epochs_.rbegin();
it != suspended_epochs_.rend();
++it) {
theMsg()->pushEpoch(*it);
theTerm()->pushEpoch(*it);
}

suspended_epochs_.clear();
Expand Down
21 changes: 1 addition & 20 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ ActiveMessenger::ActiveMessenger()
# endif
this_node_(theContext()->getNode())
{
/*
* Push the default epoch into the stack so it is always at the bottom of the
* stack during execution until the AM's destructor is invoked
*/
pushEpoch(term::any_epoch_sentinel);

// Register counters for AM/DM message sends and number of bytes
amSentCounterGauge = diagnostic::CounterGauge{
registerCounter("AM_sent", "active messages sent"),
Expand Down Expand Up @@ -171,20 +165,7 @@ void ActiveMessenger::startup() {
#endif
}

/*virtual*/ ActiveMessenger::~ActiveMessenger() {
// Pop all extraneous epochs off the stack greater than 1
auto stack_size = epoch_stack_.size();
while (stack_size > 1) {
stack_size = (epoch_stack_.pop(), epoch_stack_.size());
}
// Pop off the last epoch: term::any_epoch_sentinel
auto const ret_epoch = popEpoch(term::any_epoch_sentinel);
vtAssertInfo(
ret_epoch == term::any_epoch_sentinel, "Last pop must be any epoch",
ret_epoch, term::any_epoch_sentinel, epoch_stack_.size()
);
vtAssertExpr(epoch_stack_.size() == 0);
}
/*virtual*/ ActiveMessenger::~ActiveMessenger() {}

trace::TraceEventIDType ActiveMessenger::makeTraceCreationSend(
HandlerType const handler, ByteType serialized_msg_size, bool is_bcast
Expand Down
20 changes: 0 additions & 20 deletions src/vt/messaging/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
using ReadyHanTagType = std::tuple<HandlerType, TagType>;
using MaybeReadyType = std::vector<ReadyHanTagType>;
using HandlerManagerType = HandlerManager;
using EpochStackType = std::stack<EpochType>;
using PendingSendType = PendingSend;

/**
Expand Down Expand Up @@ -1516,17 +1515,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
MsgSizeType const& msg_size, TagType const& send_tag
);

/**
* \internal
* \brief Get the current global epoch
*
* \c Returns the top epoch on the stack iff \c epoch_stack.size() > 0, else it
* returns \c vt::no_epoch
*
* \return the current global epoch
*/
inline EpochType getGlobalEpoch() const;

/**
* \internal
* \brief Push an epoch on the stack
Expand Down Expand Up @@ -1563,12 +1551,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
*/
inline EpochType getEpoch() const;

/**
* \internal
* \brief Access the epoch stack
*/
inline EpochStackType& getEpochStack() { return epoch_stack_; }

/**
* \internal
* \brief Get the epoch for a message based on the current context so an
Expand Down Expand Up @@ -1644,7 +1626,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
| pending_handler_msgs_
| pending_recvs_
| cur_direct_buffer_tag_
| epoch_stack_
| in_progress_active_msg_irecv
| in_progress_data_irecv
| in_progress_ops
Expand Down Expand Up @@ -1755,7 +1736,6 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
ContWaitType pending_handler_msgs_ = {};
ContainerPendingType pending_recvs_ = {};
TagType cur_direct_buffer_tag_ = starting_direct_buffer_tag;
EpochStackType epoch_stack_;
RequestHolder<InProgressIRecv> in_progress_active_msg_irecv;
RequestHolder<InProgressDataIRecv> in_progress_data_irecv;
RequestHolder<AsyncOpWrapper> in_progress_ops;
Expand Down
44 changes: 3 additions & 41 deletions src/vt/messaging/active.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,54 +445,16 @@ ActiveMessenger::PendingSendType ActiveMessenger::broadcastMsgAuto(
);
}

inline EpochType ActiveMessenger::getGlobalEpoch() const {
vtAssertInfo(
epoch_stack_.size() > 0, "Epoch stack size must be greater than zero",
epoch_stack_.size()
);
return epoch_stack_.size() ? epoch_stack_.top() : term::any_epoch_sentinel;
}

inline void ActiveMessenger::pushEpoch(EpochType const& epoch) {
/*
* pushEpoch(epoch) pushes any epoch onto the local stack iff epoch !=
* no_epoch; the epoch stack includes all locally pushed epochs and the
* current contexts pushed, transitively causally related active message
* handlers.
*/
vtAssertInfo(
epoch != no_epoch, "Do not push no_epoch onto the epoch stack",
epoch, no_epoch, epoch_stack_.size(),
epoch_stack_.size() > 0 ? epoch_stack_.top() : no_epoch
);
if (epoch != no_epoch) {
epoch_stack_.push(epoch);
}
return theTerm()->pushEpoch(epoch);
}

inline EpochType ActiveMessenger::popEpoch(EpochType const& epoch) {
/*
* popEpoch(epoch) shall remove the top entry from epoch_stack_, iff the size
* is non-zero and the `epoch' passed, if `epoch != no_epoch', is equal to the
* top of the `epoch_stack_.top()'; else, it shall remove any entry from the
* top of the stack.
*/
auto const& non_zero = epoch_stack_.size() > 0;
vtAssertExprInfo(
non_zero and (epoch_stack_.top() == epoch or epoch == no_epoch),
epoch, non_zero, epoch_stack_.top()
);
if (epoch == no_epoch) {
return non_zero ? epoch_stack_.pop(),epoch_stack_.top() : no_epoch;
} else {
return non_zero && epoch == epoch_stack_.top() ?
epoch_stack_.pop(),epoch :
no_epoch;
}
return theTerm()->popEpoch(epoch);
}

inline EpochType ActiveMessenger::getEpoch() const {
return getGlobalEpoch();
return theTerm()->getEpoch();
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like keeping these wrappers here, but I guess it spares us rewriting all the call sites at the same time. It does also allow for caching theTerm() in ActiveMessenger for internal usage

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've not changed this. Do you think it's worth an issue to remove them or to go with the caching, @PhilMiller @lifflander ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave it be for now, and think about caching if that lookup ever turns up as a hot spot in profiles

template <typename MsgT>
Expand Down
63 changes: 7 additions & 56 deletions src/vt/termination/termination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ namespace vt { namespace term {

TerminationDetector::TerminationDetector()
: collective::tree::Tree(collective::tree::tree_cons_tag_t),
any_epoch_state_(any_epoch_sentinel, false, true, getNumChildren()),
hang_(no_epoch, true, false, getNumChildren())
{ }
any_epoch_state_(any_epoch_sentinel, false, true, getNumChildren()),
hang_(no_epoch, true, false, getNumChildren()),
this_node_(theContext()->getNode())
{
pushEpoch(term::any_epoch_sentinel);
}

/*static*/ void TerminationDetector::makeRootedHandler(TermMsg* msg) {
theTerm()->makeRootedHan(msg->new_epoch, false);
Expand Down Expand Up @@ -112,7 +115,7 @@ void TerminationDetector::setLocalTerminated(
any_epoch_state_.notifyLocalTerminated(local_terminated);

if (local_terminated && !no_propagate) {
theTerm()->maybePropagate();
maybePropagate();
}
}

Expand Down Expand Up @@ -141,26 +144,6 @@ TerminationDetector::findOrCreateState(EpochType const& epoch, bool is_ready) {
return epoch_iter->second;
}

void TerminationDetector::produceConsumeState(
TermStateType& state, TermCounterType const num_units, bool produce,
NodeType node
) {
auto& counter = produce ? state.l_prod : state.l_cons;
counter += num_units;

vt_debug_print(
verbose, term,
"produceConsumeState: epoch={:x}, event_count={}, l_prod={}, l_cons={}, "
"num_units={}, produce={}, node={}\n",
state.getEpoch(), state.getRecvChildCount(), state.l_prod, state.l_cons, num_units,
print_bool(produce), node
);

if (state.readySubmitParent()) {
propagateEpoch(state);
}
}

TerminationDetector::TermStateDSType*
TerminationDetector::getDSTerm(EpochType epoch, bool is_root) {
vt_debug_print(
Expand Down Expand Up @@ -188,38 +171,6 @@ TerminationDetector::getDSTerm(EpochType epoch, bool is_root) {
}
}

void TerminationDetector::produceConsume(
EpochType epoch, TermCounterType num_units, bool produce, NodeType node
) {
vt_debug_print(
normal, term,
"produceConsume: epoch={:x}, rooted={}, ds={}, count={}, produce={}, "
"node={}\n",
epoch, isRooted(epoch), isDS(epoch), num_units, produce, node
);

// If a node is not passed, use the current node (self-prod/cons)
if (node == uninitialized_destination) {
node = theContext()->getNode();
}

produceConsumeState(any_epoch_state_, num_units, produce, node);

if (epoch != any_epoch_sentinel) {
if (isDS(epoch)) {
auto ds_term = getDSTerm(epoch);
if (produce) {
ds_term->msgSent(node,num_units);
} else {
ds_term->msgProcessed(node,num_units);
}
} else {
auto& state = findOrCreateState(epoch, false);
produceConsumeState(state, num_units, produce, node);
}
}
}

void TerminationDetector::maybePropagate() {
if (any_epoch_state_.readySubmitParent()) {
propagateEpoch(any_epoch_state_);
Expand Down
Loading