Skip to content

Commit

Permalink
[Relay][Pytorch] Fix bug when converting models with torch.nn.Paramet…
Browse files Browse the repository at this point in the history
…erList (apache#16180)

* include index in the attribute name if the node is torch.nn.ParameterList

* add test
  • Loading branch information
mshr-h authored Dec 1, 2023
1 parent 1994f40 commit e9a3b60
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
15 changes: 11 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5134,11 +5134,18 @@ def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False

full_attr = _getattr_full_name(getattrs, attr_name_sep)
full_attr_node_name = _get_output_name(getattrs[-1])
# set variable name by concatenating first consumer's name with full attribute

# check if the node is a torch.nn.ParameterList, and if so, include the index in
# the attribute name as well
# e.g. "weights.1"
if re.search(attr_name_sep + r"\d+$", full_attr):
attr_name = full_attr.split(attr_name_sep)[-2:]
else:
attr_name = [full_attr.split(attr_name_sep)[-1]]

# set variable name by concatenating first consumer's name with attribute name
# e.g. "aten::batch_norm_5.running_mean"
var_name = attr_name_sep.join(
[source_map[_get_users(getattrs[-1])[0]], full_attr.split(attr_name_sep)[-1]]
)
var_name = attr_name_sep.join([source_map[_get_users(getattrs[-1])[0]]] + attr_name)

if full_attr.endswith("_packed_params"): # for quantized models
packed_param_map[full_attr_node_name] = full_attr
Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5571,6 +5571,25 @@ def test_fn(attn_mask=None, is_causal=False):
verify_model(test_fn(), [query_3d, key_3d, value_3d])


def test_parameterlist():
"""test_parameterlist"""
torch.set_grad_enabled(False)

class ParamListModel(torch.nn.Module):
def __init__(self, num_layer=2):
super().__init__()
self.biases = torch.nn.ParameterList([torch.randn(10)] * num_layer)
self.weights = torch.nn.ParameterList([torch.randn(10, 10)] * num_layer)

def forward(self, x):
for i in range(len(self.weights) - 1):
x = torch.addmm(self.biases[i], x, self.weights[i])
return torch.addmm(self.biases[-1], x, self.weights[-1])

input_data = torch.randn(20, 10)
verify_model(ParamListModel().float().eval(), input_data=input_data)


class TestSetSpan:
"""test structural equal between translated / hand-crafted relay IR with span tagged."""

Expand Down

0 comments on commit e9a3b60

Please sign in to comment.