Skip to content

Commit

Permalink
Merge pull request #271 from darma-mpi-backend/246-callable
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander authored Feb 27, 2019
2 parents 82fb410 + 9ab8b92 commit 48fcf37
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 20 deletions.
49 changes: 33 additions & 16 deletions src/vt/termination/term_action.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ void TermAction::addAction(EpochType const& epoch, ActionType action) {
return addActionEpoch(epoch,action);
}

void TermAction::afterAddEpochAction(EpochType const& epoch) {
/*
* Produce a unit of any epoch type to inhibit global termination when
* local termination of a specific epoch is waiting for detection
*/
theTerm()->produce(term::any_epoch_sentinel);

auto const& status = testEpochFinished(epoch);
if (status == TermStatusEnum::Finished) {
triggerAllEpochActions(epoch);
}
}

void TermAction::addActionEpoch(EpochType const& epoch, ActionType action) {
if (epoch == term::any_epoch_sentinel) {
return addAction(action);
Expand All @@ -85,17 +98,8 @@ void TermAction::addActionEpoch(EpochType const& epoch, ActionType action) {
} else {
epoch_iter->second.emplace_back(action);
}
/*
* Produce a unit of any epoch type to inhibit global termination when
* local termination of a specific epoch is waiting for detection
*/
theTerm()->produce(term::any_epoch_sentinel);

auto const& status = testEpochFinished(epoch);
if (status == TermStatusEnum::Finished) {
triggerAllEpochActions(epoch);
}
}
afterAddEpochAction(epoch);
}

void TermAction::clearActions() {
Expand Down Expand Up @@ -142,19 +146,32 @@ void TermAction::triggerAllActions(
}

void TermAction::triggerAllEpochActions(EpochType const& epoch) {
// Run through the normal ActionType elements associated with this epoch
std::size_t epoch_actions_count = 0;
auto iter = epoch_actions_.find(epoch);
if (iter != epoch_actions_.end()) {
auto const& epoch_actions_count = iter->second.size();
epoch_actions_count += iter->second.size();
for (auto&& action : iter->second) {
action();
}
epoch_actions_.erase(iter);
/*
* Consume `size' units of any epoch type to match the production in
* addActionEpoch() so global termination can now be detected
*/
theTerm()->consume(term::any_epoch_sentinel, epoch_actions_count);
}
// Run through the callables associated with this epoch
auto iter2 = epoch_callable_actions_.find(epoch);
if (iter2 != epoch_callable_actions_.end()) {
epoch_actions_count += iter2->second.size();

for (auto&& action : iter2->second) {
action->invoke();
}

epoch_callable_actions_.erase(iter2);
}
/*
* Consume number of action units of any epoch type to match the production
* in addActionEpoch() so global termination can now be detected
*/
theTerm()->consume(term::any_epoch_sentinel, epoch_actions_count);
}

}} /* end namespace vt::term */
55 changes: 51 additions & 4 deletions src/vt/termination/term_action.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,53 @@

#include <vector>
#include <unordered_map>
#include <memory>

