From b4e5f58cd3e7ab7721046bc30d9c3e24ffedfcee Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Mon, 22 Jun 2020 15:33:04 +0200 Subject: [PATCH] Improve type handling in PyTorch frontend (#5834) * Improve type handling in PyTorch frontend - Use type information from graph for inputs if available. Check against shape information from graph if available. - Allow user to set default dtype (default to float32 for sanity and compatibility). - Implement type promotion to follow PyTorch mechanism. This includes fixing the handling of many "Scalar" overloads in PyTorch binary ops. - Fix arange/linspace type semantics. - Added support for traced functions. (Because it really is about the "self" input handling.) Aside from adding an optional default_dtype keyword argument, this does not change the signature/requirements of from_pytorch. * Fix scalar detection using numpy.isscalar and address other review comments. Thank you @siju-samuel * refine test criteron on qnn_test::test_serialized_modules, fix bool conversion of const --- python/tvm/relay/frontend/pytorch.py | 435 +++++++++++------- tests/python/frontend/pytorch/qnn_test.py | 8 +- tests/python/frontend/pytorch/test_forward.py | 34 +- 3 files changed, 290 insertions(+), 187 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index d3b65102ea2c..f70a64a6c93c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -126,18 +126,7 @@ def _is_quantized_tensor(data, prelude): # operator implementation def _elemwise(name): def _impl(inputs, input_types): - # TODO: Figure out a better way to get typing to work for tensor + scalar - type0 = input_types[0] - if isinstance(inputs[1], _expr.Expr): - type0 = input_types[1] - - type1 = input_types[1] - if isinstance(inputs[0], _expr.Expr): - type1 = input_types[0] - - data0 = _convert_elemwise_input(inputs[0], type0) - data1 = _convert_elemwise_input(inputs[1], type1) - + data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name)(data0, data1) return _impl @@ -145,8 +134,8 @@ def _impl(inputs, input_types): def _unary(name): def _impl(inputs, input_types): input_type = input_types[0] - data = _convert_elemwise_input(inputs[0], input_type) - + # this is just to ensure tensor input + data, = _pytorch_promote_types(inputs[:1], input_types[:1]) return get_relay_op(name)(data) return _impl @@ -154,7 +143,8 @@ def _impl(inputs, input_types): def _log1p(): def _impl(inputs, input_types): # 1_plus_log x = log(x + 1) - one = _expr.const(1, dtype="float32") + dtype, = input_types + one = _expr.const(1, dtype=dtype) return _op.log(inputs[0] + one) return _impl @@ -162,25 +152,40 @@ def _impl(inputs, input_types): def _arange(): def _impl(inputs, input_types): def _get_value(val, dtype): + # dtype is a tvm dtype if isinstance(val, _expr.Expr): - return _op.cast(val, _convert_data_type(dtype)) + return _op.cast(val, dtype) return _create_typed_const(val, dtype) def _get_type(val, inp_type): if isinstance(val, _expr.Expr): dtype = str(_infer_type(val).checked_type) - return dtype if dtype != "float32" else "float" + return dtype return inp_type + # PyTorch arange uses the following type semantics: + # - if a dtype is given, start, stop, step are converted to that dtype + # - if no dtype is given and all args are integral, dtype is int64 + # - if no dtype is given and there is a float arg, dtype is float32 if len(inputs) == 5: dtype0 = _get_type(inputs[0], input_types[0]) - dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1]) + if inputs[1] is not None: + dtype = _convert_dtype_value(inputs[1]) + elif dtype0.startswith("float"): + dtype = "float32" + else: + dtype = "int64" start = _get_value(0, dtype) stop = _get_value(inputs[0], dtype) step = _get_value(1, dtype) elif len(inputs) == 7: types = [_get_type(inputs[i], input_types[i]) for i in range(3)] - dtype = "float" if "float" in types else _convert_dtype_value(inputs[3]) + if inputs[3] is not None: + dtype = _convert_dtype_value(inputs[3]) + elif any([t.startswith("float") for t in types]): + dtype = "float32" + else: + dtype = "int64" start = _get_value(inputs[0], dtype) stop = _get_value(inputs[1], dtype) step = _get_value(inputs[2], dtype) @@ -191,7 +196,7 @@ def _get_type(val, inp_type): return _op.transform.arange(start=start, stop=stop, step=step, - dtype=_convert_data_type(dtype)) + dtype=dtype) return _impl def _squeeze(): @@ -200,6 +205,7 @@ def _impl(inputs, input_types): if len(inputs) == 1: axis = None else: + # TODO (t-vi): why is the cast to int needed? similarly elsewhere axis = [int(inputs[1])] return _op.transform.squeeze(data, axis) @@ -295,7 +301,7 @@ def _impl(inputs, input_types): return _impl def _split_with_sizes(): - def _impl(inputs, inputs_types): + def _impl(inputs, input_types): data = inputs[0] dim = int(inputs[2]) @@ -345,7 +351,7 @@ def _impl(inputs, input_types): def _reciprocal(): def _impl(inputs, input_types): data = inputs[0] - return _expr.const(1.0) / data + return _expr.const(1.0, dtype=input_types[0]) / data return _impl def _repeat(): @@ -373,22 +379,14 @@ def _impl(inputs, input_types): def _addcdiv(): def _impl(inputs, input_types): - data = inputs[0] - c = _expr.const(inputs[3]) - t1 = inputs[1] - t2 = inputs[2] - + data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 / t2)) return _impl def _addcmul(): def _impl(inputs, input_types): - data = inputs[0] - c = _expr.const(inputs[3]) - t1 = inputs[1] - t2 = inputs[2] - + data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 * t2)) return _impl @@ -396,9 +394,7 @@ def _impl(inputs, input_types): def _where(): def _impl(inputs, input_types): cond = inputs[0] - x = inputs[1] - y = inputs[2] - + x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3]) return _op.where(cond, x, y) return _impl @@ -419,7 +415,7 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in ones op" % (type(data)) raise AssertionError(msg) - dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + dtype = _convert_dtype_value(inputs[1]) return _op.full(_expr.const(1), shape, dtype=dtype) return _impl @@ -430,8 +426,8 @@ def _impl(inputs, input_types): out = _op.ones_like(data) # If the input and the output datatype is different, do a cast - dtype = _convert_data_type(_convert_dtype_value(inputs[1])) - if input_types[0] not in dtype: + dtype = _convert_dtype_value(inputs[1]) + if input_types[0] != dtype: out = _op.cast(out, dtype) return out @@ -453,7 +449,7 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in zeros op" % (type(data)) raise AssertionError(msg) - dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + dtype = _convert_dtype_value(inputs[1]) return _op.full(_expr.const(0), shape, dtype=dtype) return _impl @@ -465,7 +461,7 @@ def _impl(inputs, input_types): out = _op.zeros_like(data) # If the input and the output datatype is different, do a cast - dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + dtype = _convert_dtype_value(inputs[1]) if input_types[0] not in dtype: out = _op.cast(out, dtype) @@ -490,7 +486,7 @@ def _impl(inputs, input_types): raise AssertionError(msg) if inputs[2] is not None: # dtype given - dtype = _convert_data_type(_convert_dtype_value(inputs[2])) + dtype = _convert_dtype_value(inputs[2]) else: dtype = data.type_annotation.dtype @@ -505,7 +501,7 @@ def _impl(inputs, input_types): out = _op.full_like(data, _expr.const(fill_value)) # If the input and the output datatype is different, do a cast - dtype = _convert_data_type(_convert_dtype_value(inputs[2])) + dtype = _convert_dtype_value(inputs[2]) if input_types[0] not in dtype: out = _op.cast(out, dtype) @@ -526,7 +522,8 @@ def _impl(inputs, input_types): else: stop = start + step - dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3]) + dtype = ("float32" if inputs[3] is not None + else _convert_dtype_value(inputs[3])) start = _create_typed_const(start, dtype) stop = _create_typed_const(stop, dtype) step = _create_typed_const(step, dtype) @@ -534,7 +531,7 @@ def _impl(inputs, input_types): return _op.transform.arange(start=start, stop=stop, step=step, - dtype=_convert_data_type(dtype)) + dtype=dtype) return _impl @@ -565,35 +562,41 @@ def _impl(inputs, input_types): def _elu(): def _impl(inputs, input_types): data = inputs[0] - alpha = _expr.const(float(inputs[1])) - return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data)) + _op.nn.relu(data) + dtype = input_types[0] + alpha = _expr.const(float(inputs[1]), dtype=dtype) + return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) return _impl def _celu(): def _impl(inputs, input_types): data = inputs[0] - alpha = _expr.const(float(inputs[1])) - return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data / alpha)) + _op.nn.relu(data) + dtype = input_types[0] + alpha = _expr.const(float(inputs[1]), dtype=dtype) + return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) + - _op.exp(data / alpha)) + _op.nn.relu(data) return _impl def _gelu(): def _impl(inputs, input_types): data = inputs[0] + dtype = input_types[0] # gelu is data * normcdf(data) # normcdf expressed as erf because we don't currently have that intrinsic # note that there is also a fastgelu variant approximating normcdf # with tanh and third order polynomials, but this is "true" gelu - return data * (_expr.const(0.5) + - _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5)) + return data * (_expr.const(0.5, dtype=dtype) + + _op.erf(data * _expr.const(0.5**0.5, dtype=dtype)) + * _expr.const(0.5, dtype=dtype)) return _impl def _selu(): def _impl(inputs, input_types): data = inputs[0] # https://pytorch.org/docs/stable/nn.html#selu - alpha = _expr.const(-1.6732632423543772848170429916717) - gamma = _expr.const(1.0507009873554804934193349852946) - return gamma * (alpha * _op.nn.relu(_expr.const(1.0) + dtype = input_types[0] + alpha = _expr.const(-1.6732632423543772848170429916717, dtype=dtype) + gamma = _expr.const(1.0507009873554804934193349852946, dtype=dtype) + return gamma * (alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)) return _impl @@ -1112,8 +1115,9 @@ def _impl(inputs, input_types): def _softplus(): def _impl(inputs, input_types): data = inputs[0] - beta = _expr.const(float(inputs[1])) - return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta + dtype = input_types[0] + beta = _expr.const(float(inputs[1]), dtype=dtype) + return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1., dtype=dtype)) / beta return _impl def _avg_pool2d(prelude): @@ -1195,6 +1199,7 @@ def _impl(inputs, input_types): def _norm(): def _impl(inputs, input_types): data = inputs[0] + dtype = input_types[0] axis = None keepdims = False if len(inputs) > 3: @@ -1207,7 +1212,7 @@ def _impl(inputs, input_types): elif order == np.NINF: return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims) else: - reci_order = _expr.const(1.0 / order) + reci_order = _expr.const(1.0 / order, dtype=dtype) order = _expr.const(order) return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order), axis=axis, @@ -1239,7 +1244,7 @@ def _impl(inputs, input_types): if unbiased: msg = "Currently only supports standard-deviation calculated via the biased "\ - "estimator. Pytorch's Bessel's correction is not supported." + "estimator. PyTorch's Bessel's correction is not supported." raise NotImplementedError(msg) return _op.reduce.std(data, axis=axis, keepdims=keepdims) @@ -1255,7 +1260,7 @@ def _impl(inputs, input_types): if unbiased: msg = "Currently only supports standard-deviation calculated via the biased "\ - "estimator. Pytorch's Bessel's correction is not supported." + "estimator. PyTorch's Bessel's correction is not supported." raise NotImplementedError(msg) return _op.reduce.variance(data, axis=axis, keepdims=keepdims) @@ -1657,7 +1662,7 @@ def _type_as(): def _impl(inputs, input_types): assert len(inputs) == 2 assert len(input_types) == 2 - return _op.cast(inputs[0], _convert_data_type(input_types[1])) + return _op.cast(inputs[0], input_types[1]) return _impl @@ -1687,20 +1692,13 @@ def _impl(inputs, input_types): def _rsub(): def _impl(inputs, input_types): - # TODO: Figure out a better way to get typing to work for tensor + scalar - type0 = input_types[0] - if isinstance(inputs[1], _expr.Expr): - type0 = input_types[1] + data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) - type1 = input_types[1] - if isinstance(inputs[0], _expr.Expr): - type1 = input_types[0] - - data1 = _convert_elemwise_input(inputs[0], type0) - data0 = _convert_elemwise_input(inputs[1], type1) + # TODO (t-vi): should this also be part of the type promotion? alpha = _expr.const(float(inputs[2])) - return get_relay_op("subtract")(data0, alpha * data1) + # note: rsub means data0 and data1 swap places + return get_relay_op("subtract")(data1, alpha * data0) return _impl @@ -1729,8 +1727,55 @@ def _impl(inputs, input_types): return _impl +def _pytorch_result_type(dtypes, non_tensor_inputs): + """This promotes TVM dtypes like PyTorch would""" + import torch + dtype_map = { + "float64": torch.float64, + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "int64": torch.int64, + "int32": torch.int32, + "int16": torch.int16, + "int8": torch.int8, + "uint8": torch.uint8, + "bool": torch.bool + } + if len(dtypes) > 0: + result_type = dtypes[0] + for dt in dtypes[1:]: + if dt != result_type: # we don't want to work with same types as we + # don't do quantized here (which cannot be promoted?) + result_type = _convert_data_type(str(torch.result_type( + torch.zeros((), dtype=dtype_map[result_type]), + torch.zeros((), dtype=dtype_map[dt])))) + else: + result_type = "bool" # this is the smallest type... + for inp in non_tensor_inputs: + result_type = _convert_data_type( + str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]), + inp))) + return result_type + +def _pytorch_promote_types(inputs, dtypes): + """This promotes TVM inputs with TVM dtypes passed like PyTorch would""" + tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)] + non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)] + result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs) + results = [] + for inp, dt in zip(inputs, dtypes): + if np.isscalar(inp): + results.append(_expr.const(inp, dtype=result_type)) + elif dt == result_type: + results.append(inp) + else: + results.append(_op.cast(inp, result_type)) + return results + # Helper functions for operator implementation def _convert_dtype_value(val): + """converts a PyTorch the PyTorch numeric type id to a torch scalar type.""" convert_torch_dtype_map = {7:"torch.float64", 6:"torch.float32", 5:"torch.float16", @@ -1741,12 +1786,19 @@ def _convert_dtype_value(val): 0:"torch.unit8", None:"torch.int64"} # Default is torch.int64 if val in convert_torch_dtype_map: - return convert_torch_dtype_map[val] + return _convert_data_type(convert_torch_dtype_map[val]) else: msg = "Torch data type value %d is not handled yet." % (val) raise NotImplementedError(msg) -def _convert_data_type(input_type): +def _convert_data_type(input_type, default_dtype=None): + """converts the PyTorch scalar type input_type to a TVM dtype. + optionally, default_dtype can be a TVM dtype that is used + if input_type is None (but not when it is unknown)""" + if input_type is None and default_dtype is not None: + return default_dtype + + input_type = input_type.lower() if input_type in ["double", "torch.float64"]: return "float64" elif input_type in ["float", "torch.float32"]: @@ -1763,12 +1815,21 @@ def _convert_data_type(input_type): return "int8" elif input_type in ["byte", "torch.uint8"]: return "uint8" + elif input_type in ["quint8", "torch.quint8"]: + return "quint8" + elif input_type in ["qint8", "torch.qint8"]: + return "qint8" + elif input_type in ["qint32", "torch.qint32"]: + return "qint32" + elif input_type in ["bool", "torch.bool"]: + return "bool" else: - raise NotImplementedError("input_type {} is not handled yet" % (input_type)) - return "float32" + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + return "float32" # Never reached -def _create_typed_const(data, data_type): - dtype = _convert_data_type(data_type) +def _create_typed_const(data, dtype): + """create a (scalar) constant of given value and dtype. + dtype should be a TVM dtype""" if dtype == "float64": typed_data = _expr.const(np.float64(data), dtype=dtype) @@ -1787,18 +1848,9 @@ def _create_typed_const(data, data_type): elif dtype == "uint8": typed_data = _expr.const(np.uint8(data), dtype=dtype) else: - raise NotImplementedError("input_type {} is not handled yet" % (data_type)) + raise NotImplementedError("input_type {} is not handled yet".format(dtype)) return typed_data -def _convert_elemwise_input(data, input_type): - import torch - if isinstance(data, torch.Tensor): - return _expr.const(data.item(), dtype=_convert_data_type(input_type)) - elif not isinstance(data, _expr.Expr): - return _expr.const(data, dtype=_convert_data_type(input_type)) - else: - return data - def _wrap_const(c): if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)): return _expr.const(c) @@ -1891,6 +1943,7 @@ def _get_convert_map(prelude): "aten::mean" : _mean(prelude), "aten::chunk" : _chunk(prelude), "aten::matmul" : _matmul(prelude), + "aten::bmm" : _matmul(prelude), "aten::expand" : _expand(), "aten::Int" : _int(), "prim::NumToTensor" : _numtotensor(), @@ -1981,12 +2034,13 @@ def _run_jit_passes(graph): def _is_int_seq(seq): + # TODO (t-vi): handle non-int constants? (like numpy.intXX) return len(seq) > 0 and all([isinstance(i, int) for i in seq]) def _get_tensor_and_var(torch_tensor, name): tensor = tvm.nd.array(torch_tensor.cpu().numpy()) - var = _expr.var(name, shape=tensor.shape) + var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype) return tensor, var @@ -2039,35 +2093,6 @@ def _report_missing_conversion(op_names, convert_map): msg = "The following operators are not implemented: {}".format(missing) raise NotImplementedError(msg) - -def _check_inputs(graph, input_shapes): - """ - Check the graph inputs match the expected number of inputs - and are in the correct format - """ - ir_inputs = _get_graph_input_names(graph) - - if not isinstance(input_shapes, list): - msg = "Graph inputs input_shapes should be list" - raise RuntimeError(msg) - missing_inputs = len(ir_inputs) - len(input_shapes) - if missing_inputs > 0: - msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs) - raise RuntimeError(msg) - - for num, inp in enumerate(input_shapes): - if num < len(ir_inputs): - if not isinstance(inp, tuple): - msg = "Graph input {} is not a tuple".format(num) - raise RuntimeError(msg) - if (len(inp) != 2 or not isinstance(inp[0], str)): - msg = "Graph input {} is not valid, expected ('name', shape)".format(inp) - raise RuntimeError(msg) - else: - msg = "Unused graph input {} in input_shapes".format(inp) - logging.warning(msg) - - def _getattr_attr_name(node): attribute_names = node.attributeNames() assert len(attribute_names) == 1 @@ -2078,37 +2103,38 @@ def _getattr_attr_name(node): def _getattr_full_name(getattrs): return ".".join([_getattr_attr_name(node) for node in getattrs]) +def _get_pytorch_value_type(typ, default_dtype="float32"): + kind = typ.kind() + if kind == 'TensorType': + if typ.scalarType() is None: + # Tensor's type can be unknown if we use torch.jit.script(...) + # Defaults can be passed in, if not it is float32 + logging.warning("Untyped Tensor found, assume it is %s", default_dtype) + return default_dtype + else: + return _convert_data_type(typ.scalarType()) + + elif kind == 'ListType': + return "ListType" + elif kind in ['IntType', 'FloatType', 'BoolType', + 'StringType', 'OptionalType']: + pt_dtype = str(typ).lower() + dtype = pt_dtype if pt_dtype == 'OptionalType' else _convert_data_type(pt_dtype) + return dtype + else: + return 'UnsupportedType' -def _get_input_types(op_node): - """ Returns a torch type for each input nodes """ - input_list_types = [] - for input_node in op_node.inputs(): - in_ty = input_node.type() - input_node_kind = in_ty.kind() - if input_node_kind == 'TensorType': - if in_ty.scalarType() is None: - # Tensor's type can be unknown if we use torch.jit.script(...) - # Defaults to float for now - logging.warning("Untyped Tensor found, assume it is float") - input_list_types.append("float") - else: - input_list_types.append(in_ty.scalarType().lower()) - elif input_node_kind == 'ListType': - input_list_types.append("ListType") - elif input_node_kind in ['IntType', 'FloatType', 'BoolType', - 'StringType', 'OptionalType']: - input_list_types.append(str(in_ty).lower()) - else: - input_list_types.append('UnsupportedType') +def _get_input_types(op_node, default_dtype="float32"): + """Returns a TVM dtype for each input nodes derived from the torch type""" + return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype) + for i in op_node.inputs()] - if op_node.kind() in ['aten::ones', 'aten::zeros']: - node_type = op_node.output().type() - scalar_type = node_type.scalarType() - if scalar_type: - input_list_types[0] = scalar_type.lower() - return input_list_types +def _get_output_types(op_node, default_dtype="float32"): + """Returns a TVM dtype for each input nodes derived from the torch type""" + return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype) + for i in op_node.outputs()] def _get_constant(node): @@ -2120,14 +2146,17 @@ def _get_constant(node): attr_name = attribute_names[0] ty = node.output().type().kind() - if ty in ["IntType", "BoolType"]: + if ty == "IntType": return node.i(attr_name) + elif ty == "BoolType": + return bool(node.i(attr_name)) elif ty in ["FloatType", "LongType"]: return node.f(attr_name) elif ty in ["TensorType", "CompleteTensorType"]: tensor = node.t(attr_name) if len(tensor.shape) == 0: # tensor(0.1) - return float(tensor) + # TODO(t-vi): When is this needed? + return tensor.item() return _wrap_const(tensor.numpy()) elif ty == "DeviceObjType": return node.s(attr_name) @@ -2156,35 +2185,75 @@ def _get_operator_nodes(nodes): return ops -def _get_graph_input_names(graph): - """ Get the graph input names (use after graph copy and run jit passes) """ - # Variable names could change the first time a copy is made and after - # _run_jit_passes is called, expected that those functions already invoked - ir_inputs = _get_input_names(graph) - return ir_inputs[1:] # remove self at the 0th arg - - -def _get_relay_input_vars(graph, input_shapes, prelude): +def _get_relay_input_vars(graph, input_shapes, prelude, is_module=True, default_dtype="float32"): """ Return Relay vars from input shapes and create entries based on expected graph inputs - to allow translation """ - def get_relay_ty(ishape): - if _is_int_seq(ishape) or len(ishape) == 0: - return TensorType(ishape) - elif isinstance(ishape, tuple): - return TupleType([get_relay_ty(elem) for elem in ishape]) - elif isinstance(ishape, list): - assert len(ishape) > 0 - elem_tys = [get_relay_ty(s) for s in ishape] - msg = "List elements should have identical types" - assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg + + graph_inputs = list(graph.inputs()) + if is_module: + # a module has "self" as first input, which we do not need/want + graph_inputs = graph_inputs[1:] + + if not isinstance(input_shapes, list): + msg = "Graph inputs input_shapes should be a list" + raise RuntimeError(msg) + + if len(graph_inputs) != len(input_shapes): + msg = "PyTorch has {} inputs and input_shapes lists {}.".format( + len(graph_inputs), len(input_shapes)) + raise RuntimeError(msg) + + def get_relay_ty(ishape, pt_type): + if pt_type.kind() == 'TensorType': + if not (_is_int_seq(ishape) or len(ishape) == 0): + msg = "Shape for Tensors must be lists of ints" + raise RuntimeError(msg) + if ((pt_type.dim() is not None and pt_type.dim() != len(ishape)) or + (pt_type.sizes() is not None + and any([s1 != s2 for s1, s2 in zip(pt_type.sizes(), ishape)]))): + msg = "Shapes of input list and information in the graph do not match" + raise RuntimeError(msg) + pt_dtype = pt_type.scalarType() + dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype) + return TensorType(ishape, dtype) + elif pt_type.kind() == 'TupleType': + if not isinstance(ishape, tuple): + msg = "Shapes for tuples must be tuples" + raise RuntimeError(msg) + return TupleType([get_relay_ty(elem, pt_t) + for elem, pt_t in zip(ishape, pt_type.elements())]) + elif pt_type.kind() == 'ListType': + if not isinstance(ishape, list): + msg = "Shapes for lists must be lists" + raise RuntimeError(msg) + pt_elemtype = pt_type.getElementType() + elem_tys = [get_relay_ty(s, pt_elemtype) for s in ishape] + if len(elem_tys) > 0 and not all(map(lambda ty: ty == elem_tys[0], elem_tys)): + msg = "List elements need have identical types" + raise RuntimeError(msg) return prelude.l(elem_tys[0]) + elif pt_type.kind() == 'OptionalType': + # we do not support None yet, so we fill in the type + return get_relay_ty(ishape, pt_type.getElementType()) + # TODO: scalar inputs raise NotImplementedError("unsupported input type") - input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes] input_vars = {} - ir_inputs = _get_graph_input_names(graph) + + for num, inp in enumerate(input_shapes): + if not isinstance(inp, tuple): + msg = "Graph input {} is not a tuple".format(num) + raise RuntimeError(msg) + if (len(inp) != 2 or not isinstance(inp[0], str)): + msg = "Graph input {} is not valid, expected ('name', shape)".format(inp) + raise RuntimeError(msg) + + input_types = [(name, get_relay_ty(shape, gi.type())) + for (name, shape), gi in zip(input_shapes, graph_inputs)] + + ir_inputs = [i.debugName() for i in graph_inputs] for ir_input, (name, itype) in zip(ir_inputs, input_types): inp = _expr.var(name, type_annotation=itype) # Translate from graph input to user input name @@ -2292,19 +2361,22 @@ def convert_params(graph, state_dict): return params, param_tensors, packed_param_map -def convert_block(block, outputs, convert_map, prelude): +def convert_block(block, outputs, convert_map, prelude, default_dtype="float32"): """ Translate Torch "Block", used for prim::If and prim::Loop """ ops = _get_operator_nodes(block.nodes()) ret_names = _get_input_names(block.returnNode()) - return convert_operators(ops, outputs, ret_names, convert_map, prelude) + return convert_operators(ops, outputs, ret_names, convert_map, prelude, + default_dtype=default_dtype) -def convert_if(if_node, outputs, convert_map, prelude): +def convert_if(if_node, outputs, convert_map, prelude, default_dtype="float32"): """ Translate Torch prim::If to Relay If """ cond = outputs[if_node.inputsAt(0).debugName()] blocks = list(if_node.blocks()) - true_branch = convert_block(blocks[0], outputs, convert_map, prelude) - false_branch = convert_block(blocks[1], outputs, convert_map, prelude) + true_branch = convert_block(blocks[0], outputs, convert_map, prelude, + default_dtype=default_dtype) + false_branch = convert_block(blocks[1], outputs, convert_map, prelude, + default_dtype=default_dtype) assert len(true_branch) == 1 and len(false_branch) == 1 return _expr.If(cond, true_branch[0], false_branch[0]) @@ -2424,7 +2496,7 @@ def body(*current_vals): return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] -def convert_operators(operators, outputs, ret_names, convert_map, prelude): +def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators: operator = op_node.kind() @@ -2450,7 +2522,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): unpacked = _unpack_tuple(inputs[0]) outputs.update(zip(_get_output_names(op_node), unpacked)) elif operator == "prim::If": - if_out = convert_if(op_node, outputs, convert_map, prelude) + if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype) outputs[node_name] = if_out elif operator == "prim::Loop": loop_out = convert_loop(op_node, outputs, convert_map, prelude) @@ -2459,7 +2531,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude): outputs.update(zip(unpacked_names, loop_out)) else: relay_op = convert_map[operator] - relay_out = relay_op(inputs, _get_input_types(op_node)) + relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype)) if isinstance(relay_out, tuple): # This is for torch operators that return multiple outputs @@ -2486,7 +2558,7 @@ def get_all_op_names(graph): return set(node.kind() for node in nodes) -def from_pytorch(script_module, input_shapes, custom_convert_map=None): +def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_dtype="float32"): """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -2512,6 +2584,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): params : dict of str to tvm.runtime.NDArray Dict of converted parameters stored in tvm.runtime.ndarray format """ + import torch + mod = tvm.IRModule() prelude = Prelude(mod) @@ -2525,10 +2599,12 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): op_names = get_all_op_names(graph) _report_missing_conversion(op_names, convert_map) - _check_inputs(graph, input_shapes) - params = script_module.state_dict() - outputs = _get_relay_input_vars(graph, input_shapes, prelude) + is_module = isinstance(script_module, torch.jit.ScriptModule) + params = script_module.state_dict() if is_module else {} + outputs = _get_relay_input_vars(graph, input_shapes, prelude, + default_dtype=default_dtype, + is_module=is_module) param_vars, tensors, packed_param_map = convert_params(graph, params) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} @@ -2546,7 +2622,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): convert_map.update(qnn_torch.convert_map) ret = convert_operators(_get_operator_nodes(graph.nodes()), - outputs, ret_name, convert_map, prelude) + outputs, ret_name, convert_map, prelude, + default_dtype=default_dtype) mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0]) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 8c6c248b0af4..f6c7280d17cf 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -501,6 +501,10 @@ def test_serialized_modules(): runtime.run() tvm_result = runtime.get_output(0).asnumpy() - num_identical = np.sum(tvm_result == pt_result) + # with 0.5ish results, 1e-2 is relative accuracy close to 2**-6. + # for simple layers like here this should be achievable + # with 8 bit quantization + # we only require 90% match just to be sure + num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2) match_ratio = num_identical / float(np.prod(tvm_result.shape)) - assert match_ratio > 0.2 + assert match_ratio > 0.90 diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 96e9144e03cc..6ec3110dcf69 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -152,7 +152,8 @@ def verify_model(model_name, input_data=[], assert False, "Unexpected input format" if torch.cuda.is_available(): - baseline_model = baseline_model.cuda() + if isinstance(baseline_model, torch.nn.Module): + baseline_model = baseline_model.cuda() baseline_input = [inp.cuda() for inp in baseline_input] with torch.no_grad(): @@ -163,12 +164,14 @@ def verify_model(model_name, input_data=[], else: baseline_outputs = (baseline_outputs.cpu().numpy(),) - trace = torch.jit.trace(baseline_model, baseline_input).float().eval() + trace = torch.jit.trace(baseline_model, baseline_input) + if isinstance(baseline_model, torch.nn.Module): + trace = trace.float().eval() - if torch.cuda.is_available(): - trace = trace.cuda() - else: - trace = trace.cpu() + if torch.cuda.is_available(): + trace = trace.cuda() + else: + trace = trace.cpu() input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] input_shapes = list(zip(input_names, @@ -2363,6 +2366,23 @@ def forward(self, *args): t2 = torch.rand([1, 3]).float() verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2]) +def test_forward_traced_function(): + def fn(t1, t2): + return t1 + t2 + + tensor1 = torch.randn(3, 4) + tensor2 = torch.randn(3, 4) + verify_model(fn, input_data=[tensor1, tensor2]) + +def test_forward_dtypes(): + def fn(t1, t2): + return 2.5 * t1 + t2 + + for dt in [torch.int32, torch.int64, torch.double]: + tensor1 = torch.randn(3, 4).to(dtype=dt) + tensor2 = torch.randn(3, 4).to(dtype=dt) + verify_model(fn, input_data=[tensor1, tensor2]) + def test_forward_matmul(): torch.set_grad_enabled(False) @@ -2526,6 +2546,8 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": + test_forward_traced_function() + test_forward_dtypes() # Single operator tests test_forward_add() test_forward_subtract()