Skip to content

Commit

Permalink
Merge pull request #7636 from kexinzhao/save_inference_model
Browse files Browse the repository at this point in the history
Add feed and fetch op to ProgramDesc before saving for inference
  • Loading branch information
kexinzhao authored Jan 18, 2018
2 parents 7905e36 + 856f650 commit d77e6a6
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 47 deletions.
21 changes: 0 additions & 21 deletions paddle/inference/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,6 @@ cc_library(paddle_fluid_api
# Merge all modules into a simgle static library
cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES})

# ptools
# just for testing, we may need to change the storing format for inference_model
# and move the dependent of pickle.
# download from http://www.picklingtools.com/
# build in the C++ sub-directory, using command
# make -f Makefile.Linux libptools.so
set(PTOOLS_LIB)
set(PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH "Folder contains PicklingTools")
find_path(PTOOLS_INC_DIR chooseser.h PATHS ${PTOOLS_ROOT}/C++)
find_library(PTOOLS_SHARED_LIB NAMES ptools PATHS ${PTOOLS_ROOT}/C++)
if(PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB)
add_definitions(-DPADDLE_USE_PTOOLS)
set(PTOOLS_LIB ptools)
message(STATUS "Found PicklingTools: ${PTOOLS_SHARED_LIB}")
add_library(${PTOOLS_LIB} SHARED IMPORTED GLOBAL)
set_property(TARGET ${PTOOLS_LIB} PROPERTY IMPORTED_LOCATION ${PTOOLS_SHARED_LIB})
include_directories(${PTOOLS_ROOT}/C++)
include_directories(${PTOOLS_ROOT}/C++/opencontainers_1_8_5/include)
add_definitions(-DOC_NEW_STYLE_INCLUDES) # used in ptools
endif()

add_executable(example example.cc)
if(APPLE)
set(OPTIONAL_LINK_FLAGS)
Expand Down
18 changes: 3 additions & 15 deletions paddle/inference/example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,21 @@ limitations under the License. */
#include "paddle/inference/inference.h"

DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(feed_var_names, "", "Names of feeding variables");
DEFINE_string(fetch_var_names, "", "Names of fetching variables");

int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() ||
FLAGS_fetch_var_names.empty()) {
if (FLAGS_dirname.empty()) {
// Example:
// ./example --dirname=recognize_digits_mlp.inference.model
// --feed_var_names="x"
// --fetch_var_names="fc_2.tmp_2"
std::cout << "Usage: ./example --dirname=path/to/your/model "
"--feed_var_names=x --fetch_var_names=y"
<< std::endl;
std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
exit(1);
}

std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;

std::string dirname = FLAGS_dirname;
std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};

paddle::InferenceEngine* engine = new paddle::InferenceEngine();
engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
engine->LoadInferenceModel(dirname);

paddle::framework::LoDTensor input;
srand(time(0));
Expand Down
40 changes: 29 additions & 11 deletions paddle/inference/inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,37 @@ limitations under the License. */

namespace paddle {

void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
std::string model_filename = dirname + "/__model__.dat";
LOG(INFO) << "loading model from " << model_filename;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
std::string program_desc_str;
inputfs.seekg(0, std::ios::end);
program_desc_str.resize(inputfs.tellg());
inputfs.seekg(0, std::ios::beg);
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
inputfs.read(&program_desc_str[0], program_desc_str.size());
inputfs.close();

program_ = new framework::ProgramDesc(program_desc_str);
GenerateLoadProgram(dirname);

framework::BlockDesc* global_block = program_->MutableBlock(0);
feed_var_names_.clear();
fetch_var_names_.clear();
for (auto* op : global_block->AllOps()) {
if (op->Type() == "feed") {
feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
} else if (op->Type() == "fetch") {
fetch_var_names_.push_back(op->Input("X")[0]);
}
}
}

void InferenceEngine::LoadInferenceModel(
const std::string& dirname,
const std::vector<std::string>& feed_var_names,
const std::vector<std::string>& fetch_var_names) {
#ifdef PADDLE_USE_PTOOLS
std::string model_filename = dirname + "/__model__";
LOG(INFO) << "Using PicklingTools, loading model from " << model_filename;
Val v;
LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0);
std::string program_desc_str = v["program_desc_str"];
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
// PicklingTools cannot parse the vector of strings correctly.
#else
std::string model_filename = dirname + "/__model__.dat";
LOG(INFO) << "loading model from " << model_filename;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
Expand All @@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
inputfs.read(&program_desc_str[0], program_desc_str.size());
inputfs.close();
#endif

program_ = new framework::ProgramDesc(program_desc_str);
GenerateLoadProgram(dirname);

Expand All @@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
}

bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
if (var->Persistable()) {
if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
// There are many unreachable variables in the program
for (size_t i = 0; i < program_->Size(); ++i) {
const framework::BlockDesc& block = program_->Block(i);
Expand Down
1 change: 1 addition & 0 deletions paddle/inference/inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class InferenceEngine {
delete load_program_;
}

void LoadInferenceModel(const std::string& dirname);
void LoadInferenceModel(const std::string& dirname,
const std::vector<std::string>& feed_var_names,
const std::vector<std::string>& fetch_var_names);
Expand Down
31 changes: 31 additions & 0 deletions python/paddle/v2/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import cPickle as pickle

from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core

__all__ = [
'save_vars',
Expand Down Expand Up @@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
return inference_program


def prepend_feed_ops(inference_program, feeded_var_names):
global_block = inference_program.global_block()
feed_var = global_block.create_var(
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)

for i, name in enumerate(feeded_var_names):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})


def append_fetch_ops(inference_program, fetch_var_names):
global_block = inference_program.global_block()
fetch_var = global_block.create_var(
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)

for i, name in enumerate(fetch_var_names):
global_block.append_op(
type='fetch',
inputs={'X': [name]},
outputs={'Out': [fetch_var]},
attrs={'col': i})


def save_inference_model(dirname,
feeded_var_names,
target_vars,
Expand Down Expand Up @@ -241,6 +269,9 @@ def save_inference_model(dirname,
"fetch_var_names": fetch_var_names
}, f, -1)

prepend_feed_ops(inference_program, feeded_var_names)
append_fetch_ops(inference_program, fetch_var_names)

# Save only programDesc of inference_program in binary format
# in another file: __model__.dat
with open(model_file_name + ".dat", "wb") as fp:
Expand Down

0 comments on commit d77e6a6

Please sign in to comment.