Skip to content

Commit

Permalink
Cache scope in While Op during Inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
carryyu committed Dec 15, 2022
1 parent fd3169d commit da4cb8e
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions paddle/fluid/operators/controlflow/while_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,27 @@ class WhileOp : public framework::OperatorBase {
scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
}
} else {
auto &current_scope = scope.NewScope();
bool need_create_variable = false;
if (inference_cache_scope_ == nullptr) {
inference_cache_scope_ = &scope.NewScope();
need_create_variable = true;
VLOG(4) << "Inference cache scope in While Op: "
<< inference_cache_scope_;
}

if (FLAGS_control_flow_use_new_executor) {
BuildScopeForControlFlowOp(*core_, *block, &current_scope);
core_->reset_scope(&current_scope);
BuildScopeForControlFlowOp(*core_, *block, inference_cache_scope_);
core_->reset_scope(inference_cache_scope_);
} else {
executor_->CreateVariables(*program, &current_scope, block->ID());
if (need_create_variable) {
executor_->CreateVariables(
*program, inference_cache_scope_, block->ID());
}
}

while (cond_data) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
for (auto &name : inference_cache_scope_->LocalVarNames()) {
auto *var = inference_cache_scope_->Var(name);
if (var->IsType<phi::DenseTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<phi::DenseTensor>();
Expand All @@ -299,20 +308,20 @@ class WhileOp : public framework::OperatorBase {
core_->Run({}, false);
} else {
executor_->RunPreparedContext(
ctx_.get(), &current_scope, false, false, false);
ctx_.get(), inference_cache_scope_, false, false, false);
}

cond_data = GetCondData(
scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
}
scope.DeleteScope(&current_scope);
}
}

private:
mutable std::shared_ptr<framework::Executor> executor_{nullptr};
mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
mutable framework::Scope *inference_cache_scope_ = nullptr;
};

class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down

0 comments on commit da4cb8e

Please sign in to comment.