From 296b3ff04562691467f053845e6febfc8277309b Mon Sep 17 00:00:00 2001 From: zqw_1997 <118182234+zhengqiwen1997@users.noreply.github.com> Date: Wed, 1 Mar 2023 20:25:53 +0800 Subject: [PATCH] add topk prim backward (#50679) * tmp gather vjp * support gather * remove useless code * fix compiling error * fix ut * add eager test * add eager test * add seed * small change * fix cpu error * fix transpose op compat * remove tensor index case * fix prim_cinn * small commit * add cumsum prim backward * small commit * skip axis=None test case * fix op generate error * fix static test error * remove unused code * fix static test error * small commit * skip cpu float16 test case * skip eager cpu cumsum float16 test case * add eager and static UT * fix ut * add composite backward rule * fix error * fix type error and format error * add try cpu+float16 test * fix test bugs * remove test for cpu+float16 and make y[0] be the grad arg * add cinn test * fix UT * fix the wrong dim of v in test cases * change y[0] to y[1] for grad in UT * reshape flatten out * Disable cinn single test * use scatter_nd_add * modify the reshape part of topk_grad * delete useless build file * to make the syntax right * modify bug * try use of put_along_axis * remove cinn test * reformat todo * add silu composite rule * fix code style. * add cinn test * fix composite grad maker code gen * add prim in cumsum op test * remove old test * fix typo * pass the static test * fix typo * modify optest and delete old test files * remove normal test_top_k_op test * fix typo * pass axis=None test case * buffer comment * for debug * add silu fp16 unit test. 
* add static guard * remove forward prim test * remove same name axis * modify the test_top_k_v2_op.py to pass all local tests * delete the useless testcase * fix mistake * add more testcases to test dtype16 and dtype32 --------- Co-authored-by: JiabinYang <360788950@qq.com> Co-authored-by: GGBond8488 <857631483@qq.com> Co-authored-by: zxcd <228587199@qq.com> Co-authored-by: Charles-hit --- paddle/fluid/prim/api/api.yaml | 1 + .../composite_backward_api.h | 17 +++++++++ paddle/phi/api/yaml/backward.yaml | 1 + .../fluid/tests/unittests/CMakeLists.txt | 3 +- .../fluid/tests/unittests/test_top_k_v2_op.py | 38 ++++++++++++++++--- 5 files changed, 53 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index e47c7a45713dc..55831ca02d082 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -26,3 +26,4 @@ - transpose - pad - cumsum +- put_along_axis diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index b18a33d582436..83e8975c7afe8 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -744,5 +744,22 @@ void cumsum_grad(const Tensor& x, } } +template <typename T> +void topk_grad(const Tensor& x, + const Tensor& indices, + const Tensor& out_grad, + const Scalar& k, + const int& axis, + const bool& largest, + const bool& sorted, + Tensor* x_grad) { + if (x_grad) { + auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype()); + auto x_grad_tmp = put_along_axis<T>(zero_tensor, indices, out_grad, axis); + + set_output<T>(x_grad_tmp, x_grad); + } +} + } // namespace prim } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index ee75d281b97da..8492da75eb251 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1524,6 +1524,7 @@ 
kernel : func : topk_grad data_type : out_grad + composite : topk_grad(x, indices, out_grad, k, axis, largest, sorted, x_grad) - backward_op : trace_grad forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 271dd98250c40..2d4db9df69e81 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1209,7 +1209,8 @@ set(TEST_CINN_OPS test_slice_op test_activation_op test_full_like_op - test_fill_any_like_op) + test_fill_any_like_op + test_top_k_v2_op) foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) if(WITH_CINN) diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py index 2a8af4d4ad9a6..9f6e9ad9d736b 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py @@ -45,6 +45,7 @@ def init_args(self): def setUp(self): self.op_type = "top_k_v2" + self.prim_op_type = "prim" self.python_api = paddle.topk self.dtype = np.float64 self.input_data = np.random.rand(10, 20) @@ -60,7 +61,7 @@ def test_check_output(self): self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(set(['X']), 'Out', check_eager=True) + self.check_grad(['X'], 'Out', check_eager=True, check_prim=True) class TestTopkOp1(TestTopkOp): @@ -77,7 +78,7 @@ def init_args(self): self.largest = False -class TestTopkOp3(OpTest): +class TestTopkOp3(TestTopkOp): def init_args(self): self.k = 6 self.axis = 1 @@ -85,6 +86,7 @@ def init_args(self): def setUp(self): self.op_type = "top_k_v2" + self.prim_op_type = "prim" self.python_api = paddle.topk self.dtype = np.float64 self.input_data = np.random.rand(16, 100) @@ -105,6 +107,7 @@ def init_args(self): def setUp(self): self.op_type = "top_k_v2" + self.prim_op_type = "prim" self.python_api = paddle.topk 
self.dtype = np.float64 self.input_data = np.random.rand(10, 10, 5) @@ -125,6 +128,7 @@ def init_args(self): def setUp(self): self.op_type = "top_k_v2" + self.prim_op_type = "prim" self.python_api = paddle.topk self.dtype = np.float64 self.input_data = np.random.rand(10, 10, 5) @@ -137,17 +141,39 @@ def setUp(self): self.outputs = {'Out': output, 'Indices': indices} -class TestTopkOp6(OpTest): +class TestTopkOp6(TestTopkOp): def init_args(self): - self.k = 100 + self.k = 3 self.axis = 1 self.largest = True def setUp(self): self.op_type = "top_k_v2" + self.prim_op_type = "prim" self.python_api = paddle.topk - self.dtype = np.float64 - self.input_data = np.random.rand(80, 16384) + self.dtype = np.float32 + self.input_data = np.random.rand(10, 10, 5) + self.init_args() + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=self.largest + ) + self.outputs = {'Out': output, 'Indices': indices} + + +class TestTopkOp7(TestTopkOp): + def init_args(self): + self.k = 10 + self.axis = 1 + self.largest = True + + def setUp(self): + self.op_type = "top_k_v2" + self.prim_op_type = "prim" + self.python_api = paddle.topk + self.dtype = np.float16 + self.input_data = np.random.rand(10, 20, 10) self.init_args() self.inputs = {'X': self.input_data} self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}