diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index f7a10ada809e6..66f07b6757fe1 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) { paddle::framework::UseCPU(); op->Run(scope, cpu_place); - EXPECT_EQ(op_test_value, -20); + EXPECT_EQ(op_test_value, -9); // add cuda kernels paddle::framework::UseCUDA(); op->Run(scope, cuda_place); - EXPECT_EQ(op_test_value, -30); + EXPECT_EQ(op_test_value, -10); // use cudnn kernel paddle::framework::UseCUDNN(); op->Run(scope, cuda_place); - EXPECT_EQ(op_test_value, -40); + EXPECT_EQ(op_test_value, -20); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index adc85b1049f98..3744eae696896 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -474,6 +474,20 @@ void OperatorWithKernel::Run(const Scope& scope, ExecutionContext ctx(*this, scope, *dev_ctx); auto expected_kernel_key = this->GetExpectedKernelType(ctx); + OpKernelMap& kernels = kernels_iter->second; + + for (auto& candidate : kKernelPriority) { + auto candidate_key = + OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate), + expected_kernel_key.data_layout_, std::get<1>(candidate)); + + if ((candidate_key == expected_kernel_key) || + (kernels.count(candidate_key))) { + expected_kernel_key = candidate_key; + break; + } + } + Scope& new_scope = scope.NewScope(); for (auto& var_name_item : this->Inputs()) { @@ -504,7 +518,6 @@ void OperatorWithKernel::Run(const Scope& scope, } } - OpKernelMap& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));