Skip to content

Commit

Permalink
update median
Browse files Browse the repository at this point in the history
  • Loading branch information
RedContritio committed Apr 3, 2024
1 parent c7887ca commit 0c58f78
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 18 deletions.
10 changes: 8 additions & 2 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -8494,7 +8494,7 @@
}
},
"torch.median": {
"Matcher": "MedianMatcher",
"Matcher": "DoubleAssignMatcher",
"paddle_api": "paddle.median",
"args_list": [
"input",
Expand All @@ -8506,6 +8506,9 @@
"kwargs_change": {
"input": "x",
"dim": "axis"
},
"paddle_default_kwargs": {
"mode": "'min'"
}
},
"torch.meshgrid": {
Expand Down Expand Up @@ -8711,7 +8714,7 @@
}
},
"torch.nanmedian": {
"Matcher": "MedianMatcher",
"Matcher": "DoubleAssignMatcher",
"paddle_api": "paddle.nanmedian",
"args_list": [
"input",
Expand All @@ -8723,6 +8726,9 @@
"kwargs_change": {
"input": "x",
"dim": "axis"
},
"paddle_default_kwargs": {
"mode": "'min'"
}
},
"torch.nanquantile": {
Expand Down
20 changes: 12 additions & 8 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3305,6 +3305,12 @@ def generate_code(self, kwargs):
if k in kwargs:
kwargs[kwargs_change[k]] = kwargs.pop(k)

paddle_default_kwargs = self.api_mapping.get("paddle_default_kwargs", {})

for k in paddle_default_kwargs:
if k not in kwargs:
kwargs[k] = paddle_default_kwargs[k]

if "out" in kwargs:
out_v = kwargs.pop("out")
API_TEMPLATE = textwrap.dedent(
Expand Down Expand Up @@ -3336,6 +3342,12 @@ def generate_code(self, kwargs):
else:
kwargs.pop(k)

paddle_default_kwargs = self.api_mapping.get("paddle_default_kwargs", {})

for k in paddle_default_kwargs:
if k not in kwargs:
kwargs[k] = paddle_default_kwargs[k]

if "out" in kwargs:
out_v = kwargs.pop("out")
API_TEMPLATE = textwrap.dedent(
Expand Down Expand Up @@ -4328,11 +4340,3 @@ def generate_code(self, kwargs):
else:
code = "misidentify"
return code


class MedianMatcher(BaseMatcher):
def generate_code(self, kwargs):
if kwargs.get("dim", -1) != "-1":
kwargs["mode"] = "'min'"

return GenericMatcher.generate_code(self, kwargs)
15 changes: 13 additions & 2 deletions tests/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_case_4():
)


def _test_case_5():
def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
Expand All @@ -81,7 +81,7 @@ def _test_case_5():
result = torch.median(input, dim=1, keepdim=True, out=out)
"""
)
obj.run(pytorch_code, ["result"])
obj.run(pytorch_code, ["result", "out"])


def test_case_6():
Expand All @@ -96,3 +96,14 @@ def test_case_6():
pytorch_code,
["result"],
)


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
result = torch.median(input, dim=1, keepdim=True)
"""
)
obj.run(pytorch_code, ["result"])
20 changes: 14 additions & 6 deletions tests/test_nanmedian.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,18 @@ def test_case_2():
)


# 会引发段错误,先屏蔽
def _test_case_4():
def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
result = torch.nanmedian(input, 1, keepdim=True)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
Expand All @@ -59,8 +69,7 @@ def _test_case_4():
)


# 会引发段错误,先屏蔽
def _test_case_5():
def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
Expand All @@ -75,8 +84,7 @@ def _test_case_5():
)


# 会引发段错误,先屏蔽
def _test_case_6():
def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
Expand Down

0 comments on commit 0c58f78

Please sign in to comment.