Skip to content

Commit

Permalink
#1570: LB: combine all stats into a single handler/reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Sep 30, 2021
1 parent 2d05393 commit 3326d8b
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 125 deletions.
2 changes: 1 addition & 1 deletion src/vt/collective/reduce/operators/functors/plus_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct PlusOp< std::vector<T> > {
void operator()(std::vector<T>& v1, std::vector<T> const& v2) {
vtAssert(v1.size() == v2.size(), "Sizes of vectors in reduce must be equal");
for (size_t ii = 0; ii < v1.size(); ++ii)
v1[ii] += v2[ii];
v1[ii] = v1[ii] + v2[ii];
}
};

Expand Down
182 changes: 78 additions & 104 deletions src/vt/vrt/collection/balance/lb_invoke/lb_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,66 +298,58 @@ void LBManager::finishedLB(PhaseType phase) {
}
}

void LBManager::statsHandler(StatsMsgType* msg) {
auto in_stat_vec = msg->getConstVal();

for (auto&& st : in_stat_vec) {
auto stat = st.stat_;
auto max = st.max();
auto min = st.min();
auto avg = st.avg();
auto sum = st.sum();
auto npr = st.npr();
auto car = st.N_;
auto imb = st.I();
auto var = st.var();
auto stdv = st.stdv();
auto skew = st.skew();
auto krte = st.krte();

stats[stat][lb::StatisticQuantity::max] = max;
stats[stat][lb::StatisticQuantity::min] = min;
stats[stat][lb::StatisticQuantity::avg] = avg;
stats[stat][lb::StatisticQuantity::sum] = sum;
stats[stat][lb::StatisticQuantity::npr] = npr;
stats[stat][lb::StatisticQuantity::car] = car;
stats[stat][lb::StatisticQuantity::var] = var;
stats[stat][lb::StatisticQuantity::npr] = npr;
stats[stat][lb::StatisticQuantity::imb] = imb;
stats[stat][lb::StatisticQuantity::std] = stdv;
stats[stat][lb::StatisticQuantity::skw] = skew;
stats[stat][lb::StatisticQuantity::kur] = krte;

if (theContext()->getNode() == 0) {
vt_print(
lb,
"BaseLB: Statistic={}: "
" max={:.2f}, min={:.2f}, sum={:.2f}, avg={:.2f}, var={:.2f},"
" stdev={:.2f}, nproc={}, cardinality={} skewness={:.2f}, kurtosis={:.2f},"
" npr={}, imb={:.2f}, num_stats={}\n",
lb::lb_stat_name_[stat],
max, min, sum, avg, var, stdv, npr, car, skew, krte, npr, imb,
stats.size()
);
}
}
}

void LBManager::computeStatistics(PhaseType phase) {
vt_debug_print(
normal, lb,
"computeStatistics\n"
);

computeStatisticsOver(phase, lb::Statistic::P_l);
computeStatisticsOver(phase, lb::Statistic::O_l);

// if (comm_aware_) {
// computeStatisticsOver(lb::Statistic::P_c);
// computeStatisticsOver(lb::Statistic::O_c);
// }
// @todo: add P_c, P_t, O_c, O_t
}

void LBManager::statsHandler(StatsMsgType* msg) {
auto in = msg->getConstVal();
auto max = in.max();
auto min = in.min();
auto avg = in.avg();
auto sum = in.sum();
auto npr = in.npr();
auto car = in.N_;
auto imb = in.I();
auto var = in.var();
auto stdv = in.stdv();
auto skew = in.skew();
auto krte = in.krte();
auto the_stat = msg->stat_;

stats[the_stat][lb::StatisticQuantity::max] = max;
stats[the_stat][lb::StatisticQuantity::min] = min;
stats[the_stat][lb::StatisticQuantity::avg] = avg;
stats[the_stat][lb::StatisticQuantity::sum] = sum;
stats[the_stat][lb::StatisticQuantity::npr] = npr;
stats[the_stat][lb::StatisticQuantity::car] = car;
stats[the_stat][lb::StatisticQuantity::var] = var;
stats[the_stat][lb::StatisticQuantity::npr] = npr;
stats[the_stat][lb::StatisticQuantity::imb] = imb;
stats[the_stat][lb::StatisticQuantity::std] = stdv;
stats[the_stat][lb::StatisticQuantity::skw] = skew;
stats[the_stat][lb::StatisticQuantity::kur] = krte;

if (theContext()->getNode() == 0) {
vt_print(
lb,
"BaseLB: Statistic={}: "
" max={:.2f}, min={:.2f}, sum={:.2f}, avg={:.2f}, var={:.2f},"
" stdev={:.2f}, nproc={}, cardinality={} skewness={:.2f}, kurtosis={:.2f},"
" npr={}, imb={:.2f}, num_stats={}\n",
lb::lb_stat_name_[the_stat],
max, min, sum, avg, var, stdv, npr, car, skew, krte, npr, imb,
stats.size()
);
}
}

