From 018e8453aa38455aa1561224b2e039b1e3f647a8 Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 21 Aug 2020 11:29:45 +0900 Subject: [PATCH] [Torch] Fix dtype handling for modules with integer parameters (#6311) * return the correct type for GetAttr node * keep _get_pytorch_value_type intact * add test and handle quantized param --- python/tvm/relay/frontend/pytorch.py | 30 +++++++++++++------ tests/python/frontend/pytorch/test_forward.py | 13 ++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 85dd5f4ce48ba..8725a64a689c3 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2130,6 +2130,7 @@ def _report_missing_conversion(op_names, convert_map): msg = "The following operators are not implemented: {}".format(missing) raise NotImplementedError(msg) + def _getattr_attr_name(node): attribute_names = node.attributeNames() assert len(attribute_names) == 1 @@ -2140,6 +2141,7 @@ def _getattr_attr_name(node): def _getattr_full_name(getattrs): return ".".join([_getattr_attr_name(node) for node in getattrs]) + def _get_pytorch_value_type(typ, default_dtype="float32"): kind = typ.kind() if kind == 'TensorType': @@ -2162,16 +2164,25 @@ def _get_pytorch_value_type(typ, default_dtype="float32"): return 'UnsupportedType' -def _get_input_types(op_node, default_dtype="float32"): +def _get_input_types(op_node, outputs, default_dtype="float32"): """Returns a TVM dtype for each input nodes derived from the torch type""" - return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype) - for i in op_node.inputs()] - + in_types = [] + for inp in op_node.inputs(): + if inp.node().kind() == "prim::GetAttr": + # GetAttr nodes always return None when we call scalarType() on it + name = inp.debugName() + assert name in outputs + if isinstance(outputs[name], _expr.Var): + in_types.append(outputs[name].type_annotation.dtype) + else: + # For quantized modules with parameters, here we would get + # "prim::GetAttr[name="_packed_params"]". Since the dtype corresponding to + # _packed_params is not needed by quantized ops, we return an arbitrary type. + in_types.append(default_dtype) + else: + in_types.append(_get_pytorch_value_type(inp.type(), default_dtype=default_dtype)) -def _get_output_types(op_node, default_dtype="float32"): - """Returns a TVM dtype for each input nodes derived from the torch type""" - return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype) - for i in op_node.outputs()] + return in_types def _get_constant(node): @@ -2575,7 +2586,8 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, defau outputs.update(zip(unpacked_names, loop_out)) else: relay_op = convert_map[operator] - relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype)) + relay_out = relay_op(inputs, _get_input_types(op_node, outputs, + default_dtype=default_dtype)) if isinstance(relay_out, tuple): # This is for torch operators that return multiple outputs diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d5b4ed2fc9c88..e5c9634544500 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2550,6 +2550,19 @@ def fn(t1, t2): tensor2 = torch.randn(3, 4).to(dtype=dt) verify_model(fn, input_data=[tensor1, tensor2]) + class ModuleWithIntParameters(Module): + def __init__(self, arr): + super().__init__() + self.param = torch.nn.Parameter(torch.LongTensor(arr), requires_grad=False) + + def forward(self, x): + return x.long() + self.param + + shape = (10, 10) + param = torch.ones(shape, dtype=torch.long) + inp = torch.ones(shape, dtype=torch.int) + verify_model(ModuleWithIntParameters(param), input_data=inp) + def test_weight_names(): tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])