Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add topk prim backward #50679

Merged
merged 100 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
abba3f7
tmp gather vjp
JiabinYang Feb 3, 2023
86f7cc7
merge develop
JiabinYang Feb 3, 2023
ff51755
support gather
JiabinYang Feb 7, 2023
a37ce9f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Feb 7, 2023
acea6c4
remove useless code
JiabinYang Feb 7, 2023
422e93e
fix compiling error
JiabinYang Feb 8, 2023
49b8dd7
fix ut
JiabinYang Feb 8, 2023
6cb7e01
add eager test
JiabinYang Feb 9, 2023
41ebe1b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Feb 9, 2023
66602a4
add eager test
JiabinYang Feb 9, 2023
73f2822
add seed
JiabinYang Feb 9, 2023
238a0fb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Feb 9, 2023
7502b31
small change
zhengqiwen1997 Feb 10, 2023
95fab71
fix cpu error
JiabinYang Feb 13, 2023
af37274
merge develop
JiabinYang Feb 13, 2023
3106480
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengqiwen1997 Feb 14, 2023
d297072
fix transpose op compat
JiabinYang Feb 15, 2023
147436f
remove tensor index case
JiabinYang Feb 15, 2023
f82f1d6
fix prim_cinn
JiabinYang Feb 15, 2023
cdd3dbb
small commit
zhengqiwen1997 Feb 15, 2023
9e50329
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengqiwen1997 Feb 15, 2023
3fd552b
add cumsum prim backward
GGBond8488 Feb 16, 2023
d4907e8
small commit
zhengqiwen1997 Feb 16, 2023
91e4152
skip aixs=None test case
GGBond8488 Feb 16, 2023
830ea51
fix op generante eror
GGBond8488 Feb 16, 2023
1f87a33
fix static test error
GGBond8488 Feb 16, 2023
b035a91
remove unused code
GGBond8488 Feb 17, 2023
3507a6b
fix static test error
GGBond8488 Feb 17, 2023
ac97fac
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengqiwen1997 Feb 17, 2023
29a4801
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GGBond8488 Feb 17, 2023
5f9a53d
small commit
zhengqiwen1997 Feb 17, 2023
4a37f43
skip cpu float16 test case
GGBond8488 Feb 17, 2023
8f23cbe
skip eager cpu cumsum float16 test case
GGBond8488 Feb 20, 2023
6da723e
fix conflicts
GGBond8488 Feb 20, 2023
ee5f0ce
add eager and static UT
zhengqiwen1997 Feb 20, 2023
981309f
Merge commit 'refs/pull/50305/head' of https://github.com/PaddlePaddl…
zhengqiwen1997 Feb 20, 2023
028005b
fix ut
JiabinYang Feb 20, 2023
564bdc6
merge develop
JiabinYang Feb 20, 2023
39dcbc7
add composite backward rule
zhengqiwen1997 Feb 20, 2023
66cb700
conflict resolved
zhengqiwen1997 Feb 20, 2023
5cbb91e
Merge commit 'refs/pull/50305/head' of https://github.com/PaddlePaddl…
zhengqiwen1997 Feb 20, 2023
a244bcd
fix error
zhengqiwen1997 Feb 20, 2023
50b256c
fix type error and format error
zhengqiwen1997 Feb 21, 2023
4c511ee
add try cpu+float16 test
zhengqiwen1997 Feb 21, 2023
a3fb0fc
fix test bugs
zhengqiwen1997 Feb 21, 2023
51decaf
remove test for cpu+float16 and make y[0] be the grad arg
zhengqiwen1997 Feb 21, 2023
67b1293
add cinn test
GGBond8488 Feb 21, 2023
9c3fd8f
fix UT
zhengqiwen1997 Feb 21, 2023
8f7b029
fix the wrong dim of v in test cases
zhengqiwen1997 Feb 21, 2023
467e883
change y[0] to y[1] for grad in UT
zhengqiwen1997 Feb 21, 2023
51ff339
reshape flatten out
GGBond8488 Feb 21, 2023
635beef
Disable cinn single test
GGBond8488 Feb 22, 2023
a1725b9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GGBond8488 Feb 22, 2023
e2265ef
use scatter_nd_add
zhengqiwen1997 Feb 22, 2023
1a9b642
conflict resolved
zhengqiwen1997 Feb 22, 2023
cb6d3ca
modify the reshape part of topk_grad
zhengqiwen1997 Feb 22, 2023
6923d34
delete useless build file
zhengqiwen1997 Feb 22, 2023
f559336
to make the syntax right
zhengqiwen1997 Feb 22, 2023
06d8555
modify bug
zhengqiwen1997 Feb 22, 2023
af2cd5e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GGBond8488 Feb 23, 2023
ef57763
try use of put_along_axis
zhengqiwen1997 Feb 23, 2023
2313a12
remove cinn test
GGBond8488 Feb 23, 2023
35de0a0
fix conflicts
GGBond8488 Feb 23, 2023
cbcc61c
reformat todo
GGBond8488 Feb 23, 2023
fc94cc5
add silu composite rule
zxcd Feb 23, 2023
a91d063
fix code style.
zxcd Feb 23, 2023
583addb
add cinn test
zhengqiwen1997 Feb 23, 2023
ccdca72
conflict resolved
zhengqiwen1997 Feb 23, 2023
3a23ea6
fix composite grad maker code gen
Charles-hit Feb 24, 2023
dd1b94b
Merge commit 'refs/pull/50854/head' of https://github.com/PaddlePaddl…
zhengqiwen1997 Feb 24, 2023
abdbcf5
fix conflcits
GGBond8488 Feb 24, 2023
0de9ebe
add prim in cumsum op test
GGBond8488 Feb 24, 2023
69d6df2
remove old test
GGBond8488 Feb 24, 2023
8cb8968
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GGBond8488 Feb 24, 2023
532fdb8
fix typro
GGBond8488 Feb 24, 2023
910d41c
pass the static test
zhengqiwen1997 Feb 24, 2023
e3a9eef
conflict resolved
zhengqiwen1997 Feb 24, 2023
ffe8930
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengqiwen1997 Feb 24, 2023
ca54bd3
fix typro
GGBond8488 Feb 24, 2023
bd3310f
modify optest and delete old test files
zhengqiwen1997 Feb 24, 2023
556a824
remove normal test_top_k_op test
zhengqiwen1997 Feb 24, 2023
596895f
fix typro
GGBond8488 Feb 24, 2023
2631fd4
pass axis=None test case
GGBond8488 Feb 24, 2023
42dcfc2
buffer comment
zhengqiwen1997 Feb 27, 2023
04e5819
for debug
zhengqiwen1997 Feb 27, 2023
a4ccc9f
add silu fp16 unit test.
zxcd Feb 27, 2023
1683995
add static guard
Charles-hit Feb 27, 2023
0d90712
Merge commit 'refs/pull/50971/head' of https://github.com/PaddlePaddl…
zhengqiwen1997 Feb 27, 2023
47835ca
merge
zhengqiwen1997 Feb 27, 2023
d478f09
remove forward prim test
GGBond8488 Feb 27, 2023
dc24092
remove same name axis
GGBond8488 Feb 27, 2023
e002a82
modify the test_top_v2_op.py to pass all local tests
zhengqiwen1997 Feb 27, 2023
d75c5bd
delete the useless testcase
zhengqiwen1997 Feb 28, 2023
5af4155
merge conflict
zhengqiwen1997 Feb 28, 2023
e20bb2a
Merge commit 'refs/pull/50838/head' of https://github.com/PaddlePaddl…
zhengqiwen1997 Feb 28, 2023
bc485d1
fix mistake
zhengqiwen1997 Feb 28, 2023
c87ad2a
Merge branch 'develop' into topk_grad_comp
zhengqiwen1997 Feb 28, 2023
6fc520a
add more testcases to test dtype16 and dtype32
zhengqiwen1997 Mar 1, 2023
6ae7aee
Merge branch 'topk_grad_comp' of https://github.com/zhengqiwen1997/Pa…
zhengqiwen1997 Mar 1, 2023
4d7ab72
Merge branch 'develop' into topk_grad_comp
zhengqiwen1997 Mar 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
- transpose
- pad
- cumsum
- put_along_axis
Original file line number Diff line number Diff line change
Expand Up @@ -744,5 +744,22 @@ void cumsum_grad(const Tensor& x,
}
}