namespace vt { namespace term {

struct CallableBase {
protected:
CallableBase() = default;
public:
CallableBase(CallableBase const&) = delete;
CallableBase(CallableBase&&) = default;
virtual ~CallableBase() = default;
virtual void invoke() = 0;
};

template <typename Callable>
struct CallableHolder : CallableBase {
explicit CallableHolder(Callable&& in_c) : c_(std::move(in_c)) { }
CallableHolder(CallableHolder const&) = delete;
CallableHolder(CallableHolder&&) = default;

protected:
constexpr Callable&& move() { return std::move(c_); }

template <typename... A>
auto operator()(A&&... a) -> decltype(auto) {
return c_(std::forward<A>(a)...);
}

public:
virtual void invoke() override {
// Skipping arguments for now (ActionType use case)
this->operator()();
}

private:
Callable c_;
};


struct TermAction : TermFinished {
using TermStateType = TermState;
using ActionContType = std::vector<ActionType>;
using EpochActionContType = std::unordered_map<EpochType,ActionContType>;
using EpochStateType = std::unordered_map<EpochType,TermStateType>;
using TermStateType = TermState;
using ActionContType = std::vector<ActionType>;
using CallableActionType = std::unique_ptr<CallableBase>;
using CallableVecType = std::vector<CallableActionType>;
using CallableContType = std::unordered_map<EpochType,CallableVecType>;
using EpochActionContType = std::unordered_map<EpochType,ActionContType>;
using EpochStateType = std::unordered_map<EpochType,TermStateType>;

TermAction() = default;

Expand All @@ -71,6 +110,9 @@ struct TermAction : TermFinished {
void clearActions();
void clearActionsEpoch(EpochType const& epoch);

template <typename Callable>
void addActionUnique(EpochType const& epoch, Callable&& c);

public:
/*
* Deprecated methods for adding a termination action
Expand All @@ -83,14 +125,19 @@ struct TermAction : TermFinished {
protected:
void triggerAllActions(EpochType const& epoch, EpochStateType const& state);
void triggerAllEpochActions(EpochType const& epoch);
void afterAddEpochAction(EpochType const& epoch);

protected:
// Container for hold global termination actions
ActionContType global_term_actions_ = {};
// Container to hold actions to perform when an epoch has terminated
EpochActionContType epoch_actions_ = {};
// Container for "callables"; restricted in semantic wrt std::function
CallableContType epoch_callable_actions_ = {};
};

}} /* end namespace vt::term */

#include "vt/termination/term_action.impl.h"

#endif /*INCLUDED_TERMINATION_TERM_ACTION_H*/
24 changes: 24 additions & 0 deletions src/vt/termination/term_action.impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

#if !defined INCLUDED_VT_TERMINATION_TERM_ACTION_IMPL_H
#define INCLUDED_VT_TERMINATION_TERM_ACTION_IMPL_H

#include "vt/config.h"
#include "vt/termination/term_common.h"
#include "vt/termination/term_action.h"

#include <memory>
#include <unordered_map>

namespace vt { namespace term {

template <typename Callable>
void TermAction::addActionUnique(EpochType const& epoch, Callable&& c) {
std::unique_ptr<CallableBase> callable =
std::make_unique<CallableHolder<Callable>>(std::move(c));
epoch_callable_actions_[epoch].emplace_back(std::move(callable));
afterAddEpochAction(epoch);
}

}} /* end namespace vt::term */

#endif /*INCLUDED_VT_TERMINATION_TERM_ACTION_IMPL_H*/
98 changes: 98 additions & 0 deletions tests/unit/termination/test_termination_action_callable.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
//@HEADER
// ************************************************************************
//
// test_termination_action_callable.cc
// vt (Virtual Transport)
// Copyright (C) 2018 NTESS, LLC
//
// Under the terms of Contract DE-NA-0003525 with NTESS, LLC,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact [email protected]
//
// ************************************************************************
//@HEADER
*/

#include "test_termination_action_common.h"

#if !defined INCLUDED_TERMINATION_ACTION_CALLABLE_H
#define INCLUDED_TERMINATION_ACTION_CALLABLE_H

namespace vt { namespace tests { namespace unit {

static bool has_finished_1 = false;
static bool has_finished_2 = false;
static int num = 0;

// no need for a parameterized fixture
struct TestTermActionCallable : action::SimpleFixture {};

TEST_F(TestTermActionCallable, test_add_action_unique) {

has_finished_1 = has_finished_2 = false;
num = 0;

vt::theCollective()->barrier();

// create an epoch and a related termination flag
auto ep = ::vt::theTerm()->makeEpochCollective();

// assign an arbitrary action to be triggered at
// the end of the epoch and toggle the previous flag.
::vt::theTerm()->addActionEpoch(ep, [&]{
debug_print(term, node, "current epoch:{:x} finished\n", ep);
EXPECT_FALSE(has_finished_1);
EXPECT_TRUE(has_finished_2 == false or num == 1);
has_finished_1 = true;
num++;
});

// assign a callable to be triggered after
// the action submitted for the given epoch.
::vt::theTerm()->addActionUnique(ep, [&]{
debug_print(term, node, "trigger callable for epoch:{:x}\n", ep);
EXPECT_FALSE(has_finished_2);
EXPECT_TRUE(has_finished_1 == false or num == 1);
has_finished_2 = true;
num++;
});

if (channel::me == channel::root) {
action::compute(ep);
}

::vt::theTerm()->finishedEpoch(ep);
}

}}} // namespace vt::tests::unit::action

#endif /*INCLUDED_TERMINATION_ACTION_CALLABLE_H*/
12 changes: 12 additions & 0 deletions tests/unit/termination/test_termination_action_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ struct BaseFixture : Base {
int depth_ = 1;
};

struct SimpleFixture : TestParallelHarness {
virtual void SetUp() {
// explicit inheritance
TestParallelHarness::SetUp();
// set channel counting ranks
channel::root = 0;
channel::me = vt::theContext()->getNode();
channel::all = vt::theContext()->getNumNodes();
vtAssertExpr(channel::all > 1);
}
};

// epoch sequence creation
std::vector<vt::EpochType> newEpochSeq(
int nb=1, bool rooted=false, bool useDS=false
Expand Down

0 comments on commit 48fcf37

Please sign in to comment.