Skip to content

Commit

Permalink
Exposing solver callbacks to python
Browse files Browse the repository at this point in the history
  • Loading branch information
philkr committed Apr 29, 2016
1 parent f623d04 commit 35e91e5
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
return bp::object();
}

template<typename Dtype>
class PythonCallback: public Solver<Dtype>::Callback {
protected:
bp::object on_start_, on_gradients_ready_;

public:
PythonCallback(bp::object on_start, bp::object on_gradients_ready)
: on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
virtual void on_gradients_ready() {
on_gradients_ready_();
}
virtual void on_start() {
on_start_();
}
};
template<typename Dtype>
void Solver_addCallback(Solver<Dtype> * solver, bp::object on_start,
bp::object on_gradients_ready) {
solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
}

BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);

BOOST_PYTHON_MODULE(_caffe) {
Expand Down Expand Up @@ -307,6 +328,7 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
bp::return_internal_reference<>()))
.add_property("iter", &Solver<Dtype>::iter)
.def("add_callback", &Solver_addCallback<Dtype>)
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
&Solver<Dtype>::Solve), SolveOverloads())
.def("step", &Solver<Dtype>::Step)
Expand Down

0 comments on commit 35e91e5

Please sign in to comment.