diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index fd66e3c1f367..b256faa5d6f9 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -19,7 +19,6 @@
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
 """PT: PyTorch frontend."""
 import itertools
-from packaging import version

 import numpy as np

@@ -31,6 +30,7 @@
 from .. import op as _op
 from .common import get_relay_op
 from .common import infer_shape as _infer_shape
+from .common import infer_value as _infer_value

 __all__ = ["from_pytorch"]

@@ -614,6 +614,61 @@ def _impl(inputs, input_types):
         return _op.tensor.sqrt(data)
     return _impl

+def _floor():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.floor(data)
+    return _impl
+
+def _to():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        if inputs[3] in ["cpu", "cuda"]:
+            return data
+        # special handling for aten::to(data, 6, _, _, _) case
+        # 6 means dtype = float
+        # this happens when converting upsampling with scale factor
+        cast_func = {
+            6: float,
+            3: int,
+        }
+        cast_func_expr = {
+            6: lambda x: _op.cast(x, "float32"),
+            3: lambda x: _op.cast(x, "int32"),
+        }
+        if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
+            return cast_func[inputs[1]](data)
+        elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
+            return cast_func_expr[inputs[1]](data)
+        return data
+
+    return _impl
+
+def _upsample(method):
+    def _impl(inputs, input_types):
+        if isinstance(inputs[1], _expr.Var):
+            out_size = _infer_shape(inputs[1])
+        elif isinstance(inputs[1], list):
+            infer_res = [_infer_value(size, {}) for size in inputs[1]]
+            out_size = [np.asscalar(res.asnumpy().astype(np.int))
+                        for res in infer_res]
+
+        data = inputs[0]
+
+        if len(inputs) > 2:
+            align_corners = inputs[2]
+        else:
+            align_corners = False
+
+        if align_corners:
+            coord_trans = "align_corners"
+        else:
+            coord_trans = "half_pixel"
+
+        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
+
+    return _impl
+
 # Helper functions for operator implementation

 def _convert_data_type(input_type):
@@ -686,7 +741,7 @@ def _convert_elemwise_input(data, input_type):
     "aten::div_" : _elemwise("divide"),
     "aten::ones" : _ones(),
     "aten::zeros" : _zeros(),
-    "aten::to" : _identity(),
+    "aten::to" : _to(),
     "aten::unsqueeze" : _unsqueeze(),
     "aten::cat" : _concatenate(),
     "aten::slice" : _slice(),
@@ -729,15 +784,18 @@ def _convert_elemwise_input(data, input_type):
     "aten::permute" : _transpose(),
     "aten::sum" : _reduce("sum"),
     "aten::prod" : _reduce("prod"),
-    "aten::sqrt" : _sqrt()
+    "aten::sqrt" : _sqrt(),
+    "aten::floor" : _floor(),
+    "aten::detach" : _identity(),
+    "aten::upsample_bilinear2d" : _upsample("bilinear"),
+    "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
 }


 def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
-    if version.parse(torch.__version__) >= version.parse("1.4.0"):
-        torch._C._jit_pass_inline(graph)
+    torch._C._jit_pass_inline(graph)


 def _is_int_seq(seq):
@@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):

 def get_all_op_names(graph):
     """ Return all operator names in the input graph """
-    nodes = list(graph.nodes())
-    return set(node.kind() for node in nodes)
+    return set(node.kind() for node in graph.nodes())


 def get_graph_input_names(script_module):
@@ -997,7 +1054,7 @@
     return ir_inputs[1:] # remove self at the 0th arg


-def from_pytorch(script_module, input_shapes):
+def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes):
         Graph level input shape dictionary
         The keys should be the same one returned by get_graph_input_names(...) above

+    custom_convert_map: Dictionary of str to Relay op
+        A custom op conversion map in the same format as _convert_map above
+
     Returns
     -------
     mod : tvm.relay.Module
