Skip to content

Commit

Permalink
updated tests, remove packaging and drop PT 1.2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 28, 2020
1 parent 0455c87 commit d610392
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 31 deletions.
4 changes: 1 addition & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
from packaging import version

import numpy as np

Expand Down Expand Up @@ -786,8 +785,7 @@ def _convert_elemwise_input(data, input_type):
def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)
torch._C._jit_pass_inline(graph)


def _is_int_seq(seq):
Expand Down
36 changes: 9 additions & 27 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@
# pylint: disable=import-self, invalid-name, unused-argument
"""Unit tests for various models and operators"""
from time import time
import os
import sys
from tempfile import TemporaryDirectory
from scipy.stats import t as tdistr
import numpy as np
import torch
from torch import nn
from torch.nn import Module
import tvm
from tvm import te
import torchvision

from tvm import relay
Expand All @@ -37,22 +33,6 @@

sys.setrecursionlimit(10000)

def _vectorize(ten):
return ten.reshape(-1)

def atol(tru, est):
def _atol_elt(tru, est):
return abs(tru - est)
tru = _vectorize(tru)
est = _vectorize(est)
return max([_atol_elt(x, y) for x, y in zip(tru, est)])

def rtol(tru, est):
def _rtol_elt(tru, est):
return abs(tru - est) / min(abs(tru), abs(est))
tru = _vectorize(tru)
est = _vectorize(est)
return max([_rtol_elt(x, y) for x, y in zip(tru, est)])

def assert_shapes_match(tru, est):
if tru.shape != est.shape:
Expand Down Expand Up @@ -117,7 +97,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
latencies = []
count = 0
while True:
if isinstance(model, torch.nn.Module):
if isinstance(model, Module):
input_data = [torch.rand(shape).float() for shape in input_shapes]
if torch.cuda.is_available():
input_data = list(map(lambda x: x.cuda(), input_data))
Expand Down Expand Up @@ -670,7 +650,7 @@ def forward(self, *args):
verify_model(Chunk1().float().eval(), input_data=input_data)

def test_upsample():
class Upsample(nn.Module):
class Upsample(Module):
def __init__(self, size=None, scale=None,
mode="nearest", align_corners=None):
super().__init__()
Expand All @@ -680,15 +660,17 @@ def __init__(self, size=None, scale=None,
self.align_corners = align_corners

def forward(self, x):
return nn.functional.interpolate(x, size=self.size,
scale_factor=self.scale,
mode=self.mode,
align_corners=self.align_corners)
return torch.nn.functional.interpolate(x, size=self.size,
scale_factor=self.scale,
mode=self.mode,
align_corners=self.align_corners)
inp = torch.rand((1, 3, 32, 32))
verify_model(Upsample(size=(64, 64), mode="nearest"), inp)
verify_model(Upsample(scale=2, mode="nearest"), inp)
verify_model(Upsample(size=(50, 50), mode="nearest"), inp)
verify_model(Upsample(size=(64, 64), mode="bilinear", align_corners=True), inp)
verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp)
verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp)

# Model tests
def test_resnet18():
Expand Down Expand Up @@ -769,7 +751,7 @@ def _impl(inputs, input_types):


def test_segmentaton_models():
class SegmentationModelWrapper(torch.nn.Module):
class SegmentationModelWrapper(Module):
def __init__(self, model):
super().__init__()
self.model = model
Expand Down
2 changes: 1 addition & 1 deletion tutorials/frontend/from_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
PyTorch versions should be backwards compatible but should be used
with the proper TorchVision version.
Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
be unstable.
"""

Expand Down

0 comments on commit d610392

Please sign in to comment.