diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index e4d605aa4560a..027d6bd76141a 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -563,6 +563,23 @@ def infer_value_simulated(input_val, params):
     return output_value
 
 
+def try_infer_value(val, on_success=None, on_failure=None):
+    """Try running infer_value on the input val, and if successful, return the inferred value or
+    pass it to on_success callback if provided. Otherwise, run on_failure callback if it is
+    provided, or return the input val as output. In each case, the second return value
+    indicates whether infer_value has succeeded or not.
+    """
+    try:
+        ret = infer_value(val, {}).asnumpy()
+        if on_success:
+            return on_success(ret), True
+        return ret, True
+    except Exception:
+        if on_failure:
+            return on_failure(), False
+        return val, False
+
+
 def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"):
     return _expr.var(name_hint, type_annotation, shape, dtype)
 
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9ceb9fc66ec4c..c667b0430f006 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
 # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
-# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
+# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
 """PT: PyTorch frontend."""
 import itertools
 import logging
@@ -36,6 +36,7 @@
 from .common import AttrCvt, get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
+from .common import try_infer_value
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import infer_type as _infer_type
 from ..prelude import Prelude, StaticTensorArrayOps
@@ -185,11 +186,8 @@ def _impl(inputs, input_types):
         def _get_value(val, dtype):
             # dtype is a tvm dtype
             if isinstance(val, _expr.Expr):
-                try:
-                    ret = _infer_value(_op.cast(val, dtype), {}).asnumpy()
-                    ret = _expr.const(ret, dtype)
-                except Exception:
-                    ret = _op.cast(val, dtype)
+                inp = _op.cast(val, dtype)
+                ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype))
             else:
                 ret = _create_typed_const(val, dtype)
             return ret
@@ -305,10 +303,7 @@ def _impl(inputs, input_types):
         dim = int(inputs[1])
         stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            try:
-                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
-            except Exception:
-                begin[dim] = inputs[2]
+            begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
         else:
             begin[dim] = int(inputs[2])
 
@@ -329,10 +324,9 @@ def _impl(inputs, input_types):
             target_end = int(inputs[3])
         else:
             if isinstance(inputs[3], _expr.Expr):
-                try:
-                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
-                except Exception:
-                    target_end = inputs[3]
+                target_end, _ = try_infer_value(
+                    inputs[3], lambda ret: np.asscalar(ret.astype(np.int))
+                )
             else:
                 target_end = inputs[3]
 
@@ -457,10 +451,7 @@ def _impl(inputs, input_types):
         sort = bool(inputs[4])
 
         if isinstance(inputs[1], _expr.Expr):
-            try:
-                k = _infer_value(inputs[1], {}).asnumpy().tolist()
-            except Exception:
-                k = inputs[1]
+            k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist())
         else:
             k = inputs[1]
 
@@ -546,15 +537,15 @@ def _full_impl(data, fill_value, dtype):
                     size.append(dim)
                 new_shape.append(dim)
             else:
-                try:
-                    dim = int(_infer_value(dim, {}).asnumpy())
+                dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0)
+                new_shape.append(dim)
+
+                if success:
                     if isinstance(size, list):
                         size.append(dim)
-                    new_shape.append(dim)
-                except Exception:
+                else:
                     size = None
                     need_reshape = True
-                    new_shape.append(0)
         else:
             if isinstance(size, list):
                 size.append(dim)
@@ -1346,12 +1337,11 @@ def _impl(inputs, input_types):
             if isinstance(s, _expr.Constant):
                 tmp_shape.append(int(s.data.asnumpy()))
             elif isinstance(s, _expr.Expr):
-                try:
-                    dim = int(_infer_value(s, {}).asnumpy())
-                    tmp_shape.append(dim)
-                except Exception:
+                dim, success = try_infer_value(s, lambda ret: int(ret))
+                tmp_shape.append(dim)
+
+                if not success:
                     is_dyn = True
-                    tmp_shape.append(s)
             else:
                 tmp_shape.append(s)
 
@@ -2312,13 +2302,15 @@ def _impl(inputs, input_types):
         if isinstance(inputs[1], _expr.Expr):
             out_size = inputs[1]
         elif isinstance(inputs[1], list):
-            try:
-                infer_res = [_infer_value(size, {}) for size in inputs[1]]
-                out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
-            except Exception:
-                h = _op.expand_dims(inputs[1][0], axis=0)
-                w = _op.expand_dims(inputs[1][1], axis=0)
-                out_size = _op.concatenate([h, w], axis=0)
+            out_size = []
+            for i in [0, 1]:
+                size, _ = try_infer_value(
+                    inputs[1][i],
+                    lambda ret: ret.astype(np.int),
+                    lambda: _op.expand_dims(inputs[1][i], axis=0),
+                )
+                out_size.append(size)
+            out_size = _op.concatenate(out_size, axis=0)
 
         data = inputs[0]
         align_corners = inputs[4]
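A minimal sketch of the call pattern this diff introduces, assuming the patched tree is on the path; the helper name `_dim_to_int` and the argument `shape_expr` below are hypothetical and only illustrate how a converter might use the new function. The helper folds the repeated inline `try: _infer_value(...) except Exception: fallback` blocks into one place, and the boolean it returns tells the caller whether the value is now static:

```python
# try_infer_value is added to tvm.relay.frontend.common by the diff above.
from tvm.relay.frontend.common import try_infer_value


def _dim_to_int(shape_expr):
    """Hypothetical converter helper: fold a Relay shape expression to a
    Python int when it is statically inferable, else keep the expression."""
    # on_success receives the inferred numpy value; the second return value
    # reports whether infer_value succeeded, mirroring the new _reshape logic.
    dim, is_static = try_infer_value(shape_expr, lambda ret: int(ret))
    return dim, is_static
```

The new `cell-var-from-loop` and `unnecessary-lambda` pylint disables cover the lambdas passed to this helper, e.g. `lambda ret: int(ret)` and the `lambda: _op.expand_dims(inputs[1][i], axis=0)` fallback that closes over the loop variable `i` in the upsampling converter.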