From f3651c846a7b7e04bf981d088e9edb6aaae3e750 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Thu, 10 Sep 2020 11:58:36 -0700 Subject: [PATCH 01/23] Improve Pytorch Frontend --- python/tvm/relay/frontend/pytorch.py | 501 +++++++++++++----- tests/python/frontend/pytorch/test_forward.py | 2 + 2 files changed, 381 insertions(+), 122 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2eff4153592d..0f5350b759c3 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -25,6 +25,7 @@ import numpy as np import tvm +from tvm.topi.util import get_const_tuple from .. import analysis as _analysis from .. import expr as _expr @@ -129,7 +130,6 @@ def _elemwise(name): def _impl(inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name)(data0, data1) - return _impl @@ -184,8 +184,14 @@ def _impl(inputs, input_types): def _get_value(val, dtype): # dtype is a tvm dtype if isinstance(val, _expr.Expr): - return _op.cast(val, dtype) - return _create_typed_const(val, dtype) + try: + ret = _infer_value(_op.cast(val, dtype), {}).asnumpy() + ret = _expr.const(ret, dtype) + except Exception: + ret = _op.cast(val, dtype) + else: + ret = _create_typed_const(val, dtype) + return ret def _get_type(val, inp_type): if isinstance(val, _expr.Expr): @@ -282,38 +288,92 @@ def _impl(inputs, input_types): def _slice(): def _impl(inputs, input_types): + index_size_limit = 2**63 - 1 data = inputs[0] - strides = [] - - if isinstance(data, _expr.Expr): - inferred_shape = _infer_shape(data) - end = [] - for infer in inferred_shape: - end.append(int(infer)) - if isinstance(data, _expr.Var): - end = inferred_shape - end = list(end) - else: - end = data.shape + dshape = _infer_shape(data) + ndim = len(dshape) + end = [] + for dim in dshape: + if isinstance(dim, tvm.tir.Any): + end = _op.shape_of(data) + break + else: + end.append(int(dim)) - begin = [0] * len(end) + begin = [0] * ndim dim = int(inputs[1]) + stride = int(inputs[4]) if isinstance(inputs[2], _expr.Call): - begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int)) + try: + begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int)) + except Exception: + begin[dim] = inputs[2] else: begin[dim] = int(inputs[2]) + # Process begin + if not isinstance(begin[dim], int): + tmp = [] + for b in begin: + if isinstance(b, int): + tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0)) + else: + tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64")) + begin = _op.concatenate(tmp, axis=0) + btype = _infer_type(begin).checked_type.dtype + if str(btype) != "int32": + begin = _op.cast(begin, "int32") + if isinstance(inputs[3], str) and inputs[3].isdigit(): - end[dim] = min(end[dim], int(inputs[3])) + target_end = int(inputs[3]) else: - if isinstance(inputs[3], _expr.Call): - target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) + if isinstance(inputs[3], _expr.Expr): + try: + target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) + except Exception: + target_end = inputs[3] else: target_end = inputs[3] - end[dim] = min(end[dim], target_end) - - strides = [1] * len(end) + if isinstance(target_end, int) and target_end >= index_size_limit: + # Quick path for original data. 
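+            # PyTorch encodes an open-ended slice end (e.g. x[b:]) as
+            # 2**63 - 1 (int64 max). If begin is 0 and the stride is 1,
+            # such a slice is the identity, so return the input as-is.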
+ if isinstance(begin, _expr.Constant) and begin.data.asnumpy().tolist()[dim] == 0 \ + and stride == 1: + return data + target_end = dshape[dim] + + # Process end + if isinstance(target_end, int): + if isinstance(end, list): + end[dim] = target_end + else: + all_static = True + for i in range(len(dshape)): + if i != dim and isinstance(dshape[i], tvm.tir.Any): + all_static = False + + if all_static: + end = list(get_const_tuple(dshape)) + end[dim] = target_end + else: + target_end = _expr.const(target_end) + end = _op.scatter(end, _op.expand_dims(_expr.const(dim), axis=0), + _op.expand_dims(target_end, axis=0), axis=0) + else: + end = _op.shape_of(data) + if not isinstance(target_end, tvm.tir.Any): + ttype = _infer_type(target_end).checked_type.dtype + if str(ttype) != "int32": + target_end = _op.cast(target_end, 'int32') + end = _op.scatter(end, _op.expand_dims(_expr.const(dim), axis=0), + _op.expand_dims(target_end, axis=0), axis=0) + + if not isinstance(end, list): + etype = _infer_type(end).checked_type.dtype + if str(etype) != "int32": + end = _op.cast(end, "int32") + + strides = [1] * ndim strides[dim] = int(inputs[4]) return _op.transform.strided_slice( @@ -380,7 +440,11 @@ def _impl(inputs, input_types): def _topk(): def _impl(inputs, input_types): data = inputs[0] - k = int(inputs[1]) + try: + k = int(_infer_value(inputs[1], {}).asnumpy().tolist()) + k = _expr.const(k) + except Exception: + k = inputs[1] axis = int(inputs[2]) is_ascend = not bool(inputs[3]) sort = bool(inputs[4]) @@ -389,7 +453,7 @@ def _impl(inputs, input_types): msg = "Currently supports only sorted output for topk operator." raise AssertionError(msg) - outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both") + outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both", dtype="int64") return outs[0], outs[1] @@ -407,7 +471,7 @@ def _impl(inputs, input_types): def _repeat(): def _impl(inputs, input_types): data = inputs[0] - reps = _get_dims(inputs[1]) + reps = inputs[1] return _op.transform.tile(data, reps=reps) return _impl @@ -454,27 +518,55 @@ def _impl(inputs, input_types): return _impl +def _full_impl(data, fill_value, dtype): + size = [] + need_reshape = False + new_shape = [] + for dim in data: + if isinstance(dim, _expr.Expr): + if isinstance(dim, _expr.Constant): + dim = int(dim.data.asnumpy()) + if isinstance(size, list): + size.append(dim) + new_shape.append(dim) + else: + try: + dim = int(_infer_value(dim, {}).asnumpy()) + if isinstance(size, list): + size.append(dim) + new_shape.append(dim) + except Exception: + size = None + need_reshape = True + new_shape.append(0) + else: + if isinstance(size, list): + size.append(dim) + new_shape.append(dim) + + + if size is None: + tmp = [] + for dim in data: + tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64")) + size = _op.concatenate(tmp, axis=0) + + out = _op.full(_expr.const(fill_value), size, dtype=dtype) + if need_reshape: + out = _op.reshape(out, new_shape) + return out def _ones(): def _impl(inputs, input_types): data = inputs[0] import torch - - if isinstance(data, _expr.Expr): - shape = _infer_shape(data) - elif isinstance(data, list): - shape = data - elif isinstance(data, (torch.Tensor, np.ndarray)): - shape = data.shape - else: + if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)): msg = "Data type %s could not be parsed in ones op" % (type(data)) raise AssertionError(msg) dtype = _convert_dtype_value(inputs[1]) - - return _op.full(_expr.const(1), shape, dtype=dtype) - + return 
_full_impl(data, 1, dtype) return _impl @@ -498,21 +590,12 @@ def _impl(inputs, input_types): data = inputs[0] import torch - - if isinstance(data, _expr.Expr): - shape = _infer_shape(data) - elif isinstance(data, list): - shape = data - elif isinstance(data, (torch.Tensor, np.ndarray)): - shape = data.shape - else: + if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)): msg = "Data type %s could not be parsed in zeros op" % (type(data)) raise AssertionError(msg) dtype = _convert_dtype_value(inputs[1]) - - return _op.full(_expr.const(0), shape, dtype=dtype) - + return _full_impl(data, 0, dtype) return _impl @@ -534,18 +617,11 @@ def _impl(inputs, input_types): def _full(default_dtype): def _impl(inputs, input_types): data = inputs[0] - fill_value = inputs[1] - import torch - if isinstance(data, _expr.Expr): - shape = _infer_shape(data) - elif isinstance(data, list): - shape = data - elif isinstance(data, (torch.Tensor, np.ndarray)): - shape = data.shape - else: - msg = "Data type %s could not be parsed in zeros op" % (type(data)) + import torch + if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)): + msg = "Data type %s could not be parsed in full op" % (type(data)) raise AssertionError(msg) if inputs[2] is not None: # dtype given @@ -554,8 +630,7 @@ def _impl(inputs, input_types): # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() dtype = default_dtype - return _op.full(_expr.const(fill_value), shape, dtype=dtype) - + return _full_impl(data, fill_value, dtype) return _impl @@ -1100,16 +1175,25 @@ def _impl(inputs, input_types): def _flatten(): def _impl(inputs, input_types): data = inputs[0] - start_dim = inputs[1] if len(inputs) > 0 else 0 - end_dim = inputs[2] if len(inputs) > 1 else -1 - - if start_dim == 0 and end_dim == -1: - return _op.transform.reshape(data, (-1,)) - if start_dim == 1 and end_dim == -1: - return _op.nn.batch_flatten(data) - - raise NotImplementedError("Only support 1d flatten or batch flatten") - + start = int(inputs[1]) + end = int(inputs[2]) + dshape = get_const_tuple(_infer_shape(data)) + ndim = len(dshape) + if end < 0: + end += ndim + new_shape = [0] * start + + new_shape.append(-1) + squeeze_axes = [] + for i in range(start + 1, end + 1): + new_shape.append(1) + squeeze_axes.append(i) + for _ in range(end + 1, ndim): + new_shape.append(0) + out = _op.reshape(data, new_shape) + if squeeze_axes: + out = _op.squeeze(out, axis=squeeze_axes) + return out return _impl @@ -1141,14 +1225,13 @@ def _impl(inputs, input_types): bias = inputs[0] return _op.nn.bias_add(dense_out, bias) else: - return dense_out - + return dense_out + _expr.const(inputs[0]) return _impl def _size(prelude): def _impl_dynamic(inp, axis): - shape_dynamic = _op.shape_of(inp) + shape_dynamic = _op.shape_of(inp, dtype="int32") if axis is not None: return _op.take(shape_dynamic, _expr.const(axis), 0) return shape_dynamic @@ -1164,9 +1247,8 @@ def _impl(inputs, input_types): return _impl_dynamic(inputs[0], axis) if axis is not None: - return shape[axis] - return shape - + return _expr.const(shape[axis]) + return _expr.const(shape) return _impl @@ -1220,12 +1302,34 @@ def _impl(inputs, input_types): def _reshape(): def _impl(inputs, input_types): data = inputs[0] - if _is_int_seq(inputs[1]): - new_shape = inputs[1] + new_shape = inputs[1] + + tmp_shape = [] + is_dyn = False + for s in new_shape: + if isinstance(s, _expr.Constant): + tmp_shape.append(int(s.data.asnumpy())) + elif isinstance(s, _expr.Expr): + try: + dim = 
int(_infer_value(s, {}).asnumpy()) + tmp_shape.append(dim) + except Exception: + is_dyn = True + tmp_shape.append(s) + else: + tmp_shape.append(s) + + if is_dyn: + new_shape = [] + for i, s in enumerate(tmp_shape): + if not isinstance(s, _expr.Expr): + s = _expr.const(s, "int64") + else: + s = _op.cast(s, "int64") + new_shape.append(_op.expand_dims(s, axis=0)) + new_shape = _op.concatenate(new_shape, axis=0) else: - assert isinstance(inputs[1], list) - infer_res = [_infer_value(_wrap_const(size), {}) for size in inputs[1]] - new_shape = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res] + new_shape = tmp_shape return _op.transform.reshape(data, new_shape) return _impl @@ -1573,12 +1677,11 @@ def _impl(inputs, input_types): def _expand(): def _impl(inputs, input_types): data_in = inputs[0] - if isinstance(data_in, _expr.Expr): - shape = list(_infer_shape(data_in)) + shape = list(_infer_shape(data_in)) ndims = len(shape) sizes = inputs[1] - out = inputs[0] + out = data_in out_dims = len(sizes) if ndims < out_dims: @@ -1586,14 +1689,11 @@ def _impl(inputs, input_types): out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis) shape = [1] * num_newaxis + shape - for i in range(ndims): - if sizes[i] == -1 or sizes[i] == shape[i]: - continue - data = list() - for temp in range(sizes[i]): - data.append(out) - - out = _op.tensor.concatenate(data, i) + for i in range(out_dims): + if sizes[i] != -1 and shape[i] == 1: + if not isinstance(sizes[i], int): + sizes[i] = int(_infer_value(sizes[i], {}).asnumpy()) + out = _op.repeat(out, sizes[i], axis=i) return out @@ -1648,10 +1748,18 @@ def _impl(inputs, input_types): # group into tuple of 2 ints paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)] + const_paddings = [] + for pad in paddings: + const_paddings.append([]) + for p in pad: + if not isinstance(p, int): + p = int(_infer_value(p, {}).asnumpy()) + const_paddings[-1].append(p) + if mode == "constant": - return _op.nn.pad(data, paddings, pad_value=inputs[2], pad_mode=mode) + return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode) else: - return _op.nn.pad(data, paddings, pad_mode=mode) + return _op.nn.pad(data, const_paddings, pad_mode=mode) return _impl @@ -1669,36 +1777,45 @@ def _impl(inputs, input_types): def _to(): def _impl(inputs, input_types): data = inputs[0] - if inputs[3] in ["cpu", "cuda"]: - return data + dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) \ + else inputs[2] # special handling for aten::to(data, 6, _, _, _) case # 6 means dtype = float # this happens when converting upsampling with scale factor - cast_func = {6: float, 7: float, 3: int, 4: int} - cast_func_expr = { - 6: lambda x: _op.cast(x, "float32"), - 7: lambda x: _op.cast(x, "float64"), - 3: lambda x: _op.cast(x, "int32"), - 4: lambda x: _op.cast(x, "int64"), + cast_map = { + 6: "float32", + 7: "float64", + 3: "int32", + 4: "int64", } - if inputs[1] in cast_func and not isinstance(data, _expr.Expr): - return cast_func[inputs[1]](data) - elif inputs[1] in cast_func_expr and isinstance(data, _expr.Expr): - return cast_func_expr[inputs[1]](data) - return data + cast_func = { + 6: float, + 7: float, + 3: int, + 4: int + } + + ret = data + if isinstance(data, _expr.Expr): + actual_dtype = str(_infer_type(data).checked_type.dtype) + if dtype in cast_map and cast_map[dtype] != actual_dtype: + ret = _op.cast(data, cast_map[dtype]) + elif dtype in cast_map: + ret = cast_func[dtype](data) + + return ret return _impl def _upsample(method, 
prelude): def _impl(inputs, input_types): - if isinstance(inputs[1], _expr.Var): - out_size = _infer_shape(inputs[1]) - elif _is_int_seq(inputs[1]): - out_size = inputs[1] - elif isinstance(inputs[1], list): - infer_res = [_infer_value(size, {}) for size in inputs[1]] - out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res] + out_size = [] + for size in inputs[1]: + if not isinstance(size, int): + out_size.append(int(_infer_value(size, {}).asnumpy())) + else: + out_size.append(size) data = inputs[0] @@ -1768,12 +1885,12 @@ def _impl(inputs, input_types): def _expand_as(): def _impl(inputs, input_types): - # TODO: maybe fix this - # This assumes expand_as can be removed because TVM has broadcast op - msg = "aten::expand_as(...) found, assume it is part of broadcast op" - logging.warning(msg) - return inputs[0] - + target = inputs[1] + t0 = _infer_type(inputs[0]).checked_type.dtype + t1 = _infer_type(inputs[1]).checked_type.dtype + if str(t0) != str(t1): + target = _op.cast(target, t0) + return _op.broadcast_to_like(inputs[0], target) return _impl @@ -2042,6 +2159,128 @@ def _impl(inputs, input_types): return _impl +def _roi_align(prelude): + def _impl(inputs, input_types): + data = inputs[0] + boxes = inputs[1] + + output_size = (inputs[3], inputs[4]) + spatial_scale = inputs[2] + + return _op.vision.roi_align(data, boxes, output_size, spatial_scale) + return _impl + +def _unbind(): + def _impl(inputs, input_types): + data = inputs[0] + dim = int(inputs[1]) + ishapes = _infer_shape(data) + if dim >= len(ishapes): + msg = "Please check input dim, it shouldn't" \ + "be greater than or equal to rank." + raise AttributeError(msg) + + selections = ishapes[dim] + res_split = _op.split(data, selections, dim) + # squeeze each split piece to get same shape as aten::unbind + # TODO (yongwww): add new op to avoid the squeeze overhead + ret = [] + for i in range(selections): + ret.append(_op.transform.squeeze(res_split[i], axis=[dim])) + ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) + return ret + return _impl + +def _shape_as_tensor(prelude): + def _impl(inputs, input_types): + is_symbolic_shape = False + input_shape = _infer_shape(inputs[0], prelude.mod) + for axis in input_shape: + if not isinstance(axis, (int, tvm.tir.IntImm)): + is_symbolic_shape = True + break + + if is_symbolic_shape: + ret = _op.shape_of(inputs[0], dtype='int64') + else: + ret = _expr.const(np.array(input_shape), dtype="int64") + + return ret + return _impl + +def _logical_and(): + def _impl(inputs, input_types): + lhs = _op.cast(inputs[0], "bool") + rhs = _op.cast(inputs[1], "bool") + + return _op.logical_and(lhs, rhs) + return _impl + +def _nonzero(): + def _impl(inputs, input_types): + data = inputs[0] + ret = _op.transform.argwhere(data) + if len(inputs) > 1 and inputs[1]: + ret = _unbind()([ret, 0], None) + return ret + return _impl + +def _scatter(): + def _impl(inputs, input_types): + data = inputs[0] + axis = int(inputs[1]) + index = inputs[2] + src = inputs[3] + return _op.transform.scatter(data, index, src, axis) + return _impl + +def _scalar_tensor(): + def _impl(inputs, input_types): + data = inputs[0] + cast_map = { + 6: "float32", + 7: "float64", + 3: "int32", + 4: "int64", + } + type_key = inputs[1] + if isinstance(data, _expr.Constant): + data = data.data.asnumpy().tolist() + return _expr.const(data, cast_map[type_key]) + return _impl + +def _interpolate(): + def _impl(inputs, input_types): + if isinstance(inputs[1], _expr.Expr): + out_size = inputs[1] + elif isinstance(inputs[1], 
list): + try: + infer_res = [_infer_value(size, {}) for size in inputs[1]] + out_size = [np.asscalar(res.asnumpy().astype(np.int)) + for res in infer_res] + except Exception: + h = _op.expand_dims(inputs[1][0], axis=0) + w = _op.expand_dims(inputs[1][1], axis=0) + out_size = _op.concatenate([h, w], axis=0) + + data = inputs[0] + align_corners = inputs[4] + method = inputs[3] + if method.startswith("nearest"): + method = "nearest_neighbor" + + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + def func(x): + return _op.image.resize(x, out_size, "NCHW", method, coord_trans) + + return func(data) + return _impl def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" @@ -2360,6 +2599,14 @@ def _get_convert_map(prelude, default_dtype): "aten::index": _index(), "torchvision::nms": _nms(prelude), "aten::logsumexp": _logsumexp(), + "torchvision::roi_align" : _roi_align(prelude), + "aten::unbind" : _unbind(), + "aten::__and__": _logical_and(), + "aten::_shape_as_tensor" : _shape_as_tensor(prelude), + "aten::nonzero" : _nonzero(), + "aten::scatter" : _scatter(), + "aten::scalar_tensor" : _scalar_tensor(), + "aten::__interpolate" : _interpolate(), } return convert_map @@ -2795,7 +3042,17 @@ def get_input(index): def get_var(name, val): if val: checked_type = _infer_type_with_prelude(val, prelude) - return _expr.var(name, type_annotation=checked_type) + if hasattr(checked_type, "shape"): + shape = get_const_tuple(checked_type.shape) + actual_shape = [] + for dim in shape: + if isinstance(dim, int) and dim == 0: + actual_shape.append(Any()) + else: + actual_shape.append(dim) + return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype) + else: + return _expr.var(name, type_annotation=checked_type) return _expr.var(name) loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype) @@ -2812,7 +3069,7 @@ def get_var(name, val): var for var in _get_free_vars_from_block(body_block) if var in outputs - and not isinstance(outputs[var], (_expr.Constant, int, float)) + and not isinstance(outputs[var], (_expr.Constant, int, float, str)) and outputs[var] ] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4192cf45737d..8f5b31915edb 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3167,6 +3167,7 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": + """ # some structural tests test_forward_traced_function() test_forward_dtypes() @@ -3312,6 +3313,7 @@ def test_forward_pretrained_bert_base_uncased(): # Test simple conditionals and loop test_control_flow() test_simple_rnn() + """ # More complex recurrent models from lstm_test import test_custom_lstm From 35bcf6597cedc5d8b7d73bf34f0eff4608dac5d6 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Thu, 10 Sep 2020 18:09:56 -0700 Subject: [PATCH 02/23] Add tests --- python/tvm/relay/frontend/pytorch.py | 16 +++- tests/python/frontend/pytorch/test_forward.py | 84 ++++++++++++++++++- 2 files changed, 92 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0f5350b759c3..4fe72c201096 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2166,8 +2166,13 @@ def _impl(inputs, input_types): output_size = (inputs[3], inputs[4]) spatial_scale 
= inputs[2] + sample_ratio = inputs[5] + aligned = inputs[6] - return _op.vision.roi_align(data, boxes, output_size, spatial_scale) + if aligned: + data -= _expr.const(0.5) + + return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio) return _impl def _unbind(): @@ -2216,12 +2221,15 @@ def _impl(inputs, input_types): return _op.logical_and(lhs, rhs) return _impl -def _nonzero(): +def _nonzero(is_numpy_style): def _impl(inputs, input_types): data = inputs[0] ret = _op.transform.argwhere(data) - if len(inputs) > 1 and inputs[1]: - ret = _unbind()([ret, 0], None) + + if is_numpy_style or (len(inputs) > 1 and inputs[1]): + # TODO(kevinthesun): Support this by adding unbind op + # ret = _unbind()([ret, 0], None) + raise RuntimeError("as_tuple is not supported yet for nonzero.") return ret return _impl diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8f5b31915edb..801abd6365ee 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1679,6 +1679,33 @@ def _gen_rand_inputs(num_boxes): verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores], targets) +def test_forward_roi_align(): + """ROI align""" + torch.set_grad_enabled(False) + class ROIAlgin(Module): + def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1, aligned=False): + super().__init__() + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + self.aligned = aligned + self.output_sizes = output_sizes + + def forward(self, *args): + return torchvision.ops.roi_align(args[0], args[1], self.output_sizes, + self.spatial_scale, self.sampling_ratio, + self.aligned) + + in_data = torch.Tensor(np.random.uniform(size=(1, 8, 100, 100))) + in_boxes = torch.Tensor(np.random.uniform(0.0, 100.0, size=(35, 4))) + in_batch = torch.zeros((35, 1), dtype=torch.float) + in_boxes = torch.cat([in_batch, in_boxes], dim=1) + + + verify_model(ROIAlgin(7), [in_data, in_boxes]) + verify_model(ROIAlgin((10, 10), 0.7, 5), [in_data, in_boxes]) + verify_model(ROIAlgin(15, 0.9, 3, False), [in_data, in_boxes]) + + @tvm.testing.uses_gpu def test_conv3d(): for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), (1, 32, 13, 7, 7)]: @@ -1718,7 +1745,6 @@ def test_conv3d_transpose(): inp, ) - # Model tests @tvm.testing.uses_gpu def test_resnet18(): @@ -3025,6 +3051,54 @@ def forward(self, x): verify_script_model(Stack(), [(8, 8, 8)], _get_default_vm_targets()) +def test_forward_unbind(): + class Unbind(torch.nn.Module): + def __init__(self, axis=0): + super().__init__() + self.axis = axis + + def forward(self, x): + return torch.unbind(x, self.axis) + + inp = torch.randn(8, 8, 8) + verify_model(Unbind(0), input_data=inp) + verify_model(Unbind(1), input_data=inp) + verify_model(Unbind(2), input_data=inp) + + +def test_forward_nonzero(): + class Nonzero(Module): + def __init__(self, as_tuple=False): + super().__init__() + self.as_tuple = as_tuple + + def forward(self, data): + return torch.nonzero(data, as_tuple=self.as_tuple) + + inp = torch.Tensor(np.array([[0, 1, 0], [2, 0, 9], [-1, -1, 0]]).astype("float32")) + verify_trace_model(Nonzero(), [inp], ['llvm']) + + +def test_forward_scatter(): + class Scatter(Module): + def __init__(self, dim=0): + super().__init__() + self.dim = dim + + def forward(self, data, index, src): + return torch.scatter(data, dim=self.dim, index=index, src=src) + + in_data = torch.zeros(3, 5) + in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) + 
in_src = torch.rand(2, 5) + verify_model(Scatter(), input_data=[in_data, in_index, in_src]) + + in_data = torch.zeros(2, 4) + in_index = torch.tensor([[2], [3]]) + in_src = torch.rand(2, 1) + verify_model(Scatter(1), input_data=[in_data, in_index, in_src]) + + def test_forward_pretrained_bert_base_uncased(): ###################################################################### # This is an example how to run BERT models using TVM @@ -3167,7 +3241,6 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": - """ # some structural tests test_forward_traced_function() test_forward_dtypes() @@ -3265,6 +3338,7 @@ def test_forward_pretrained_bert_base_uncased(): test_upsample() test_forward_upsample3d() test_forward_nms() + test_forward_roi_align() test_to() test_flatten() test_type_as() @@ -3286,6 +3360,9 @@ def test_forward_pretrained_bert_base_uncased(): test_logsumexp() test_stack() test_stack_dynamic() + test_forward_unbind() + test_forward_nonzero() + test_forward_scatter() # Model tests test_resnet18() @@ -3313,10 +3390,9 @@ def test_forward_pretrained_bert_base_uncased(): # Test simple conditionals and loop test_control_flow() test_simple_rnn() - """ # More complex recurrent models - from lstm_test import test_custom_lstm + from test_lstm import test_custom_lstm test_custom_lstm() From fb881dd1c2a0d7a71fb2c89911317af95920654c Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Thu, 10 Sep 2020 19:04:29 -0700 Subject: [PATCH 03/23] Fix pylint --- python/tvm/relay/frontend/pytorch.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 4fe72c201096..0fa89b9c50ba 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -297,8 +297,7 @@ def _impl(inputs, input_types): if isinstance(dim, tvm.tir.Any): end = _op.shape_of(data) break - else: - end.append(int(dim)) + end.append(int(dim)) begin = [0] * ndim dim = int(inputs[1]) @@ -348,8 +347,8 @@ def _impl(inputs, input_types): end[dim] = target_end else: all_static = True - for i in range(len(dshape)): - if i != dim and isinstance(dshape[i], tvm.tir.Any): + for i, shape_dim in enumerate(dshape): + if i != dim and isinstance(shape_dim, tvm.tir.Any): all_static = False if all_static: @@ -1192,7 +1191,7 @@ def _impl(inputs, input_types): new_shape.append(0) out = _op.reshape(data, new_shape) if squeeze_axes: - out = _op.squeeze(out, axis=squeeze_axes) + out = _op.squeeze(out, axis=squeeze_axes) return out return _impl From 989fcbd0f5c7f44debd0ee7dd64168b00652178b Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Fri, 11 Sep 2020 14:43:44 -0700 Subject: [PATCH 04/23] Improve data cast --- python/tvm/relay/frontend/pytorch.py | 73 +++++++++++++++++++--------- 1 file changed, 51 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0fa89b9c50ba..902de656c58a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -185,7 +185,7 @@ def _get_value(val, dtype): # dtype is a tvm dtype if isinstance(val, _expr.Expr): try: - ret = _infer_value(_op.cast(val, dtype), {}).asnumpy() + ret = _infer_value(val, {}).asnumpy() ret = _expr.const(ret, dtype) except Exception: ret = _op.cast(val, dtype) @@ -439,10 +439,13 @@ def _impl(inputs, input_types): def _topk(): def _impl(inputs, input_types): data = inputs[0] - try: - k = int(_infer_value(inputs[1], {}).asnumpy().tolist()) - k = _expr.const(k) - 
except Exception: + if isinstance(k, _expr.Expr): + try: + k = _infer_value(inputs[1], {}).asnumpy().tolist() + k = _expr.const(k) + except Exception: + k = inputs[1] + else: k = inputs[1] axis = int(inputs[2]) is_ascend = not bool(inputs[3]) @@ -555,7 +558,7 @@ def _full_impl(data, fill_value, dtype): out = _op.reshape(out, new_shape) return out -def _ones(): +def _ones(default_dtype): def _impl(inputs, input_types): data = inputs[0] @@ -564,18 +567,24 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in ones op" % (type(data)) raise AssertionError(msg) - dtype = _convert_dtype_value(inputs[1]) + if inputs[1] is not None: + dtype = _convert_dtype_value(inputs[1]) + else: + dtype = default_dtype return _full_impl(data, 1, dtype) return _impl -def _ones_like(): +def _ones_like(default_dtype): def _impl(inputs, input_types): data = inputs[0] out = _op.ones_like(data) # If the input and the output datatype is different, do a cast - dtype = _convert_dtype_value(inputs[1]) + if inputs[1] is not None: + dtype = _convert_dtype_value(inputs[1]) + else: + dtype = default_dtype if input_types[0] != dtype: out = _op.cast(out, dtype) @@ -584,7 +593,7 @@ def _impl(inputs, input_types): return _impl -def _zeros(): +def _zeros(default_dtype): def _impl(inputs, input_types): data = inputs[0] @@ -593,18 +602,24 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in zeros op" % (type(data)) raise AssertionError(msg) - dtype = _convert_dtype_value(inputs[1]) + if inputs[1] is not None: + dtype = _convert_dtype_value(inputs[1]) + else: + dtype = default_dtype return _full_impl(data, 0, dtype) return _impl -def _zeros_like(): +def _zeros_like(default_dtype): def _impl(inputs, input_types): data = inputs[0] out = _op.zeros_like(data) # If the input and the output datatype is different, do a cast - dtype = _convert_dtype_value(inputs[1]) + if inputs[1] is not None: + dtype = _convert_dtype_value(inputs[1]) + else: + dtype = default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) @@ -633,7 +648,7 @@ def _impl(inputs, input_types): return _impl -def _full_like(): +def _full_like(default_dtype): def _impl(inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -641,7 +656,11 @@ def _impl(inputs, input_types): out = _op.full_like(data, _expr.const(fill_value)) # If the input and the output datatype is different, do a cast - dtype = _convert_dtype_value(inputs[2]) + if inputs[2] is not None: # dtype given + dtype = _convert_dtype_value(inputs[2]) + else: + # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() + dtype = default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) @@ -2166,7 +2185,7 @@ def _impl(inputs, input_types): output_size = (inputs[3], inputs[4]) spatial_scale = inputs[2] sample_ratio = inputs[5] - aligned = inputs[6] + aligned = False if len(inputs) < 7 else inputs[6] if aligned: data -= _expr.const(0.5) @@ -2329,6 +2348,14 @@ def _pytorch_result_type(dtypes, non_tensor_inputs): def _pytorch_promote_types(inputs, dtypes): """This promotes TVM inputs with TVM dtypes passed like PyTorch would""" + actual_dtypes = [] + for i, inp in enumerate(inputs): + if isinstance(inp, _expr.Expr): + idt = _infer_type(inp).checked_type.dtype + actual_dtypes.append(idt) + else: + actual_dtypes.append(dtypes[i]) + dtypes = actual_dtypes tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)] non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)] result_type = 
_pytorch_result_type(tensor_dtypes, non_tensor_inputs) @@ -2396,6 +2423,8 @@ def _convert_data_type(input_type, default_dtype=None): return "qint32" elif input_type in ["bool", "torch.bool"]: return "bool" + elif input_type in ["str"]: + return "str" else: raise NotImplementedError("input_type {} is not handled yet".format(input_type)) return "float32" # Never reached @@ -2452,12 +2481,12 @@ def _get_convert_map(prelude, default_dtype): "aten::floor_divide": _elemwise("floor_divide"), "aten::addcdiv": _addcdiv(), "aten::addcmul": _addcmul(), - "aten::ones": _ones(), - "aten::ones_like": _ones_like(), - "aten::zeros": _zeros(), - "aten::zeros_like": _zeros_like(), + "aten::ones": _ones(default_dtype), + "aten::ones_like": _ones_like(default_dtype), + "aten::zeros": _zeros(default_dtype), + "aten::zeros_like": _zeros_like(default_dtype), "aten::full": _full(default_dtype), - "aten::full_like": _full_like(), + "aten::full_like": _full_like(default_dtype), "aten::linspace": _linspace(), "aten::reciprocal": _reciprocal(), "aten::repeat": _repeat(), @@ -2762,7 +2791,7 @@ def _get_constant(node): # TODO(t-vi): When is this needed? return tensor.item() return _wrap_const(tensor.numpy()) - elif ty == "DeviceObjType": + elif ty == "DeviceObjType" or ty == "StringType": return node.s(attr_name) elif ty == "FunctionType": return None From 8104c16eeb2fe887d0c4b84e0ac1a8535cf929ef Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Fri, 11 Sep 2020 15:48:41 -0700 Subject: [PATCH 05/23] Use int64 for slice axis --- python/tvm/relay/frontend/pytorch.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 902de656c58a..270a1ce3e391 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -288,6 +288,7 @@ def _impl(inputs, input_types): def _slice(): def _impl(inputs, input_types): + axis_dtype = "int64" index_size_limit = 2**63 - 1 data = inputs[0] dshape = _infer_shape(data) @@ -315,13 +316,13 @@ def _impl(inputs, input_types): tmp = [] for b in begin: if isinstance(b, int): - tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0)) + tmp.append(_op.expand_dims(_expr.const(b, axis_dtype), axis=0)) else: - tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64")) + tmp.append(_op.cast(_op.expand_dims(b, axis=0), axis_dtype)) begin = _op.concatenate(tmp, axis=0) btype = _infer_type(begin).checked_type.dtype - if str(btype) != "int32": - begin = _op.cast(begin, "int32") + if str(btype) != axis_dtype: + begin = _op.cast(begin, axis_dtype) if isinstance(inputs[3], str) and inputs[3].isdigit(): target_end = int(inputs[3]) @@ -362,15 +363,15 @@ def _impl(inputs, input_types): end = _op.shape_of(data) if not isinstance(target_end, tvm.tir.Any): ttype = _infer_type(target_end).checked_type.dtype - if str(ttype) != "int32": - target_end = _op.cast(target_end, 'int32') + if str(ttype) != axis_dtype: + target_end = _op.cast(target_end, axis_dtype) end = _op.scatter(end, _op.expand_dims(_expr.const(dim), axis=0), _op.expand_dims(target_end, axis=0), axis=0) if not isinstance(end, list): etype = _infer_type(end).checked_type.dtype - if str(etype) != "int32": - end = _op.cast(end, "int32") + if str(etype) != axis_dtype: + end = _op.cast(end, axis_dtype) strides = [1] * ndim strides[dim] = int(inputs[4]) From 1daf122c431d5df610090d0ad0aa4fec792c9884 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Fri, 11 Sep 2020 15:53:07 -0700 Subject: [PATCH 06/23] Fix lint --- 
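Note: besides style cleanups, this fixes a use-before-assignment of `k` in
`_topk`. The dynamic-k branch only matters when TorchScript passes `k` as a
graph value rather than a Python int. A minimal sketch of such a case
(illustrative only; the module name and shapes are not part of this patch or
its tests):

    import torch

    class DynTopK(torch.nn.Module):
        def forward(self, x):
            # k is computed from the input shape, so scripting keeps it
            # in the graph instead of folding it to a Python constant
            k = x.size(0) // 2
            return torch.topk(x, k)

    scripted = torch.jit.script(DynTopK())
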
python/tvm/relay/frontend/pytorch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 270a1ce3e391..cd76f4d87bff 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -440,7 +440,7 @@ def _impl(inputs, input_types): def _topk(): def _impl(inputs, input_types): data = inputs[0] - if isinstance(k, _expr.Expr): + if isinstance(inputs[0], _expr.Expr): try: k = _infer_value(inputs[1], {}).asnumpy().tolist() k = _expr.const(k) @@ -2640,7 +2640,8 @@ def _get_convert_map(prelude, default_dtype): "aten::unbind" : _unbind(), "aten::__and__": _logical_and(), "aten::_shape_as_tensor" : _shape_as_tensor(prelude), - "aten::nonzero" : _nonzero(), + "aten::nonzero" : _nonzero(False), + "aten::nonzero_numpy" : _nonzero(True), "aten::scatter" : _scatter(), "aten::scalar_tensor" : _scalar_tensor(), "aten::__interpolate" : _interpolate(), @@ -2792,7 +2793,7 @@ def _get_constant(node): # TODO(t-vi): When is this needed? return tensor.item() return _wrap_const(tensor.numpy()) - elif ty == "DeviceObjType" or ty == "StringType": + elif ty in ["DeviceObjType", "StringType"]: return node.s(attr_name) elif ty == "FunctionType": return None From f7fc5a5dc9e27cd7923b325ebd4405d8066d7b90 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 14 Sep 2020 11:37:22 +0800 Subject: [PATCH 07/23] fix roi_align(..., aligned=True) --- python/tvm/relay/frontend/pytorch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index cd76f4d87bff..a1a1381c9520 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2189,7 +2189,8 @@ def _impl(inputs, input_types): aligned = False if len(inputs) < 7 else inputs[6] if aligned: - data -= _expr.const(0.5) + # boxes[:,1:] -= 0.5/spatial_scale + boxes-=_expr.const([0]+[0.5/spatial_scale]*4) return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio) return _impl From 7a14ed5e4f1e4e724fcb804785a9a82354761d76 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Mon, 14 Sep 2020 14:51:57 -0700 Subject: [PATCH 08/23] Minor fix --- python/tvm/relay/frontend/pytorch.py | 29 ++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a1a1381c9520..62ca83bffcbf 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -185,7 +185,7 @@ def _get_value(val, dtype): # dtype is a tvm dtype if isinstance(val, _expr.Expr): try: - ret = _infer_value(val, {}).asnumpy() + ret = _infer_value(_op.cast(val, dtype), {}).asnumpy() ret = _expr.const(ret, dtype) except Exception: ret = _op.cast(val, dtype) @@ -211,9 +211,14 @@ def _get_type(val, inp_type): dtype = "float32" else: dtype = "int64" - start = _get_value(0, dtype) - stop = _get_value(inputs[0], dtype) - step = _get_value(1, dtype) + if inputs[1] is not None: + start = _get_value(inputs[0], dtype) + stop = _get_value(inputs[1], dtype) + step = _get_value(inputs[2], dtype) + else: + start = _expr.const(0, dtype) + stop = _get_value(inputs[0], dtype) + step = _expr.const(1, dtype) elif len(inputs) == 7: types = [_get_type(inputs[i], input_types[i]) for i in range(3)] if inputs[3] is not None: @@ -222,9 +227,14 @@ def _get_type(val, inp_type): dtype = "float32" else: dtype = "int64" - start = _get_value(inputs[0], 
dtype) - stop = _get_value(inputs[1], dtype) - step = _get_value(inputs[2], dtype) + if inputs[1] is not None: + start = _get_value(inputs[0], dtype) + stop = _get_value(inputs[1], dtype) + step = _get_value(inputs[2], dtype) + else: + start = _expr.const(0, dtype) + stop = _get_value(inputs[0], dtype) + step = _expr.const(1, dtype) else: msg = "Unknown number of arguments (%d) to parse." % (len(inputs)) raise AssertionError(msg) @@ -360,7 +370,7 @@ def _impl(inputs, input_types): end = _op.scatter(end, _op.expand_dims(_expr.const(dim), axis=0), _op.expand_dims(target_end, axis=0), axis=0) else: - end = _op.shape_of(data) + end = _op.cast(_op.shape_of(data), axis_dtype) if not isinstance(target_end, tvm.tir.Any): ttype = _infer_type(target_end).checked_type.dtype if str(ttype) != axis_dtype: @@ -2189,8 +2199,7 @@ def _impl(inputs, input_types): aligned = False if len(inputs) < 7 else inputs[6] if aligned: - # boxes[:,1:] -= 0.5/spatial_scale - boxes-=_expr.const([0]+[0.5/spatial_scale]*4) + boxes -= _expr.const(0.5 / spatial_scale) return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio) return _impl From 57e12d9f11275f474e8b197ab4d2ced6a6d65bf9 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Mon, 14 Sep 2020 16:07:17 -0700 Subject: [PATCH 09/23] Add e2e test --- tests/python/frontend/pytorch/test_forward.py | 6 + .../frontend/pytorch/test_object_detection.py | 113 ++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 tests/python/frontend/pytorch/test_object_detection.py diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 801abd6365ee..6d90e70c1474 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3398,3 +3398,9 @@ def test_forward_pretrained_bert_base_uncased(): # Test bert model test_forward_pretrained_bert_base_uncased() + + # Test object detection models + from test_object_detection import test_detection_models + + test_detection_models(0) + test_detection_models(1) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py new file mode 100644 index 000000000000..879124e1afbd --- /dev/null +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -0,0 +1,113 @@ +import torch +import torchvision +import cv2 + +import tvm + +from tvm import relay +from tvm.runtime.vm import VirtualMachine +from tvm.contrib.download import download + + +in_size = 512 + +def process_image(img): + img = cv2.imread(img).astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img/255.).permute(2,0,1).float() + img = torch.unsqueeze(img, axis=0) + + return img + + +def do_trace(model, inp, in_size=in_size): + model_trace = torch.jit.trace(model, inp) + model_trace.eval() + return model_trace + + +def dict_to_tuple(out_dict): + if "masks" in out_dict.keys(): + return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"] + return out_dict["boxes"], out_dict["scores"], out_dict["labels"] + + +class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return dict_to_tuple(out[0]) + + +def generate_jit_model(index, img): + model_funcs = [torchvision.models.detection.fasterrcnn_resnet50_fpn, + torchvision.models.detection.maskrcnn_resnet50_fpn] + + model_func = model_funcs[index] + model = 
TraceWrapper(model_func(pretrained=True)) + + model.eval() + inp = process_image(img) + + with torch.no_grad(): + out = model(inp) + + script_module = do_trace(model, inp) + script_out = script_module(inp) + + assert len(out[0]) > 0 and len(script_out[0]) > 0 + torch._C._jit_pass_inline(script_module.graph) + return script_module + + +def test_detection_models(model_index, score_threshold=0.9): + img = "test_street_small.jpg" + img_url = "https://raw.githubusercontent.com/dmlc/web-data/" \ + "master/gluoncv/detection/street_small.jpg" + download(img_url, img) + + input_shape = (1, 3, in_size, in_size) + target = "llvm" + input_name = 'input0' + shape_list = [(input_name, input_shape)] + + scripted_model = generate_jit_model(model_index, img) + mod, params = relay.frontend.from_pytorch(scripted_model, + shape_list) + + with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): + vm_exec = relay.vm.compile(mod, target=target, params=params) + + ctx = tvm.cpu() + vm = VirtualMachine(vm_exec, ctx) + data = process_image(img) + pt_res = scripted_model(data) + data = data.detach().numpy() + vm.set_input("main", **{input_name: data}) + tvm_res = vm.run() + + # Note: due to accumulated numerical error, we can't directly compare results + # with pytorch output. Some boxes might have a quite tiny difference in score + # and the order can become different. We just measure how many valid boxes + # there are for input image. + pt_scores = pt_res[1].detach().numpy().tolist() + tvm_scores = tvm_res[1].asnumpy().tolist() + num_pt_valid_scores = num_tvm_valid_scores = 0 + + for score in pt_scores: + if score >= score_threshold: + num_pt_valid_scores += 1 + else: + break + + for score in tvm_scores: + if score >= score_threshold: + num_tvm_valid_scores += 1 + + assert num_pt_valid_scores == num_tvm_valid_scores, \ + "Output mismatch: Under score threshold {}, Pytorch has {} valid " \ + "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, + num_tvm_valid_scores) From e7eb64a2536a77a00747fef65c185ec9e7fb3649 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Mon, 14 Sep 2020 16:09:27 -0700 Subject: [PATCH 10/23] Add asf header --- .../frontend/pytorch/test_object_detection.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 879124e1afbd..7c35e973f9fe 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -1,3 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=import-self, invalid-name, unused-argument +"""Test torch vision fasterrcnn and maskrcnn models""" import torch import torchvision import cv2 From 0385a1dddf7ae66c4f37d8553469aa1060c80ce7 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Mon, 14 Sep 2020 16:25:51 -0700 Subject: [PATCH 11/23] Minor change --- python/tvm/relay/frontend/pytorch.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 62ca83bffcbf..db13c28b5a3f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -211,14 +211,9 @@ def _get_type(val, inp_type): dtype = "float32" else: dtype = "int64" - if inputs[1] is not None: - start = _get_value(inputs[0], dtype) - stop = _get_value(inputs[1], dtype) - step = _get_value(inputs[2], dtype) - else: - start = _expr.const(0, dtype) - stop = _get_value(inputs[0], dtype) - step = _expr.const(1, dtype) + start = _expr.const(0, dtype) + stop = _get_value(inputs[0], dtype) + step = _expr.const(1, dtype) elif len(inputs) == 7: types = [_get_type(inputs[i], input_types[i]) for i in range(3)] if inputs[3] is not None: @@ -227,14 +222,9 @@ def _get_type(val, inp_type): dtype = "float32" else: dtype = "int64" - if inputs[1] is not None: - start = _get_value(inputs[0], dtype) - stop = _get_value(inputs[1], dtype) - step = _get_value(inputs[2], dtype) - else: - start = _expr.const(0, dtype) - stop = _get_value(inputs[0], dtype) - step = _expr.const(1, dtype) + start = _get_value(inputs[0], dtype) + stop = _get_value(inputs[1], dtype) + step = _get_value(inputs[2], dtype) else: msg = "Unknown number of arguments (%d) to parse." % (len(inputs)) raise AssertionError(msg) From 648344612eb17d73e30ab4236cf06ab634483b4b Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Mon, 14 Sep 2020 18:09:29 -0700 Subject: [PATCH 12/23] Use dynamic topk --- python/tvm/relay/frontend/pytorch.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index db13c28b5a3f..708c9a34c3c6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -440,14 +440,7 @@ def _impl(inputs, input_types): def _topk(): def _impl(inputs, input_types): data = inputs[0] - if isinstance(inputs[0], _expr.Expr): - try: - k = _infer_value(inputs[1], {}).asnumpy().tolist() - k = _expr.const(k) - except Exception: - k = inputs[1] - else: - k = inputs[1] + k = inputs[1] axis = int(inputs[2]) is_ascend = not bool(inputs[3]) sort = bool(inputs[4]) From 82940b18a698dd8be42285dd562a72cdcb2d69f5 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 15 Sep 2020 10:51:28 -0700 Subject: [PATCH 13/23] Improve test --- tests/python/frontend/pytorch/test_forward.py | 2 ++ tests/python/frontend/pytorch/test_object_detection.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6d90e70c1474..8fe8ce2137d7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3241,6 +3241,7 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": + """ # some structural tests test_forward_traced_function() test_forward_dtypes() @@ -3398,6 +3399,7 @@ def test_forward_pretrained_bert_base_uncased(): # Test bert model test_forward_pretrained_bert_base_uncased() + """ # Test object 
detection models from test_object_detection import test_detection_models diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 7c35e973f9fe..938cd5a44eab 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=import-self, invalid-name, unused-argument """Test torch vision fasterrcnn and maskrcnn models""" +import numpy as np import torch import torchvision import cv2 @@ -31,6 +32,7 @@ def process_image(img): img = cv2.imread(img).astype("float32") + img = cv2.resize(img, (in_size, in_size)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = torch.from_numpy(img/255.).permute(2,0,1).float() img = torch.unsqueeze(img, axis=0) @@ -68,7 +70,7 @@ def generate_jit_model(index, img): model = TraceWrapper(model_func(pretrained=True)) model.eval() - inp = process_image(img) + inp = torch.Tensor(np.random.uniform(0.0, 250.0,size=(1, 3, in_size, in_size))) with torch.no_grad(): out = model(inp) From 7beac1b5c54d3defa9da1e07e0f735d53274cd89 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 15 Sep 2020 14:41:58 -0700 Subject: [PATCH 14/23] Rollback topk --- python/tvm/relay/frontend/pytorch.py | 10 +++++++++- tests/python/frontend/pytorch/test_forward.py | 3 --- tests/python/frontend/pytorch/test_object_detection.py | 6 +++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 708c9a34c3c6..0a351163f489 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -440,11 +440,18 @@ def _impl(inputs, input_types): def _topk(): def _impl(inputs, input_types): data = inputs[0] - k = inputs[1] axis = int(inputs[2]) is_ascend = not bool(inputs[3]) sort = bool(inputs[4]) + if isinstance(inputs[1], _expr.Expr): + try: + k = _infer_value(inputs[1], {}).asnumpy().tolist() + except Exception: + k = inputs[1] + else: + k = inputs[1] + if not sort: msg = "Currently supports only sorted output for topk operator." 
raise AssertionError(msg) @@ -3249,6 +3256,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d graph = script_module.graph.copy() _run_jit_passes(graph) + print(graph) if custom_convert_map: convert_map.update(custom_convert_map) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8fe8ce2137d7..9ea1a8a0a282 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3241,7 +3241,6 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": - """ # some structural tests test_forward_traced_function() test_forward_dtypes() @@ -3399,10 +3398,8 @@ def test_forward_pretrained_bert_base_uncased(): # Test bert model test_forward_pretrained_bert_base_uncased() - """ # Test object detection models from test_object_detection import test_detection_models - test_detection_models(0) test_detection_models(1) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 938cd5a44eab..aa7428d1dcfc 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -62,7 +62,7 @@ def forward(self, inp): return dict_to_tuple(out[0]) -def generate_jit_model(index, img): +def generate_jit_model(index): model_funcs = [torchvision.models.detection.fasterrcnn_resnet50_fpn, torchvision.models.detection.maskrcnn_resnet50_fpn] @@ -79,7 +79,6 @@ def generate_jit_model(index, img): script_out = script_module(inp) assert len(out[0]) > 0 and len(script_out[0]) > 0 - torch._C._jit_pass_inline(script_module.graph) return script_module @@ -94,9 +93,10 @@ def test_detection_models(model_index, score_threshold=0.9): input_name = 'input0' shape_list = [(input_name, input_shape)] - scripted_model = generate_jit_model(model_index, img) + scripted_model = generate_jit_model(model_index) mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + print(mod["main"]) with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): vm_exec = relay.vm.compile(mod, target=target, params=params) From 2346b4ec2e1b46354a52b154fa5fb5b1128d62fd Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 15 Sep 2020 14:48:26 -0700 Subject: [PATCH 15/23] py format --- python/tvm/relay/frontend/pytorch.py | 99 +++++++++++++------ tests/python/frontend/pytorch/test_forward.py | 18 ++-- .../frontend/pytorch/test_object_detection.py | 30 +++--- 3 files changed, 97 insertions(+), 50 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0a351163f489..364ee71e0f96 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -130,6 +130,7 @@ def _elemwise(name): def _impl(inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name)(data0, data1) + return _impl @@ -289,7 +290,7 @@ def _impl(inputs, input_types): def _slice(): def _impl(inputs, input_types): axis_dtype = "int64" - index_size_limit = 2**63 - 1 + index_size_limit = 2 ** 63 - 1 data = inputs[0] dshape = _infer_shape(data) ndim = len(dshape) @@ -337,8 +338,11 @@ def _impl(inputs, input_types): if isinstance(target_end, int) and target_end >= index_size_limit: # Quick path for original data. 
- if isinstance(begin, _expr.Constant) and begin.data.asnumpy().tolist()[dim] == 0 \ - and stride == 1: + if ( + isinstance(begin, _expr.Constant) + and begin.data.asnumpy().tolist()[dim] == 0 + and stride == 1 + ): return data target_end = dshape[dim] @@ -357,16 +361,24 @@ def _impl(inputs, input_types): end[dim] = target_end else: target_end = _expr.const(target_end) - end = _op.scatter(end, _op.expand_dims(_expr.const(dim), axis=0), - _op.expand_dims(target_end, axis=0), axis=0) + end = _op.scatter( + end, + _op.expand_dims(_expr.const(dim), axis=0), + _op.expand_dims(target_end, axis=0), + axis=0, + ) else: end = _op.cast(_op.shape_of(data), axis_dtype) if not isinstance(target_end, tvm.tir.Any): ttype = _infer_type(target_end).checked_type.dtype if str(ttype) != axis_dtype: target_end = _op.cast(target_end, axis_dtype) - end = _op.scatter(end, _op.expand_dims(_expr.const(dim), axis=0), - _op.expand_dims(target_end, axis=0), axis=0) + end = _op.scatter( + end, + _op.expand_dims(_expr.const(dim), axis=0), + _op.expand_dims(target_end, axis=0), + axis=0, + ) if not isinstance(end, list): etype = _infer_type(end).checked_type.dtype @@ -521,6 +533,7 @@ def _impl(inputs, input_types): return _impl + def _full_impl(data, fill_value, dtype): size = [] need_reshape = False @@ -547,7 +560,6 @@ def _full_impl(data, fill_value, dtype): size.append(dim) new_shape.append(dim) - if size is None: tmp = [] for dim in data: @@ -559,11 +571,13 @@ def _full_impl(data, fill_value, dtype): out = _op.reshape(out, new_shape) return out + def _ones(default_dtype): def _impl(inputs, input_types): data = inputs[0] import torch + if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)): msg = "Data type %s could not be parsed in ones op" % (type(data)) raise AssertionError(msg) @@ -573,6 +587,7 @@ def _impl(inputs, input_types): else: dtype = default_dtype return _full_impl(data, 1, dtype) + return _impl @@ -599,6 +614,7 @@ def _impl(inputs, input_types): data = inputs[0] import torch + if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)): msg = "Data type %s could not be parsed in zeros op" % (type(data)) raise AssertionError(msg) @@ -608,6 +624,7 @@ def _impl(inputs, input_types): else: dtype = default_dtype return _full_impl(data, 0, dtype) + return _impl @@ -635,6 +652,7 @@ def _impl(inputs, input_types): fill_value = inputs[1] import torch + if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)): msg = "Data type %s could not be parsed in full op" % (type(data)) raise AssertionError(msg) @@ -646,6 +664,7 @@ def _impl(inputs, input_types): dtype = default_dtype return _full_impl(data, fill_value, dtype) + return _impl @@ -1213,6 +1232,7 @@ def _impl(inputs, input_types): if squeeze_axes: out = _op.squeeze(out, axis=squeeze_axes) return out + return _impl @@ -1245,6 +1265,7 @@ def _impl(inputs, input_types): return _op.nn.bias_add(dense_out, bias) else: return dense_out + _expr.const(inputs[0]) + return _impl @@ -1268,6 +1289,7 @@ def _impl(inputs, input_types): if axis is not None: return _expr.const(shape[axis]) return _expr.const(shape) + return _impl @@ -1353,14 +1375,16 @@ def _impl(inputs, input_types): return _impl + def _pixel_shuffle(prelude): def _impl(inputs, input_types): data = inputs[0] upscale_factor = inputs[1] upscale_squared = upscale_factor * upscale_factor b, c, h, w = _infer_shape(data) - assert c % upscale_squared == 0, \ - "input channel should be divisible by square of upscale_factor" + assert ( + c % upscale_squared == 0 + ), "input channel 
should be divisible by square of upscale_factor"
 
         ndims = len(_infer_shape(data, prelude.mod))
         axes = list(range(ndims))
@@ -1379,8 +1403,10 @@ def _impl(inputs, input_types):
         axes = [0, 1, 4, 2, 5, 3]
         data = _op.transform.transpose(data, axes)
         return _op.transform.reshape(data, out_shape)
+
     return _impl
 
+
 def _clone():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1796,8 +1822,7 @@ def _impl(inputs, input_types):
 def _to():
     def _impl(inputs, input_types):
         data = inputs[0]
-        dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) \
-            else inputs[2]
+        dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2]
         # special handling for aten::to(data, 6, _, _, _) case
         # 6 means dtype = float
         # this happens when converting upsampling with scale factor
@@ -1808,12 +1833,7 @@ def _impl(inputs, input_types):
             4: "int64",
         }
 
-        cast_func = {
-            6: float,
-            7: float,
-            3: int,
-            4: int
-        }
+        cast_func = {6: float, 7: float, 3: int, 4: int}
 
         ret = data
         if isinstance(data, _expr.Expr):
@@ -1824,6 +1844,7 @@ def _impl(inputs, input_types):
                 ret = cast_func[dtype](data)
 
         return ret
+
     return _impl
 
@@ -1910,6 +1931,7 @@ def _impl(inputs, input_types):
         if str(t0) != str(t1):
             target = _op.cast(target, t0)
         return _op.broadcast_to_like(inputs[0], target)
+
     return _impl
 
@@ -2178,6 +2200,7 @@ def _impl(inputs, input_types):
 
     return _impl
 
+
 def _roi_align(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -2192,16 +2215,17 @@ def _impl(inputs, input_types):
             boxes -= _expr.const(0.5 / spatial_scale)
 
         return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)
+
     return _impl
 
+
 def _unbind():
     def _impl(inputs, input_types):
         data = inputs[0]
         dim = int(inputs[1])
         ishapes = _infer_shape(data)
         if dim >= len(ishapes):
-            msg = "Please check input dim, it shouldn't" \
-                "be greater than or equal to rank."
+            msg = "Please check input dim, it shouldn't be greater than or equal to rank."
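+            # An out-of-range dim cannot be lowered: the conversion below splits
+            # the input along `dim` and then squeezes that axis from every chunk.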
raise AttributeError(msg) selections = ishapes[dim] @@ -2213,8 +2237,10 @@ def _impl(inputs, input_types): ret.append(_op.transform.squeeze(res_split[i], axis=[dim])) ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) return ret + return _impl + def _shape_as_tensor(prelude): def _impl(inputs, input_types): is_symbolic_shape = False @@ -2225,21 +2251,25 @@ def _impl(inputs, input_types): break if is_symbolic_shape: - ret = _op.shape_of(inputs[0], dtype='int64') + ret = _op.shape_of(inputs[0], dtype="int64") else: ret = _expr.const(np.array(input_shape), dtype="int64") return ret + return _impl + def _logical_and(): def _impl(inputs, input_types): lhs = _op.cast(inputs[0], "bool") rhs = _op.cast(inputs[1], "bool") return _op.logical_and(lhs, rhs) + return _impl + def _nonzero(is_numpy_style): def _impl(inputs, input_types): data = inputs[0] @@ -2250,8 +2280,10 @@ def _impl(inputs, input_types): # ret = _unbind()([ret, 0], None) raise RuntimeError("as_tuple is not supported yet for nonzero.") return ret + return _impl + def _scatter(): def _impl(inputs, input_types): data = inputs[0] @@ -2259,8 +2291,10 @@ def _impl(inputs, input_types): index = inputs[2] src = inputs[3] return _op.transform.scatter(data, index, src, axis) + return _impl + def _scalar_tensor(): def _impl(inputs, input_types): data = inputs[0] @@ -2274,8 +2308,10 @@ def _impl(inputs, input_types): if isinstance(data, _expr.Constant): data = data.data.asnumpy().tolist() return _expr.const(data, cast_map[type_key]) + return _impl + def _interpolate(): def _impl(inputs, input_types): if isinstance(inputs[1], _expr.Expr): @@ -2283,8 +2319,7 @@ def _impl(inputs, input_types): elif isinstance(inputs[1], list): try: infer_res = [_infer_value(size, {}) for size in inputs[1]] - out_size = [np.asscalar(res.asnumpy().astype(np.int)) - for res in infer_res] + out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res] except Exception: h = _op.expand_dims(inputs[1][0], axis=0) w = _op.expand_dims(inputs[1][1], axis=0) @@ -2307,8 +2342,10 @@ def func(x): return _op.image.resize(x, out_size, "NCHW", method, coord_trans) return func(data) + return _impl + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" import torch @@ -2636,15 +2673,15 @@ def _get_convert_map(prelude, default_dtype): "aten::index": _index(), "torchvision::nms": _nms(prelude), "aten::logsumexp": _logsumexp(), - "torchvision::roi_align" : _roi_align(prelude), - "aten::unbind" : _unbind(), + "torchvision::roi_align": _roi_align(prelude), + "aten::unbind": _unbind(), "aten::__and__": _logical_and(), - "aten::_shape_as_tensor" : _shape_as_tensor(prelude), - "aten::nonzero" : _nonzero(False), - "aten::nonzero_numpy" : _nonzero(True), - "aten::scatter" : _scatter(), - "aten::scalar_tensor" : _scalar_tensor(), - "aten::__interpolate" : _interpolate(), + "aten::_shape_as_tensor": _shape_as_tensor(prelude), + "aten::nonzero": _nonzero(False), + "aten::nonzero_numpy": _nonzero(True), + "aten::scatter": _scatter(), + "aten::scalar_tensor": _scalar_tensor(), + "aten::__interpolate": _interpolate(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9ea1a8a0a282..8488da640a7c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -182,7 +182,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at with torch.no_grad(): baseline_outputs = 
baseline_model(*baseline_input) - + if isinstance(baseline_outputs, tuple): baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) else: @@ -1682,6 +1682,7 @@ def _gen_rand_inputs(num_boxes): def test_forward_roi_align(): """ROI align""" torch.set_grad_enabled(False) + class ROIAlgin(Module): def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1, aligned=False): super().__init__() @@ -1691,16 +1692,20 @@ def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1, aligned=F self.output_sizes = output_sizes def forward(self, *args): - return torchvision.ops.roi_align(args[0], args[1], self.output_sizes, - self.spatial_scale, self.sampling_ratio, - self.aligned) + return torchvision.ops.roi_align( + args[0], + args[1], + self.output_sizes, + self.spatial_scale, + self.sampling_ratio, + self.aligned, + ) in_data = torch.Tensor(np.random.uniform(size=(1, 8, 100, 100))) in_boxes = torch.Tensor(np.random.uniform(0.0, 100.0, size=(35, 4))) in_batch = torch.zeros((35, 1), dtype=torch.float) in_boxes = torch.cat([in_batch, in_boxes], dim=1) - verify_model(ROIAlgin(7), [in_data, in_boxes]) verify_model(ROIAlgin((10, 10), 0.7, 5), [in_data, in_boxes]) verify_model(ROIAlgin(15, 0.9, 3, False), [in_data, in_boxes]) @@ -1745,6 +1750,7 @@ def test_conv3d_transpose(): inp, ) + # Model tests @tvm.testing.uses_gpu def test_resnet18(): @@ -3076,7 +3082,7 @@ def forward(self, data): return torch.nonzero(data, as_tuple=self.as_tuple) inp = torch.Tensor(np.array([[0, 1, 0], [2, 0, 9], [-1, -1, 0]]).astype("float32")) - verify_trace_model(Nonzero(), [inp], ['llvm']) + verify_trace_model(Nonzero(), [inp], ["llvm"]) def test_forward_scatter(): diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index aa7428d1dcfc..16c49019fa84 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -30,11 +30,12 @@ in_size = 512 + def process_image(img): img = cv2.imread(img).astype("float32") img = cv2.resize(img, (in_size, in_size)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(img/255.).permute(2,0,1).float() + img = torch.from_numpy(img / 255.0).permute(2, 0, 1).float() img = torch.unsqueeze(img, axis=0) return img @@ -63,14 +64,16 @@ def forward(self, inp): def generate_jit_model(index): - model_funcs = [torchvision.models.detection.fasterrcnn_resnet50_fpn, - torchvision.models.detection.maskrcnn_resnet50_fpn] + model_funcs = [ + torchvision.models.detection.fasterrcnn_resnet50_fpn, + torchvision.models.detection.maskrcnn_resnet50_fpn, + ] model_func = model_funcs[index] model = TraceWrapper(model_func(pretrained=True)) model.eval() - inp = torch.Tensor(np.random.uniform(0.0, 250.0,size=(1, 3, in_size, in_size))) + inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size))) with torch.no_grad(): out = model(inp) @@ -84,18 +87,19 @@ def generate_jit_model(index): def test_detection_models(model_index, score_threshold=0.9): img = "test_street_small.jpg" - img_url = "https://raw.githubusercontent.com/dmlc/web-data/" \ - "master/gluoncv/detection/street_small.jpg" + img_url = ( + "https://raw.githubusercontent.com/dmlc/web-data/" + "master/gluoncv/detection/street_small.jpg" + ) download(img_url, img) input_shape = (1, 3, in_size, in_size) target = "llvm" - input_name = 'input0' + input_name = "input0" shape_list = [(input_name, input_shape)] scripted_model = generate_jit_model(model_index) - mod, 
params = relay.frontend.from_pytorch(scripted_model, - shape_list) + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) print(mod["main"]) with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): @@ -127,7 +131,7 @@ def test_detection_models(model_index, score_threshold=0.9): if score >= score_threshold: num_tvm_valid_scores += 1 - assert num_pt_valid_scores == num_tvm_valid_scores, \ - "Output mismatch: Under score threshold {}, Pytorch has {} valid " \ - "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, - num_tvm_valid_scores) + assert num_pt_valid_scores == num_tvm_valid_scores, ( + "Output mismatch: Under score threshold {}, Pytorch has {} valid " + "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores) + ) From 006512b6b9f336887be70621486f41d17dc49a16 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 15 Sep 2020 15:46:26 -0700 Subject: [PATCH 16/23] remove print --- python/tvm/relay/frontend/pytorch.py | 1 - tests/python/frontend/pytorch/test_object_detection.py | 1 - 2 files changed, 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 364ee71e0f96..75757a230f8d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3293,7 +3293,6 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_d graph = script_module.graph.copy() _run_jit_passes(graph) - print(graph) if custom_convert_map: convert_map.update(custom_convert_map) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 16c49019fa84..ec5fe24d9789 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -100,7 +100,6 @@ def test_detection_models(model_index, score_threshold=0.9): scripted_model = generate_jit_model(model_index) mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) - print(mod["main"]) with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): vm_exec = relay.vm.compile(mod, target=target, params=params) From 819ef5c15b5b51a86c16c6eb8a5a69a38f5a2435 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 15 Sep 2020 18:17:22 -0700 Subject: [PATCH 17/23] More improve --- python/tvm/relay/frontend/pytorch.py | 9 ++------- tests/python/frontend/pytorch/test_object_detection.py | 4 +++- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 75757a230f8d..7817d3bcde89 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1238,8 +1238,6 @@ def _impl(inputs, input_types): def _dense(): def _impl(inputs, input_types): - use_bias = isinstance(inputs[0], _expr.Expr) - data = inputs[1] data_type = input_types[1] weight = inputs[2] @@ -1260,7 +1258,7 @@ def _impl(inputs, input_types): units = _infer_shape(weight_out)[0] dense_out = _op.nn.dense(data, weight_out, units=units) - if use_bias: + if isinstance(inputs[0], _expr.Expr): bias = inputs[0] return _op.nn.bias_add(dense_out, bias) else: @@ -2338,10 +2336,7 @@ def _impl(inputs, input_types): else: coord_trans = "half_pixel" - def func(x): - return _op.image.resize(x, out_size, "NCHW", method, coord_trans) - - return func(data) + return _op.image.resize(data, out_size, "NCHW", method, coord_trans) return _impl diff --git 
a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
index ec5fe24d9789..d5043ee16e1f 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -28,7 +28,7 @@
 from tvm.contrib.download import download
 
 
-in_size = 512
+in_size = 300
 
 
 def process_image(img):
@@ -129,6 +129,8 @@ def test_detection_models(model_index, score_threshold=0.9):
     for score in tvm_scores:
         if score >= score_threshold:
             num_tvm_valid_scores += 1
+        else:
+            break
 
     assert num_pt_valid_scores == num_tvm_valid_scores, (
         "Output mismatch: Under score threshold {}, Pytorch has {} valid "

From 8c34780213ed145ab2e8f164f57143067e052881 Mon Sep 17 00:00:00 2001
From: Wang Yao
Date: Tue, 15 Sep 2020 18:23:23 -0700
Subject: [PATCH 18/23] Fix test

---
 tests/python/frontend/pytorch/test_forward.py | 5 -----
 tests/python/frontend/pytorch/test_object_detection.py | 3 +++
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 8488da640a7c..a925a1db9508 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3404,8 +3404,3 @@ def test_forward_pretrained_bert_base_uncased():
 
     # Test bert model
     test_forward_pretrained_bert_base_uncased()
-
-    # Test object detection models
-    from test_object_detection import test_detection_models
-
-    test_detection_models(1)
diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
index d5043ee16e1f..5d4652c8115b 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -136,3 +136,6 @@ def test_detection_models(model_index, score_threshold=0.9):
     "Output mismatch: Under score threshold {}, Pytorch has {} valid "
     "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores)
     )
+
+def run_test():
+    test_detection_models(1)

From afe7398cabec5d9482f45e3d542ef8d4394124e5 Mon Sep 17 00:00:00 2001
From: Wang Yao
Date: Tue, 15 Sep 2020 18:31:56 -0700
Subject: [PATCH 19/23] Improve addmm

---
 python/tvm/relay/frontend/pytorch.py | 25 ++++++++-----------
 .../frontend/pytorch/test_object_detection.py | 1 +
 2 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7817d3bcde89..c9320a9b2882 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1236,33 +1236,30 @@ def _impl(inputs, input_types):
     return _impl
 
 
-def _dense():
+def _addmm():
     def _impl(inputs, input_types):
-        data = inputs[1]
+        input_mat = inputs[0]
+        mat1 = inputs[1]
         data_type = input_types[1]
-        weight = inputs[2]
+        mat2 = inputs[2]
 
         beta = inputs[3]
         alpha = inputs[4]
 
         if not isinstance(alpha, _expr.Expr) and alpha != 1:
             alpha = _create_typed_const(alpha, data_type)
-            data *= alpha
+            mat1 *= alpha
 
         if not isinstance(beta, _expr.Expr) and beta != 1:
             beta = _create_typed_const(beta, data_type)
-            weight *= beta
+            # addmm computes beta * input + alpha * (mat1 @ mat2), so beta must
+            # scale the added input matrix rather than mat2.
+            input_mat *= beta
 
-        weight_out = _op.transform.transpose(weight, axes=[1, 0])
+        transposed_mat2 = _op.transform.transpose(mat2, axes=[1, 0])
 
-        units = _infer_shape(weight_out)[0]
-        dense_out = _op.nn.dense(data, weight_out, units=units)
+        units = _infer_shape(transposed_mat2)[0]
+        dense_out = _op.nn.dense(mat1, transposed_mat2, units=units)
 
-        if isinstance(inputs[0], 
_expr.Expr): - bias = inputs[0] - return _op.nn.bias_add(dense_out, bias) - else: - return dense_out + _expr.const(inputs[0]) + return dense_out + input_mat return _impl @@ -2567,7 +2564,7 @@ def _get_convert_map(prelude, default_dtype): "aten::transpose_": _transpose(prelude), "aten::t": _transpose(prelude), "aten::flatten": _flatten(), - "aten::addmm": _dense(), + "aten::addmm": _addmm(), "aten::size": _size(prelude), "aten::view": _view(), "aten::reshape": _reshape(), diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 5d4652c8115b..07dadfe2835f 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -137,5 +137,6 @@ def test_detection_models(model_index, score_threshold=0.9): "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores) ) + def run_test(): test_detection_models(1) From 5e69bbfc50eef23bcc68e2b095336b81ae7140b0 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 16 Sep 2020 11:07:34 -0700 Subject: [PATCH 20/23] Fix test --- tests/python/frontend/pytorch/test_forward.py | 10 +++++----- tests/python/frontend/pytorch/test_object_detection.py | 9 +++------ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a925a1db9508..da874e16aa37 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1684,11 +1684,10 @@ def test_forward_roi_align(): torch.set_grad_enabled(False) class ROIAlgin(Module): - def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1, aligned=False): + def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1): super().__init__() self.spatial_scale = spatial_scale self.sampling_ratio = sampling_ratio - self.aligned = aligned self.output_sizes = output_sizes def forward(self, *args): @@ -1698,7 +1697,6 @@ def forward(self, *args): self.output_sizes, self.spatial_scale, self.sampling_ratio, - self.aligned, ) in_data = torch.Tensor(np.random.uniform(size=(1, 8, 100, 100))) @@ -1708,7 +1706,7 @@ def forward(self, *args): verify_model(ROIAlgin(7), [in_data, in_boxes]) verify_model(ROIAlgin((10, 10), 0.7, 5), [in_data, in_boxes]) - verify_model(ROIAlgin(15, 0.9, 3, False), [in_data, in_boxes]) + verify_model(ROIAlgin(15, 0.9, 3), [in_data, in_boxes]) @tvm.testing.uses_gpu @@ -3102,7 +3100,9 @@ def forward(self, data, index, src): in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) - verify_model(Scatter(1), input_data=[in_data, in_index, in_src]) + + #TODO: add scatter gpu schedule to enable gpu test. 
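+    # Unlike verify_model, verify_trace_model takes an explicit target list, so this
+    # test keeps running on CPU ("llvm") until a GPU schedule for scatter exists.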
+ verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"]) def test_forward_pretrained_bert_base_uncased(): diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 07dadfe2835f..a2aeeb8bed99 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -85,7 +85,7 @@ def generate_jit_model(index): return script_module -def test_detection_models(model_index, score_threshold=0.9): +def test_detection_models(): img = "test_street_small.jpg" img_url = ( "https://raw.githubusercontent.com/dmlc/web-data/" @@ -97,8 +97,9 @@ def test_detection_models(model_index, score_threshold=0.9): target = "llvm" input_name = "input0" shape_list = [(input_name, input_shape)] + score_threshold=0.9 - scripted_model = generate_jit_model(model_index) + scripted_model = generate_jit_model(1) mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): @@ -136,7 +137,3 @@ def test_detection_models(model_index, score_threshold=0.9): "Output mismatch: Under score threshold {}, Pytorch has {} valid " "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores) ) - - -def run_test(): - test_detection_models(1) From 2880b40b60369552926bb098ac6676326604a6ae Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 16 Sep 2020 11:16:59 -0700 Subject: [PATCH 21/23] Fix format --- tests/python/frontend/pytorch/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index da874e16aa37..8e7fbba46474 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3101,7 +3101,7 @@ def forward(self, data, index, src): in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) - #TODO: add scatter gpu schedule to enable gpu test. + # TODO: add scatter gpu schedule to enable gpu test. 
verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"]) From 98f7386eb092b53db843c9740c4808f6a12bf621 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 16 Sep 2020 11:18:53 -0700 Subject: [PATCH 22/23] Fix format --- tests/python/frontend/pytorch/test_object_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index a2aeeb8bed99..f5197494a345 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -97,7 +97,7 @@ def test_detection_models(): target = "llvm" input_name = "input0" shape_list = [(input_name, input_shape)] - score_threshold=0.9 + score_threshold = 0.9 scripted_model = generate_jit_model(1) mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) From 27c2e97008e3193ad4102335cf28bbd6eb7d55e7 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 16 Sep 2020 15:10:05 -0700 Subject: [PATCH 23/23] Fix test scatter --- tests/python/frontend/pytorch/test_forward.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8e7fbba46474..e8a8507158a3 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3095,7 +3095,8 @@ def forward(self, data, index, src): in_data = torch.zeros(3, 5) in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) in_src = torch.rand(2, 5) - verify_model(Scatter(), input_data=[in_data, in_index, in_src]) + # TODO: add scatter gpu schedule to enable gpu test. + verify_trace_model(Scatter(), [in_data, in_index, in_src], ["llvm"]) in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]])
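
The scatter changes in patches 20 and 23 are easiest to read end to end. The sketch below mirrors the path those tests exercise, using only calls that already appear in this series (torch.jit tracing, relay.frontend.from_pytorch, relay.vm.compile); the VirtualMachine construction, the generated input names, and the tolerance are illustrative assumptions, not something these patches add:

    import numpy as np
    import torch
    import tvm
    from tvm import relay
    from tvm.runtime.vm import VirtualMachine


    class Scatter(torch.nn.Module):
        def __init__(self, dim=0):
            super().__init__()
            self.dim = dim

        def forward(self, data, index, src):
            # Out-of-place scatter: write src into data at index along self.dim.
            return data.scatter(self.dim, index, src)


    # Same shapes as the updated test above.
    data = torch.zeros(3, 5)
    index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
    src = torch.rand(2, 5)

    # Trace, convert, and compile for the Relay VM on CPU only.
    traced = torch.jit.trace(Scatter(), (data, index, src))
    shape_list = [("input%d" % i, list(t.shape)) for i, t in enumerate([data, index, src])]
    mod, params = relay.frontend.from_pytorch(traced, shape_list)
    with tvm.transform.PassContext(opt_level=3):
        vm_exec = relay.vm.compile(mod, target="llvm", params=params)

    vm = VirtualMachine(vm_exec, tvm.cpu())
    tvm_out = vm.run(data.numpy(), index.numpy(), src.numpy())
    np.testing.assert_allclose(tvm_out.asnumpy(), traced(data, index, src).numpy(), rtol=1e-5)

The VM runner matches what test_detection_models does with relay.vm.compile earlier in the series; pinning the target to "llvm" is what keeps the scatter tests running while the GPU schedule is still missing.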