From 742c93f31be4c874aa5fd0103f25f8a2f8d4d63d Mon Sep 17 00:00:00 2001 From: philkr Date: Mon, 23 May 2016 20:09:45 -0700 Subject: [PATCH] Exposing load_hdf5 and save_hdf5 to python --- python/caffe/_caffe.cpp | 12 +++++++++++- python/caffe/test/test_net.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 32b5d921094..48a0c8f2e95 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -114,6 +114,14 @@ void Net_Save(const Net& net, string filename) { WriteProtoToBinaryFile(net_param, filename.c_str()); } +void Net_SaveHDF5(const Net& net, string filename) { + net.ToHDF5(filename); +} + +void Net_LoadHDF5(Net* net, string filename) { + net->CopyTrainedLayersFromHDF5(filename.c_str()); +} + void Net_SetInputArrays(Net* net, bp::object data_obj, bp::object labels_obj) { // check that this network has an input MemoryDataLayer @@ -267,7 +275,9 @@ BOOST_PYTHON_MODULE(_caffe) { bp::return_value_policy())) .def("_set_input_arrays", &Net_SetInputArrays, bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >()) - .def("save", &Net_Save); + .def("save", &Net_Save) + .def("save_hdf5", &Net_SaveHDF5) + .def("load_hdf5", &Net_LoadHDF5); BP_REGISTER_SHARED_PTR_TO_PYTHON(Net); bp::class_, shared_ptr >, boost::noncopyable>( diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py index aad828aa8aa..4cacfcd05bb 100644 --- a/python/caffe/test/test_net.py +++ b/python/caffe/test/test_net.py @@ -79,3 +79,17 @@ def test_save_and_read(self): for i in range(len(self.net.params[name])): self.assertEqual(abs(self.net.params[name][i].data - net2.params[name][i].data).sum(), 0) + + def test_save_hdf5(self): + f = tempfile.NamedTemporaryFile(mode='w+', delete=False) + f.close() + self.net.save_hdf5(f.name) + net_file = simple_net_file(self.num_output) + net2 = caffe.Net(net_file, caffe.TRAIN) + net2.load_hdf5(f.name) + os.remove(net_file) + os.remove(f.name) + for name in self.net.params: + for i in range(len(self.net.params[name])): + self.assertEqual(abs(self.net.params[name][i].data + - net2.params[name][i].data).sum(), 0)