diff --git a/paddle/inference/CMakeLists.txt b/paddle/inference/CMakeLists.txt index 8437b2b21942ea..02ca8a45a851d2 100644 --- a/paddle/inference/CMakeLists.txt +++ b/paddle/inference/CMakeLists.txt @@ -8,27 +8,6 @@ cc_library(paddle_fluid_api # Merge all modules into a simgle static library cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES}) -# ptools -# just for testing, we may need to change the storing format for inference_model -# and move the dependent of pickle. -# download from http://www.picklingtools.com/ -# build in the C++ sub-directory, using command -# make -f Makefile.Linux libptools.so -set(PTOOLS_LIB) -set(PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH "Folder contains PicklingTools") -find_path(PTOOLS_INC_DIR chooseser.h PATHS ${PTOOLS_ROOT}/C++) -find_library(PTOOLS_SHARED_LIB NAMES ptools PATHS ${PTOOLS_ROOT}/C++) -if(PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB) - add_definitions(-DPADDLE_USE_PTOOLS) - set(PTOOLS_LIB ptools) - message(STATUS "Found PicklingTools: ${PTOOLS_SHARED_LIB}") - add_library(${PTOOLS_LIB} SHARED IMPORTED GLOBAL) - set_property(TARGET ${PTOOLS_LIB} PROPERTY IMPORTED_LOCATION ${PTOOLS_SHARED_LIB}) - include_directories(${PTOOLS_ROOT}/C++) - include_directories(${PTOOLS_ROOT}/C++/opencontainers_1_8_5/include) - add_definitions(-DOC_NEW_STYLE_INCLUDES) # used in ptools -endif() - add_executable(example example.cc) if(APPLE) set(OPTIONAL_LINK_FLAGS) diff --git a/paddle/inference/example.cc b/paddle/inference/example.cc index 9711b20e6fb409..0c18b45624dedc 100644 --- a/paddle/inference/example.cc +++ b/paddle/inference/example.cc @@ -18,33 +18,21 @@ limitations under the License. */ #include "paddle/inference/inference.h" DEFINE_string(dirname, "", "Directory of the inference model."); -DEFINE_string(feed_var_names, "", "Names of feeding variables"); -DEFINE_string(fetch_var_names, "", "Names of fetching variables"); int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() || - FLAGS_fetch_var_names.empty()) { + if (FLAGS_dirname.empty()) { // Example: // ./example --dirname=recognize_digits_mlp.inference.model - // --feed_var_names="x" - // --fetch_var_names="fc_2.tmp_2" - std::cout << "Usage: ./example --dirname=path/to/your/model " - "--feed_var_names=x --fetch_var_names=y" - << std::endl; + std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl; exit(1); } std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; - std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl; - std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl; - std::string dirname = FLAGS_dirname; - std::vector feed_var_names = {FLAGS_feed_var_names}; - std::vector fetch_var_names = {FLAGS_fetch_var_names}; paddle::InferenceEngine* engine = new paddle::InferenceEngine(); - engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names); + engine->LoadInferenceModel(dirname); paddle::framework::LoDTensor input; srand(time(0)); diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index 37b8b20ddfcf25..49001778808173 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -25,19 +25,37 @@ limitations under the License. */ namespace paddle { +void InferenceEngine::LoadInferenceModel(const std::string& dirname) { + std::string model_filename = dirname + "/__model__.dat"; + LOG(INFO) << "loading model from " << model_filename; + std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); + std::string program_desc_str; + inputfs.seekg(0, std::ios::end); + program_desc_str.resize(inputfs.tellg()); + inputfs.seekg(0, std::ios::beg); + LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); + inputfs.read(&program_desc_str[0], program_desc_str.size()); + inputfs.close(); + + program_ = new framework::ProgramDesc(program_desc_str); + GenerateLoadProgram(dirname); + + framework::BlockDesc* global_block = program_->MutableBlock(0); + feed_var_names_.clear(); + fetch_var_names_.clear(); + for (auto* op : global_block->AllOps()) { + if (op->Type() == "feed") { + feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]); + } else if (op->Type() == "fetch") { + fetch_var_names_.push_back(op->Input("X")[0]); + } + } +} + void InferenceEngine::LoadInferenceModel( const std::string& dirname, const std::vector& feed_var_names, const std::vector& fetch_var_names) { -#ifdef PADDLE_USE_PTOOLS - std::string model_filename = dirname + "/__model__"; - LOG(INFO) << "Using PicklingTools, loading model from " << model_filename; - Val v; - LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0); - std::string program_desc_str = v["program_desc_str"]; - LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); -// PicklingTools cannot parse the vector of strings correctly. -#else std::string model_filename = dirname + "/__model__.dat"; LOG(INFO) << "loading model from " << model_filename; std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); @@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel( LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); inputfs.read(&program_desc_str[0], program_desc_str.size()); inputfs.close(); -#endif + program_ = new framework::ProgramDesc(program_desc_str); GenerateLoadProgram(dirname); @@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel( } bool InferenceEngine::IsParameter(const framework::VarDesc* var) { - if (var->Persistable()) { + if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") { // There are many unreachable variables in the program for (size_t i = 0; i < program_->Size(); ++i) { const framework::BlockDesc& block = program_->Block(i); diff --git a/paddle/inference/inference.h b/paddle/inference/inference.h index a3f3ef4b440036..7fc09cb9e539a6 100644 --- a/paddle/inference/inference.h +++ b/paddle/inference/inference.h @@ -28,6 +28,7 @@ class InferenceEngine { delete load_program_; } + void LoadInferenceModel(const std::string& dirname); void LoadInferenceModel(const std::string& dirname, const std::vector& feed_var_names, const std::vector& fetch_var_names); diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 499df05e592855..e7a06a07145758 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -15,6 +15,7 @@ import cPickle as pickle from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable +from . import core __all__ = [ 'save_vars', @@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None): return inference_program +def prepend_feed_ops(inference_program, feeded_var_names): + global_block = inference_program.global_block() + feed_var = global_block.create_var( + name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True) + + for i, name in enumerate(feeded_var_names): + out = global_block.var(name) + global_block.prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}) + + +def append_fetch_ops(inference_program, fetch_var_names): + global_block = inference_program.global_block() + fetch_var = global_block.create_var( + name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True) + + for i, name in enumerate(fetch_var_names): + global_block.append_op( + type='fetch', + inputs={'X': [name]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}) + + def save_inference_model(dirname, feeded_var_names, target_vars, @@ -241,6 +269,9 @@ def save_inference_model(dirname, "fetch_var_names": fetch_var_names }, f, -1) + prepend_feed_ops(inference_program, feeded_var_names) + append_fetch_ops(inference_program, fetch_var_names) + # Save only programDesc of inference_program in binary format # in another file: __model__.dat with open(model_file_name + ".dat", "wb") as fp: