From efc0c9b423855e2b8d83ffe369b3cea13c4d8d82 Mon Sep 17 00:00:00 2001 From: Jian Sheng Date: Mon, 16 Aug 2021 12:29:20 -0700 Subject: [PATCH] Cherry-pick PT frontend bug fix: https://github.com/apache/tvm/pull/8622 --- python/tvm/relay/frontend/pytorch.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index acc33d73e826..f56243d83e88 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1444,7 +1444,16 @@ def linear(self, inputs, input_types): # 0 - input # 1 - weight bias = inputs[2] - mm_out = self.matmul(inputs[:2], input_types[:2]) + a_shape = self.infer_shape_with_prelude(inputs[0]) + b_shape = self.infer_shape_with_prelude(inputs[1]) + if len(a_shape) == 2 and len(b_shape) == 2: + mm_out = _op.nn.dense(inputs[0], inputs[1]) + elif len(b_shape) == 1: + mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2]) + else: + mm_out = self.matmul( + [inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2] + ) if isinstance(bias, _expr.Expr): bias_ndims = len(self.infer_shape_with_prelude(bias)) if bias_ndims == 1: