From ca90356b0e81133a06816d1348208b9aba87302b Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 8 Jan 2018 15:56:26 +0800 Subject: [PATCH 1/3] add back priority --- paddle/framework/op_registry_test.cc | 6 +++--- paddle/framework/operator.cc | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index f7a10ada809e6..66f07b6757fe1 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) { paddle::framework::UseCPU(); op->Run(scope, cpu_place); - EXPECT_EQ(op_test_value, -20); + EXPECT_EQ(op_test_value, -9); // add cuda kernels paddle::framework::UseCUDA(); op->Run(scope, cuda_place); - EXPECT_EQ(op_test_value, -30); + EXPECT_EQ(op_test_value, -10); // use cudnn kernel paddle::framework::UseCUDNN(); op->Run(scope, cuda_place); - EXPECT_EQ(op_test_value, -40); + EXPECT_EQ(op_test_value, -20); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index adc85b1049f98..3744eae696896 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -474,6 +474,20 @@ void OperatorWithKernel::Run(const Scope& scope, ExecutionContext ctx(*this, scope, *dev_ctx); auto expected_kernel_key = this->GetExpectedKernelType(ctx); + OpKernelMap& kernels = kernels_iter->second; + + for (auto& candidate : kKernelPriority) { + auto candidate_key = + OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate), + expected_kernel_key.data_layout_, std::get<1>(candidate)); + + if ((candidate_key == expected_kernel_key) || + (kernels.count(candidate_key))) { + expected_kernel_key = candidate_key; + break; + } + } + Scope& new_scope = scope.NewScope(); for (auto& var_name_item : this->Inputs()) { @@ -504,7 +518,6 @@ void OperatorWithKernel::Run(const Scope& scope, } } - OpKernelMap& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx)); From 0b52cc886f2fbc0e491c9a73ff5ee3e856915b55 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 8 Jan 2018 13:06:56 +0000 Subject: [PATCH 2/3] fix priority --- paddle/framework/operator.cc | 5 ++++- paddle/operators/fetch_op.cc | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 3744eae696896..febad37b42b1a 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -488,6 +488,8 @@ void OperatorWithKernel::Run(const Scope& scope, } } + VLOG(3) << "expected_kernel_key:" << expected_kernel_key; + Scope& new_scope = scope.NewScope(); for (auto& var_name_item : this->Inputs()) { @@ -520,7 +522,8 @@ void OperatorWithKernel::Run(const Scope& scope, auto kernel_iter = kernels.find(expected_kernel_key); - kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx)); + kernel_iter->second->Compute(ExecutionContext( + *this, new_scope, *pool.Get(expected_kernel_key.place_))); } proto::DataType OperatorWithKernel::IndicateDataType( diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index 387d1e0a747f7..48c01f984f825 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -53,7 +53,7 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); + auto &dev_ctx = *pool.Get(src_item.place()); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); dev_ctx.Wait(); From 5b94948b320dd7499f8b847bd535278f89db2a8e Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 8 Jan 2018 14:56:16 +0000 Subject: [PATCH 3/3] disable UseAll when init --- paddle/framework/init.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc index 7ec8d18b0e886..e7087e063cbe8 100644 --- a/paddle/framework/init.cc +++ b/paddle/framework/init.cc @@ -72,7 +72,7 @@ bool InitDevices(const std::vector &devices) { LOG(WARNING) << "Not specified CPU device, create CPU by Default."; } platform::DeviceContextPool::Init(places); - framework::UseALL(); + // framework::UseALL(); return true; }