diff --git a/paddle/phi/kernels/gpu/expand_grad_kernel.cu b/paddle/phi/kernels/gpu/expand_grad_kernel.cu index 224e435e58c851..92bf07492b4916 100644 --- a/paddle/phi/kernels/gpu/expand_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_grad_kernel.cu @@ -39,6 +39,14 @@ void ExpandGradKernel(const Context& ctx, } } +template +void LegacyExpandGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const IntArray& shape, + DenseTensor* x_grad) { + ExpandGradKernel(ctx, x, out_grad, shape, x_grad); +} } // namespace phi PD_REGISTER_KERNEL(expand_grad, @@ -57,3 +65,13 @@ PD_REGISTER_KERNEL(expand_grad, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(legacy_expand_grad, + GPU, + ALL_LAYOUT, + phi::LegacyExpandGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/expand_kernel.cu b/paddle/phi/kernels/gpu/expand_kernel.cu index ef5643737f4007..1d08f54dfe6d4a 100644 --- a/paddle/phi/kernels/gpu/expand_kernel.cu +++ b/paddle/phi/kernels/gpu/expand_kernel.cu @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/expand_kernel.h" +#include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" @@ -31,6 +32,8 @@ void ExpandKernel(const Context& ctx, auto diff = expand_shape.size() - x.dims().size(); auto out_shape = common::vectorize(x.dims()); out_shape.insert(out_shape.begin(), diff, 1); + VLOG(2) << "expand_shape: " << expand_shape; + VLOG(2) << "out_shape: " << out_shape; for (size_t i = 0; i < out_shape.size(); ++i) { PADDLE_ENFORCE_NE( expand_shape[i], @@ -76,6 +79,13 @@ void ExpandKernel(const Context& ctx, phi::funcs::BroadcastKernel(ctx, ins, &outs, kps::IdentityFunctor()); } +template +void LegacyExpandKernel(const Context& ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out) { + ExpandKernel(ctx, x, shape, out); +} } // namespace phi PD_REGISTER_KERNEL(expand, @@ -94,3 +104,14 @@ PD_REGISTER_KERNEL(expand, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(legacy_expand, + GPU, + ALL_LAYOUT, + phi::LegacyExpandKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::float16) {}