@@ -1021,6 +1081,10 @@
     """
     graph = script_module.graph.copy()
     _run_jit_passes(graph)
+
+    if custom_convert_map:
+        _convert_map.update(custom_convert_map)
+
     op_names = get_all_op_names(graph)
     _report_missing_conversion(op_names)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 831389b7ebf5..c2ff94de546f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -17,15 +17,12 @@
 # pylint: disable=import-self, invalid-name, unused-argument
 """Unit tests for various models and operators"""
 from time import time
-import os
 import sys
-from tempfile import TemporaryDirectory
 from scipy.stats import t as tdistr
 import numpy as np
 import torch
 from torch.nn import Module
 import tvm
-from tvm import te
 import torchvision

 from tvm import relay
@@ -36,22 +33,6 @@

 sys.setrecursionlimit(10000)

-def _vectorize(ten):
-    return ten.reshape(-1)
-
-def atol(tru, est):
-    def _atol_elt(tru, est):
-        return abs(tru - est)
-    tru = _vectorize(tru)
-    est = _vectorize(est)
-    return max([_atol_elt(x, y) for x, y in zip(tru, est)])
-
-def rtol(tru, est):
-    def _rtol_elt(tru, est):
-        return abs(tru - est) / min(abs(tru), abs(est))
-    tru = _vectorize(tru)
-    est = _vectorize(est)
-    return max([_rtol_elt(x, y) for x, y in zip(tru, est)])

 def assert_shapes_match(tru, est):
     if tru.shape != est.shape:
@@ -77,7 +58,7 @@ def load_torchvision(model_name):
         input_data[:, channel] /= std[channel]
     model = getattr(torchvision.models, model_name)(pretrained=True)
     model = model.float().eval()
-    return model, input_data
+    return model, [input_data]

 def load_pretrainedmodels(model_name):
     """Given a model name, returns a pretrainedmodels.pytorch model in eval
@@ -89,7 +70,7 @@ def load_pretrainedmodels(model_name):
     for channel in range(3):
         input_data[:, channel] -= model.mean[channel]
         input_data[:, channel] /= model.std[channel]
-    return model, input_data
+    return model, [input_data]

 def load_model(model_name):
     """Given a model name, returns a model as well as an example input."""
@@ -116,7 +97,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
     latencies = []
     count = 0
     while True:
-        if isinstance(model, torch.nn.Module):
+        if isinstance(model, Module):
             input_data = [torch.rand(shape).float() for shape in input_shapes]
             if torch.cuda.is_available():
                 input_data = list(map(lambda x: x.cuda(), input_data))
@@ -153,23 +134,34 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
         if err < thresh:
             return est

-def verify_model(model_name, input_data=[]):
+def verify_model(model_name, input_data=[],
+                 custom_convert_map={},
+                 ctx_list=ctx_list()):
     """Assert that the output of a compiled model matches with that
     of its baseline."""
-    if len(input_data) == 0:
+    if isinstance(model_name, str):
         baseline_model, baseline_input = load_model(model_name)
-    else:
+    elif isinstance(input_data, list):
         baseline_model = model_name
         baseline_input = input_data
+    elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
+        baseline_model = model_name
+        baseline_input = [input_data]
+    else:
+        assert False, "Unexpected input format"
+
     if torch.cuda.is_available():
         baseline_model = baseline_model.cuda()
-        baseline_input = baseline_input.cuda()
+        baseline_input = [inp.cuda() for inp in baseline_input]
+
     with torch.no_grad():
-        baseline_outputs = baseline_model(baseline_input)
+        baseline_outputs = baseline_model(*baseline_input)
+
     if isinstance(baseline_outputs, tuple):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
         baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
+
     trace = torch.jit.trace(baseline_model, baseline_input).float().eval()

     if torch.cuda.is_available():
@@ -177,17 +169,21 @@ def verify_model(model_name, input_data=[]):
     else:
         trace = trace.cpu()

-    input_name = get_graph_input_names(trace)[0] # only one input
-    input_shapes = {input_name: list(baseline_input.shape)}
-    mod, params = relay.frontend.from_pytorch(trace, input_shapes)
-    compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
+    input_names = get_graph_input_names(trace)
+    input_shapes = dict(zip(input_names,
+                            [inp.shape for inp in baseline_input]))
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes,
+                                              custom_convert_map)
+    compiled_input = dict(zip(input_names,
+                              [inp.cpu().numpy() for inp in baseline_input]))

     with relay.build_config(opt_level=3):
-        for target, ctx in ctx_list():
+        for target, ctx in ctx_list:
             relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
             relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
             relay_model.set_input(**relay_params)
-            relay_model.set_input(**compiled_input)
+            for name, inp in compiled_input.items():
+                relay_model.set_input(name, inp)
             relay_model.run()

             for i, baseline_output in enumerate(baseline_outputs):
