From ee7937d6803ea3bdc55f7bef0f11b8e27f42b36b Mon Sep 17 00:00:00 2001 From: TR666 Date: Tue, 25 Jun 2024 10:39:56 +0800 Subject: [PATCH] [XPU] Add bool type for op:concat; test=develop --- lite/kernels/xpu/concat_compute.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lite/kernels/xpu/concat_compute.cc b/lite/kernels/xpu/concat_compute.cc index 00cf5bb1012..8b4d4965338 100644 --- a/lite/kernels/xpu/concat_compute.cc +++ b/lite/kernels/xpu/concat_compute.cc @@ -104,6 +104,8 @@ using concati64 = paddle::lite::kernels::xpu::ConcatCompute; using concati8 = paddle::lite::kernels::xpu::ConcatCompute; +using concatbool = + paddle::lite::kernels::xpu::ConcatCompute; REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatfp32, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))}) @@ -147,3 +149,9 @@ REGISTER_LITE_KERNEL(concat, kXPU, kInt8, kNCHW, concati8, concat_INT8) {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))}) .Finalize(); +REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatbool, concat_BOOL) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))}) + .BindInput("AxisTensor", + {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))}) + .Finalize();