diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
index 0dde78f30dabb..65e63f35c12fe 100644
--- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
@@ -71,9 +71,12 @@ void AddDoubleGradImpl(const Context& dev_ctx,
   auto out_shape = dout.dims();
   dev_ctx.template Alloc<T>(ddout);
   if (ddx_tensor == nullptr && ddy_tensor == nullptr) {
+    VLOG(4) << "Special case when ddx and ddy are not needed\n";
     ddout = nullptr;
   } else if (ddx_tensor == nullptr && ddy_tensor != nullptr) {
     if (ddy_tensor->dims() != out_shape) {
+      VLOG(4) << "Special case when ddx is not needed and ddy needs to "
+                 "broadcast\n";
       std::vector<const DenseTensor*> ins = {ddy_tensor};
       std::vector<DenseTensor*> outs = {ddout};
       ExpandKernel<T, Context>(dev_ctx,
@@ -81,10 +84,14 @@ void AddDoubleGradImpl(const Context& dev_ctx,
                                IntArray{phi::vectorize(out_shape)},
                                ddout);
     } else {
+      VLOG(4) << "Special case when ddx is not needed and ddy doesn't need "
+                 "to broadcast\n";
       phi::Copy(dev_ctx, *ddy_tensor, dev_ctx.GetPlace(), false, ddout);
     }
   } else if (ddx_tensor != nullptr && ddy_tensor == nullptr) {
     if (ddx_tensor->dims() != out_shape) {
+      VLOG(4) << "Special case when ddy is not needed and ddx needs to "
+                 "broadcast\n";
       std::vector<const DenseTensor*> ins = {ddx_tensor};
       std::vector<DenseTensor*> outs = {ddout};
       ExpandKernel<T, Context>(dev_ctx,
@@ -92,6 +99,8 @@ void AddDoubleGradImpl(const Context& dev_ctx,
                                IntArray{phi::vectorize(out_shape)},
                                ddout);
     } else {
+      VLOG(4) << "Special case when ddy is not needed and ddx doesn't need "
+                 "to broadcast\n";
       phi::Copy(dev_ctx, *ddx_tensor, dev_ctx.GetPlace(), false, ddout);
     }
   } else {
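
For readers skimming the patch, the four-way dispatch that the added VLOG(4) lines annotate can be summarized in a short standalone sketch. This is a simplified illustration under stated assumptions, not PHI code: Shape, Tensor, and AddDoubleGradSketch are hypothetical stand-ins, and the real ExpandKernel/phi::Copy calls are replaced with prints.

// Minimal sketch (illustration only, not PHI code) of the special-case
// dispatch in AddDoubleGradImpl: either second-order grad input may be
// absent, and a present input may need broadcasting to the output shape.
#include <iostream>
#include <vector>

using Shape = std::vector<int>;  // stand-in for phi::DDim

struct Tensor {
  Shape dims;
};

void AddDoubleGradSketch(const Tensor* ddx,
                         const Tensor* ddy,
                         const Shape& out_shape,
                         Tensor* ddout) {
  if (ddx == nullptr && ddy == nullptr) {
    ddout = nullptr;  // nothing to compute; mirrors "ddout = nullptr;"
    std::cout << "ddx and ddy are not needed\n";
  } else if (ddx == nullptr) {
    if (ddy->dims != out_shape) {
      std::cout << "broadcast ddy into ddout (ExpandKernel path)\n";
    } else {
      std::cout << "copy ddy into ddout (phi::Copy path)\n";
    }
  } else if (ddy == nullptr) {
    if (ddx->dims != out_shape) {
      std::cout << "broadcast ddx into ddout (ExpandKernel path)\n";
    } else {
      std::cout << "copy ddx into ddout (phi::Copy path)\n";
    }
  } else {
    std::cout << "general case: broadcast-add ddx and ddy into ddout\n";
  }
}

int main() {
  Tensor ddy{{2, 1}};
  Tensor out{{2, 3}};
  AddDoubleGradSketch(nullptr, &ddy, out.dims, &out);     // ddy broadcast path
  AddDoubleGradSketch(nullptr, nullptr, out.dims, &out);  // ddout skipped
  return 0;
}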