From a683f9ebd78ca1ef693de69f689c5e9a44eadb5b Mon Sep 17 00:00:00 2001 From: philkr Date: Thu, 3 Sep 2015 14:28:55 -0700 Subject: [PATCH] Exposing solver callbacks to python --- python/caffe/_caffe.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index ccd5776ac40..9c66d6ec20c 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -205,6 +205,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) { return bp::object(); } +template +class PythonCallback: public Solver::Callback { + protected: + bp::object on_start_, on_gradients_ready_; + + public: + PythonCallback(bp::object on_start, bp::object on_gradients_ready) + : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { } + virtual void on_gradients_ready() { + on_gradients_ready_(); + } + virtual void on_start() { + on_start_(); + } +}; +template +void Solver_addCallback(Solver * solver, bp::object on_start, + bp::object on_gradients_ready) { + solver->add_callback(new PythonCallback(on_start, on_gradients_ready)); +} + BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1); BOOST_PYTHON_MODULE(_caffe) { @@ -283,6 +304,7 @@ BOOST_PYTHON_MODULE(_caffe) { .add_property("test_nets", bp::make_function(&Solver::test_nets, bp::return_internal_reference<>())) .add_property("iter", &Solver::iter) + .def("add_callback", &Solver_addCallback) .def("solve", static_cast::*)(const char*)>( &Solver::Solve), SolveOverloads()) .def("step", &Solver::Step)