Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Solver to allow interactive stepping #1228

Merged
merged 2 commits into from
Jan 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Solver {
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
virtual ~Solver() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
Expand All @@ -34,9 +35,6 @@ class Solver {
int iter() { return iter_; }

protected:
// PreSolve is run before any solving iteration starts, allowing one to
// put up some scaffold.
virtual void PreSolve() {}
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
Expand Down Expand Up @@ -73,14 +71,14 @@ template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) {}
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) {}
: Solver<Dtype>(param_file) { PreSolve(); }

const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

protected:
virtual void PreSolve();
void PreSolve();
Dtype GetLearningRate();
virtual void ComputeUpdateValue();
virtual void SnapshotSolverState(SolverState * state);
Expand Down
3 changes: 2 additions & 1 deletion python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("test_nets", &PySGDSolver::test_nets)
.add_property("iter", &PySGDSolver::iter)
.def("solve", &PySGDSolver::Solve)
.def("solve", &PySGDSolver::SolveResume);
.def("solve", &PySGDSolver::SolveResume)
.def("step", &PySGDSolver::Step);

bp::class_<vector<shared_ptr<PyNet> > >("NetVec")
.def(bp::vector_indexing_suite<vector<shared_ptr<PyNet> >, true>());
Expand Down
1 change: 1 addition & 0 deletions python/caffe/_caffe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class PySGDSolver {
vector<shared_ptr<PyNet> > test_nets() { return test_nets_; }
int iter() { return solver_->iter(); }
void Solve() { return solver_->Solve(); }
void Step(int iters) { solver_->Step(iters); }
void SolveResume(const string& resume_file);

protected:
Expand Down
73 changes: 38 additions & 35 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
LOG(INFO) << "Initializing solver from parameters: " << std::endl
<< param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
if (param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
}
// Scaffolding code
InitTrainNet();
InitTestNets();
LOG(INFO) << "Solver scaffolding done.";
iter_ = 0;
current_step_ = 0;
}

template <typename Dtype>
Expand Down Expand Up @@ -155,39 +158,15 @@ void Solver<Dtype>::InitTestNets() {
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
PreSolve();

iter_ = 0;
current_step_ = 0;
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}
// Remember the initial iter_ value; will be non-zero if we loaded from a
// resume_file above.
void Solver<Dtype>::Step(int iters) {
vector<Blob<Dtype>*> bottom_vec;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be removed since you changed the call to ForwardPrefilled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changed call is in Solve, this is in Step, where the call is to ForwardBackward. An alternative is to remove bottom_vec but call ForwardPrefilled and Backward separately; that's what I've now implemented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think I'll revert that, because it makes #1663 awkward. We should probably just clean up the Net interface to avoid these dummy vectors at some later point.

const int start_iter = iter_;

const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();

CHECK_GE(average_loss, 1) << "average_loss should be non-negative.";

vector<Dtype> losses;
Dtype smoothed_loss = 0;

// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
for (; iter_ < param_.max_iter(); ++iter_) {
// Save a snapshot if needed.
if (param_.snapshot() && iter_ > start_iter &&
iter_ % param_.snapshot() == 0) {
Snapshot();
}

for (; iter_ < stop_iter; ++iter_) {
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())) {
TestAll();
Expand Down Expand Up @@ -227,13 +206,36 @@ void Solver<Dtype>::Solve(const char* resume_file) {
}
}
}

ComputeUpdateValue();
net_->Update();

// Save a snapshot if needed.
if (param_.snapshot() && (iter_ + 1) % param_.snapshot() == 0) {
Snapshot();
}
}
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}

// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
// Always save a snapshot after optimization, unless overridden by setting
// snapshot_after_train := false.
if (param_.snapshot_after_train()) { Snapshot(); }
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
Expand All @@ -242,7 +244,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == 0) {
Dtype loss;
net_->Forward(bottom_vec, &loss);
net_->ForwardPrefilled(&loss);
LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
Expand Down Expand Up @@ -328,14 +330,15 @@ void Solver<Dtype>::Snapshot() {
string model_filename, snapshot_filename;
const int kBufferSize = 20;
char iter_str_buffer[kBufferSize];
snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_);
// Add one to iter_ to get the number of iterations that have completed.
snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_ + 1);
filename += iter_str_buffer;
model_filename = filename + ".caffemodel";
LOG(INFO) << "Snapshotting to " << model_filename;
WriteProtoToBinaryFile(net_param, model_filename.c_str());
SolverState state;
SnapshotSolverState(&state);
state.set_iter(iter_);
state.set_iter(iter_ + 1);
state.set_learned_net(model_filename);
state.set_current_step(current_step_);
snapshot_filename = filename + ".solverstate";
Expand Down