fix inference prepare data bug (#33305)
* fix inference prepare data bug

* rename functions

* typo

* typo

* typo

* UT correct

* correct condition

* correct condition

* ci coverage

* morelines

* fix ci coverage
b3602sss authored Jun 4, 2021
1 parent 6877b13 commit dd18123
Showing 4 changed files with 48 additions and 2 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/framework/operator.cc
@@ -1531,7 +1531,12 @@ Scope* OperatorWithKernel::PrepareData(
   // the rest iterations to save the elapsed time.
   // We do not support skipping PrepareData in a while block, because the Op's
   // input may be changed by subsequent Ops, which may cause an error.
-  if (pre_scope_ == &scope && new_scope == nullptr) {
+
+  // For inference, ops behind a conditional branch aren't supported well,
+  // so conservatively disable the prepare-data optimization for them.
+  bool force_prepare_data = HasAttr("inference_force_prepare_data") &&
+                            Attr<bool>("inference_force_prepare_data");
+  if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) {
     need_prepare_data_ = false;
   }

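With this guard, the predictor skips PrepareData on later iterations only for ops that were not marked with the force attribute. Below is a minimal standalone sketch, not Paddle code (Op, Run and the member names are illustrative only), of why caching the "no preparation needed" decision breaks once the executed branch can change between runs, and how a per-op force flag avoids it:

// Minimal standalone sketch of the caching behavior; illustrative names only.
#include <iostream>

struct Op {
  bool need_prepare_data_ = true;   // cached decision, as in OperatorWithKernel
  bool force_prepare_data = false;  // stands in for "inference_force_prepare_data"

  void Run(const char *branch) {
    std::cout << branch << ": "
              << (need_prepare_data_ ? "PrepareData called" : "PrepareData skipped")
              << "\n";
    // After one iteration the op caches "no preparation needed" -- unless it
    // was forced to keep preparing (ops behind conditional branches).
    if (!force_prepare_data) need_prepare_data_ = false;
  }
};

int main() {
  Op cached;               // old behavior: cache after the first run
  cached.Run("branch 2");  // first run takes the branch that needs no prepared data
  cached.Run("branch 1");  // bug: this branch needed PrepareData, but it is skipped

  Op forced;               // behavior once the new attribute is set
  forced.force_prepare_data = true;
  forced.Run("branch 2");
  forced.Run("branch 1");  // PrepareData still runs, so this branch stays safe
  return 0;
}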
39 changes: 39 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -270,7 +270,46 @@ bool AnalysisPredictor::CreateExecutor() {
  executor_.reset(new paddle::framework::NaiveExecutor(place_));
  return true;
}

static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
  // The prepare-data optimization has a bad case with control flow: assume an
  // op sits behind a conditional_block. If the conditional_block takes branch
  // 1, the op needs to call PrepareData; if it takes branch 2, it does not.
  // When the predictor first takes branch 2, the optimization kicks in and
  // PrepareData is skipped from then on, so a later run that takes branch 1
  // fails because the op never got the chance to prepare its data.
  std::vector<std::string> op_type = {"conditional_block_infer",
                                      "select_input"};
  for (const auto &type : op_type) {
    if (op->Type() == type) {
      return true;
    }
  }
  return false;
}

static void DisablePrepareDataOpt(
    std::shared_ptr<framework::ProgramDesc> inference_program, int block,
    bool pre_disable_opt) {
  bool disable_opt = false;
  auto &infer_block = inference_program->Block(block);
  for (auto *op : infer_block.AllOps()) {
    if (disable_opt || pre_disable_opt) {
      op->SetAttr("inference_force_prepare_data", true);
    }
    if (op->HasAttr("sub_block")) {
      int blockID = op->GetBlockAttrId("sub_block");
      DisablePrepareDataOpt(inference_program, blockID,
                            disable_opt || pre_disable_opt);
    }
    // Force PrepareData for the op that follows an unfriendly op.
    disable_opt = IsPrepareDataOptTargetOp(op);
  }
}

bool AnalysisPredictor::PrepareExecutor() {
  DisablePrepareDataOpt(inference_program_, 0, false);

  executor_->Prepare(sub_scope_, *inference_program_, 0,
                     config_.use_feed_fetch_ops_);

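DisablePrepareDataOpt sets the attribute on the op that directly follows a conditional_block_infer or select_input op, and carries the flag into any sub_block it visits while the flag is active. The simplified stand-in below, using hypothetical MiniOp/MiniBlock types rather than the real framework::OpDesc and ProgramDesc API, reproduces that marking order on a toy program:

// Simplified stand-in for the block walk; types and names are hypothetical.
#include <iostream>
#include <string>
#include <vector>

struct MiniOp {
  std::string type;
  int sub_block = -1;          // index into blocks; -1 means no sub-block
  bool force_prepare = false;  // stands in for "inference_force_prepare_data"
};

using MiniBlock = std::vector<MiniOp>;

void MarkForcePrepare(std::vector<MiniBlock> &blocks, int block_id,
                      bool pre_disable_opt) {
  bool disable_opt = false;
  for (MiniOp &op : blocks[block_id]) {
    if (disable_opt || pre_disable_opt) op.force_prepare = true;
    if (op.sub_block >= 0)
      MarkForcePrepare(blocks, op.sub_block, disable_opt || pre_disable_opt);
    // Only the op right after an "unfriendly" op is forced; the flag also
    // reaches every op inside that op's sub-blocks via the recursion above.
    disable_opt =
        (op.type == "conditional_block_infer" || op.type == "select_input");
  }
}

int main() {
  // block 0: feed -> conditional_block_infer (sub_block 1) -> select_input -> fetch
  // block 1: the branch body
  std::vector<MiniBlock> blocks = {
      {{"feed"}, {"conditional_block_infer", 1}, {"select_input"}, {"fetch"}},
      {{"elementwise_add"}}};
  MarkForcePrepare(blocks, 0, false);
  for (const MiniOp &op : blocks[0]) {
    std::cout << op.type << " force_prepare=" << op.force_prepare << "\n";
  }
  return 0;
}

In this toy program select_input is marked because it follows the conditional_block_infer, and fetch is marked because select_input is itself a target op; the feed op before the branch keeps the optimization.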
@@ -73,6 +73,8 @@ class ConditionalBlockInferOp : public ConditionalOp {

      framework::Executor exec(dev_place);
      auto *block = Attr<framework::BlockDesc *>("sub_block");
      VLOG(3) << "Conditional block.idx = " << block->ID()
              << ", scope = " << &cur_scope;
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);
      scope.DeleteScope(scopes->front());
    }
@@ -219,7 +219,7 @@ def setUp(self):
         }, {
             "conv2d_0.tmp_0": [16, 6, 16, 16],
             "data": [16, 6, 16, 16],
-            "depthwise_conv2d_0.tmp_0": [32, 6, 64, 64]
+            "depthwise_conv2d_0.tmp_0": [16, 6, 16, 16]
         }, False)
         self.fetch_list = [conv_out]

