[Torch] Fix dtype handling for modules with integer parameters #6311

Merged 3 commits on Aug 21, 2020
python/tvm/relay/frontend/pytorch.py (21 additions, 9 deletions)
@@ -2130,6 +2130,7 @@ def _report_missing_conversion(op_names, convert_map):
         msg = "The following operators are not implemented: {}".format(missing)
         raise NotImplementedError(msg)
 
+
 def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
@@ -2140,6 +2141,7 @@ def _getattr_attr_name(node):
 def _getattr_full_name(getattrs):
     return ".".join([_getattr_attr_name(node) for node in getattrs])
 
+
 def _get_pytorch_value_type(typ, default_dtype="float32"):
     kind = typ.kind()
     if kind == 'TensorType':
@@ -2162,16 +2164,25 @@ def _get_pytorch_value_type(typ, default_dtype="float32"):
     return 'UnsupportedType'
 
 
-def _get_input_types(op_node, default_dtype="float32"):
+def _get_input_types(op_node, outputs, default_dtype="float32"):
     """Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.inputs()]
-
+    in_types = []
+    for inp in op_node.inputs():
+        if inp.node().kind() == "prim::GetAttr":
+            # GetAttr nodes always return None when we call scalarType() on it
+            name = inp.debugName()
+            assert name in outputs
+            if isinstance(outputs[name], _expr.Var):
+                in_types.append(outputs[name].type_annotation.dtype)
+            else:
+                # For quantized modules with parameters, here we would get
+                # "prim::GetAttr[name="_packed_params"]". Since the dtype corresponding to
+                # _packed_params is not needed by quantized ops, we return an arbitrary type.
+                in_types.append(default_dtype)
+        else:
+            in_types.append(_get_pytorch_value_type(inp.type(), default_dtype=default_dtype))
 
-def _get_output_types(op_node, default_dtype="float32"):
-    """Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.outputs()]
+    return in_types
 
 
 def _get_constant(node):
@@ -2575,7 +2586,8 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"):
             outputs.update(zip(unpacked_names, loop_out))
         else:
            relay_op = convert_map[operator]
-            relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
+            relay_out = relay_op(inputs, _get_input_types(op_node, outputs,
+                                                          default_dtype=default_dtype))
 
             if isinstance(relay_out, tuple):
                 # This is for torch operators that return multiple outputs
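
Why the new branch in _get_input_types is needed: when an operator input comes from a prim::GetAttr node (i.e. a module parameter), the TorchScript graph does not record a scalar type, so the old code, which read dtypes straight off the graph, silently fell back to the float32 default even for integer parameters. Below is a minimal sketch, not part of this PR (IntModule is a made-up repro), that exhibits the behavior the diff comment describes:

    import torch

    class IntModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.ones(2, 2, dtype=torch.long),
                                            requires_grad=False)

        def forward(self, x):
            return x.long() + self.param

    traced = torch.jit.trace(IntModule(), torch.ones(2, 2))
    for node in traced.graph.nodes():
        if node.kind() == "prim::GetAttr":
            out = node.output()
            # Per the comment added in the diff, scalarType() returns None
            # here, so the parameter's dtype cannot be read off the graph.
            print(out.debugName(), out.type().scalarType())

The fix sidesteps the graph entirely: it looks the input up in outputs, the converter's map from debug names to already-converted Relay expressions, and a converted parameter is a Relay Var whose type annotation carries the real dtype. For instance:

    from tvm import relay

    v = relay.var("param", shape=(2, 2), dtype="int64")
    print(v.type_annotation.dtype)  # prints "int64", what the fixed code reads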
tests/python/frontend/pytorch/test_forward.py (13 additions, 0 deletions)
@@ -2550,6 +2550,19 @@ def fn(t1, t2):
         tensor2 = torch.randn(3, 4).to(dtype=dt)
         verify_model(fn, input_data=[tensor1, tensor2])
 
+    class ModuleWithIntParameters(Module):
+        def __init__(self, arr):
+            super().__init__()
+            self.param = torch.nn.Parameter(torch.LongTensor(arr), requires_grad=False)
+
+        def forward(self, x):
+            return x.long() + self.param
+
+    shape = (10, 10)
+    param = torch.ones(shape, dtype=torch.long)
+    inp = torch.ones(shape, dtype=torch.int)
+    verify_model(ModuleWithIntParameters(param), input_data=inp)
+
 
 def test_weight_names():
     tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
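
For context, a hedged end-to-end sketch of the user-visible path this test exercises, not taken from the PR: the input name "x" and tracing with a float input (which forward immediately casts via x.long()) are assumptions.

    import torch
    from tvm import relay

    class ModuleWithIntParameters(torch.nn.Module):
        def __init__(self, arr):
            super().__init__()
            self.param = torch.nn.Parameter(torch.LongTensor(arr),
                                            requires_grad=False)

        def forward(self, x):
            return x.long() + self.param

    shape = (10, 10)
    model = ModuleWithIntParameters(torch.ones(shape, dtype=torch.long))
    traced = torch.jit.trace(model, torch.ones(shape))
    # With the fix, the int64 parameter keeps its dtype through conversion;
    # previously _get_input_types reported the float32 default for it, so
    # the add could be converted with a mismatched type.
    mod, params = relay.frontend.from_pytorch(traced, [("x", shape)])
    print(mod["main"])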