Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose a interface to create all variables and enable the test of not creating variables every time. #9301

Merged
merged 8 commits into from
Apr 9, 2018
72 changes: 42 additions & 30 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,43 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name);
}

void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) {
auto& global_block = pdesc.Block(block_id);

const Scope* ancestor_scope = scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}

if (ancestor_scope != scope) {
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}

if (var->Persistable()) {
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : global_block.AllVars()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id);
Expand Down Expand Up @@ -184,8 +221,8 @@ static bool has_fetch_operators(
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
const std::string& fetch_holder_name, bool create_vars) {
bool create_vars, const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
Expand Down Expand Up @@ -296,38 +333,13 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(

void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_);

Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}

if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
} // if (create_local_scope)
} // if (create_vars)
}
CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
}

for (auto& op : ctx->ops_) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,18 @@ class Executor {
void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch",
bool create_vars = true);
const std::string& fetch_holder_name = "fetch");

static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);

static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids);

void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);

void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
bool create_vars = true);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Scope {
/// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;

const Scope& parent() const { return *parent_; }
const Scope* parent() const { return parent_; }

/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
Expand Down
11 changes: 9 additions & 2 deletions paddle/fluid/inference/tests/test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,14 @@ void TestInference(const std::string& dirname,

// 6. Run the inference program
{
const bool create_vars = false;
if (!create_vars) {
executor.CreateVariables(*inference_program, scope, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just some trivial issues, !create_vars -> executor.CreateVariables( seems a little wried.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_vars意思是在executor.Run函数里面是否创建Variables,默认值为true,意思是每次Run都会create和destroy所有的Variables。如果executor.Run时设置参数create_varsfalse,那么用户需要提前手动CreateVaribales,不然就会出现变量不存在的错误。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

能否增加一个单测说明create_varsfalse时,用户如何提前手动CreateVaribales

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done。在image classification里面加了create_vars为false的单测。

}

// Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(
*inference_program, scope, feed_targets, fetch_targets, create_vars);

// Enable the profiler
paddle::platform::EnableProfiler(state);
Expand All @@ -181,7 +187,8 @@ void TestInference(const std::string& dirname,
"run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place));

executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(
*inference_program, scope, feed_targets, fetch_targets, create_vars);
}

// Disable the profiler and print the timing information
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/go_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class GoOp : public framework::OperatorBase {

// TODO(varunarora): Consider moving this root scope lookup to scope.h.
const framework::Scope *root_scope = &scope;
const framework::Scope *parent_scope = &(root_scope->parent());
const framework::Scope *parent_scope = root_scope->parent();

while (parent_scope != nullptr) {
root_scope = parent_scope;
parent_scope = &(parent_scope->parent());
parent_scope = parent_scope->parent();
}

framework::BlockDesc *block = Attr<framework::BlockDesc *>(kBlock);
Expand Down