@@ -228,12 +224,11 @@ def forward(self, *args):
                 ones = ones.cuda()
             return args[0] + ones

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Add1().float().eval(), input_data=input_data)
-        verify_model(Add2().float().eval(), input_data=input_data)
-        verify_model(Add3().float().eval(), input_data=input_data)
-        verify_model(Add4().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Add1().float().eval(), input_data=input_data)
+    verify_model(Add2().float().eval(), input_data=input_data)
+    verify_model(Add3().float().eval(), input_data=input_data)
+    verify_model(Add4().float().eval(), input_data=input_data)

 def test_forward_subtract():
     torch.set_grad_enabled(False)
@@ -261,12 +256,11 @@ def forward(self, *args):
                 ones = ones.cuda()
             return args[0] - ones

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Subtract1().float().eval(), input_data=input_data)
-        verify_model(Subtract2().float().eval(), input_data=input_data)
-        verify_model(Subtract3().float().eval(), input_data=input_data)
-        verify_model(Subtract4().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Subtract1().float().eval(), input_data=input_data)
+    verify_model(Subtract2().float().eval(), input_data=input_data)
+    verify_model(Subtract3().float().eval(), input_data=input_data)
+    verify_model(Subtract4().float().eval(), input_data=input_data)

 def test_forward_multiply():
     torch.set_grad_enabled(False)
@@ -294,12 +288,11 @@ def forward(self, *args):
                 ones = ones.cuda()
             return args[0] * ones

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Multiply1().float().eval(), input_data=input_data)
-        verify_model(Multiply2().float().eval(), input_data=input_data)
-        verify_model(Multiply3().float().eval(), input_data=input_data)
-        verify_model(Multiply4().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Multiply1().float().eval(), input_data=input_data)
+    verify_model(Multiply2().float().eval(), input_data=input_data)
+    verify_model(Multiply3().float().eval(), input_data=input_data)
+    verify_model(Multiply4().float().eval(), input_data=input_data)

 def test_forward_unsqueeze():
     torch.set_grad_enabled(False)
@@ -327,10 +320,9 @@ def forward(self, *args):
         c = (args[0][:, :, 2] + 5) * 13
         return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Concatenate1().float().eval(), input_data=input_data)
-        verify_model(Concatenate2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Concatenate1().float().eval(), input_data=input_data)
+    verify_model(Concatenate2().float().eval(), input_data=input_data)

 def test_forward_relu():
     torch.set_grad_enabled(False)
