Skip to content

Commit

Permalink
upgrade torch.median and 3 apis (#382)
Browse files Browse the repository at this point in the history
* update torch.median

* update test_nanmedian

* update torch.nn.functional.group_norm

* update torch.nn.TransformerEncoderLayer

* update median

* update
  • Loading branch information
RedContritio authored Apr 10, 2024
1 parent c7bbf82 commit b6cda71
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 77 deletions.
34 changes: 24 additions & 10 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -8494,7 +8494,7 @@
}
},
"torch.median": {
"Matcher": "GenericMatcher",
"Matcher": "DoubleAssignMatcher",
"paddle_api": "paddle.median",
"args_list": [
"input",
Expand All @@ -8507,9 +8507,9 @@
"input": "x",
"dim": "axis"
},
"unsupport_args": [
"dim"
]
"paddle_default_kwargs": {
"mode": "'min'"
}
},
"torch.meshgrid": {
"Matcher": "MeshgridMatcher",
Expand Down Expand Up @@ -8714,7 +8714,7 @@
}
},
"torch.nanmedian": {
"Matcher": "GenericMatcher",
"Matcher": "DoubleAssignMatcher",
"paddle_api": "paddle.nanmedian",
"args_list": [
"input",
Expand All @@ -8727,9 +8727,9 @@
"input": "x",
"dim": "axis"
},
"unsupport_args": [
"dim"
]
"paddle_default_kwargs": {
"mode": "'min'"
}
},
"torch.nanquantile": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -10747,7 +10747,6 @@
"dtype": ""
},
"unsupport_args": [
"layer_norm_eps",
"batch_first"
],
"paddle_default_kwargs": {
Expand Down Expand Up @@ -10792,7 +10791,6 @@
"dtype": ""
},
"unsupport_args": [
"layer_norm_eps",
"batch_first"
],
"paddle_default_kwargs": {
Expand Down Expand Up @@ -11468,6 +11466,22 @@
},
"min_input_args": 2
},
"torch.nn.functional.group_norm": {
"Matcher": "GenericMatcher",
"min_input_args": 2,
"paddle_api": "paddle.nn.functional.group_norm",
"args_list": [
"input",
"num_groups",
"weight",
"bias",
"eps"
],
"kwargs_change": {
"input": "x",
"eps": "epsilon"
}
},
"torch.nn.functional.gumbel_softmax": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.gumbel_softmax",
Expand Down
14 changes: 7 additions & 7 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3297,13 +3297,15 @@ def generate_code(self, kwargs):

class DoubleAssignMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs_change = {}
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
kwargs = self.set_paddle_default_kwargs(kwargs)
kwargs_change = self.api_mapping.get("kwargs_change", {})

for k in kwargs_change:
if k in kwargs:
kwargs[kwargs_change[k]] = kwargs.pop(k)
if kwargs[k]:
kwargs[kwargs_change[k]] = kwargs.pop(k)
else:
kwargs.pop(k)

if "out" in kwargs:
out_v = kwargs.pop("out")
Expand All @@ -3325,9 +3327,7 @@ def generate_code(self, kwargs):
class TripleAssignMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs = self.set_paddle_default_kwargs(kwargs)
kwargs_change = {}
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
kwargs_change = self.api_mapping.get("kwargs_change", {})

for k in kwargs_change:
if k in kwargs:
Expand Down
28 changes: 13 additions & 15 deletions tests/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def test_case_2():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -57,8 +55,6 @@ def test_case_3():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -73,8 +69,6 @@ def test_case_4():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -83,16 +77,11 @@ def test_case_5():
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = (torch.tensor([[1.1], [1.2]]), torch.tensor([[1], [2]]))
result = torch.median(input, dim=1, keepdim=True, out=out)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)
obj.run(pytorch_code, ["result", "out"])


def test_case_6():
Expand All @@ -106,6 +95,15 @@ def test_case_6():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
result = torch.median(input, dim=1, keepdim=True)
"""
)
obj.run(pytorch_code, ["result"])
22 changes: 3 additions & 19 deletions tests/test_nanmedian.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def test_case_2():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -54,12 +52,7 @@ def test_case_3():
result = torch.nanmedian(input, 1, keepdim=True)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)
obj.run(pytorch_code, ["result"])


def test_case_4():
Expand All @@ -73,8 +66,6 @@ def test_case_4():
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -83,15 +74,13 @@ def test_case_5():
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
out = (torch.tensor([[1.1], [1.2]]), torch.tensor([[1], [2]]))
result = torch.nanmedian(input, dim=1, keepdim=True, out=out)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)


Expand All @@ -103,9 +92,4 @@ def test_case_6():
result = torch.nanmedian(input, 0)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not return index when dim is specified",
)
obj.run(pytorch_code, ["result"])
7 changes: 1 addition & 6 deletions tests/test_nn_TransformerEncoderLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,7 @@ def test_case_5():
result = model(tgt)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle unsupport layer_norm_eps args",
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_6():
Expand Down
93 changes: 73 additions & 20 deletions tests/test_nn_functional_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,10 @@ def test_case_1():
[-0.1920, 0.1826, 1.9217, -0.4359],
[ 1.1926, -0.0247, 0.4744, -1.0216],
[-0.0360, -1.1656, 0.3661, -1.8147]]]])
result = F.group_norm(x, 3)
result = F.group_norm(x, 2)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
obj.run(pytorch_code, ["result"], atol=1e-4)


def test_case_2():
Expand All @@ -58,15 +53,10 @@ def test_case_2():
[-0.1920, 0.1826, 1.9217, -0.4359],
[ 1.1926, -0.0247, 0.4744, -1.0216],
[-0.0360, -1.1656, 0.3661, -1.8147]]]])
result = F.group_norm(x, 3)
result = F.group_norm(x, 2)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
obj.run(pytorch_code, ["result"], atol=1e-4)


def test_case_3():
Expand All @@ -83,12 +73,75 @@ def test_case_3():
[-0.1920, 0.1826, 1.9217, -0.4359],
[ 1.1926, -0.0247, 0.4744, -1.0216],
[-0.0360, -1.1656, 0.3661, -1.8147]]]])
result = F.group_norm(x, 3, eps=1e-5)
result = F.group_norm(x, 2, eps=1e-5)
"""
)
obj.run(pytorch_code, ["result"], atol=1e-4)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
x = torch.tensor([[[[-0.0878, 0.3378, 0.0547, 1.2068],
[0.4212, -1.6113, 0.7277, 0.0766],
[0.8189, 0.0958, 1.7780, 1.1192],
[0.7286, -0.1988, 1.0519, 0.9217]],
[[0.0088, -1.9815, -0.3543, 0.1712],
[-0.1830, 0.0325, -0.1784, 0.1072],
[1.1752, -0.0234, -1.0873, -0.5568],
[0.4471, 0.4073, -1.6031, -0.0310]]]])
weight = torch.tensor([1.3, 1.2])
bias = torch.tensor([0.1, 0.2])
result = F.group_norm(x, 2, weight, bias, 1e-5)
"""
)
obj.run(pytorch_code, ["result"], atol=1e-4)


# generated by validate_unittest autofix, based on test_case_4
def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
x = torch.tensor([[[[-0.0878, 0.3378, 0.0547, 1.2068],
[0.4212, -1.6113, 0.7277, 0.0766],
[0.8189, 0.0958, 1.7780, 1.1192],
[0.7286, -0.1988, 1.0519, 0.9217]],
[[0.0088, -1.9815, -0.3543, 0.1712],
[-0.1830, 0.0325, -0.1784, 0.1072],
[1.1752, -0.0234, -1.0873, -0.5568],
[0.4471, 0.4073, -1.6031, -0.0310]]]])
weight = torch.tensor([1.3, 1.2])
bias = torch.tensor([0.1, 0.2])
result = F.group_norm(input=x, num_groups=2, weight=weight, bias=bias, eps=1e-5)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
obj.run(pytorch_code, ["result"], atol=1e-4)


# generated by validate_unittest autofix, based on test_case_4
def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
x = torch.tensor([[[[-0.0878, 0.3378, 0.0547, 1.2068],
[0.4212, -1.6113, 0.7277, 0.0766],
[0.8189, 0.0958, 1.7780, 1.1192],
[0.7286, -0.1988, 1.0519, 0.9217]],
[[0.0088, -1.9815, -0.3543, 0.1712],
[-0.1830, 0.0325, -0.1784, 0.1072],
[1.1752, -0.0234, -1.0873, -0.5568],
[0.4471, 0.4073, -1.6031, -0.0310]]]])
weight = torch.tensor([1.3, 1.2])
bias = torch.tensor([0.1, 0.2])
result = F.group_norm(eps=1e-5, bias=bias, weight=weight, num_groups=2, input=x)
"""
)
obj.run(pytorch_code, ["result"], atol=1e-4)

0 comments on commit b6cda71

Please sign in to comment.