Skip to content

Commit

Permalink
feat(pir): reg global_scatter and limit_by_capacity
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyewww committed Mar 11, 2024
1 parent 03bc4c7 commit 203da1a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@
kernel :
func : global_scatter
data_type : dtype
backward : global_scatter_grad

- op : greater_equal
args : (Tensor x, Tensor y)
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,15 @@
func : gelu_grad
composite: gelu_grad(x, out_grad, approximate, x_grad)

- backward_op : global_scatter_grad
forward : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) -> Tensor(out)
args : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false)
output : Tensor(x_grad), Tensor(local_count), Tensor(global_count)
infer_meta :
func : UnchangedInferMeta
kernel :
func : global_scatter_grad

- backward_op : grid_sample_grad
forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out)
args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners)
Expand Down

0 comments on commit 203da1a

Please sign in to comment.