From 15e2d135348b06e78614e48baa1474901b015621 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 3 Aug 2021 19:06:33 +0000 Subject: [PATCH] fix --- python/tvm/relay/frontend/pytorch.py | 9 ++++++++- tests/python/frontend/pytorch/test_forward.py | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b528fff8d455..73c3caffbb6f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1445,7 +1445,14 @@ def linear(self, inputs, input_types): # 0 - input # 1 - weight bias = inputs[2] - mm_out = self.matmul([inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]) + a_shape = self.infer_shape_with_prelude(inputs[0]) + b_shape = self.infer_shape_with_prelude(inputs[1]) + if len(a_shape) == 2 and len(b_shape) == 2: + mm_out = _op.nn.dense(inputs[0], inputs[1]) + elif len(b_shape) == 1: + mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2]) + else: + mm_out = self.matmul([inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]) if isinstance(bias, _expr.Expr): bias_ndims = len(self.infer_shape_with_prelude(bias)) if bias_ndims == 1: diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index c981c0552cd7..e58575266414 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1569,6 +1569,7 @@ def forward(self, input, weight): return F.linear(input, weight) input2d = torch.rand([2, 2]).float() + input3d = torch.rand([4, 3, 2]).float() weight1d = torch.rand([2]).float() weight2d = torch.rand([2, 2]).float() weight3x2 = torch.rand([3, 2]).float() @@ -1584,6 +1585,8 @@ def forward(self, input, weight): # 2D input, 1D weight, 1D bias is not supported by torch.linear() # 2D input, 1D weight, no bias verify_model(LinearNoBias(), input_data=[input2d, weight1d]) + # 3D input, 2D weight, no bias + verify_model(LinearNoBias(), input_data=[input3d, weight3x2]) # TODO: Add the following cases when matmul(1D, _) is supported by TVM # 1D input, 2D weight, 1D bias # 1D input, 2D weight, no bias