Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade torch.median and 3 apis #382

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -8494,7 +8494,7 @@
}
},
"torch.median": {
"Matcher": "GenericMatcher",
"Matcher": "DoubleAssignMatcher",
"paddle_api": "paddle.median",
"args_list": [
"input",
Expand All @@ -8507,9 +8507,9 @@
"input": "x",
"dim": "axis"
},
"unsupport_args": [
"dim"
]
"paddle_default_kwargs": {
"mode": "'min'"
}
RedContritio marked this conversation as resolved.
Show resolved Hide resolved
},
"torch.meshgrid": {
"Matcher": "MeshgridMatcher",
Expand Down Expand Up @@ -8714,7 +8714,7 @@
}
},
"torch.nanmedian": {
"Matcher": "GenericMatcher",
"Matcher": "DoubleAssignMatcher",
"paddle_api": "paddle.nanmedian",
"args_list": [
"input",
Expand All @@ -8727,9 +8727,9 @@
"input": "x",
"dim": "axis"
},
"unsupport_args": [
"dim"
]
"paddle_default_kwargs": {
"mode": "'min'"
}
},
"torch.nanquantile": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -10747,7 +10747,6 @@
"dtype": ""
},
"unsupport_args": [
"layer_norm_eps",
"batch_first"
],
"paddle_default_kwargs": {
Expand Down Expand Up @@ -10792,7 +10791,6 @@
"dtype": ""
},
"unsupport_args": [
"layer_norm_eps",
"batch_first"
],
"paddle_default_kwargs": {
Expand Down Expand Up @@ -11468,6 +11466,22 @@
},
"min_input_args": 2
},
"torch.nn.functional.group_norm": {
"Matcher": "GenericMatcher",
"min_input_args": 2,
"paddle_api": "paddle.nn.functional.group_norm",
"args_list": [
"input",
"num_groups",
"weight",
"bias",
"eps"
],
"kwargs_change": {
"input": "x",
"eps": "epsilon"
}
},
"torch.nn.functional.gumbel_softmax": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.gumbel_softmax",
Expand Down
14 changes: 7 additions & 7 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3297,13 +3297,15 @@ def generate_code(self, kwargs):

class DoubleAssignMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs_change = {}
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
kwargs = self.set_paddle_default_kwargs(kwargs)
kwargs_change = self.api_mapping.get("kwargs_change", {})

for k in kwargs_change:
if k in kwargs:
kwargs[kwargs_change[k]] = kwargs.pop(k)
if kwargs[k]:
kwargs[kwargs_change[k]] = kwargs.pop(k)
else:
kwargs.pop(k)

if "out" in kwargs:
out_v = kwargs.pop("out")
Expand All @@ -3325,9 +3327,7 @@ def generate_code(self, kwargs):
class TripleAssignMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs = self.set_paddle_default_kwargs(kwargs)
kwargs_change = {}
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
kwargs_change = self.api_mapping.get("kwargs_change", {})

for k in kwargs_change:
if k in kwargs:
Expand Down
28 changes: 13 additions & 15 deletions tests/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def test_case_2():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -57,8 +55,6 @@ def test_case_3():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -73,8 +69,6 @@ def test_case_4():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -83,16 +77,11 @@ def test_case_5():
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = (torch.tensor([[1.1], [1.2]]), torch.tensor([[1], [2]]))
result = torch.median(input, dim=1, keepdim=True, out=out)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)
obj.run(pytorch_code, ["result", "out"])


