Skip to content

Commit

Permalink
add log
Browse files Browse the repository at this point in the history
  • Loading branch information
YibinLiu666 committed Mar 6, 2024
1 parent 1f05792 commit 53340cf
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,36 @@ void AddDoubleGradImpl(const Context& dev_ctx,
auto out_shape = dout.dims();
dev_ctx.template Alloc<T>(ddout);
if (ddx_tensor == nullptr && ddy_tensor == nullptr) {
VLOG(4) << "Special case when ddx and ddy are not needed\n";
ddout = nullptr;
} else if (ddx_tensor == nullptr && ddy_tensor != nullptr) {
if (ddy_tensor->dims() != out_shape) {
VLOG(4) << "Special case when ddx is not needed and ddy needs to "
"broadcast\n";
std::vector<const DenseTensor*> ins = {ddy_tensor};
std::vector<DenseTensor*> outs = {ddout};
ExpandKernel<T, Context>(dev_ctx,
*ddy_tensor,
IntArray{phi::vectorize<int64_t>(out_shape)},
ddout);
} else {
VLOG(4) << "Special case when ddx is not needed and ddy doesn't need "
"to broadcast\n";
phi::Copy(dev_ctx, *ddy_tensor, dev_ctx.GetPlace(), false, ddout);
}
} else if (ddx_tensor != nullptr && ddy_tensor == nullptr) {
if (ddx_tensor->dims() != out_shape) {
VLOG(4) << "Special case when ddy is not needed and ddx need to "
"broadcast\n";
std::vector<const DenseTensor*> ins = {ddx_tensor};
std::vector<DenseTensor*> outs = {ddout};
ExpandKernel<T, Context>(dev_ctx,
*ddx_tensor,
IntArray{phi::vectorize<int64_t>(out_shape)},
ddout);
} else {
VLOG(4) << "Special case when ddx is not needed and ddy doesn't need "
"to broadcast\n";
phi::Copy(dev_ctx, *ddx_tensor, dev_ctx.GetPlace(), false, ddout);
}
} else {
Expand Down

0 comments on commit 53340cf

Please sign in to comment.