diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index a6d29fe6288bc..adc1f38d5117d 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -305,6 +305,11 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx, std::map> inputs = op_desc.Inputs(); std::vector input_types; for (const auto& pair : inputs) { + if (op_desc.Type() == "sparse_sum" || op_desc.Type() == "sparse_slice") { + if (pair.first != "x") { + continue; + } + } VarDesc* var_desc = op_desc.Block()->FindVarRecursive(pair.second[0]); PADDLE_ENFORCE_NE( var_desc, diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 883f6ec122f43..38aadc92fee60 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3226,6 +3226,23 @@ attrs: data_format: data_layout +- op : sparse_slice + int_array : + starts : + data_type : int + tensor_name : StartsTensor + tensors_name : StartsTensorList + ends : + data_type : int + tensor_name : EndsTensor + tensors_name : EndsTensorList + +- op : sparse_sum + scalar : + axis : + data_type : int + tensor_name : AxisTensor + - op : sparse_sync_batch_norm attrs: data_format: data_layout diff --git a/test/deprecated/legacy_test/test_sparse_slice_op.py b/test/deprecated/legacy_test/test_sparse_slice_op.py index 483720b0663a2..714a55d24f21c 100644 --- a/test/deprecated/legacy_test/test_sparse_slice_op.py +++ b/test/deprecated/legacy_test/test_sparse_slice_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import compare_legacy_with_pt import paddle @@ -206,26 +207,32 @@ def check_result_with_list(self, x, axes, starts, ends, format='coo'): if format == 'coo': self._check_result_coo(np_x, axes, starts, ends) + @compare_legacy_with_pt def test_coo_5d(self): for item in data_5d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_4d(self): for item in data_4d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_3d(self): for item in data_3d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_2d(self): for item in data_2d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_1d(self): x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0] self.check_result_with_list(x, [0], [3], [5], format='coo') + @compare_legacy_with_pt def test_coo_1d_zero(self): x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0] self.check_result_with_list(x, [0], [-3], [-1], format='coo') diff --git a/test/deprecated/legacy_test/test_sparse_sum_op.py b/test/deprecated/legacy_test/test_sparse_sum_op.py index 3690341c51dc0..8d245508b3d3e 100644 --- a/test/deprecated/legacy_test/test_sparse_sum_op.py +++ b/test/deprecated/legacy_test/test_sparse_sum_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import compare_legacy_with_pt import paddle @@ -172,6 +173,7 @@ def check_result_coo(self, x_shape, dims, keepdim, dtype=None): ) paddle.disable_static() + @compare_legacy_with_pt def test_sum(self): # 1d self.check_result_coo([5], None, False)