Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hzfan committed Aug 3, 2021
1 parent ecdd872 commit 15e2d13
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,7 +1445,14 @@ def linear(self, inputs, input_types):
# 0 - input
# 1 - weight
bias = inputs[2]
mm_out = self.matmul([inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2])
a_shape = self.infer_shape_with_prelude(inputs[0])
b_shape = self.infer_shape_with_prelude(inputs[1])
if len(a_shape) == 2 and len(b_shape) == 2:
mm_out = _op.nn.dense(inputs[0], inputs[1])
elif len(b_shape) == 1:
mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2])
else:
mm_out = self.matmul([inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2])
if isinstance(bias, _expr.Expr):
bias_ndims = len(self.infer_shape_with_prelude(bias))
if bias_ndims == 1:
Expand Down
3 changes: 3 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,7 @@ def forward(self, input, weight):
return F.linear(input, weight)

input2d = torch.rand([2, 2]).float()
input3d = torch.rand([4, 3, 2]).float()
weight1d = torch.rand([2]).float()
weight2d = torch.rand([2, 2]).float()
weight3x2 = torch.rand([3, 2]).float()
Expand All @@ -1584,6 +1585,8 @@ def forward(self, input, weight):
# 2D input, 1D weight, 1D bias is not supported by torch.linear()
# 2D input, 1D weight, no bias
verify_model(LinearNoBias(), input_data=[input2d, weight1d])
# 3D input, 2D weight, no bias
verify_model(LinearNoBias(), input_data=[input3d, weight3x2])
# TODO: Add the following cases when matmul(1D, _) is supported by TVM
# 1D input, 2D weight, 1D bias
# 1D input, 2D weight, no bias
Expand Down

0 comments on commit 15e2d13

Please sign in to comment.