Skip to content

Commit

Permalink
Merge pull request BVLC#4227 from philkr/save_hdf5
Browse files Browse the repository at this point in the history
[pycaffe] expose saving/loading nets as hdf5 to python
  • Loading branch information
shelhamer committed Jun 3, 2016
2 parents f45c3a6 + 742c93f commit df412ac
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ void Net_Save(const Net<Dtype>& net, string filename) {
WriteProtoToBinaryFile(net_param, filename.c_str());
}

void Net_SaveHDF5(const Net<Dtype>& net, string filename) {
net.ToHDF5(filename);
}

void Net_LoadHDF5(Net<Dtype>* net, string filename) {
net->CopyTrainedLayersFromHDF5(filename.c_str());
}

void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
bp::object labels_obj) {
// check that this network has an input MemoryDataLayer
Expand Down Expand Up @@ -267,7 +275,9 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::return_value_policy<bp::copy_const_reference>()))
.def("_set_input_arrays", &Net_SetInputArrays,
bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
.def("save", &Net_Save);
.def("save", &Net_Save)
.def("save_hdf5", &Net_SaveHDF5)
.def("load_hdf5", &Net_LoadHDF5);
BP_REGISTER_SHARED_PTR_TO_PYTHON(Net<Dtype>);

bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
Expand Down
14 changes: 14 additions & 0 deletions python/caffe/test/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,17 @@ def test_save_and_read(self):
for i in range(len(self.net.params[name])):
self.assertEqual(abs(self.net.params[name][i].data
- net2.params[name][i].data).sum(), 0)

def test_save_hdf5(self):
f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
f.close()
self.net.save_hdf5(f.name)
net_file = simple_net_file(self.num_output)
net2 = caffe.Net(net_file, caffe.TRAIN)
net2.load_hdf5(f.name)
os.remove(net_file)
os.remove(f.name)
for name in self.net.params:
for i in range(len(self.net.params[name])):
self.assertEqual(abs(self.net.params[name][i].data
- net2.params[name][i].data).sum(), 0)

0 comments on commit df412ac

Please sign in to comment.