@@ -340,9 +332,8 @@ class ReLU1(Module):
         def forward(self, *args):
             return torch.nn.ReLU()(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(ReLU1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(ReLU1().float().eval(), input_data=input_data)

 def test_forward_adaptiveavgpool():
     torch.set_grad_enabled(False)
@@ -356,10 +347,9 @@ class AdaptiveAvgPool2D2(Module):
         def forward(self, *args):
             return torch.nn.AdaptiveAvgPool2d([10, 10])(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
-        verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
+    verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)

 def test_forward_maxpool():
     torch.set_grad_enabled(False)
@@ -373,10 +363,9 @@ class MaxPool2D2(Module):
         def forward(self, *args):
             return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(MaxPool2D1().float().eval(), input_data=input_data)
-        verify_model(MaxPool2D2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(MaxPool2D1().float().eval(), input_data=input_data)
+    verify_model(MaxPool2D2().float().eval(), input_data=input_data)

 def test_forward_avgpool():
     torch.set_grad_enabled(False)
@@ -386,9 +375,8 @@ class AvgPool2D1(Module):
         def forward(self, *args):
             return torch.nn.AvgPool2d(kernel_size=[10, 10])(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(AvgPool2D1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(AvgPool2D1().float().eval(), input_data=input_data)

 def test_forward_hardtanh():
     torch.set_grad_enabled(False)
@@ -398,9 +386,8 @@ class HardTanh1(Module):
         def forward(self, *args):
             return torch.nn.Hardtanh()(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(HardTanh1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(HardTanh1().float().eval(), input_data=input_data)

 def test_forward_conv():
     torch.set_grad_enabled(False)
@@ -433,11 +420,10 @@ def __init__(self):
         def forward(self, *args):
             return self.softmax(self.conv(args[0]))

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Conv2D1().float().eval(), input_data=input_data)
-        verify_model(Conv2D2().float().eval(), input_data=input_data)
-        verify_model(Conv2D3().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Conv2D1().float().eval(), input_data=input_data)
+    verify_model(Conv2D2().float().eval(), input_data=input_data)
+    verify_model(Conv2D3().float().eval(), input_data=input_data)

 def test_forward_threshold():
     torch.set_grad_enabled(False)
@@ -447,9 +433,8 @@ class Threshold1(Module):
         def forward(self, *args):
             return torch.nn.Threshold(0, 0)(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Threshold1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Threshold1().float().eval(), input_data=input_data)

 def test_forward_contiguous():
     torch.set_grad_enabled(False)
@@ -459,9 +444,8 @@ class Contiguous1(Module):
         def forward(self, *args):
             return args[0].contiguous()

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Contiguous1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Contiguous1().float().eval(), input_data=input_data)

 def test_forward_batchnorm():
     torch.set_grad_enabled(False)
@@ -481,10 +465,9 @@ def __init__(self):
         def forward(self, *args):
             return self.batch_norm(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(BatchNorm1().float().eval(), input_data=input_data)
-        verify_model(BatchNorm2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(BatchNorm1().float().eval(), input_data=input_data)
+    verify_model(BatchNorm2().float().eval(), input_data=input_data)

 def test_forward_transpose():
     torch.set_grad_enabled(False)
@@ -498,10 +481,9 @@ class Transpose2(Module):
         def forward(self, *args):
             return args[0].transpose(-2, -1)

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Transpose1().float().eval(), input_data=input_data)
-        verify_model(Transpose2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Transpose1().float().eval(), input_data=input_data)
+    verify_model(Transpose2().float().eval(), input_data=input_data)

 def test_forward_size():
     torch.set_grad_enabled(False)
@@ -511,9 +493,8 @@ class Size1(Module):
         def forward(self, *args):
             return float(args[0].size(0)) * args[0]

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Size1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Size1().float().eval(), input_data=input_data)

 def test_forward_view():
     torch.set_grad_enabled(False)
@@ -527,10 +508,9 @@ class View2(Module):
         def forward(self, *args):
             return args[0].view(args[0].shape[0], -1)

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(View1().float().eval(), input_data=input_data)
-        verify_model(View2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(View1().float().eval(), input_data=input_data)
+    verify_model(View2().float().eval(), input_data=input_data)

 def test_forward_select():
     torch.set_grad_enabled(False)
@@ -540,9 +520,8 @@ class Select1(Module):
         def forward(self, *args):
             return args[0].select(1, 1)

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Select1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Select1().float().eval(), input_data=input_data)

 def test_forward_clone():
     torch.set_grad_enabled(False)
@@ -552,9 +531,8 @@ class Clone1(Module):
         def forward(self, *args):
             return args[0].clone()

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Clone1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Clone1().float().eval(), input_data=input_data)

 def test_forward_logsoftmax():
     torch.set_grad_enabled(False)
@@ -564,9 +542,8 @@ class LogSoftmax1(Module):
         def forward(self, *args):
             return torch.nn.LogSoftmax(dim=1)(args[0][0, 0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(LogSoftmax1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(LogSoftmax1().float().eval(), input_data=input_data)

 def test_forward_sigmoid():
     torch.set_grad_enabled(False)
@@ -576,9 +553,8 @@ class Sigmoid1(Module):
         def forward(self, *args):
             return torch.nn.Sigmoid()(args[0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Sigmoid1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Sigmoid1().float().eval(), input_data=input_data)

 def test_forward_dense():
     torch.set_grad_enabled(False)
@@ -598,10 +574,9 @@ def __init__(self):
         def forward(self, *args):
             return self.linear(args[0][0, 0])

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Dense1().float().eval(), input_data=input_data)
-        verify_model(Dense2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Dense1().float().eval(), input_data=input_data)
+    verify_model(Dense2().float().eval(), input_data=input_data)

 def test_forward_dropout():
     torch.set_grad_enabled(False)
@@ -611,9 +586,8 @@ class Dropout1(Module):
         def forward(self, *args):
             return torch.nn.functional.dropout(args[0][0, 0], 0.5, False)

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Dropout1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Dropout1().float().eval(), input_data=input_data)

 def test_forward_slice():
     torch.set_grad_enabled(False)
@@ -627,10 +601,9 @@ class Slice2(Module):
         def forward(self, *args):
             return args[0][0, :, :, :]

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Slice1().float().eval(), input_data=input_data)
-        verify_model(Slice2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Slice1().float().eval(), input_data=input_data)
+    verify_model(Slice2().float().eval(), input_data=input_data)

 def test_forward_mean():
     torch.set_grad_enabled(False)
@@ -640,9 +613,8 @@ class Mean1(Module):
         def forward(self, *args):
             return args[0].mean(2)

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Mean1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Mean1().float().eval(), input_data=input_data)

 def test_forward_expand():
     torch.set_grad_enabled(False)
@@ -652,9 +624,8 @@ class Expand1(Module):
         def forward(self, *args):
             return args[0].expand((3, -1, -1, -1))

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Expand1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Expand1().float().eval(), input_data=input_data)

 def test_forward_pow():
     torch.set_grad_enabled(False)
@@ -664,9 +635,8 @@ class Pow1(Module):
         def forward(self, *args):
             return args[0] ** 2

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Pow1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Pow1().float().eval(), input_data=input_data)

 def test_forward_chunk():
     torch.set_grad_enabled(False)
@@ -677,9 +647,61 @@ def forward(self, *args):
         chunks = args[0].chunk(7, 2)
         return torch.cat(chunks, 2)

-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Chunk1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Chunk1().float().eval(), input_data=input_data)
+
+def test_upsample():
+    class Upsample(Module):
+        def __init__(self, size=None, scale=None,
+                     mode="nearest", align_corners=None):
+            super().__init__()
+            self.size = size
+            self.scale = scale
+            self.mode = mode
+            self.align_corners = align_corners
+
+        def forward(self, x):
+            return torch.nn.functional.interpolate(x, size=self.size,
+                                                   scale_factor=self.scale,
+                                                   mode=self.mode,
+                                                   align_corners=self.align_corners)
+
+    inp = torch.rand((1, 3, 32, 32))
+    verify_model(Upsample(size=(64, 64), mode="nearest"), inp)
+    verify_model(Upsample(scale=2, mode="nearest"), inp)
+    verify_model(Upsample(size=(50, 50), mode="nearest"), inp)
+    verify_model(Upsample(size=(64, 64), mode="bilinear", align_corners=True), inp)
+    verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp)
+    verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp)
+
+def test_to():
+    """ test for aten::to(...) """
+    class ToCPU(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x.to("cpu")
+
+    class ToFloat(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x.float()
+
+    class ToInt(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x.int()
+
+    verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
+    verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
+    verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
+    verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
+    verify_model(ToInt().eval(), torch.tensor(2.0))

 # Model tests
 def test_resnet18():
@@ -730,6 +752,57 @@ def test_vgg11_bn():
     """

+def test_custom_conversion_map():
+    def get_roi_align():
+        pool_size = 5
+        n_channels = 2 * (pool_size ** 2)
+        x = torch.rand(2, n_channels, 10, 10)
+        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
+                             [0, 0, 5, 4, 9],
+                             [0, 5, 5, 9, 9],
+                             [1, 0, 0, 9, 9]], dtype=torch.float)
+        roi_align = torchvision.ops.RoIAlign(pool_size, spatial_scale=1,
+                                             sampling_ratio=-1)
+        return roi_align.eval(), [x, rois]
+
+    def convert_roi_align():
+        def _impl(inputs, input_types):
+            spatial_scale = inputs[2]
+            pooled_size = (inputs[3], inputs[4])
+            sampling_ratio = inputs[5]
+            return relay.op.vision.roi_align(inputs[0], inputs[1],
+                                             pooled_size, spatial_scale,
+                                             sampling_ratio)
+        return _impl
+
+    custom_map = {"torchvision::roi_align": convert_roi_align()}
+    model, inputs = get_roi_align()
+
+    verify_model(model, inputs, custom_map)
+
+
+def test_segmentation_models():
+    class SegmentationModelWrapper(Module):
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, inp):
+            out = self.model(inp)
+            return out["out"]
+
+    fcn = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
+    deeplab = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
+
+    inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
+
+    for model in [fcn, deeplab]:
+        # depthwise + dilated convolution not supported on x86
+        # see https://github.com/apache/incubator-tvm/issues/4962
+        verify_model(SegmentationModelWrapper(model.eval()), inp,
+                     ctx_list=[("cuda", tvm.gpu(0))])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -760,6 +833,8 @@ def test_vgg11_bn():
     test_forward_expand()
     test_forward_pow()
     test_forward_chunk()
+    test_upsample()
+    test_to()

     # Model tests
     test_resnet18()
@@ -770,3 +845,7 @@ def test_vgg11_bn():
     test_googlenet()
     test_mnasnet0_5()
     test_mobilenet_v2()
+
+    test_custom_conversion_map()
+
+    test_segmentation_models()
diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py
index 503f64a4e7d9..1c568ceb3ef5 100644
--- a/tutorials/frontend/from_pytorch.py
+++ b/tutorials/frontend/from_pytorch.py
@@ -37,7 +37,7 @@
 PyTorch versions should be backwards compatible but should be used
 with the proper TorchVision version.

-Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
+Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
 be unstable.
 """