diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 374e1c2651cc1..78e9d41e1b5c1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2335,6 +2335,8 @@ def convert_params(graph, state_dict):
     params = {}
     param_tensors = {}
     packed_param_map = {}
+    # One var per state_dict entry, reused when the entry is referenced again
+    vars_by_name = {}
     seen = set()
 
     for node in getattr_nodes:
@@ -2352,10 +2354,14 @@ def convert_params(graph, state_dict):
                 assert full_attr in state_dict, err_msg
                 packed_param_map[full_attr_node_name] = full_attr
             elif full_attr in state_dict:
-                torch_tensor = state_dict[full_attr]
-                tensor, var = _get_tensor_and_var(torch_tensor,
-                                                  full_attr)
-                param_tensors[full_attr] = tensor
+                if full_attr in vars_by_name:
+                    var = vars_by_name[full_attr]
+                else:
+                    torch_tensor = state_dict[full_attr]
+                    tensor, var = _get_tensor_and_var(torch_tensor,
+                                                      full_attr)
+                    param_tensors[full_attr] = tensor
+                    vars_by_name[full_attr] = var
                 params[full_attr_node_name] = var
 
     return params, param_tensors, packed_param_map
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d56496577176b..12d1260a4a50e 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2390,6 +2390,23 @@ def test_weight_names():
     assert set(params.keys()) == set(n for n, p in tm.named_parameters())
 
 
+def test_duplicate_weight_use():
+    # The same weight is used twice in forward(). The test case doesn't make
+    # any sense as a neural network; the issue popped up in shared
+    # input/output embeddings of bert, but this is quicker
+    class Test(Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(5, 3)
+
+        def forward(self, x):
+            x = self.lin(x)
+            x = x @ self.lin.weight
+            return x
+
+    verify_model(Test(), input_data=[torch.randn(5, 5)])
+
+
 def test_forward_matmul():
     torch.set_grad_enabled(False)
 
@@ -2556,6 +2573,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_forward_traced_function()
     test_forward_dtypes()
     test_weight_names()
+    test_duplicate_weight_use()
 
     # Single operator tests
     test_forward_add()