Skip to content

Commit

Permalink
add topk prim backward (#50679)
Browse files Browse the repository at this point in the history
* tmp gather vjp

* support gather

* remove useless code

* fix compiling error

* fix ut

* add eager test

* add eager test

* add seed

* small change

* fix cpu error

* fix transpose op compat

* remove tensor index case

* fix prim_cinn

* small commit

* add cumsum prim backward

* small commit

* skip axis=None test case

* fix op generate error

* fix static test error

* remove unused code

* fix static test error

* small commit

* skip cpu float16 test case

* skip eager cpu cumsum float16 test case

* add eager and static UT

* fix ut

* add composite backward rule

* fix error

* fix type error and format error

* add try cpu+float16 test

* fix test bugs

* remove test for cpu+float16 and make y[0] be the grad arg

* add cinn test

* fix UT

* fix the wrong dim of v in test cases

* change y[0] to y[1] for grad in UT

* reshape flatten out

* Disable cinn single test

* use scatter_nd_add

* modify the reshape part of topk_grad

* delete useless build file

* to make the syntax right

* modify bug

* try use of put_along_axis

* remove cinn test

* reformat todo

* add silu composite rule

* fix code style.

* add cinn test

* fix composite grad maker code gen

* add prim in cumsum op test

* remove old test

* fix typo

* pass the static test

* fix typo

* modify optest and delete old test files

* remove normal test_top_k_op test

* fix typo

* pass axis=None test case

* buffer comment

* for debug

* add silu fp16 unit test.

* add static guard

* remove forward prim test

* remove same name axis

* modify the test_top_v2_op.py to pass all local tests

* delete the useless testcase

* fix mistake

* add more testcases to test dtype16 and dtype32

---------

Co-authored-by: JiabinYang <[email protected]>
Co-authored-by: GGBond8488 <[email protected]>
Co-authored-by: zxcd <[email protected]>
Co-authored-by: Charles-hit <[email protected]>
  • Loading branch information
5 people authored Mar 1, 2023
1 parent e152e89 commit 296b3ff
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 7 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
- transpose
- pad
- cumsum
- put_along_axis
17 changes: 17 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -744,5 +744,22 @@ void cumsum_grad(const Tensor& x,
}
}

template <typename T>
void topk_grad(const Tensor& x,
               const Tensor& indices,
               const Tensor& out_grad,
               const Scalar& k,
               const int& axis,
               const bool& largest,
               const bool& sorted,
               Tensor* x_grad) {
  // Composite backward rule for topk: only the k selected positions
  // receive gradient; everything else stays zero. `k`, `largest` and
  // `sorted` are part of the generated signature but are not needed to
  // route the gradient — `indices` already identifies the winners.
  if (!x_grad) {
    return;
  }
  // Start from an all-zero tensor shaped like the forward input, then
  // scatter out_grad into the positions picked by the forward pass.
  auto zeros = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
  set_output<T>(put_along_axis<T>(zeros, indices, out_grad, axis), x_grad);
}

} // namespace prim
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,7 @@
kernel :
func : topk_grad
data_type : out_grad
composite : topk_grad(x, indices, out_grad, k, axis, largest, sorted, x_grad)

- backward_op : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,8 @@ set(TEST_CINN_OPS
test_slice_op
test_activation_op
test_full_like_op
test_fill_any_like_op)
test_fill_any_like_op
test_top_k_v2_op)

foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
Expand Down
38 changes: 32 additions & 6 deletions python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def init_args(self):

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(10, 20)
Expand All @@ -60,7 +61,7 @@ def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(set(['X']), 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)


class TestTopkOp1(TestTopkOp):
Expand All @@ -77,14 +78,15 @@ def init_args(self):
self.largest = False


class TestTopkOp3(OpTest):
class TestTopkOp3(TestTopkOp):
def init_args(self):
self.k = 6
self.axis = 1
self.largest = True

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(16, 100)
Expand All @@ -105,6 +107,7 @@ def init_args(self):

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(10, 10, 5)
Expand All @@ -125,6 +128,7 @@ def init_args(self):

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(10, 10, 5)
Expand All @@ -137,17 +141,39 @@ def setUp(self):
self.outputs = {'Out': output, 'Indices': indices}


class TestTopkOp6(OpTest):
class TestTopkOp6(TestTopkOp):
def init_args(self):
self.k = 100
self.k = 3
self.axis = 1
self.largest = True

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(80, 16384)
self.dtype = np.float32
self.input_data = np.random.rand(10, 10, 5)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=self.largest
)
self.outputs = {'Out': output, 'Indices': indices}


class TestTopkOp7(TestTopkOp):
def init_args(self):
self.k = 10
self.axis = 1
self.largest = True

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float16
self.input_data = np.random.rand(10, 20, 10)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
Expand Down

0 comments on commit 296b3ff

Please sign in to comment.