Skip to content

Commit

Permalink
Fix a bug in save_inference_model and prune when the program is initialized by load_inference_model (#10011)
Browse files Browse the repository at this point in the history

* Fix bug in save_inference_model and prune when the program is initialized by load_inference_model.

* Save the transpiled program instead.
  • Loading branch information
Xreki authored and kexinzhao committed Apr 18, 2018
1 parent 9ca578d commit 598035f
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 22 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/op_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class OpDesc {

void InferVarType(BlockDesc *block) const;

void MarkAsTarget() { desc_.set_is_target(true); }
void SetIsTarget(bool is_target) { desc_.set_is_target(is_target); }

void Flush();

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ void BindProgramDesc(pybind11::module *m) {
.def("block", &pd::ProgramDesc::MutableBlock,
pybind11::return_value_policy::reference)
.def("num_blocks", &pd::ProgramDesc::Size)
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
.def("parse_from_string",
[](pd::ProgramDesc &program_desc, const std::string &data) {
Expand Down Expand Up @@ -299,6 +301,7 @@ void BindOpDesc(pybind11::module *m) {
.def("check_attrs", &pd::OpDesc::CheckAttrs)
.def("infer_shape", &pd::OpDesc::InferShape)
.def("infer_var_type", &pd::OpDesc::InferVarType)
.def("set_is_target", &pd::OpDesc::SetIsTarget)
.def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", &pd::OpDesc::Block,
pybind11::return_value_policy::reference);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ All parameter, weight, gradient are variables in Paddle.
const std::vector<std::array<size_t, 2>> &targets) {
ProgramDesc prog_with_targets(origin);
for (const auto &t : targets) {
prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
}
proto::ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc);
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,12 @@ def prune(self, targets):
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
if t.op is None:
global_block = self.global_block()
for op in global_block.ops:
if t.name in op.output_arg_names:
t.op = op
break
t = t.op
else:
raise ValueError(("All targets of prune() can only be "
Expand Down
29 changes: 9 additions & 20 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,13 @@ def save_inference_model(dirname,
if not os.path.isdir(dirname):
os.makedirs(dirname)

# Clear the is_target information and remove the existed feed and fetch op
global_block = main_program.global_block()
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
global_block.remove_op(i)

pruned_program = main_program.prune(targets=target_vars)
inference_program = pruned_program.inference_optimize()
fetch_var_names = [v.name for v in target_vars]
Expand All @@ -362,24 +369,6 @@ def save_inference_model(dirname,
save_persistables(executor, dirname, inference_program, params_filename)


def get_feed_targets_names(program):
    """Collect the 'Out' variable name of every 'feed' op in the
    program's global block.

    Each newly found name is prepended, so the result is in reverse
    op-traversal order (matching the feed-slot ordering).
    """
    names = []
    for operator in program.global_block().ops:
        if operator.desc.type() == 'feed':
            # Prepend rather than append — later feed ops come first.
            names = [operator.desc.output('Out')[0]] + names
    return names


def get_fetch_targets_names(program):
    """Return the 'X' input variable name of every 'fetch' op in the
    program's global block, in op-traversal order.
    """
    return [
        operator.desc.input('X')[0]
        for operator in program.global_block().ops
        if operator.desc.type() == 'fetch'
    ]


def load_inference_model(dirname,
executor,
model_filename=None,
Expand Down Expand Up @@ -418,8 +407,8 @@ def load_inference_model(dirname,
program = Program.parse_from_string(program_desc_str)
load_persistables(executor, dirname, program, params_filename)

feed_target_names = get_feed_targets_names(program)
fetch_target_names = get_fetch_targets_names(program)
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()
fetch_targets = [
program.global_block().var(name) for name in fetch_target_names
]
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/fluid/tests/book/test_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def infer(use_cuda, save_dirname=None):

print("infer results: ", results[0])

fluid.io.save_inference_model(save_dirname, feed_target_names,
fetch_targets, exe,
inference_transpiler_program)


def main(net_type, use_cuda, is_local=True):
if use_cuda and not fluid.core.is_compiled_with_cuda():
Expand Down

0 comments on commit 598035f

Please sign in to comment.