Skip to content

Commit

Permalink
Merge pull request #9301 from Xreki/core_inference_fix_run
Browse files Browse the repository at this point in the history
Expose a interface to create all variables and enable the test of not creating variables every time.
  • Loading branch information
luotao1 authored Apr 9, 2018
2 parents b1a5a3c + 93e9905 commit 46d6f4c
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 42 deletions.
72 changes: 42 additions & 30 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,43 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name);
}

void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) {
auto& global_block = pdesc.Block(block_id);

const Scope* ancestor_scope = scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}

if (ancestor_scope != scope) {
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}

if (var->Persistable()) {
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : global_block.AllVars()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id);
Expand Down Expand Up @@ -184,8 +221,8 @@ static bool has_fetch_operators(
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
const std::string& fetch_holder_name, bool create_vars) {
bool create_vars, const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
Expand Down Expand Up @@ -296,38 +333,13 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(

void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_);

Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}

if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
} // if (create_local_scope)
} // if (create_vars)
}
CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
}

for (auto& op : ctx->ops_) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,18 @@ class Executor {
void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch",
bool create_vars = true);
const std::string& fetch_holder_name = "fetch");

static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);

static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids);

void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);

void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
bool create_vars = true);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Scope {
/// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;

const Scope& parent() const { return *parent_; }
const Scope* parent() const { return parent_; }

/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ TEST(inference, image_classification) {

// Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1,
FLAGS_repeat);
TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims();

#ifdef PADDLE_WITH_CUDA
Expand All @@ -57,8 +57,8 @@ TEST(inference, image_classification) {

// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2,
FLAGS_repeat);
TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds,
cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims();

CheckError<float>(output1, output2);
Expand Down
15 changes: 12 additions & 3 deletions paddle/fluid/inference/tests/test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}

template <typename Place>
template <typename Place, bool CreateVars = true>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
Expand Down Expand Up @@ -166,8 +166,16 @@ void TestInference(const std::string& dirname,

// 6. Run the inference program
{
if (!CreateVars) {
// If users don't want to create and destroy variables every time they
// run, they need to set `create_vars` to false and manually call
// `CreateVariables` before running.
executor.CreateVariables(*inference_program, scope, 0);
}

// Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);

// Enable the profiler
paddle::platform::EnableProfiler(state);
Expand All @@ -178,7 +186,8 @@ void TestInference(const std::string& dirname,
"run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place));

executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
}

// Disable the profiler and print the timing information
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/go_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class GoOp : public framework::OperatorBase {

// TODO(varunarora): Consider moving this root scope lookup to scope.h.
const framework::Scope *root_scope = &scope;
const framework::Scope *parent_scope = &(root_scope->parent());
const framework::Scope *parent_scope = root_scope->parent();

while (parent_scope != nullptr) {
root_scope = parent_scope;
parent_scope = &(parent_scope->parent());
parent_scope = parent_scope->parent();
}

framework::BlockDesc *block = Attr<framework::BlockDesc *>(kBlock);
Expand Down

0 comments on commit 46d6f4c

Please sign in to comment.