Skip to content

Commit

Permalink
add tanh_triple_grad composite logic (#56072) (#58657)
Browse files Browse the repository at this point in the history
* decompose tanh_triple_grad and add it into prim_white_list test=develop

* fix TanhTripleGradKernel bugs test=develop

* decompose tanh_triple_grad test=develop
  • Loading branch information
lxd-cumt authored Nov 6, 2023
1 parent fdd0689 commit 1383a2f
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"subtract_double_grad",
"add_triple_grad",
"silu_double_grad",
"tanh_triple_grad",
]

# dict of special api that forward api's output will affect bacward api's output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,76 @@ void tanh_double_grad(const Tensor& out,
}
}

template <typename T>
void tanh_triple_grad(const Tensor& out,
const Tensor& grad_out_forward,
const Tensor& grad_x_grad_forward,
const paddle::optional<Tensor>& grad_out_new_grad,
const paddle::optional<Tensor>& grad_out_grad_grad,
Tensor* out_grad,
Tensor* grad_out_forward_grad,
Tensor* grad_x_grad_forward_grad) {
if (out_grad) {
if (grad_out_grad_grad) {
if (grad_out_new_grad) {
auto out_grad_tmp =
(-2 * out * grad_x_grad_forward * grad_out_grad_grad.get()) -
(2 * grad_out_forward * grad_x_grad_forward *
grad_out_new_grad.get());
set_output<T>(out_grad_tmp, out_grad);
} else {
auto out_grad_tmp =
-2 * out * grad_x_grad_forward * grad_out_grad_grad.get();
set_output<T>(out_grad_tmp, out_grad);
}
} else {
if (grad_out_new_grad) {
auto out_grad_tmp = -(2 * grad_out_forward * grad_x_grad_forward *
grad_out_new_grad.get());
set_output<T>(out_grad_tmp, out_grad);
} else {
auto out_grad_tmp = 0 * out;
set_output<T>(out_grad_tmp, out_grad);
}
}
}

if (grad_out_forward_grad) {
if (grad_out_new_grad) {
auto grad_out_forward_grad_tmp =
-2 * out * grad_x_grad_forward * grad_out_new_grad.get();
set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
} else {
auto grad_out_forward_grad_tmp = 0 * out;
set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
}
}

if (grad_x_grad_forward_grad) {
if (grad_out_grad_grad) {
if (grad_out_new_grad) {
auto grad_x_grad_forward_grad_tmp =
(1 - (out * out)) * grad_out_grad_grad.get() -
2 * out * grad_out_forward * grad_out_new_grad.get();
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
} else {
auto grad_x_grad_forward_grad_tmp =
(1 - (out * out)) * grad_out_grad_grad.get();
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
}
} else {
if (grad_out_new_grad) {
auto grad_x_grad_forward_grad_tmp =
-(2 * out * grad_out_forward * grad_out_new_grad.get());
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
} else {
auto grad_x_grad_forward_grad_tmp = 0 * grad_x_grad_forward;
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
}
}
}
}

template <typename T>
void matmul_double_grad(const Tensor& x,
const Tensor& y,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,7 @@
param : [out, out, grad_x_grad_forward]
kernel :
func : tanh_triple_grad
composite : tanh_triple_grad(out, grad_out_forward, grad_x_grad_forward, grad_out_new_grad, grad_out_grad_grad, out_grad, grad_out_forward_grad, grad_x_grad_forward_grad)
inplace : (grad_x_grad_forward -> grad_out_forward_grad)
optional : grad_out_new_grad, grad_out_grad_grad

Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/impl/activation_grad_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ void TanhTripleGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(d_dout);
}
if (d_out_new) {
d_dout->Resize(out.dims());
d_out_new->Resize(out.dims());
dev_ctx.template Alloc<T>(d_out_new);
}
if (d_ddx) {
d_dout->Resize(ddx.dims());
d_ddx->Resize(ddx.dims());
dev_ctx.template Alloc<T>(d_ddx);
}
funcs::TanhTripleGradFunctor<T> functor;
Expand Down

0 comments on commit 1383a2f

Please sign in to comment.