From b886c4d804cb492fc9a1eacf7c08f40114c9a013 Mon Sep 17 00:00:00 2001 From: risemeup1 <62429225+risemeup1@users.noreply.github.com> Date: Mon, 6 May 2024 17:30:45 +0800 Subject: [PATCH] [PIR] Support sparse_slice and sparse_sum in pt (#64009) * support sparse_slice and sparse_sum in pt * support sparse_slice and sparse_sum in pt * support sparse_slice and sparse_sum in pt --- .../ir_adaptor/translator/op_translator.cc | 5 +++++ paddle/phi/api/yaml/op_compat.yaml | 17 +++++++++++++++++ .../legacy_test/test_sparse_slice_op.py | 7 +++++++ .../legacy_test/test_sparse_sum_op.py | 2 ++ 4 files changed, 31 insertions(+) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index a6d29fe6288bc8..adc1f38d5117d4 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -305,6 +305,11 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx, std::map> inputs = op_desc.Inputs(); std::vector input_types; for (const auto& pair : inputs) { + if (op_desc.Type() == "sparse_sum" || op_desc.Type() == "sparse_slice") { + if (pair.first != "x") { + continue; + } + } VarDesc* var_desc = op_desc.Block()->FindVarRecursive(pair.second[0]); PADDLE_ENFORCE_NE( var_desc, diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 3e226e63e82b0c..9fdee62e604a25 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3250,6 +3250,23 @@ attrs: data_format: data_layout +- op : sparse_slice + int_array : + starts : + data_type : int + tensor_name : StartsTensor + tensors_name : StartsTensorList + ends : + data_type : int + tensor_name : EndsTensor + tensors_name : EndsTensorList + +- op : sparse_sum + scalar : + axis : + data_type : int + tensor_name : AxisTensor + - op : sparse_sync_batch_norm attrs: data_format: data_layout diff --git a/test/deprecated/legacy_test/test_sparse_slice_op.py b/test/deprecated/legacy_test/test_sparse_slice_op.py index 483720b0663a26..714a55d24f21c9 100644 --- a/test/deprecated/legacy_test/test_sparse_slice_op.py +++ b/test/deprecated/legacy_test/test_sparse_slice_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import compare_legacy_with_pt import paddle @@ -206,26 +207,32 @@ def check_result_with_list(self, x, axes, starts, ends, format='coo'): if format == 'coo': self._check_result_coo(np_x, axes, starts, ends) + @compare_legacy_with_pt def test_coo_5d(self): for item in data_5d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_4d(self): for item in data_4d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_3d(self): for item in data_3d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_2d(self): for item in data_2d: self.check_result_with_shape(*item, format='coo') + @compare_legacy_with_pt def test_coo_1d(self): x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0] self.check_result_with_list(x, [0], [3], [5], format='coo') + @compare_legacy_with_pt def test_coo_1d_zero(self): x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0] self.check_result_with_list(x, [0], [-3], [-1], format='coo') diff --git a/test/deprecated/legacy_test/test_sparse_sum_op.py b/test/deprecated/legacy_test/test_sparse_sum_op.py index 3690341c51dc0d..8d245508b3d3ef 100644 --- a/test/deprecated/legacy_test/test_sparse_sum_op.py +++ b/test/deprecated/legacy_test/test_sparse_sum_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import compare_legacy_with_pt import paddle @@ -172,6 +173,7 @@ def check_result_coo(self, x_shape, dims, keepdim, dtype=None): ) paddle.disable_static() + @compare_legacy_with_pt def test_sum(self): # 1d self.check_result_coo([5], None, False)