Skip to content

Commit

Permalink
fix test_activation_op (#59618)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingmingyyj authored Dec 6, 2023
1 parent 5a79827 commit b5ebcae
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 1 deletion.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
'send_v2',
'shadow_feed',
'sparse_momentum',
'soft_relu',
'uniform_random_batch_size_like',
]

Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,16 @@
func : slice
backward : slice_grad

- op : soft_relu
args : (Tensor x, float threshold = 20.0f)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : soft_relu
backward : soft_relu_grad

- op : softmax
args : (Tensor x, int axis)
output : Tensor(out)
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,16 @@
backward : slice_double_grad
no_need_buffer : input

- backward_op : soft_relu_grad
forward : soft_relu (Tensor x, float threshold) -> Tensor(out)
args : (Tensor out, Tensor out_grad, float threshold)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out]
kernel :
func : soft_relu_grad

- backward_op : softmax_grad
forward : softmax (Tensor x, int axis) -> Tensor(out)
args : (Tensor out, Tensor out_grad, int axis)
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ const std::unordered_set<std::string> LegacyOpList = {
SeedOp::name(),
ShareDataOp::name(),
SparseMomentumOp::name(),
GetTensorFromSelectedRowsOp::name()};
GetTensorFromSelectedRowsOp::name(),
SoftReluOp::name(),
SoftReluGradOp::name()};

enum class AttrType {
UNDEFINED = 0,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2828,6 +2828,13 @@
outputs :
out : Out

- op : soft_relu
backward : soft_relu_grad
inputs :
x : X
outputs :
out : Out

- op : softmax
backward : softmax_grad
inputs :
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
test_accuracy_op
test_activation_bf16_mkldnn_op
test_activation_mkldnn_op
test_activation_op
test_adadelta_op
test_adagrad_op
test_adagrad_op_static_build
Expand Down

0 comments on commit b5ebcae

Please sign in to comment.