From 0bd1aa1cba2a8aeea2d70bba0c3404d1baad89c0 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 6 Apr 2023 21:24:27 +0800 Subject: [PATCH] Sync the pull request #51903. --- paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu | 4 +++- paddle/phi/kernels/gpu/gather_nd_kernel.cu | 2 ++ python/paddle/tensor/manipulation.py | 13 +++++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu index a78dc717b046b..da1045c27c58d 100644 --- a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/gather_nd_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" @@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad, double, int64_t, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gather_nd_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_kernel.cu index 7b2412958902d..dc642ffd58f1c 100644 --- a/paddle/phi/kernels/gpu/gather_nd_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_nd_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/gather_nd_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" @@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd, int, int16_t, bool, + phi::dtype::float16, phi::dtype::float16) {} diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index f987e8b89cf25..422a11c7e8853 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3791,8 +3791,17 @@ def gather_nd(x, index, name=None): check_variable_and_dtype( x, 'x', - ['bool', 'float32', 'float64', 'int16', 'int32', 'int64'], - 'gather_np', + [ + 'bool', + 'float16', + 'uint16', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + ], + 'gather_nd', ) check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np') helper = LayerHelper('gather_nd', **locals())