Skip to content

Commit

Permalink
Add several op mapping in PyTorch frontend (#6472)
Browse files Browse the repository at this point in the history
* Add copy_ and clamp_ in PyTorch frontend

* add true_divide in PyTorch frontend

* more test cases for copy_

* fix format

* remove copy_

* fix format

* skip true_divide for torch < 1.5
  • Loading branch information
yongwww authored Sep 18, 2020
1 parent 7aed468 commit 28ea54a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2509,6 +2509,7 @@ def _get_convert_map(prelude, default_dtype):
"aten::div": _elemwise("divide"),
"aten::div_": _elemwise("divide"),
"aten::floor_divide": _elemwise("floor_divide"),
"aten::true_divide": _elemwise("divide"),
"aten::addcdiv": _addcdiv(),
"aten::addcmul": _addcmul(),
"aten::ones": _ones(default_dtype),
Expand Down Expand Up @@ -2630,6 +2631,7 @@ def _get_convert_map(prelude, default_dtype):
"aten::isinf": _unary("isinf"),
"aten::isnan": _unary("isnan"),
"aten::clamp": _clamp(),
"aten::clamp_": _clamp(),
"aten::detach": _identity(),
"aten::upsample_bilinear2d": _upsample("bilinear", prelude),
"aten::upsample_nearest2d": _upsample("nearest_neighbor", prelude),
Expand Down
47 changes: 44 additions & 3 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
from scipy.stats import t as tdistr
import numpy as np
import torch
import torchvision
from torch.nn import Module
import tvm
import torchvision

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.contrib.nvcc import have_fp16
import tvm.testing

from packaging import version as package_version

sys.setrecursionlimit(10000)

Expand Down Expand Up @@ -2398,6 +2397,24 @@ def forward(self, *args):
verify_model(Clamp3().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_clamp_():
torch.set_grad_enabled(False)

class ClampInPlace(Module):
def __init__(self, min, max):
super(ClampInPlace, self).__init__()
self.min = min
self.max = max

def forward(self, *args):
return torch.clamp_(args[0], self.min, self.max)

for ishape, min, max in (([4, 8], 0.1, 0.9), ([7, 6], 0.2, 0.5)):
input_data = torch.rand(ishape).float()
verify_model(ClampInPlace(min, max).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_ones():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -2895,6 +2912,28 @@ def forward(self, *args):
verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])


@tvm.testing.uses_gpu
def test_forward_true_divide():
if package_version.parse(torch.__version__) < package_version.parse("1.5.0"):
return
torch.set_grad_enabled(False)

class TrueDivide(Module):
def forward(self, *args):
return torch.true_divide(args[0], args[1])

dividend = torch.rand([5, 3]).float()
# divisor could be either tensor or scalar
divisor_tensor = torch.rand([5, 3]).float() + 0.5
divisor_scalar = torch.tensor(1.0, dtype=torch.float32)
verify_model(
TrueDivide().float().eval(), input_data=[dividend, divisor_tensor], atol=1e-4, rtol=1e-4
)
verify_model(
TrueDivide().float().eval(), input_data=[dividend, divisor_scalar], atol=1e-4, rtol=1e-4
)


@tvm.testing.uses_gpu
def test_forward_traced_function():
def fn(t1, t2):
Expand Down Expand Up @@ -3308,6 +3347,7 @@ def test_forward_pretrained_bert_base_uncased():
test_forward_where()
test_forward_addcdiv()
test_forward_addcmul()
test_forward_true_divide()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()
Expand All @@ -3323,6 +3363,7 @@ def test_forward_pretrained_bert_base_uncased():
test_forward_pow()
test_forward_unary()
test_forward_clamp()
test_forward_clamp_()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()
Expand Down

0 comments on commit 28ea54a

Please sign in to comment.