Skip to content

Commit

Permalink
Merge pull request #7294 from jacquesqiao/add-back-priority
Browse files Browse the repository at this point in the history
add back priority
  • Loading branch information
jacquesqiao authored Jan 8, 2018
2 parents 43dab72 + 5b94948 commit d762e07
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion paddle/framework/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
platform::DeviceContextPool::Init(places);
framework::UseALL();
// framework::UseALL();
return true;
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::UseCPU();
op->Run(scope, cpu_place);

EXPECT_EQ(op_test_value, -20);
EXPECT_EQ(op_test_value, -9);

// add cuda kernels
paddle::framework::UseCUDA();
op->Run(scope, cuda_place);

EXPECT_EQ(op_test_value, -30);
EXPECT_EQ(op_test_value, -10);

// use cudnn kernel
paddle::framework::UseCUDNN();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -40);
EXPECT_EQ(op_test_value, -20);
}
20 changes: 18 additions & 2 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,22 @@ void OperatorWithKernel::Run(const Scope& scope,
ExecutionContext ctx(*this, scope, *dev_ctx);
auto expected_kernel_key = this->GetExpectedKernelType(ctx);

OpKernelMap& kernels = kernels_iter->second;

for (auto& candidate : kKernelPriority) {
auto candidate_key =
OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate),
expected_kernel_key.data_layout_, std::get<1>(candidate));

if ((candidate_key == expected_kernel_key) ||
(kernels.count(candidate_key))) {
expected_kernel_key = candidate_key;
break;
}
}

VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

Scope& new_scope = scope.NewScope();

for (auto& var_name_item : this->Inputs()) {
Expand Down Expand Up @@ -525,10 +541,10 @@ void OperatorWithKernel::Run(const Scope& scope,
}
}

OpKernelMap& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);

kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));
kernel_iter->second->Compute(ExecutionContext(
*this, new_scope, *pool.Get(expected_kernel_key.place_)));
}

proto::DataType OperatorWithKernel::IndicateDataType(
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/fetch_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto &dev_ctx = *pool.Get(src_item.place());

CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait();
Expand Down

0 comments on commit d762e07

Please sign in to comment.