Skip to content

Commit

Permalink
Add signal handler and early exit/snapshot to Solver.
Browse files Browse the repository at this point in the history
Add signal handler and early exit/snapshot to Solver.

Add signal handler and early exit/snapshot to Solver.

Also check for exit and snapshot when testing.

Skip running test after early exit.

Fix more lint.

Rebase on master.
  • Loading branch information
jyegerlehner committed Aug 14, 2015
1 parent c6b9f58 commit c946a00
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 12 deletions.
22 changes: 21 additions & 1 deletion include/caffe/solver.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_

#include <boost/function.hpp>
#include <string>
#include <vector>

#include "caffe/net.hpp"

namespace caffe {

/**
* @brief Type of a function that returns a Solver Action enumeration.
*/
typedef boost::function<SolverParameter_Action()> ActionCallback;

/**
* @brief An interface for classes that perform optimization on Net%s.
*
Expand All @@ -23,6 +28,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();

// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
void SetActionFunction(ActionCallback func);
SolverParameter_Action GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
Expand Down Expand Up @@ -66,6 +77,8 @@ class Solver {
string SnapshotToBinaryProto();
string SnapshotToHDF5();
// The test routine
// stop_was_requested will be set to true iff a request to stop training
// was received whilst testing.
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(const string& model_filename) = 0;
Expand All @@ -84,6 +97,13 @@ class Solver {
// in data parallelism
const Solver* const root_solver_;

// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
ActionCallback action_request_function_;

// True iff a request to stop early was received.
bool requested_early_exit_;

DISABLE_COPY_AND_ASSIGN(Solver);
};

Expand Down
24 changes: 24 additions & 0 deletions include/caffe/util/signal_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_

#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"

namespace caffe {

class SignalHandler {
public:
// Contructor. Specify what action to take when a signal is received.
SignalHandler(SolverParameter_Action SIGINT_action,
SolverParameter_Action SIGHUP_action);
ActionCallback GetActionFunction();
private:
SignalHandler(); // Not implemented.
SolverParameter_Action CheckForSignals() const;
SolverParameter_Action SIGINT_action_;
SolverParameter_Action SIGHUP_action_;
};

} // namespace caffe

#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
13 changes: 13 additions & 0 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,19 @@ message SolverParameter {

// If false, don't save a snapshot after training finishes.
optional bool snapshot_after_train = 28 [default = true];

// Enumeration of actions that a client of the Solver may request by
// implementing the Solver's action request function, which a
// a client may optionally provide in order to request early termination
// or saving a snapshot without exiting. In the executable caffe, this
// mechanism is used to allow the snapshot to be saved when stopping
// execution with a SIGINT (Ctrl-C).
enum Action {
NONE = 0; // Take no special action.
STOP = 1; // Stop training. snapshot_after_train controls whether a snapshot
// is created.
SNAPSHOT = 2; // Take a snapshot, and keep training.
}
}

// A message that stores the solver snapshots
Expand Down
74 changes: 64 additions & 10 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,31 @@

namespace caffe {

template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
action_request_function_ = func;
}

template<typename Dtype>
SolverParameter_Action Solver<Dtype>::GetRequestedAction() {
if (action_request_function_) {
// If the external request function has been set, call it.
return action_request_function_();
}
return SolverParameter_Action_NONE;
}

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver) {
Solver<Dtype>::Solver(const SolverParameter& param)
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
Init(param);
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver) {
Solver<Dtype>::Solver(const string& param_file)
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
Expand Down Expand Up @@ -195,6 +211,10 @@ void Solver<Dtype>::Step(int iters) {
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
}

for (int i = 0; i < callbacks_.size(); ++i) {
Expand Down Expand Up @@ -250,12 +270,20 @@ void Solver<Dtype>::Step(int iters) {
// the number of times the weights have been updated.
++iter_;

SolverParameter_Action request = GetRequestedAction();

// Save a snapshot if needed.
if (param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) {
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverParameter_Action_SNAPSHOT)) {
Snapshot();
}
if (SolverParameter_Action_STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}

Expand All @@ -265,6 +293,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

// Initialize to false every time we start solving.
requested_early_exit_ = false;

if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
Expand All @@ -279,6 +310,10 @@ void Solver<Dtype>::Solve(const char* resume_file) {
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
Expand All @@ -296,10 +331,11 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Optimization Done.";
}


template <typename Dtype>
void Solver<Dtype>::TestAll() {
for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
for (int test_net_id = 0;
test_net_id < test_nets_.size() && !requested_early_exit_;
++test_net_id) {
Test(test_net_id);
}
}
Expand All @@ -317,6 +353,21 @@ void Solver<Dtype>::Test(const int test_net_id) {
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverParameter_Action request = GetRequestedAction();
// Check to see if stoppage of testing/training has been requested.
while (request != SolverParameter_Action_NONE) {
if (SolverParameter_Action_SNAPSHOT == request) {
Snapshot();
} else if (SolverParameter_Action_STOP == request) {
requested_early_exit_ = true;
}
request = GetRequestedAction();
}
if (requested_early_exit_) {
// break out of test loop.
break;
}

Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
Expand All @@ -341,6 +392,10 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
Expand All @@ -361,7 +416,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}


template <typename Dtype>
void Solver<Dtype>::Snapshot() {
CHECK(Caffe::root_solver());
Expand Down
91 changes: 91 additions & 0 deletions src/caffe/util/signal_handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include <boost/bind.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/thread/thread.hpp>
#include <glog/logging.h>

#include <signal.h>
#include <csignal>

#include "caffe/util/signal_handler.h"

namespace {
static volatile sig_atomic_t got_sigint = false;
static volatile sig_atomic_t got_sighup = false;
static bool already_hooked_up = false;

void handle_signal(int signal) {
switch (signal) {
case SIGHUP:
got_sighup = true;
break;
case SIGINT:
got_sigint = true;
break;
}
}

void HookupHandler() {
if (already_hooked_up) {
LOG(FATAL) << "Tried to hookup signal handlers more than once.";
}
already_hooked_up = true;

struct sigaction sa;
// Setup the sighub handler
sa.sa_handler = &handle_signal;
// Restart the system call, if at all possible
sa.sa_flags = SA_RESTART;
// Block every signal during the handler
sigfillset(&sa.sa_mask);
// Intercept SIGHUP and SIGINT
if (sigaction(SIGHUP, &sa, NULL) == -1) {
LOG(FATAL) << "Cannot install SIGHUP handler.";
}
if (sigaction(SIGINT, &sa, NULL) == -1) {
LOG(FATAL) << "Cannot install SIGINT handler.";
}
}

// Return true iff a SIGINT has been received since the last time this
// function was called.
bool GotSIGINT() {
bool result = got_sigint;
got_sigint = false;
return result;
}

// Return true iff a SIGHUP has been received since the last time this
// function was called.
bool GotSIGHUP() {
bool result = got_sighup;
got_sighup = false;
return result;
}
} // namespace

namespace caffe {

SignalHandler::SignalHandler(SolverParameter_Action SIGINT_action,
SolverParameter_Action SIGHUP_action):
SIGINT_action_(SIGINT_action),
SIGHUP_action_(SIGHUP_action) {
HookupHandler();
}

SolverParameter_Action SignalHandler::CheckForSignals() const {
if (GotSIGHUP()) {
return SIGHUP_action_;
}
if (GotSIGINT()) {
return SIGINT_action_;
}
return SolverParameter_Action_NONE;
}

// Return the function that the solver can use to find out if a snapshot or
// early exit is being requested.
ActionCallback SignalHandler::GetActionFunction() {
return boost::bind(&SignalHandler::CheckForSignals, this);
}

} // namespace caffe
30 changes: 29 additions & 1 deletion tools/caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace bp = boost::python;

#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
#include "caffe/util/signal_handler.h"

using caffe::Blob;
using caffe::Caffe;
Expand Down Expand Up @@ -39,6 +40,12 @@ DEFINE_string(weights, "",
"separated by ','. Cannot be set simultaneously with snapshot.");
DEFINE_int32(iterations, 50,
"The number of iterations to run.");
DEFINE_string(sigint_effect, "stop",
"Optional; action to take when a SIGINT signal is received: "
"snapshot, stop or none.");
DEFINE_string(sighup_effect, "snapshot",
"Optional; action to take when a SIGHUP signal is received: "
"snapshot, stop or none.");

// A simple registry for caffe commands.
typedef int (*BrewFunction)();
Expand Down Expand Up @@ -126,6 +133,20 @@ void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
}
}

caffe::SolverParameter_Action GetRequestedAction(
const std::string& flag_value) {
if (flag_value == "stop") {
return caffe::SolverParameter_Action_STOP;
}
if (flag_value == "snapshot") {
return caffe::SolverParameter_Action_SNAPSHOT;
}
if (flag_value == "none") {
return caffe::SolverParameter_Action_NONE;
}
LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified";
}

// Train / Finetune a model.
int train() {
CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
Expand Down Expand Up @@ -165,7 +186,14 @@ int train() {
Caffe::set_solver_count(gpus.size());
}

shared_ptr<Solver<float> > solver(caffe::GetSolver<float>(solver_param));
caffe::SignalHandler signal_handler(
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));

shared_ptr<caffe::Solver<float> >
solver(caffe::GetSolver<float>(solver_param));

solver->SetActionFunction(signal_handler.GetActionFunction());

if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
Expand Down

0 comments on commit c946a00

Please sign in to comment.