fix inference prepare data bug #33305

Merged · 22 commits · Jun 4, 2021

Changes from 12 commits
5 changes: 4 additions & 1 deletion paddle/fluid/framework/operator.cc
@@ -1531,7 +1531,10 @@ Scope* OperatorWithKernel::PrepareData(
  // the rest iterations to save the elapsed time.
  // We do not support skipping PrepareData in a while block, because the Op's
  // input may be changed by subsequent Ops, which may cause an error.
- if (pre_scope_ == &scope && new_scope == nullptr) {
+
+ // For inference, ops behind a conditional branch are not supported well,
+ // so conservatively disable the prepare-data optimization for them.
+ if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data_) {
    need_prepare_data_ = false;
  }
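For orientation, the hunk above can be condensed into a standalone predicate. This is a hypothetical sketch, not Paddle's actual API; the helper name and boolean parameters are invented to restate the condition.

    // Sketch of the caching decision in OperatorWithKernel::PrepareData:
    // after the first run, PrepareData may be skipped only when the scope is
    // unchanged, no data-transfer scope was created, and the new
    // force-prepare flag has not vetoed the cache.
    bool ShouldSkipPrepareData(bool same_scope, bool created_transfer_scope,
                               bool force_prepare) {
      return same_scope && !created_transfer_scope && !force_prepare;
    }
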
6 changes: 5 additions & 1 deletion paddle/fluid/framework/operator.h
@@ -471,7 +471,10 @@ class OperatorWithKernel : public OperatorBase {

  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
-     : OperatorBase(type, inputs, outputs, attrs) {}
+     : OperatorBase(type, inputs, outputs, attrs) {
+   force_prepare_data_ = HasAttr("inference_force_prepare") &&
+                         Attr<bool>("inference_force_prepare");
+ }

static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
@@ -572,6 +575,7 @@
  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
  mutable const Scope* pre_scope_ = nullptr;
  mutable bool need_prepare_data_ = true;
+ mutable bool force_prepare_data_ = false;
  mutable bool enable_cache_runtime_context_ = false;
  mutable bool all_kernels_must_compute_runtime_shape_ = false;
  mutable std::mutex cache_update_mutex_;
40 changes: 40 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -270,7 +270,47 @@ bool AnalysisPredictor::CreateExecutor() {
  executor_.reset(new paddle::framework::NaiveExecutor(place_));
  return true;
}
+
+static bool Is_PrePareDate_Opt_Target_Op(framework::OpDesc *op) {
+  // A bad case for the prepare-data optimization: assume an op sits behind a
+  // conditional_block. If conditional_block chooses branch 1, the op needs to
+  // call PrepareData; otherwise it does not. If the predictor chooses branch 2
+  // first, the optimization takes effect and caches the decision, and a later
+  // run that chooses branch 1 then fails, because the op has lost its chance
+  // to prepare data.
+  std::vector<std::string> op_type = {"conditional_block_infer",
+                                      "select_input"};
+  for (const auto &type : op_type) {
+    if (op->Type() == type) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static void Disable_PrePareDate_Opt(
+    std::shared_ptr<framework::ProgramDesc> inference_program, int block) {
+  bool disable_opt = false;
+  for (auto *op : inference_program->Block(block).AllOps()) {
+    if (disable_opt) {
+      op->SetAttr("inference_force_prepare", true);
+    }
+    if (op->HasAttr("sub_block")) {
+      int blockID = op->GetBlockAttrId("sub_block");
+      Disable_PrePareDate_Opt(inference_program, blockID);
+    }
+    if (Is_PrePareDate_Opt_Target_Op(op)) {
+      disable_opt = true;
+    }
+  }
+}

bool AnalysisPredictor::PrepareExecutor() {
+  Disable_PrePareDate_Opt(inference_program_, 0);
+
  executor_->Prepare(sub_scope_, *inference_program_, 0,
                     config_.use_feed_fetch_ops_);
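
Similarly, the marking pass can be exercised on a flat list of op types. This is a standalone sketch with plain strings standing in for framework::OpDesc (the sample op list is invented): every op after the first conditional_block_infer or select_input in a block gets the inference_force_prepare attribute.

    #include <iostream>
    #include <string>
    #include <vector>

    // Mirrors the loop in Disable_PrePareDate_Opt: once a conditional op is
    // seen, every subsequent op in the block gets the force-prepare attribute.
    int main() {
      std::vector<std::string> block = {"feed", "fc", "conditional_block_infer",
                                        "select_input", "scale", "fetch"};
      bool disable_opt = false;
      for (const auto &type : block) {
        if (disable_opt) {
          std::cout << type << " -> inference_force_prepare\n";
        }
        if (type == "conditional_block_infer" || type == "select_input") {
          disable_opt = true;
        }
      }
      return 0;
    }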