Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snapshot on signal #2253

Merged
merged 1 commit into from
Aug 22, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion include/caffe/solver.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_

#include <boost/function.hpp>
#include <string>
#include <vector>

#include "caffe/net.hpp"

namespace caffe {

/**
* @brief Enumeration of actions that a client of the Solver may request by
* implementing the Solver's action request function, which a
* a client may optionally provide in order to request early termination
* or saving a snapshot without exiting. In the executable caffe, this
* mechanism is used to allow the snapshot to be saved when stopping
* execution with a SIGINT (Ctrl-C).
*/
namespace SolverAction {
enum Enum {
NONE = 0, // Take no special action.
STOP = 1, // Stop training. snapshot_after_train controls whether a
// snapshot is created.
SNAPSHOT = 2 // Take a snapshot, and keep training.
};
}

/**
* @brief Type of a function that returns a Solver Action enumeration.
*/
typedef boost::function<SolverAction::Enum()> ActionCallback;

/**
* @brief An interface for classes that perform optimization on Net%s.
*
Expand All @@ -23,6 +45,12 @@ class Solver {
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();

// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
void SetActionFunction(ActionCallback func);
SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
Expand Down Expand Up @@ -84,6 +112,13 @@ class Solver {
// in data parallelism
const Solver* const root_solver_;

// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
ActionCallback action_request_function_;

// True iff a request to stop early was received.
bool requested_early_exit_;

DISABLE_COPY_AND_ASSIGN(Solver);
};

Expand Down
24 changes: 24 additions & 0 deletions include/caffe/util/signal_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_

#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"

namespace caffe {

class SignalHandler {
public:
// Contructor. Specify what action to take when a signal is received.
SignalHandler(SolverAction::Enum SIGINT_action,
SolverAction::Enum SIGHUP_action);
~SignalHandler();
ActionCallback GetActionFunction();
private:
SolverAction::Enum CheckForSignals() const;
SolverAction::Enum SIGINT_action_;
SolverAction::Enum SIGHUP_action_;
};

} // namespace caffe

#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_
70 changes: 62 additions & 8 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,31 @@

namespace caffe {

template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
action_request_function_ = func;
}

template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
if (action_request_function_) {
// If the external request function has been set, call it.
return action_request_function_();
}
return SolverAction::NONE;
}

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver) {
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
Init(param);
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
: net_(), callbacks_(), root_solver_(root_solver) {
: net_(), callbacks_(), root_solver_(root_solver),
requested_early_exit_(false) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
Expand Down Expand Up @@ -195,6 +211,10 @@ void Solver<Dtype>::Step(int iters) {
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
}

for (int i = 0; i < callbacks_.size(); ++i) {
Expand Down Expand Up @@ -250,12 +270,20 @@ void Solver<Dtype>::Step(int iters) {
// the number of times the weights have been updated.
++iter_;

SolverAction::Enum request = GetRequestedAction();

// Save a snapshot if needed.
if (param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) {
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}

Expand All @@ -265,6 +293,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

// Initialize to false every time we start solving.
requested_early_exit_ = false;

if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
Expand All @@ -279,6 +310,10 @@ void Solver<Dtype>::Solve(const char* resume_file) {
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
Expand All @@ -296,10 +331,11 @@ void Solver<Dtype>::Solve(const char* resume_file) {
LOG(INFO) << "Optimization Done.";
}


template <typename Dtype>
void Solver<Dtype>::TestAll() {
for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) {
for (int test_net_id = 0;
test_net_id < test_nets_.size() && !requested_early_exit_;
++test_net_id) {
Test(test_net_id);
}
}
Expand All @@ -317,6 +353,21 @@ void Solver<Dtype>::Test(const int test_net_id) {
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverAction::Enum request = GetRequestedAction();
// Check to see if stoppage of testing/training has been requested.
while (request != SolverAction::NONE) {
if (SolverAction::SNAPSHOT == request) {
Snapshot();
} else if (SolverAction::STOP == request) {
requested_early_exit_ = true;
}
request = GetRequestedAction();
}
if (requested_early_exit_) {
// break out of test loop.
break;
}

Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
Expand All @@ -341,6 +392,10 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
LOG(INFO) << "Test loss: " << loss;
Expand All @@ -361,7 +416,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}


template <typename Dtype>
void Solver<Dtype>::Snapshot() {
CHECK(Caffe::root_solver());
Expand Down
115 changes: 115 additions & 0 deletions src/caffe/util/signal_handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include <boost/bind.hpp>
#include <glog/logging.h>

#include <signal.h>
#include <csignal>

#include "caffe/util/signal_handler.h"

namespace {
static volatile sig_atomic_t got_sigint = false;
static volatile sig_atomic_t got_sighup = false;
static bool already_hooked_up = false;

void handle_signal(int signal) {
switch (signal) {
case SIGHUP:
got_sighup = true;
break;
case SIGINT:
got_sigint = true;
break;
}
}

void HookupHandler() {
if (already_hooked_up) {
LOG(FATAL) << "Tried to hookup signal handlers more than once.";
}
already_hooked_up = true;

struct sigaction sa;
// Setup the handler
sa.sa_handler = &handle_signal;
// Restart the system call, if at all possible
sa.sa_flags = SA_RESTART;
// Block every signal during the handler
sigfillset(&sa.sa_mask);
// Intercept SIGHUP and SIGINT
if (sigaction(SIGHUP, &sa, NULL) == -1) {
LOG(FATAL) << "Cannot install SIGHUP handler.";
}
if (sigaction(SIGINT, &sa, NULL) == -1) {
LOG(FATAL) << "Cannot install SIGINT handler.";
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe "unhook" these signals in SignalHandler::~SignalHandler()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK will do.


// Set the signal handlers to the default.
void UnhookHandler() {
if (already_hooked_up) {
struct sigaction sa;
// Setup the sighub handler
sa.sa_handler = SIG_DFL;
// Restart the system call, if at all possible
sa.sa_flags = SA_RESTART;
// Block every signal during the handler
sigfillset(&sa.sa_mask);
// Intercept SIGHUP and SIGINT
if (sigaction(SIGHUP, &sa, NULL) == -1) {
LOG(FATAL) << "Cannot uninstall SIGHUP handler.";
}
if (sigaction(SIGINT, &sa, NULL) == -1) {
LOG(FATAL) << "Cannot uninstall SIGINT handler.";
}

already_hooked_up = false;
}
}

// Return true iff a SIGINT has been received since the last time this
// function was called.
bool GotSIGINT() {
bool result = got_sigint;
got_sigint = false;
return result;
}

// Return true iff a SIGHUP has been received since the last time this
// function was called.
bool GotSIGHUP() {
bool result = got_sighup;
got_sighup = false;
return result;
}
} // namespace

namespace caffe {

SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,
SolverAction::Enum SIGHUP_action):
SIGINT_action_(SIGINT_action),
SIGHUP_action_(SIGHUP_action) {
HookupHandler();
}

SignalHandler::~SignalHandler() {
UnhookHandler();
}

SolverAction::Enum SignalHandler::CheckForSignals() const {
if (GotSIGHUP()) {
return SIGHUP_action_;
}
if (GotSIGINT()) {
return SIGINT_action_;
}
return SolverAction::NONE;
}

// Return the function that the solver can use to find out if a snapshot or
// early exit is being requested.
ActionCallback SignalHandler::GetActionFunction() {
return boost::bind(&SignalHandler::CheckForSignals, this);
}

} // namespace caffe
Loading