From 9bd4e4f302fa4ec7899a09912a4bbfd4e326f535 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 24 Jun 2020 06:42:38 +0200
Subject: [PATCH] PyTorch frontend: fix handling of duplicate use of a model
 weight (#5897)

This happens e.g. with shared input/output embeddings in BERT or in
Siamese networks.

Thank you @siju-samuel for reporting.
---
 python/tvm/relay/frontend/pytorch.py          | 13 +++++++++----
 tests/python/frontend/pytorch/test_forward.py | 18 ++++++++++++++++++
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 374e1c2651cc..92373036d2f2 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2335,6 +2335,7 @@ def convert_params(graph, state_dict):
     params = {}
     param_tensors = {}
     packed_param_map = {}
+    vars_by_name = {}
     seen = set()
 
     for node in getattr_nodes:
@@ -2352,10 +2353,14 @@ def convert_params(graph, state_dict):
                 assert full_attr in state_dict, err_msg
                 packed_param_map[full_attr_node_name] = full_attr
             elif full_attr in state_dict:
-                torch_tensor = state_dict[full_attr]
-                tensor, var = _get_tensor_and_var(torch_tensor,
-                                                  full_attr)
-                param_tensors[full_attr] = tensor
+                if full_attr in vars_by_name:
+                    var = vars_by_name[full_attr]
+                else:
+                    torch_tensor = state_dict[full_attr]
+                    tensor, var = _get_tensor_and_var(torch_tensor,
+                                                      full_attr)
+                    param_tensors[full_attr] = tensor
+                    vars_by_name[full_attr] = var
                 params[full_attr_node_name] = var
 
     return params, param_tensors, packed_param_map
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d56496577176..12d1260a4a50 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2390,6 +2390,23 @@ def test_weight_names():
     assert set(params.keys()) == set(n for n, p in tm.named_parameters())
 
 
+def test_duplicate_weight_use():
+    # The test case doesn't make sense as a neural network;
+    # the issue popped up in shared input/output embeddings of BERT,
+    # but this is quicker to exercise.
+    class Test(Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(5, 3)
+
+        def forward(self, x):
+            x = self.lin(x)
+            x = x @ self.lin.weight
+            return x
+
+    verify_model(Test(), input_data=[torch.randn(5, 5)])
+
+
 def test_forward_matmul():
     torch.set_grad_enabled(False)
 
@@ -2556,6 +2573,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_forward_traced_function()
     test_forward_dtypes()
     test_weight_names()
+    test_duplicate_weight_use()
 
     # Single operator tests
     test_forward_add()