Skip to content

Commit

Permalink
fix 3 bug of new_executor (#37142)
Browse files Browse the repository at this point in the history
* fix 3 bug, test=develop

* refine, test=develop
  • Loading branch information
wanghuancoder authored Nov 15, 2021
1 parent b628c31 commit 8358d61
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 5 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ void InterpreterCore::Convert() {

for (auto& item : op_func_node.input_index) {
for (auto id : item.second) {
if (id == kEmptyVarIndex) {
continue;
}
input_var2op_info_.at(id).push_back(op_idx);
// var can be gc-ed
if (!info.IsBuilt()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ void InterpreterCoreGarbageCollector::Add(
void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var,
paddle::platform::DeviceEvent& event,
const platform::DeviceContext* ctx) {
if (!var) {
return;
}

if (var->IsType<LoDTensor>()) {
Add(var->GetMutable<LoDTensor>()->MoveMemoryHolder(), event, ctx);
} else if (var->IsType<
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,13 @@ void build_op_func_list(const platform::Place& place,
VariableValueMap ins_map;
VariableIdMap ins_name2id;
bool enforce_exist = true;
if (op->Type() == "recurrent_grad") enforce_exist = false;
if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
op->Type() == "rnn_memory_helper_grad" ||
op->Type() == "conditional_block" ||
op->Type() == "conditional_block_grad" || op->Type() == "while" ||
op->Type() == "while_grad") {
enforce_exist = false;
}
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope, enforce_exist);

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/new_executor_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars(
VariableScope::VariableScope(Scope* scope) {
// for @EMPTY@ variable
var_list_.push_back(nullptr);
name2id_[kEmptyVarName] = 0;
name2id_[kEmptyVarName] = kEmptyVarIndex;
vec_meta_info_.emplace_back(0, nullptr);
scope_ = scope;
PADDLE_ENFORCE_NE(
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;

constexpr int kEmptyVarIndex = 0;

class InterpretercoreInferShapeContext : public InferShapeContext {
public:
InterpretercoreInferShapeContext(const OperatorBase& op,
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/fluid/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,13 +598,13 @@ def _get_exe_from_cache(self, program, scope):
assert isinstance(
program, Program), "Required type(Program), but received {}".format(
type(program).__name__)
if program not in self._cached_executors:
if str(program) not in self._cached_executors:
new_program = program.clone()
_prune_feed_ops(new_program)
new_exe = _StandaloneExecutor(self._place, new_program, scope)
self._cached_executors[program] = new_exe
self._cached_executors[str(program)] = new_exe

return self._cached_executors[program]
return self._cached_executors[str(program)]


class Executor(object):
Expand Down

0 comments on commit 8358d61

Please sign in to comment.