Skip to content

Commit

Permalink
feat: add elapsed time to training metrics
Browse files Browse the repository at this point in the history
Also, when a measure history is subsampled, the sampling rate is saved in metrics.json.
  • Loading branch information
Bycob authored and mergify[bot] committed Oct 14, 2021
1 parent 93815d7 commit fe5fc41
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
14 changes: 14 additions & 0 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1839,6 +1839,12 @@ namespace dd
if (!ad_mllib.has("resume") || !ad_mllib.get("resume").get<bool>())
this->clear_all_meas_per_iter();
float smoothed_loss = 0.0;

auto training_start = std::chrono::steady_clock::now();
double prev_elapsed_time_ms = this->get_meas("elapsed_time_ms");
if (std::isnan(prev_elapsed_time_ms))
prev_elapsed_time_ms = 0;

while (solver->iter_ < solver->param_.max_iter()
&& this->_tjob_running.load())
{
Expand Down Expand Up @@ -1962,7 +1968,15 @@ namespace dd
this->add_meas("train_loss", smoothed_loss);
this->add_meas_per_iter("train_loss", smoothed_loss);
this->add_meas("iter_time", avg_fb_time);
this->add_meas("iteration_duration_ms", avg_fb_time);
this->add_meas("remain_time", est_remain_time);
int64_t elapsed_time_ms
= prev_elapsed_time_ms
+ std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - training_start)
.count();
this->add_meas("elapsed_time_ms", elapsed_time_ms);
this->add_meas_per_iter("elapsed_time_ms", elapsed_time_ms);

caffe::SGDSolver<float> *sgd_solver
= static_cast<caffe::SGDSolver<float> *>(solver.get());
Expand Down
15 changes: 15 additions & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,10 @@ namespace dd
this->_logger->info("Training for {} iterations", iterations - it);
}

bool resume = it > 0;
if (!resume)
this->clear_all_meas_per_iter();

// [multigpu] initialize all module
typedef struct
{
Expand Down Expand Up @@ -826,6 +830,11 @@ namespace dd
{
throw MLLibBadParamException("No data found in training dataset");
}
auto training_start = steady_clock::now();

double prev_elapsed_time_ms = this->get_meas("elapsed_time_ms");
if (std::isnan(prev_elapsed_time_ms))
prev_elapsed_time_ms = 0;

// `it` is the iteration count (not epoch)
while (it < iterations)
Expand Down Expand Up @@ -1026,6 +1035,12 @@ namespace dd
}
this->add_meas("remain_time", remain_time_ms / 1000.0);
this->add_meas("train_loss", train_loss);
int64_t elapsed_time_ms
= prev_elapsed_time_ms
+ duration_cast<milliseconds>(tstop - training_start)
.count();
this->add_meas("elapsed_time_ms", elapsed_time_ms);
this->add_meas_per_iter("elapsed_time_ms", elapsed_time_ms);
this->add_meas_per_iter("learning_rate", base_lr);
this->add_meas_per_iter("train_loss", train_loss);
int64_t elapsed_it = it + 1;
Expand Down
18 changes: 11 additions & 7 deletions src/mllibstrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,15 @@ namespace dd
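* @param hist measure history vector
* @param sub_hist output vector; cleared, then filled with every rpoints-th
*                 value of hist, starting at index 0
* @param npoints max number of output points
* @return the sampling stride (rpoints) used to build sub_hist
*/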
std::vector<double> subsample_hist(const std::vector<double> &hist,
const int &npoints) const
int subsample_hist(const std::vector<double> &hist,
std::vector<double> &sub_hist, const int &npoints) const
{
std::vector<double> sub_hist;
sub_hist.clear();
sub_hist.reserve(npoints);
int rpoints = std::ceil(hist.size() / npoints) + 1;
int rpoints = static_cast<int>(std::ceil(hist.size() / (double)npoints));
for (size_t i = 0; i < hist.size(); i += rpoints)
sub_hist.push_back(hist.at(i));
return sub_hist;
return rpoints;
}

/**
Expand All @@ -249,8 +249,12 @@ namespace dd
while (hit != _meas_per_iter.end())
{
if (npoints > 0 && (int)(*hit).second.size() > npoints)
meas_hist.add((*hit).first + "_hist",
subsample_hist((*hit).second, npoints));
{
std::vector<double> sub_hist;
int sampling = subsample_hist((*hit).second, sub_hist, npoints);
meas_hist.add((*hit).first + "_hist", sub_hist);
meas_hist.add((*hit).first + "_sampling", sampling);
}
else
meas_hist.add((*hit).first + "_hist", (*hit).second);
++hit;
Expand Down

0 comments on commit fe5fc41

Please sign in to comment.