Implement the composition of minimum_double_grad (#62342)
* Implement the composition of minimum_double_grad

* add test
YibinLiu666 authored Mar 21, 2024
1 parent 55550bf commit 714ddbe
Showing 4 changed files with 109 additions and 0 deletions.
@@ -73,6 +73,7 @@
    "add_triple_grad",
    "silu_double_grad",
    "tanh_triple_grad",
    "minimum_double_grad",
]

# white ops list whose kernel can automaically do type promotion.
@@ -89,6 +89,32 @@ void cos_double_grad(const Tensor& x,
  }
}

template <typename T>
void minimum_double_grad(const Tensor& x,
                         const Tensor& y,
                         const paddle::optional<Tensor>& grad_x_grad,
                         const paddle::optional<Tensor>& grad_y_grad,
                         Tensor* grad_out_grad) {
  if (grad_out_grad) {
    if (grad_x_grad && grad_y_grad) {
      auto x_mask = cast<T>(less_than<T>(x, y), grad_x_grad.get().dtype());
      auto ddout =
          grad_x_grad.get() * x_mask + grad_y_grad.get() * (1 - x_mask);
      set_output<T>(ddout, grad_out_grad);
    } else if (grad_x_grad) {
      auto x_mask = cast<T>(less_than<T>(x, y), grad_x_grad.get().dtype());
      auto ddout = grad_x_grad.get() * x_mask;
      set_output<T>(ddout, grad_out_grad);
    } else if (grad_y_grad) {
      auto y_mask = cast<T>(greater_equal<T>(x, y), grad_y_grad.get().dtype());
      auto ddout = grad_y_grad.get() * y_mask;
      set_output<T>(ddout, grad_out_grad);
    } else {
      grad_out_grad = nullptr;
    }
  }
}

template <typename T>
void tanh_triple_grad(const Tensor& out,
                      const Tensor& grad_out_forward,
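For readers of the minimum_double_grad composite above: since out = min(x, y) takes the x element wherever x < y and the y element otherwise, a perturbation of grad_x only reaches grad_out on the x-selected elements, and a perturbation of grad_y on the complement. A minimal NumPy sketch of that rule; the reference function and its name are illustrative and not part of the patch:

    import numpy as np

    def minimum_double_grad_ref(x, y, ddx=None, ddy=None):
        # x_mask is 1 where min(x, y) takes the x element; ties fall to the y
        # branch, matching the less_than / greater_equal pair in the composite.
        x_mask = (x < y).astype(np.float64)
        ddout = np.zeros(np.broadcast(x, y).shape)
        if ddx is not None:
            ddout = ddout + ddx * x_mask        # contribution of grad_x's perturbation
        if ddy is not None:
            ddout = ddout + ddy * (1 - x_mask)  # contribution of grad_y's perturbation
        return ddout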
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -421,6 +421,7 @@
  kernel :
    func : minimum_grad
  composite : minimum_grad(x, y, out_grad, axis, x_grad, y_grad)
  backward : minimum_double_grad

- backward_op : mish_grad
  forward : mish (Tensor x, float lambda) -> Tensor(out)
@@ -876,6 +877,13 @@
    func : fused_gemm_epilogue_grad
  optional : reserve_space

- backward_op: minimum_double_grad
  forward: minimum_grad(Tensor x, Tensor y, Tensor grad_out) -> Tensor(grad_x), Tensor(grad_y)
  args: (Tensor x, Tensor y, Tensor grad_x_grad, Tensor grad_y_grad)
  output: Tensor(grad_out_grad)
  composite: minimum_double_grad(x, y, grad_x_grad, grad_y_grad, grad_out_grad)
  optional : grad_x_grad, grad_y_grad

- backward_op: unpool_grad
  forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out)
  args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
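For context, the minimum_double_grad entry above is what lets a second application of autodiff pass through minimum_grad when the prim (composite) backward is enabled. A hedged dygraph sketch of that kind of call, reusing the flags that appear in the test below; the exact setup is illustrative only, not part of the patch:

    import paddle
    from paddle.base import core

    # Illustrative only: enable the composite backward so the double grad of
    # minimum is decomposed via the minimum_double_grad rule registered above.
    core._set_prim_backward_enabled(True)

    x = paddle.rand([2, 3])
    y = paddle.rand([2, 3])
    x.stop_gradient = False
    y.stop_gradient = False

    out = paddle.minimum(x, y)
    loss = (out * out).sum()
    (dx,) = paddle.grad(loss, x, create_graph=True)  # first-order grad, kept in the graph
    (ddx,) = paddle.grad(dx.sum(), x)                # second-order grad: 2 where x < y, 0 elsewhere
    core._set_prim_backward_enabled(False)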
74 changes: 74 additions & 0 deletions test/prim/prim/vjp/test_comp_high_grad.py
@@ -411,5 +411,79 @@ def test_high_grad(self):
            self.func_triple(p)


@param.parameterized_class(
    ('shape1', 'shape2'),
    [
        (
            [2, 3, 4],
            [2, 3, 4],
        ),
        (
            [2, 3, 3, 4],
            [3, 1, 4],
        ),
        (
            [2, 3, 3, 4],
            [3, 1, 1],
        ),
        (
            [2, 3, 3, 4],
            [2, 3, 1, 4],
        ),
        (
            [2, 3, 3, 4],
            [2, 3, 1, 1],
        ),
    ],
)
class TestMinimumHighGradCheck(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.shape1 = cls.shape1
        cls.shape2 = cls.shape2

    def minimum_wrapper(self, x):
        return paddle.minimum(x[0], x[1])

    @prog_scope()
    def func_double(self, place):
        shape1 = self.shape1
        shape2 = self.shape2
        eps = 0.0005
        dtype = np.float64
        x = paddle.static.data('x', shape1, dtype=dtype)
        y = paddle.static.data('y', shape2, dtype=dtype)
        x.persistable = True
        y.persistable = True
        out = paddle.minimum(x, y)
        x_arr = np.random.uniform(-1, 1, shape1).astype(dtype)
        y_arr = np.random.uniform(-2, 2, shape2).astype(dtype)
        x_arr[np.abs(x_arr) < 0.005] = 0.002
        y_arr[np.abs(y_arr) < 0.005] = 0.002
        from paddle.base import core

        core._set_prim_backward_enabled(True)
        core._set_prim_backward_blacklist("minimum_grad")
        gradient_checker.double_grad_check(
            [x, y], y=out, x_init=[x_arr, y_arr], place=place, eps=eps
        )
        gradient_checker.double_grad_check_for_dygraph(
            self.minimum_wrapper,
            [x, y],
            y=out,
            x_init=[x_arr, y_arr],
            place=place,
        )
        core._set_prim_backward_enabled(False)

    def test_high_grad(self):
        paddle.enable_static()
        places = [base.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(base.CUDAPlace(0))
        for p in places:
            self.func_double(p)


if __name__ == '__main__':
    unittest.main()
