From 3cd58db8af11597cc371ccf24e330b9ab7bece3e Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Fri, 18 Sep 2020 18:16:27 +0800
Subject: [PATCH] Add several op mapping in PyTorch frontend (#6472)

* Add copy_ and clamp_ in PyTorch frontend

* add true_divide in PyTorch frontend

* more test cases for copy_

* fix format

* remove copy_

* fix format

* skip true_divide for torch < 1.5
---
 python/tvm/relay/frontend/pytorch.py          |  2 +
 tests/python/frontend/pytorch/test_forward.py | 47 +++++++++++++++++--
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c9320a9b2882..9ceb9fc66ec4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2509,6 +2509,7 @@ def _get_convert_map(prelude, default_dtype):
         "aten::div": _elemwise("divide"),
         "aten::div_": _elemwise("divide"),
         "aten::floor_divide": _elemwise("floor_divide"),
+        "aten::true_divide": _elemwise("divide"),
         "aten::addcdiv": _addcdiv(),
         "aten::addcmul": _addcmul(),
         "aten::ones": _ones(default_dtype),
@@ -2630,6 +2631,7 @@ def _get_convert_map(prelude, default_dtype):
         "aten::isinf": _unary("isinf"),
         "aten::isnan": _unary("isnan"),
         "aten::clamp": _clamp(),
+        "aten::clamp_": _clamp(),
         "aten::detach": _identity(),
         "aten::upsample_bilinear2d": _upsample("bilinear", prelude),
         "aten::upsample_nearest2d": _upsample("nearest_neighbor", prelude),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e8a8507158a3..83ba22b7c1d9 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -21,15 +21,14 @@
 from scipy.stats import t as tdistr
 import numpy as np
 import torch
+import torchvision
 from torch.nn import Module
 import tvm
-import torchvision
-
 from tvm import relay
 from tvm.contrib import graph_runtime
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
-
+from packaging import version as package_version
 
 sys.setrecursionlimit(10000)
 
@@ -2398,6 +2397,24 @@ def forward(self, *args):
     verify_model(Clamp3().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_clamp_():
+    torch.set_grad_enabled(False)
+
+    class ClampInPlace(Module):
+        def __init__(self, min, max):
+            super(ClampInPlace, self).__init__()
+            self.min = min
+            self.max = max
+
+        def forward(self, *args):
+            return torch.clamp_(args[0], self.min, self.max)
+
+    for ishape, min, max in (([4, 8], 0.1, 0.9), ([7, 6], 0.2, 0.5)):
+        input_data = torch.rand(ishape).float()
+        verify_model(ClampInPlace(min, max).float().eval(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_ones():
     torch.set_grad_enabled(False)
@@ -2895,6 +2912,28 @@ def forward(self, *args):
     verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
 
 
+@tvm.testing.uses_gpu
+def test_forward_true_divide():
+    if package_version.parse(torch.__version__) < package_version.parse("1.5.0"):
+        return
+    torch.set_grad_enabled(False)
+
+    class TrueDivide(Module):
+        def forward(self, *args):
+            return torch.true_divide(args[0], args[1])
+
+    dividend = torch.rand([5, 3]).float()
+    # divisor could be either tensor or scalar
+    divisor_tensor = torch.rand([5, 3]).float() + 0.5
+    divisor_scalar = torch.tensor(1.0, dtype=torch.float32)
+    verify_model(
+        TrueDivide().float().eval(), input_data=[dividend, divisor_tensor], atol=1e-4, rtol=1e-4
+    )
+    verify_model(
+        TrueDivide().float().eval(), input_data=[dividend, divisor_scalar], atol=1e-4, rtol=1e-4
+    )
+
+
 @tvm.testing.uses_gpu
 def test_forward_traced_function():
     def fn(t1, t2):
@@ -3308,6 +3347,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_forward_where()
     test_forward_addcdiv()
     test_forward_addcmul()
+    test_forward_true_divide()
     test_forward_clone()
     test_forward_softplus()
     test_forward_softsign()
@@ -3323,6 +3363,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_forward_pow()
     test_forward_unary()
     test_forward_clamp()
+    test_forward_clamp_()
     test_forward_logical_not()
     test_forward_bitwise_not()
     test_forward_bitwise_xor()