From b3eb56bc91b8c2a9de0c890ef74ea2eaa31e0027 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Tue, 2 Apr 2024 17:09:03 +0800 Subject: [PATCH 1/2] update quantile --- paconvert/api_mapping.json | 19 +++++----------- tests/test_Tensor_nanquantile.py | 17 ++++++++++++-- tests/test_Tensor_quantile.py | 7 +++--- tests/test_nanquantile.py | 21 ++++++++++++----- tests/test_quantile.py | 39 ++++++++++++++++++++++++++++---- 5 files changed, 74 insertions(+), 29 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index e50d95797..51803b12d 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -2466,10 +2466,7 @@ ], "kwargs_change": { "dim": "axis" - }, - "unsupport_args": [ - "interpolation" - ] + } }, "torch.Tensor.nansum": { "Matcher": "GenericMatcher", @@ -2785,7 +2782,9 @@ "args_list": [ "q", "dim", - "keepdim" + "keepdim", + "*", + "interpolation" ], "kwargs_change": { "dim": "axis" @@ -8747,10 +8746,7 @@ "kwargs_change": { "input": "x", "dim": "axis" - }, - "unsupport_args": [ - "interpolation " - ] + } }, "torch.nansum": { "Matcher": "GenericMatcher", @@ -13500,10 +13496,7 @@ "kwargs_change": { "input": "x", "dim": "axis" - }, - "unsupport_args": [ - "interpolation " - ] + } }, "torch.rad2deg": { "Matcher": "GenericMatcher", diff --git a/tests/test_Tensor_nanquantile.py b/tests/test_Tensor_nanquantile.py index a8e024ca5..56d82033c 100644 --- a/tests/test_Tensor_nanquantile.py +++ b/tests/test_Tensor_nanquantile.py @@ -97,6 +97,19 @@ def test_case_7(): obj.run( pytorch_code, ["result"], - unsupport=True, - reason="Paddle not support this parameter", + ) + + +# generated by validate_unittest autofix, based on test_case_7 +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + result = x.nanquantile(interpolation='higher', keepdim=True, dim=1, q=0.3) + """ + ) + obj.run( + pytorch_code, + ["result"], ) diff --git a/tests/test_Tensor_quantile.py b/tests/test_Tensor_quantile.py index b38443650..636b48ae0 100644 --- a/tests/test_Tensor_quantile.py +++ b/tests/test_Tensor_quantile.py @@ -132,13 +132,12 @@ def test_case_10(): obj.run(pytorch_code, ["result"]) -# torch.Tensor.quantile not support interpolation -def _test_case_11(): +def test_case_11(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64) - result=x.quantile(q=0.6, dim=1, keepdim=False, interpolation='linear') + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + result = x.quantile(q=0.6, dim=1, keepdim=False, interpolation='linear') """ ) obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nanquantile.py b/tests/test_nanquantile.py index 0c1c30c13..2e3142475 100644 --- a/tests/test_nanquantile.py +++ b/tests/test_nanquantile.py @@ -114,12 +114,21 @@ def test_case_9(): """ import torch x = torch.tensor([[0]], dtype=torch.float64) - result = torch.nanquantile(x=x, q=0.3, dim=1, keepdim=True, interpolation='higher') + out = torch.tensor([[1]], dtype=torch.float64) + result = torch.nanquantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='higher', out=out) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="Paddle not support this parameter", + obj.run(pytorch_code, ["result", "out"]) + + +# generated by validate_unittest autofix, based on test_case_9 +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + x = 
torch.tensor([[0]], dtype=torch.float64) + out = torch.tensor([[1]], dtype=torch.float64) + result = torch.nanquantile(out=out, interpolation='higher', keepdim=True, dim=1, q=0.3, input=x) + """ ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_quantile.py b/tests/test_quantile.py index d334f2981..b742d8ed6 100644 --- a/tests/test_quantile.py +++ b/tests/test_quantile.py @@ -86,12 +86,43 @@ def test_case_7(): """ import torch x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) - result = torch.quantile(x=x, q=0.3, dim=1, keepdim=True, interpolation='higher') + out = torch.tensor([], dtype=torch.float64) + result = torch.quantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='higher', out=out) """ ) obj.run( pytorch_code, - ["result"], - unsupport=True, - reason="Paddle not support this parameter", + ["result", "out"], + ) + + +# generated by validate_unittest autofix, based on test_case_7 +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + out = torch.tensor([], dtype=torch.float64) + result = torch.quantile(x, 0.3, 1, True) + """ + ) + obj.run( + pytorch_code, + ["result", "out"], + ) + + +# generated by validate_unittest autofix, based on test_case_7 +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + out = torch.tensor([], dtype=torch.float64) + result = torch.quantile(out=out, interpolation='higher', keepdim=True, dim=1, q=0.3, input=x) + """ + ) + obj.run( + pytorch_code, + ["result", "out"], ) From 441f868ce2aab9bd4839eb9918119cb4626ef984 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Wed, 3 Apr 2024 14:58:25 +0800 Subject: [PATCH 2/2] update tests about interpolation arg --- tests/test_Tensor_nanquantile.py | 56 +++++++++++++++++++++++++++++ tests/test_Tensor_quantile.py | 44 +++++++++++++++++++++++ tests/test_nanquantile.py | 48 +++++++++++++++++++++++++ tests/test_quantile.py | 60 ++++++++++++++++++++++++++++++++ 4 files changed, 208 insertions(+) diff --git a/tests/test_Tensor_nanquantile.py b/tests/test_Tensor_nanquantile.py index 56d82033c..ba7e7a389 100644 --- a/tests/test_Tensor_nanquantile.py +++ b/tests/test_Tensor_nanquantile.py @@ -113,3 +113,59 @@ def test_case_8(): pytorch_code, ["result"], ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + result = x.nanquantile(q=0.3, dim=1, keepdim=True, interpolation='lower') + """ + ) + obj.run( + pytorch_code, + ["result"], + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + result = x.nanquantile(q=0.3, dim=1, keepdim=True, interpolation='nearest') + """ + ) + obj.run( + pytorch_code, + ["result"], + ) + + +def test_case_11(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + result = x.nanquantile(q=0.3, dim=1, keepdim=True, interpolation='midpoint') + """ + ) + obj.run( + pytorch_code, + ["result"], + ) + + +def test_case_12(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + result = x.nanquantile(q=0.3, dim=1, keepdim=True, interpolation='linear') + """ + ) + obj.run( + pytorch_code, + ["result"], + ) diff --git a/tests/test_Tensor_quantile.py 
b/tests/test_Tensor_quantile.py index 636b48ae0..16075fe63 100644 --- a/tests/test_Tensor_quantile.py +++ b/tests/test_Tensor_quantile.py @@ -141,3 +141,47 @@ def test_case_11(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_12(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + result = x.quantile(q=0.6, dim=1, keepdim=False, interpolation='lower') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_13(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + result = x.quantile(q=0.6, dim=1, keepdim=False, interpolation='higher') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_14(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + result = x.quantile(q=0.6, dim=1, keepdim=False, interpolation='nearest') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_15(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + result = x.quantile(q=0.6, dim=1, keepdim=False, interpolation='midpoint') + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nanquantile.py b/tests/test_nanquantile.py index 2e3142475..19926870e 100644 --- a/tests/test_nanquantile.py +++ b/tests/test_nanquantile.py @@ -132,3 +132,51 @@ def test_case_10(): """ ) obj.run(pytorch_code, ["result", "out"]) + + +def test_case_11(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + out = torch.tensor([[1]], dtype=torch.float64) + result = torch.nanquantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='linear', out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_12(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + out = torch.tensor([[1]], dtype=torch.float64) + result = torch.nanquantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='lower', out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_13(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + out = torch.tensor([[1]], dtype=torch.float64) + result = torch.nanquantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='nearest', out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_14(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + out = torch.tensor([[1]], dtype=torch.float64) + result = torch.nanquantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='midpoint', out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_quantile.py b/tests/test_quantile.py index b742d8ed6..5af5a320d 100644 --- a/tests/test_quantile.py +++ b/tests/test_quantile.py @@ -126,3 +126,63 @@ def test_case_9(): pytorch_code, ["result", "out"], ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + out = torch.tensor([], dtype=torch.float64) + result = torch.quantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='linear', out=out) + """ + ) + obj.run( + 
pytorch_code, + ["result", "out"], + ) + + +def test_case_11(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + out = torch.tensor([], dtype=torch.float64) + result = torch.quantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='lower', out=out) + """ + ) + obj.run( + pytorch_code, + ["result", "out"], + ) + + +def test_case_12(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + out = torch.tensor([], dtype=torch.float64) + result = torch.quantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='nearest', out=out) + """ + ) + obj.run( + pytorch_code, + ["result", "out"], + ) + + +def test_case_13(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + out = torch.tensor([], dtype=torch.float64) + result = torch.quantile(input=x, q=0.3, dim=1, keepdim=True, interpolation='midpoint', out=out) + """ + ) + obj.run( + pytorch_code, + ["result", "out"], + )
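
For reference, the mapping update above forwards `interpolation` to Paddle instead of flagging it as unsupported, while still renaming `input` to `x` and `dim` to `axis`. A minimal sketch of the equivalence the new test cases exercise, assuming a Paddle version whose paddle.quantile accepts `interpolation`:

import numpy as np
import paddle
import torch

# PyTorch call pattern covered by the new tests.
x_t = torch.tensor([[0.0795, -1.2117, 0.9765],
                    [1.1707, 0.6706, 0.4884]], dtype=torch.float64)
res_t = torch.quantile(x_t, q=0.3, dim=1, keepdim=True, interpolation='higher')

# Converted Paddle form under the updated mapping: input -> x, dim -> axis,
# interpolation passed through unchanged.
x_p = paddle.to_tensor([[0.0795, -1.2117, 0.9765],
                        [1.1707, 0.6706, 0.4884]], dtype='float64')
res_p = paddle.quantile(x_p, q=0.3, axis=1, keepdim=True, interpolation='higher')

# Both frameworks should produce the same quantile values.
np.testing.assert_allclose(res_t.numpy(), res_p.numpy())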