void LBManager::computeStatisticsOver(PhaseType phase, lb::Statistic stat) {
using ReduceOp = collective::PlusOp<balance::LoadData>;
using ReduceOp = collective::PlusOp<std::vector<balance::LoadData>>;

bool comm_collectives_ = false;

Expand All @@ -366,12 +358,12 @@ void LBManager::computeStatisticsOver(PhaseType phase, lb::Statistic stat) {
>(proxy_);

TimeType total_load = 0;
std::vector<balance::LoadData> lds;
std::vector<balance::LoadData> P_c;
for (auto elm : *model_) {
auto work = model_->getWork(
elm, {balance::PhaseOffset::NEXT_PHASE, balance::PhaseOffset::WHOLE_PHASE}
);
lds.emplace_back(work);
P_c.emplace_back(LoadData{lb::Statistic::O_l, work});
total_load += work;
}

Expand All @@ -382,61 +374,43 @@ void LBManager::computeStatisticsOver(PhaseType phase, lb::Statistic stat) {
comm_data = &iter->second;
}

switch (stat) {
case lb::Statistic::P_l: {
// Perform the reduction for P_l -> processor load only
auto msg = makeMessage<StatsMsgType>(lb::Statistic::P_l, total_load);
proxy_.template reduce<ReduceOp>(msg,cb);
}
break;
case lb::Statistic::O_l: {
// Perform the reduction for O_l -> object load only
auto msg = makeMessage<StatsMsgType>(
lb::Statistic::O_l, reduceVec(std::move(lds))
);
proxy_.template reduce<ReduceOp>(msg,cb);
}
break;
case lb::Statistic::P_c: {
// Perform the reduction for P_c -> processor comm only
double comm_load = 0.0;
for (auto&& elm : *comm_data) {
if (not comm_collectives_ and isCollectiveComm(elm.first.cat_)) {
continue;
}
if (elm.first.onNode() or elm.first.selfEdge()) {
continue;
}
//vt_print(lb, "comm_load={}, elm={}\n", comm_load, elm.second.bytes);
comm_load += elm.second.bytes;
std::vector<LoadData> lstats;
lstats.emplace_back(LoadData{lb::Statistic::P_l, total_load});
lstats.emplace_back(reduceVec(lb::Statistic::P_c, std::move(P_c)));

double comm_load = 0.0;
for (auto&& elm : *comm_data) {
if (not comm_collectives_ and isCollectiveComm(elm.first.cat_)) {
continue;
}
auto msg = makeMessage<StatsMsgType>(lb::Statistic::P_c, comm_load);
proxy_.template reduce<ReduceOp>(msg,cb);
}
break;
case lb::Statistic::O_c: {
// Perform the reduction for O_c -> object comm only
std::vector<balance::LoadData> lds2;
for (auto&& elm : *comm_data) {
// Only count object-to-object direct edges in the O_c statistics
if (elm.first.cat_ == balance::CommCategory::SendRecv and not elm.first.selfEdge()) {
lds2.emplace_back(balance::LoadData(elm.second.bytes));
}
if (elm.first.onNode() or elm.first.selfEdge()) {
continue;
}
auto msg = makeMessage<StatsMsgType>(
lb::Statistic::O_c, reduceVec(std::move(lds2)
));
proxy_.template reduce<ReduceOp>(msg,cb);
//vt_print(lb, "comm_load={}, elm={}\n", comm_load, elm.second.bytes);
comm_load += elm.second.bytes;
}
break;
default:
break;

lstats.emplace_back(LoadData{lb::Statistic::P_c, comm_load});

std::vector<balance::LoadData> O_c;
for (auto&& elm : *comm_data) {
// Only count object-to-object direct edges in the O_c statistics
if (elm.first.cat_ == balance::CommCategory::SendRecv and not elm.first.selfEdge()) {
O_c.emplace_back(LoadData{lb::Statistic::O_c, elm.second.bytes});
}
}

lstats.emplace_back(reduceVec(lb::Statistic::O_c, std::move(O_c)));

auto msg = makeMessage<StatsMsgType>(std::move(lstats));
proxy_.template reduce<ReduceOp>(msg,cb);
}

balance::LoadData
LBManager::reduceVec(std::vector<balance::LoadData>&& vec) const {
balance::LoadData reduce_ld(0.0f);
LBManager::reduceVec(
lb::Statistic stat, std::vector<balance::LoadData>&& vec
) const {
balance::LoadData reduce_ld(stat, 0.0f);
if (vec.size() == 0) {
return reduce_ld;
} else {
Expand Down
5 changes: 3 additions & 2 deletions src/vt/vrt/collection/balance/lb_invoke/lb_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,10 @@ struct LBManager : runtime::component::Component<LBManager> {

private:
void computeStatistics(PhaseType phase);
void computeStatisticsOver(PhaseType phase, lb::Statistic stats);
void statsHandler(StatsMsgType* msg);
balance::LoadData reduceVec(std::vector<balance::LoadData>&& vec) const;
balance::LoadData reduceVec(
lb::Statistic stat, std::vector<balance::LoadData>&& vec
) const;
bool isCollectiveComm(balance::CommCategory cat) const;

private:
Expand Down
33 changes: 18 additions & 15 deletions src/vt/vrt/collection/balance/stats_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,25 @@
#include <algorithm>
#include <limits>
#include <type_traits>
#include <vector>

namespace vt { namespace vrt { namespace collection { namespace balance {

struct LoadData {
using isByteCopyable = std::true_type;

LoadData() = default;
LoadData(TimeType const y)
LoadData(lb::Statistic in_stat, TimeType const y)
: max_(y), sum_(y), min_(y), avg_(y), M2_(0.0f), M3_(0.0f), M4_(0.0f),
N_(1), P_(y not_eq 0.0f)
N_(1), P_(y not_eq 0.0f), stat_(in_stat)
{
vt_debug_print(
verbose, lb,
"LoadData: in={}, N_={}\n", y, N_
);
}

friend LoadData operator+(LoadData a1, LoadData const& a2) {
friend LoadData operator+(LoadData& a1, LoadData const& a2) {
vt_debug_print(
verbose, lb,
"operator+: a1.N_={}, a2.N_={}\n", a1.N_, a2.N_
Expand Down Expand Up @@ -110,6 +113,7 @@ struct LoadData {
a1.max_ = std::max(a1.max_, a2.max_);
a1.sum_ += a2.sum_;
a1.P_ += a2.P_;
a1.stat_ = a2.stat_;

return a1;
}
Expand Down Expand Up @@ -160,34 +164,33 @@ struct LoadData {
TimeType M4_ = 0.0;
int32_t N_ = 0;
int32_t P_ = 0;
lb::Statistic stat_ = lb::Statistic::P_l;
};

static_assert(
vt::messaging::is_byte_copyable_t<LoadData>::value,
"Must be trivially copyable to avoid serialization."
);

struct NodeStatsMsg : NonSerialized<
collective::ReduceTMsg<LoadData>,
struct NodeStatsMsg : SerializeRequired<
collective::ReduceTMsg<std::vector<LoadData>>,
NodeStatsMsg
>
{
using MessageParentType = NonSerialized<
collective::ReduceTMsg<LoadData>,
using MessageParentType = SerializeRequired<
collective::ReduceTMsg<std::vector<LoadData>>,
NodeStatsMsg
>;

NodeStatsMsg() = default;
NodeStatsMsg(lb::Statistic in_stat, TimeType const in_total_load)
: MessageParentType(LoadData(in_total_load)),
stat_(in_stat)
{ }
NodeStatsMsg(lb::Statistic in_stat, LoadData&& ld)
: MessageParentType(std::move(ld)),
stat_(in_stat)
explicit NodeStatsMsg(std::vector<LoadData> ld)
: MessageParentType(std::move(ld))
{ }

lb::Statistic stat_ = lb::Statistic::P_l;
template <typename SerializerT>
void serialize(SerializerT& s) {
MessageParentType::serialize(s);
}
};

}}}} /* end namespace vt::vrt::collection::balance */
Expand Down
10 changes: 7 additions & 3 deletions src/vt/vrt/collection/balance/temperedlb/temperedlb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,16 @@ void TemperedLB::doLBStages(TimeType start_imb) {

if (rollback_ || theConfig()->vt_debug_temperedlb || (iter_ == num_iters_ - 1)) {
runInEpochCollective("TemperedLB::doLBStages -> P_l reduce", [=] {
using ReduceOp = collective::PlusOp<balance::LoadData>;
using ReduceOp = collective::PlusOp<std::vector<balance::LoadData>>;
auto cb = vt::theCB()->makeBcast<
TemperedLB, StatsMsgType, &TemperedLB::loadStatsHandler
>(this->proxy_);
// Perform the reduction for P_l -> processor load only
auto msg = makeMessage<StatsMsgType>(Statistic::P_l, this_new_load_);
auto msg = makeMessage<StatsMsgType>(
std::vector<balance::LoadData>{
{balance::LoadData{Statistic::P_l, this_new_load_}}
}
);
this->proxy_.template reduce<ReduceOp>(msg,cb);
});
}
Expand Down Expand Up @@ -405,7 +409,7 @@ void TemperedLB::doLBStages(TimeType start_imb) {
}

void TemperedLB::loadStatsHandler(StatsMsgType* msg) {
auto in = msg->getConstVal();
auto in = msg->getConstVal()[0];
new_imbalance_ = in.I();

auto this_node = theContext()->getNode();
Expand Down

0 comments on commit 3326d8b

Please sign in to comment.