Skip to content

Commit

Permalink
#1672: move total work logic into TemperedWMin
Browse files Browse the repository at this point in the history
  • Loading branch information
cz4rs committed Mar 16, 2022
1 parent 7293993 commit 53549af
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 30 deletions.
19 changes: 0 additions & 19 deletions src/vt/vrt/collection/balance/model/load_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,25 +227,6 @@ struct LoadModel
return {};
}

/**
* \brief Provide an estimate of the total work for a given object during
* a specified interval
*
* \param[in] object The object whose total work is desired
* \param[in] when The interval in which the work takes place
*
* \return Estimated total time of work for the object
*
* The `updateLoads` method must have been called before any call to
* this.
*/
TimeType getTotalWork(
ElementIDStruct object, PhaseOffset when,
double alpha, double beta, double gamma
) {
return alpha * getLoad(object, when) + beta * getComm(object, when) + gamma;
}

/**
* \brief Compute how many phases of past load statistics need to be
* kept availble for the model to use
Expand Down
15 changes: 7 additions & 8 deletions src/vt/vrt/collection/balance/temperedlb/temperedlb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,6 @@ void TemperedLB::inputParams(balance::SpecEntry* spec) {
num_iters_ = spec->getOrDefault<int32_t>("iters", num_iters_);
num_trials_ = spec->getOrDefault<int32_t>("trials", num_trials_);

alpha_ = spec->getOrDefault<double>("alpha", alpha_);
beta_ = spec->getOrDefault<double>("beta", beta_);
gamma_ = spec->getOrDefault<double>("gamma", gamma_);

deterministic_ = spec->getOrDefault<bool>("deterministic", deterministic_);
rollback_ = spec->getOrDefault<bool>("rollback", rollback_);
target_pole_ = spec->getOrDefault<bool>("targetpole", target_pole_);
Expand Down Expand Up @@ -488,10 +484,7 @@ void TemperedLB::doLBStages(TimeType start_imb) {
cur_objs_.clear();
for (auto obj : *load_model_) {
if (obj.isMigratable()) {
// TODO: `beta_ * communication` component is still missing here
cur_objs_[obj] = alpha_ * load_model_->getLoad(
obj, {balance::PhaseOffset::NEXT_PHASE, balance::PhaseOffset::WHOLE_PHASE}
) + gamma_;
cur_objs_[obj] = getTotalWork(obj);
}
}
this_new_load_ = this_load;
Expand Down Expand Up @@ -1352,4 +1345,10 @@ void TemperedLB::migrate() {
vtAssertExpr(false);
}

TimeType TemperedLB::getTotalWork(const elm::ElementIDStruct& obj) {
return load_model_->getLoad(
obj, {balance::PhaseOffset::NEXT_PHASE, balance::PhaseOffset::WHOLE_PHASE}
);
}

}}}} /* end namespace vt::vrt::collection::lb */
4 changes: 1 addition & 3 deletions src/vt/vrt/collection/balance/temperedlb/temperedlb.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ struct TemperedLB : BaseLB {
ElementLoadType::iterator selectObject(
LoadType size, ElementLoadType& load, std::set<ObjIDType> const& available
);
virtual TimeType getTotalWork(const elm::ElementIDStruct& obj);

void lazyMigrateObjsTo(EpochType epoch, NodeType node, ObjsType const& objs);
void inLazyMigrations(balance::LazyMigrationMsg* msg);
Expand All @@ -126,9 +127,6 @@ struct TemperedLB : BaseLB {
uint8_t k_cur_ = 0;
uint16_t iter_ = 0;
uint16_t trial_ = 0;
double alpha_ = 1.0;
double beta_ = 0.0;
double gamma_ = 0.0;
uint16_t num_iters_ = 4;
/**
* \brief Number of trials
Expand Down
19 changes: 19 additions & 0 deletions src/vt/vrt/collection/balance/temperedwmin/temperedwmin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@

#include "vt/vrt/collection/balance/temperedwmin/temperedwmin.h"

#include "vt/vrt/collection/balance/lb_common.h"
#include "vt/vrt/collection/balance/model/load_model.h"

namespace vt { namespace vrt { namespace collection { namespace lb {

/*static*/ std::unordered_map<std::string, std::string>
Expand Down Expand Up @@ -72,4 +75,20 @@ Default: 0.0
return map;
}

void TemperedWMin::inputParams(balance::SpecEntry* spec) {
TemperedLB::inputParams(spec);

alpha_ = spec->getOrDefault<double>("alpha", alpha_);
beta_ = spec->getOrDefault<double>("beta", beta_);
gamma_ = spec->getOrDefault<double>("gamma", gamma_);
}

TimeType TemperedWMin::getTotalWork(const elm::ElementIDStruct& obj) {
balance::PhaseOffset when =
{balance::PhaseOffset::NEXT_PHASE, balance::PhaseOffset::WHOLE_PHASE};

return alpha_ * load_model_->getLoad(obj, when)
+ beta_ * load_model_->getComm(obj, when) + gamma_;
}

}}}} // namespace vt::vrt::collection::lb
10 changes: 10 additions & 0 deletions src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ struct TemperedWMin : TemperedLB {

public:
static std::unordered_map<std::string, std::string> getInputKeysWithHelp();

void inputParams(balance::SpecEntry* spec) override;

protected:
TimeType getTotalWork(const elm::ElementIDStruct& obj) override;

private:
double alpha_ = 1.0;
double beta_ = 0.0;
double gamma_ = 0.0;
};

}}}} /* end namespace vt::vrt::collection::lb */
Expand Down

0 comments on commit 53549af

Please sign in to comment.