From ff19d5f5c010dd8d6bfcf768b4fe27d0458f17df Mon Sep 17 00:00:00 2001 From: J Yegerlehner Date: Fri, 3 Apr 2015 16:11:23 -0500 Subject: [PATCH] Add signal handler and early exit/snapshot to Solver. Add signal handler and early exit/snapshot to Solver. Add signal handler and early exit/snapshot to Solver. Also check for exit and snapshot when testing. Skip running test after early exit. Fix more lint. Rebase on master. Finish rebase on master. Fixups per review comments. Redress review comments. Lint. Correct error message wording. --- include/caffe/solver.hpp | 37 ++++++++- include/caffe/util/signal_handler.h | 24 ++++++ src/caffe/solver.cpp | 70 +++++++++++++++-- src/caffe/util/signal_handler.cpp | 115 ++++++++++++++++++++++++++++ tools/caffe.cpp | 32 +++++++- 5 files changed, 268 insertions(+), 10 deletions(-) create mode 100644 include/caffe/util/signal_handler.h create mode 100644 src/caffe/util/signal_handler.cpp diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index ab12ef1b1bd..aba3e036004 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -1,6 +1,6 @@ #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_ #define CAFFE_OPTIMIZATION_SOLVER_HPP_ - +#include #include #include @@ -8,6 +8,28 @@ namespace caffe { +/** + * @brief Enumeration of actions that a client of the Solver may request by + * implementing the Solver's action request function, which a + * a client may optionally provide in order to request early termination + * or saving a snapshot without exiting. In the executable caffe, this + * mechanism is used to allow the snapshot to be saved when stopping + * execution with a SIGINT (Ctrl-C). + */ + namespace SolverAction { + enum Enum { + NONE = 0, // Take no special action. + STOP = 1, // Stop training. snapshot_after_train controls whether a + // snapshot is created. + SNAPSHOT = 2 // Take a snapshot, and keep training. + }; + } + +/** + * @brief Type of a function that returns a Solver Action enumeration. + */ +typedef boost::function ActionCallback; + /** * @brief An interface for classes that perform optimization on Net%s. * @@ -23,6 +45,12 @@ class Solver { void Init(const SolverParameter& param); void InitTrainNet(); void InitTestNets(); + + // Client of the Solver optionally may call this in order to set the function + // that the solver uses to see what action it should take (e.g. snapshot or + // exit training early). + void SetActionFunction(ActionCallback func); + SolverAction::Enum GetRequestedAction(); // The main entry of the solver function. In default, iter will be zero. Pass // in a non-zero iter number to resume training for a pre-trained net. virtual void Solve(const char* resume_file = NULL); @@ -84,6 +112,13 @@ class Solver { // in data parallelism const Solver* const root_solver_; + // A function that can be set by a client of the Solver to provide indication + // that it wants a snapshot saved and/or to exit early. + ActionCallback action_request_function_; + + // True iff a request to stop early was received. + bool requested_early_exit_; + DISABLE_COPY_AND_ASSIGN(Solver); }; diff --git a/include/caffe/util/signal_handler.h b/include/caffe/util/signal_handler.h new file mode 100644 index 00000000000..fb84c65bd2e --- /dev/null +++ b/include/caffe/util/signal_handler.h @@ -0,0 +1,24 @@ +#ifndef INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ +#define INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ + +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +namespace caffe { + +class SignalHandler { + public: + // Contructor. Specify what action to take when a signal is received. + SignalHandler(SolverAction::Enum SIGINT_action, + SolverAction::Enum SIGHUP_action); + ~SignalHandler(); + ActionCallback GetActionFunction(); + private: + SolverAction::Enum CheckForSignals() const; + SolverAction::Enum SIGINT_action_; + SolverAction::Enum SIGHUP_action_; +}; + +} // namespace caffe + +#endif // INCLUDE_CAFFE_UTIL_SIGNAL_HANDLER_H_ diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 9348e11c249..394ec3b3ad7 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -17,15 +17,31 @@ namespace caffe { +template +void Solver::SetActionFunction(ActionCallback func) { + action_request_function_ = func; +} + +template +SolverAction::Enum Solver::GetRequestedAction() { + if (action_request_function_) { + // If the external request function has been set, call it. + return action_request_function_(); + } + return SolverAction::NONE; +} + template Solver::Solver(const SolverParameter& param, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver) { + : net_(), callbacks_(), root_solver_(root_solver), + requested_early_exit_(false) { Init(param); } template Solver::Solver(const string& param_file, const Solver* root_solver) - : net_(), callbacks_(), root_solver_(root_solver) { + : net_(), callbacks_(), root_solver_(root_solver), + requested_early_exit_(false) { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); @@ -195,6 +211,10 @@ void Solver::Step(int iters) { && (iter_ > 0 || param_.test_initialization()) && Caffe::root_solver()) { TestAll(); + if (requested_early_exit_) { + // Break out of the while loop because stop was requested while testing. + break; + } } for (int i = 0; i < callbacks_.size(); ++i) { @@ -250,12 +270,20 @@ void Solver::Step(int iters) { // the number of times the weights have been updated. ++iter_; + SolverAction::Enum request = GetRequestedAction(); + // Save a snapshot if needed. - if (param_.snapshot() - && iter_ % param_.snapshot() == 0 - && Caffe::root_solver()) { + if ((param_.snapshot() + && iter_ % param_.snapshot() == 0 + && Caffe::root_solver()) || + (request == SolverAction::SNAPSHOT)) { Snapshot(); } + if (SolverAction::STOP == request) { + requested_early_exit_ = true; + // Break out of training loop. + break; + } } } @@ -265,6 +293,9 @@ void Solver::Solve(const char* resume_file) { LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); + // Initialize to false every time we start solving. + requested_early_exit_ = false; + if (resume_file) { LOG(INFO) << "Restoring previous solver status from " << resume_file; Restore(resume_file); @@ -279,6 +310,10 @@ void Solver::Solve(const char* resume_file) { && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { Snapshot(); } + if (requested_early_exit_) { + LOG(INFO) << "Optimization stopped early."; + return; + } // After the optimization is done, run an additional train and test pass to // display the train and test loss/outputs if appropriate (based on the // display and test_interval settings, respectively). Unlike in the rest of @@ -296,10 +331,11 @@ void Solver::Solve(const char* resume_file) { LOG(INFO) << "Optimization Done."; } - template void Solver::TestAll() { - for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) { + for (int test_net_id = 0; + test_net_id < test_nets_.size() && !requested_early_exit_; + ++test_net_id) { Test(test_net_id); } } @@ -317,6 +353,21 @@ void Solver::Test(const int test_net_id) { const shared_ptr >& test_net = test_nets_[test_net_id]; Dtype loss = 0; for (int i = 0; i < param_.test_iter(test_net_id); ++i) { + SolverAction::Enum request = GetRequestedAction(); + // Check to see if stoppage of testing/training has been requested. + while (request != SolverAction::NONE) { + if (SolverAction::SNAPSHOT == request) { + Snapshot(); + } else if (SolverAction::STOP == request) { + requested_early_exit_ = true; + } + request = GetRequestedAction(); + } + if (requested_early_exit_) { + // break out of test loop. + break; + } + Dtype iter_loss; const vector*>& result = test_net->Forward(bottom_vec, &iter_loss); @@ -341,6 +392,10 @@ void Solver::Test(const int test_net_id) { } } } + if (requested_early_exit_) { + LOG(INFO) << "Test interrupted."; + return; + } if (param_.test_compute_loss()) { loss /= param_.test_iter(test_net_id); LOG(INFO) << "Test loss: " << loss; @@ -361,7 +416,6 @@ void Solver::Test(const int test_net_id) { } } - template void Solver::Snapshot() { CHECK(Caffe::root_solver()); diff --git a/src/caffe/util/signal_handler.cpp b/src/caffe/util/signal_handler.cpp new file mode 100644 index 00000000000..5d764ec524f --- /dev/null +++ b/src/caffe/util/signal_handler.cpp @@ -0,0 +1,115 @@ +#include +#include + +#include +#include + +#include "caffe/util/signal_handler.h" + +namespace { + static volatile sig_atomic_t got_sigint = false; + static volatile sig_atomic_t got_sighup = false; + static bool already_hooked_up = false; + + void handle_signal(int signal) { + switch (signal) { + case SIGHUP: + got_sighup = true; + break; + case SIGINT: + got_sigint = true; + break; + } + } + + void HookupHandler() { + if (already_hooked_up) { + LOG(FATAL) << "Tried to hookup signal handlers more than once."; + } + already_hooked_up = true; + + struct sigaction sa; + // Setup the handler + sa.sa_handler = &handle_signal; + // Restart the system call, if at all possible + sa.sa_flags = SA_RESTART; + // Block every signal during the handler + sigfillset(&sa.sa_mask); + // Intercept SIGHUP and SIGINT + if (sigaction(SIGHUP, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot install SIGHUP handler."; + } + if (sigaction(SIGINT, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot install SIGINT handler."; + } + } + + // Set the signal handlers to the default. + void UnhookHandler() { + if (already_hooked_up) { + struct sigaction sa; + // Setup the sighub handler + sa.sa_handler = SIG_DFL; + // Restart the system call, if at all possible + sa.sa_flags = SA_RESTART; + // Block every signal during the handler + sigfillset(&sa.sa_mask); + // Intercept SIGHUP and SIGINT + if (sigaction(SIGHUP, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot uninstall SIGHUP handler."; + } + if (sigaction(SIGINT, &sa, NULL) == -1) { + LOG(FATAL) << "Cannot uninstall SIGINT handler."; + } + + already_hooked_up = false; + } + } + + // Return true iff a SIGINT has been received since the last time this + // function was called. + bool GotSIGINT() { + bool result = got_sigint; + got_sigint = false; + return result; + } + + // Return true iff a SIGHUP has been received since the last time this + // function was called. + bool GotSIGHUP() { + bool result = got_sighup; + got_sighup = false; + return result; + } +} // namespace + +namespace caffe { + +SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action, + SolverAction::Enum SIGHUP_action): + SIGINT_action_(SIGINT_action), + SIGHUP_action_(SIGHUP_action) { + HookupHandler(); +} + +SignalHandler::~SignalHandler() { + UnhookHandler(); +} + +SolverAction::Enum SignalHandler::CheckForSignals() const { + if (GotSIGHUP()) { + return SIGHUP_action_; + } + if (GotSIGINT()) { + return SIGINT_action_; + } + return SolverAction::NONE; +} + +// Return the function that the solver can use to find out if a snapshot or +// early exit is being requested. +ActionCallback SignalHandler::GetActionFunction() { + return boost::bind(&SignalHandler::CheckForSignals, this); +} + +} // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 9f31b37ac2b..ff63860a3c1 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -12,6 +12,7 @@ namespace bp = boost::python; #include "boost/algorithm/string.hpp" #include "caffe/caffe.hpp" +#include "caffe/util/signal_handler.h" using caffe::Blob; using caffe::Caffe; @@ -39,6 +40,12 @@ DEFINE_string(weights, "", "separated by ','. Cannot be set simultaneously with snapshot."); DEFINE_int32(iterations, 50, "The number of iterations to run."); +DEFINE_string(sigint_effect, "stop", + "Optional; action to take when a SIGINT signal is received: " + "snapshot, stop or none."); +DEFINE_string(sighup_effect, "snapshot", + "Optional; action to take when a SIGHUP signal is received: " + "snapshot, stop or none."); // A simple registry for caffe commands. typedef int (*BrewFunction)(); @@ -126,6 +133,22 @@ void CopyLayers(caffe::Solver* solver, const std::string& model_list) { } } +// Translate the signal effect the user specified on the command-line to the +// corresponding enumeration. +caffe::SolverAction::Enum GetRequestedAction( + const std::string& flag_value) { + if (flag_value == "stop") { + return caffe::SolverAction::STOP; + } + if (flag_value == "snapshot") { + return caffe::SolverAction::SNAPSHOT; + } + if (flag_value == "none") { + return caffe::SolverAction::NONE; + } + LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified"; +} + // Train / Finetune a model. int train() { CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; @@ -165,7 +188,14 @@ int train() { Caffe::set_solver_count(gpus.size()); } - shared_ptr > solver(caffe::GetSolver(solver_param)); + caffe::SignalHandler signal_handler( + GetRequestedAction(FLAGS_sigint_effect), + GetRequestedAction(FLAGS_sighup_effect)); + + shared_ptr > + solver(caffe::GetSolver(solver_param)); + + solver->SetActionFunction(signal_handler.GetActionFunction()); if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot;