Skip to content

Commit

Permalink
[PIR] Support sparse_slice and sparse_sum in pt (PaddlePaddle#64009)
Browse files Browse the repository at this point in the history
* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt
  • Loading branch information
risemeup1 authored and yinfan98 committed May 7, 2024
1 parent 1472564 commit b886c4d
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 0 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx,
std::map<std::string, std::vector<std::string>> inputs = op_desc.Inputs();
std::vector<std::string> input_types;
for (const auto& pair : inputs) {
if (op_desc.Type() == "sparse_sum" || op_desc.Type() == "sparse_slice") {
if (pair.first != "x") {
continue;
}
}
VarDesc* var_desc = op_desc.Block()->FindVarRecursive(pair.second[0]);
PADDLE_ENFORCE_NE(
var_desc,
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3250,6 +3250,23 @@
attrs:
data_format: data_layout

- op : sparse_slice
int_array :
starts :
data_type : int
tensor_name : StartsTensor
tensors_name : StartsTensorList
ends :
data_type : int
tensor_name : EndsTensor
tensors_name : EndsTensorList

- op : sparse_sum
scalar :
axis :
data_type : int
tensor_name : AxisTensor

- op : sparse_sync_batch_norm
attrs:
data_format: data_layout
Expand Down
7 changes: 7 additions & 0 deletions test/deprecated/legacy_test/test_sparse_slice_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import compare_legacy_with_pt

import paddle

Expand Down Expand Up @@ -206,26 +207,32 @@ def check_result_with_list(self, x, axes, starts, ends, format='coo'):
if format == 'coo':
self._check_result_coo(np_x, axes, starts, ends)

@compare_legacy_with_pt
def test_coo_5d(self):
for item in data_5d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_4d(self):
for item in data_4d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_2d(self):
for item in data_2d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_1d(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [3], [5], format='coo')

@compare_legacy_with_pt
def test_coo_1d_zero(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [-3], [-1], format='coo')
Expand Down
2 changes: 2 additions & 0 deletions test/deprecated/legacy_test/test_sparse_sum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import compare_legacy_with_pt

import paddle

Expand Down Expand Up @@ -172,6 +173,7 @@ def check_result_coo(self, x_shape, dims, keepdim, dtype=None):
)
paddle.disable_static()

@compare_legacy_with_pt
def test_sum(self):
# 1d
self.check_result_coo([5], None, False)
Expand Down

0 comments on commit b886c4d

Please sign in to comment.