template <typename T>
void topk_grad(const Tensor& x,
const Tensor& indices,
const Tensor& out_grad,
const Scalar& k,
const int& axis,
const bool& largest,
const bool& sorted,
Tensor* x_grad) {
if (x_grad) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
auto x_grad_tmp = put_along_axis<T>(zero_tensor, indices, out_grad, axis);

set_output<T>(x_grad_tmp, x_grad);
}
}

} // namespace prim
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,7 @@
kernel :
func : topk_grad
data_type : out_grad
composite : topk_grad(x, indices, out_grad, k, axis, largest, sorted, x_grad)

- backward_op : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,8 @@ set(TEST_CINN_OPS
test_slice_op
test_activation_op
test_full_like_op
test_fill_any_like_op)
test_fill_any_like_op
test_top_k_v2_op)

foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
Expand Down
38 changes: 32 additions & 6 deletions python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def init_args(self):

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(10, 20)
Expand All @@ -60,7 +61,7 @@ def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(set(['X']), 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)


class TestTopkOp1(TestTopkOp):
Expand All @@ -77,14 +78,15 @@ def init_args(self):
self.largest = False


class TestTopkOp3(OpTest):
class TestTopkOp3(TestTopkOp):
def init_args(self):
self.k = 6
self.axis = 1
self.largest = True

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(16, 100)
Expand All @@ -105,6 +107,7 @@ def init_args(self):

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(10, 10, 5)
Expand All @@ -125,6 +128,7 @@ def init_args(self):

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(10, 10, 5)
Expand All @@ -137,17 +141,39 @@ def setUp(self):
self.outputs = {'Out': output, 'Indices': indices}


class TestTopkOp6(OpTest):
class TestTopkOp6(TestTopkOp):
def init_args(self):
self.k = 100
self.k = 3
self.axis = 1
self.largest = True

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float64
self.input_data = np.random.rand(80, 16384)
self.dtype = np.float32
self.input_data = np.random.rand(10, 10, 5)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=self.largest
)
self.outputs = {'Out': output, 'Indices': indices}


class TestTopkOp7(TestTopkOp):
def init_args(self):
self.k = 10
self.axis = 1
self.largest = True

def setUp(self):
self.op_type = "top_k_v2"
self.prim_op_type = "prim"
self.python_api = paddle.topk
self.dtype = np.float16
self.input_data = np.random.rand(10, 20, 10)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
Expand Down