Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Support sparse_slice and sparse_sum in pt #64009

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx,
std::map<std::string, std::vector<std::string>> inputs = op_desc.Inputs();
std::vector<std::string> input_types;
for (const auto& pair : inputs) {
if (op_desc.Type() == "sparse_sum" || op_desc.Type() == "sparse_slice") {
if (pair.first != "x") {
continue;
}
}
Comment on lines +308 to +312
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里下个PR可以改进下,通过表来配置,避免硬编码,最好复用现在已有的信息

VarDesc* var_desc = op_desc.Block()->FindVarRecursive(pair.second[0]);
PADDLE_ENFORCE_NE(
var_desc,
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3226,6 +3226,23 @@
attrs:
data_format: data_layout

- op : sparse_slice
int_array :
starts :
data_type : int
tensor_name : StartsTensor
tensors_name : StartsTensorList
ends :
data_type : int
tensor_name : EndsTensor
tensors_name : EndsTensorList

- op : sparse_sum
scalar :
axis :
data_type : int
tensor_name : AxisTensor

- op : sparse_sync_batch_norm
attrs:
data_format: data_layout
Expand Down
7 changes: 7 additions & 0 deletions test/deprecated/legacy_test/test_sparse_slice_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import compare_legacy_with_pt

import paddle

Expand Down Expand Up @@ -206,26 +207,32 @@ def check_result_with_list(self, x, axes, starts, ends, format='coo'):
if format == 'coo':
self._check_result_coo(np_x, axes, starts, ends)

@compare_legacy_with_pt
def test_coo_5d(self):
for item in data_5d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_4d(self):
for item in data_4d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_2d(self):
for item in data_2d:
self.check_result_with_shape(*item, format='coo')

@compare_legacy_with_pt
def test_coo_1d(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [3], [5], format='coo')

@compare_legacy_with_pt
def test_coo_1d_zero(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [-3], [-1], format='coo')
Expand Down
2 changes: 2 additions & 0 deletions test/deprecated/legacy_test/test_sparse_sum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import compare_legacy_with_pt

import paddle

Expand Down Expand Up @@ -172,6 +173,7 @@ def check_result_coo(self, x_shape, dims, keepdim, dtype=None):
)
paddle.disable_static()

@compare_legacy_with_pt
def test_sum(self):
# 1d
self.check_result_coo([5], None, False)
Expand Down