def test_case_6():
Expand All @@ -106,6 +95,15 @@ def test_case_6():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
result = torch.median(input, dim=1, keepdim=True)
"""
)
obj.run(pytorch_code, ["result"])
22 changes: 3 additions & 19 deletions tests/test_nanmedian.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def test_case_2():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -54,12 +52,7 @@ def test_case_3():
result = torch.nanmedian(input, 1, keepdim=True)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)
obj.run(pytorch_code, ["result"])


def test_case_4():
Expand All @@ -73,8 +66,6 @@ def test_case_4():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -83,15 +74,13 @@ def test_case_5():
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = (torch.tensor([[1.1], [1.2]]), torch.tensor([[1], [2]]))
result = torch.nanmedian(input, dim=1, keepdim=True, out=out)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -103,9 +92,4 @@ def test_case_6():
result = torch.nanmedian(input, 0)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)
obj.run(pytorch_code, ["result"])
7 changes: 1 addition & 6 deletions tests/test_nn_TransformerEncoderLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,7 @@ def test_case_5():
result = model(tgt)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle unsupport layer_norm_eps args",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransformDecodeLayer这个的单测是不是没改?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,下个 pr 改

)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_6():
Expand Down
93 changes: 73 additions & 20 deletions tests/test_nn_functional_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,10 @@ def test_case_1():
[-0.1920, 0.1826, 1.9217, -0.4359],
[ 1.1926, -0.0247, 0.4744, -1.0216],
[-0.0360, -1.1656, 0.3661, -1.8147]]]])
result = F.group_norm(x, 3)
result = F.group_norm(x, 2)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
obj.run(pytorch_code, ["result"], atol=1e-4)


def test_case_2():
Expand All @@ -58,15 +53,10 @@ def test_case_2():
[-0.1920, 0.1826, 1.9217, -0.4359],
[ 1.1926, -0.0247, 0.4744, -1.0216],
[-0.0360, -1.1656, 0.3661, -1.8147]]]])
result = F.group_norm(x, 3)
result = F.group_norm(x, 2)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
obj.run(pytorch_code, ["result"], atol=1e-4)


def test_case_3():
Expand All @@ -83,12 +73,75 @@ def test_case_3():
[-0.1920, 0.1826, 1.9217, -0.4359],
[ 1.1926, -0.0247, 0.4744, -1.0216],
[-0.0360, -1.1656, 0.3661, -1.8147]]]])
result = F.group_norm(x, 3, eps=1e-5)
result = F.group_norm(x, 2, eps=1e-5)
"""
)
obj.run(pytorch_code, ["result"], atol=1e-4)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
x = torch.tensor([[[[-0.0878, 0.3378, 0.0547, 1.2068],
[0.4212, -1.6113, 0.7277, 0.0766],
[0.8189, 0.0958, 1.7780, 1.1192],
[0.7286, -0.1988, 1.0519, 0.9217]],

[[0.0088, -1.9815, -0.3543, 0.1712],
[-0.1830, 0.0325, -0.1784, 0.1072],
[1.1752, -0.0234, -1.0873, -0.5568],
[0.4471, 0.4073, -1.6031, -0.0310]]]])
weight = torch.tensor([1.3, 1.2])
bias = torch.tensor([0.1, 0.2])
result = F.group_norm(x, 2, weight, bias, 1e-5)
"""
)
obj.run(pytorch_code, ["result"], atol=1e-4)


# generated by validate_unittest autofix, based on test_case_4
def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
x = torch.tensor([[[[-0.0878, 0.3378, 0.0547, 1.2068],
[0.4212, -1.6113, 0.7277, 0.0766],
[0.8189, 0.0958, 1.7780, 1.1192],
[0.7286, -0.1988, 1.0519, 0.9217]],

[[0.0088, -1.9815, -0.3543, 0.1712],
[-0.1830, 0.0325, -0.1784, 0.1072],
[1.1752, -0.0234, -1.0873, -0.5568],
[0.4471, 0.4073, -1.6031, -0.0310]]]])
weight = torch.tensor([1.3, 1.2])
bias = torch.tensor([0.1, 0.2])
result = F.group_norm(input=x, num_groups=2, weight=weight, bias=bias, eps=1e-5)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
obj.run(pytorch_code, ["result"], atol=1e-4)


# generated by validate_unittest autofix, based on test_case_4
def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
x = torch.tensor([[[[-0.0878, 0.3378, 0.0547, 1.2068],
[0.4212, -1.6113, 0.7277, 0.0766],
[0.8189, 0.0958, 1.7780, 1.1192],
[0.7286, -0.1988, 1.0519, 0.9217]],

[[0.0088, -1.9815, -0.3543, 0.1712],
[-0.1830, 0.0325, -0.1784, 0.1072],
[1.1752, -0.0234, -1.0873, -0.5568],
[0.4471, 0.4073, -1.6031, -0.0310]]]])
weight = torch.tensor([1.3, 1.2])
bias = torch.tensor([0.1, 0.2])
result = F.group_norm(eps=1e-5, bias=bias, weight=weight, num_groups=2, input=x)
"""
)
obj.run(pytorch_code, ["result"], atol=1e-4)