[Torch] Fix dtype handling for modules with integer parameters (#6311)
* return the correct type for GetAttr nodes

* keep _get_pytorch_value_type intact

* add test and handle quantized param
masahi authored Aug 21, 2020
1 parent 91ea9bc commit 470dfc3
Showing 2 changed files with 34 additions and 9 deletions.
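
What goes wrong without this fix: in a traced module, parameters reach the
graph through prim::GetAttr nodes, and calling scalarType() on such a node's
output returns None, so the frontend fell back to the float32 default even
for integer parameters. A minimal repro sketch, not part of the commit
(PyTorch only; the module mirrors the test added below, and the class name
is illustrative):

    # Sketch: an integer parameter arrives via prim::GetAttr with no scalar type.
    import torch

    class AddIntParam(torch.nn.Module):  # hypothetical name, for illustration
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(
                torch.ones(10, 10, dtype=torch.long), requires_grad=False)

        def forward(self, x):
            return x.long() + self.param

    traced = torch.jit.trace(AddIntParam(), torch.ones(10, 10, dtype=torch.int))
    print(traced.graph)  # the parameter shows up as a prim::GetAttr node
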
30 changes: 21 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
@@ -2130,6 +2130,7 @@ def _report_missing_conversion(op_names, convert_map):
         msg = "The following operators are not implemented: {}".format(missing)
         raise NotImplementedError(msg)
 
+
 def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
@@ -2140,6 +2141,7 @@ def _getattr_attr_name(node):
 def _getattr_full_name(getattrs):
     return ".".join([_getattr_attr_name(node) for node in getattrs])
 
+
 def _get_pytorch_value_type(typ, default_dtype="float32"):
     kind = typ.kind()
     if kind == 'TensorType':
@@ -2162,16 +2164,25 @@ def _get_pytorch_value_type(typ, default_dtype="float32"):
     return 'UnsupportedType'
 
 
-def _get_input_types(op_node, default_dtype="float32"):
+def _get_input_types(op_node, outputs, default_dtype="float32"):
     """Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.inputs()]
+    in_types = []
+    for inp in op_node.inputs():
+        if inp.node().kind() == "prim::GetAttr":
+            # GetAttr nodes always return None when we call scalarType() on it
+            name = inp.debugName()
+            assert name in outputs
+            if isinstance(outputs[name], _expr.Var):
+                in_types.append(outputs[name].type_annotation.dtype)
+            else:
+                # For quantized modules with parameters, here we would get
+                # "prim::GetAttr[name="_packed_params"]". Since the dtype corresponding to
+                # _packed_params is not needed by quantized ops, we return an arbitrary type.
+                in_types.append(default_dtype)
+        else:
+            in_types.append(_get_pytorch_value_type(inp.type(), default_dtype=default_dtype))
 
-
-def _get_output_types(op_node, default_dtype="float32"):
-    """Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.outputs()]
+    return in_types
 
 
 def _get_constant(node):
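
Why the Var branch above is enough: by the time an op is converted, any
parameter it consumes has already been turned into a Relay variable and
stored in the outputs dict, and that variable's type annotation records the
real dtype. A small sketch of the lookup (assumes the public tvm.relay API;
"param" is a hypothetical dict entry):

    from tvm import relay

    outputs = {"param": relay.var("param", shape=(10, 10), dtype="int64")}
    inp_var = outputs["param"]
    assert isinstance(inp_var, relay.expr.Var)
    print(inp_var.type_annotation.dtype)  # "int64", not the float32 fallback
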
@@ -2575,7 +2586,8 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"):
             outputs.update(zip(unpacked_names, loop_out))
         else:
             relay_op = convert_map[operator]
-            relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
+            relay_out = relay_op(inputs, _get_input_types(op_node, outputs,
+                                                          default_dtype=default_dtype))
 
             if isinstance(relay_out, tuple):
                 # This is for torch operators that return multiple outputs
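
The call-site change above is what threads the outputs dict into the dtype
lookup. Each convert_map entry keeps the same (inputs, input_types) contract;
only the reported dtypes improve. A toy converter illustrating that contract
(the converter below is hypothetical; the real ones are defined earlier in
pytorch.py):

    from tvm import relay

    def _add(inputs, input_types):
        # input_types now reports "int64" for a GetAttr-sourced integer
        # parameter instead of the old "float32" fallback.
        lhs, rhs = inputs
        return relay.add(lhs, rhs)
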
13 changes: 13 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -2550,6 +2550,19 @@ def fn(t1, t2):
         tensor2 = torch.randn(3, 4).to(dtype=dt)
         verify_model(fn, input_data=[tensor1, tensor2])
 
+    class ModuleWithIntParameters(Module):
+        def __init__(self, arr):
+            super().__init__()
+            self.param = torch.nn.Parameter(torch.LongTensor(arr), requires_grad=False)
+
+        def forward(self, x):
+            return x.long() + self.param
+
+    shape = (10, 10)
+    param = torch.ones(shape, dtype=torch.long)
+    inp = torch.ones(shape, dtype=torch.int)
+    verify_model(ModuleWithIntParameters(param), input_data=inp)
+
 
 def test_weight_names():
     tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
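
On the quantized side the commit message mentions: a quantized module's
weights arrive through prim::GetAttr[name="_packed_params"], which never
converts to a Relay Var, so _get_input_types returns default_dtype for that
entry, and the quantized converters do not consult it. A hedged sketch of
producing such a graph (assumes PyTorch's torch.nn.quantized modules from
this era; untested here):

    import torch

    qlinear = torch.nn.quantized.Linear(3, 4)
    qinp = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1,
                                     zero_point=0, dtype=torch.quint8)
    traced = torch.jit.trace(qlinear, qinp)
    assert "_packed_params" in str(traced.graph)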
