From 7bd43095a0807915b996b4584567d72041b093af Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 27 May 2020 20:26:40 +0530 Subject: [PATCH] [PYTORCH]Minor bug fixes --- python/tvm/relay/frontend/pytorch.py | 58 +++++++++++++++++++++------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index cc7cd4830cd4..f68affd82726 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -34,6 +34,7 @@ from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from .common import infer_value_simulated as _infer_value_simulated from .common import infer_type as _infer_type from ..prelude import Prelude, StaticTensorArrayOps @@ -152,19 +153,33 @@ def _impl(inputs, input_types): def _arange(): def _impl(inputs, input_types): + def _get_value(val, dtype): + if isinstance(val, _expr.Expr): + return _op.cast(val, _convert_data_type(dtype)) + return _create_typed_const(val, dtype) + + def _get_type(val, inp_type): + if isinstance(val, _expr.Expr): + dtype = str(_infer_type(val).checked_type) + return dtype if dtype != "float32" else "float" + return inp_type + if len(inputs) == 5: - dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1]) - start = _create_typed_const(0, dtype) - stop = _create_typed_const(inputs[0], dtype) - step = _create_typed_const(1, dtype) + dtype0 = _get_type(inputs[0], input_types[0]) + dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1]) + start = _get_value(0, dtype) + stop = _get_value(inputs[0], dtype) + step = _get_value(1, dtype) elif len(inputs) == 7: - dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3]) - start = _create_typed_const(inputs[0], dtype) - stop = _create_typed_const(inputs[1], dtype) - step = _create_typed_const(inputs[2], dtype) + types = [_get_type(inputs[i], input_types[i]) for i in range(3)] + dtype = "float" if "float" in types else _convert_dtype_value(inputs[3]) + start = _get_value(inputs[0], dtype) + stop = _get_value(inputs[1], dtype) + step = _get_value(inputs[2], dtype) else: msg = "Unknown number of arguments (%d) to parse." % (len(inputs)) raise AssertionError(msg) + return _op.transform.arange(start=start, stop=stop, step=step, @@ -235,12 +250,18 @@ def _impl(inputs, input_types): begin = [0] * len(end) dim = int(inputs[1]) - begin[dim] = int(inputs[2]) + if isinstance(inputs[2], _expr.Call): + begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int)) + else: + begin[dim] = int(inputs[2]) if isinstance(inputs[3], str) and inputs[3].isdigit(): end[dim] = min(end[dim], int(inputs[3])) else: - end[dim] = inputs[3] + if isinstance(inputs[3], _expr.Call): + end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) + else: + end[dim] = inputs[3] strides.append(int(inputs[4])) return _op.transform.strided_slice(data, begin, end, strides) @@ -997,7 +1018,10 @@ def _impl(inputs, input_types): def _numtotensor(): def _impl(inputs, input_types): val = inputs[0] - dtype = type(val) + dtype = input_types[0] + + if isinstance(val, _expr.Expr): + return val if isinstance(val, tvm.tir.IntImm): val = val.__int__() @@ -1019,16 +1043,22 @@ def _impl(inputs, input_types): data = inputs[0] if len(inputs) == 3: - new_shape = [inputs[1], _infer_shape(inputs[2])[0]] + shape_inp = [inputs[1], _infer_shape(inputs[2])[0]] else: if isinstance(inputs[1], list): - new_shape = inputs[1] + shape_inp = inputs[1] else: - new_shape = _infer_shape(inputs[1]) + shape_inp = _infer_shape(inputs[1]) + new_shape = shape_inp + for i, shape in enumerate(shape_inp): + if isinstance(shape, _expr.Expr): + val = _infer_value_simulated(shape, {}) + new_shape[i] = np.asscalar(val.asnumpy()) return _op.transform.reshape(data, new_shape) return _impl + def _reshape(): def _impl(inputs, input_types): data = inputs[0]