Skip to content

Commit

Permalink
[Torch] Upsampling op support and enable registering a user defined op conversion map (apache#4961)
Browse files Browse the repository at this point in the history

* add custom conversion map

* add roi align test using custom convert map

* refactor test

* add support for upsampling op and test on segmentation models

* remove redundant no_grad

* add upsampling test case

* make the default custom map None, instead of empty dict

* updated tests, remove packaging and drop PT 1.2 support

* add better support for aten::to and tests

* add a note on dilation in x86
  • Loading branch information
masahi authored and Trevor Morris committed Apr 16, 2020
1 parent 411dcf2 commit 51f3d5b
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 142 deletions.
80 changes: 72 additions & 8 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
from packaging import version

import numpy as np

Expand All @@ -31,6 +30,7 @@
from .. import op as _op
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value

__all__ = ["from_pytorch"]

Expand Down Expand Up @@ -614,6 +614,61 @@ def _impl(inputs, input_types):
return _op.tensor.sqrt(data)
return _impl

def _floor():
    # Converter factory for aten::floor -> element-wise relay floor.
    def _impl(inputs, input_types):
        return _op.floor(inputs[0])
    return _impl

def _to():
def _impl(inputs, input_types):
data = inputs[0]
if inputs[3] in ["cpu", "cuda"]:
return data
# special handling for aten::to(data, 6, _, _, _) case
# 6 means dtype = float
# this happens when converting upsampling with scale factor
cast_func = {
6: float,
3: int,
}
cast_func_expr = {
6: lambda x: _op.cast(x, "float32"),
3: lambda x: _op.cast(x, "int32"),
}
if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
return cast_func[inputs[1]](data)
elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
return cast_func_expr[inputs[1]](data)
return data

return _impl

def _upsample(method):
    """Converter factory for aten::upsample_*2d ops.

    Parameters
    ----------
    method : str
        Resize method passed to relay, e.g. "bilinear" or "nearest_neighbor".
    """
    def _impl(inputs, input_types):
        # inputs[1] is the output size: either a relay Var whose shape can be
        # inferred, or a list of expressions that fold to constants.
        if isinstance(inputs[1], _expr.Var):
            out_size = _infer_shape(inputs[1])
        elif isinstance(inputs[1], list):
            infer_res = [_infer_value(size, {}) for size in inputs[1]]
            # int(scalar ndarray) replaces deprecated np.asscalar / np.int,
            # which were removed in modern NumPy.
            out_size = [int(res.asnumpy()) for res in infer_res]
        else:
            # Previously this fell through and raised an opaque NameError on
            # the unbound out_size; fail with an explicit message instead.
            raise NotImplementedError(
                "Unsupported upsample output size input: %s" % type(inputs[1]))

        data = inputs[0]

        # align_corners is only present for bilinear variants.
        align_corners = inputs[2] if len(inputs) > 2 else False
        coord_trans = "align_corners" if align_corners else "half_pixel"

        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)

    return _impl

# Helper functions for operator implementation

def _convert_data_type(input_type):
Expand Down Expand Up @@ -686,7 +741,7 @@ def _convert_elemwise_input(data, input_type):
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::to" : _identity(),
"aten::to" : _to(),
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
Expand Down Expand Up @@ -729,15 +784,18 @@ def _convert_elemwise_input(data, input_type):
"aten::permute" : _transpose(),
"aten::sum" : _reduce("sum"),
"aten::prod" : _reduce("prod"),
"aten::sqrt" : _sqrt()
"aten::sqrt" : _sqrt(),
'aten::floor' : _floor(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
}


def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)
torch._C._jit_pass_inline(graph)


def _is_int_seq(seq):
Expand Down Expand Up @@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):

def get_all_op_names(graph):
    """ Return the set of all operator names (node kinds) in the input graph """
    # Set comprehension instead of set(generator) — same result, idiomatic
    # form (ruff C401), and avoids materializing an intermediate list.
    return {node.kind() for node in graph.nodes()}


def get_graph_input_names(script_module):
Expand All @@ -997,7 +1054,7 @@ def get_graph_input_names(script_module):
return ir_inputs[1:] # remove self at the 0th arg


def from_pytorch(script_module, input_shapes):
def from_pytorch(script_module, input_shapes, custom_convert_map=None):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
Expand All @@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes):
Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above
custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
Returns
-------
mod : tvm.relay.Module
Expand All @@ -1021,6 +1081,10 @@ def from_pytorch(script_module, input_shapes):
"""
graph = script_module.graph.copy()
_run_jit_passes(graph)

if custom_convert_map:
_convert_map.update(custom_convert_map)

op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)

Expand Down
Loading

0 comments on commit 51f3d5b

Please sign in to comment.