diff --git a/include/tvm/topi/detail/ravel_unravel.h b/include/tvm/topi/detail/ravel_unravel.h
index dd7bcac09a041..e91d6afb666a5 100644
--- a/include/tvm/topi/detail/ravel_unravel.h
+++ b/include/tvm/topi/detail/ravel_unravel.h
@@ -44,7 +44,9 @@ using namespace tvm::te;
  */
 inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
   ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
-  ICHECK_GT(indices.size(), 0) << "indices must not be empty";
+  if (indices.size() == 0U) {
+    return 0;
+  }
   PrimExpr idx;
   for (size_t i = 0; i < indices.size(); ++i) {
     if (i == 0) {
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index acc33d73e826b..f56243d83e88e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1444,7 +1444,16 @@ def linear(self, inputs, input_types):
         # 0 - input
         # 1 - weight
         bias = inputs[2]
-        mm_out = self.matmul(inputs[:2], input_types[:2])
+        a_shape = self.infer_shape_with_prelude(inputs[0])
+        b_shape = self.infer_shape_with_prelude(inputs[1])
+        if len(a_shape) == 2 and len(b_shape) == 2:
+            mm_out = _op.nn.dense(inputs[0], inputs[1])
+        elif len(b_shape) == 1:
+            mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2])
+        else:
+            mm_out = self.matmul(
+                [inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]
+            )
         if isinstance(bias, _expr.Expr):
             bias_ndims = len(self.infer_shape_with_prelude(bias))
             if bias_ndims == 1:
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index be4d74ed205ac..ae2bedac0b29b 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1569,8 +1569,10 @@ def forward(self, input, weight):
             return F.linear(input, weight)

     input2d = torch.rand([2, 2]).float()
+    input3d = torch.rand([4, 3, 2]).float()
     weight1d = torch.rand([2]).float()
     weight2d = torch.rand([2, 2]).float()
+    weight3x2 = torch.rand([3, 2]).float()
     bias1d = torch.rand([2]).float()
     bias2d = torch.rand([2, 2]).float()
     # 2D input, 2D weight, 1D bias
@@ -1579,9 +1581,12 @@ def forward(self, input, weight):
     verify_model(Linear(), input_data=[input2d, weight2d, bias2d])
     # 2D input, 2D weight, no bias
     verify_model(LinearNoBias(), input_data=[input2d, weight2d])
+    verify_model(LinearNoBias(), input_data=[input2d, weight3x2])
     # 2D input, 1D weight, 1D bias is not supported by torch.linear()
     # 2D input, 1D weight, no bias
     verify_model(LinearNoBias(), input_data=[input2d, weight1d])
+    # 3D input, 2D weight, no bias
+    verify_model(LinearNoBias(), input_data=[input3d, weight3x2])
     # TODO: Add the following cases when matmul(1D, _) is supported by TVM
     # 1D input, 2D weight, 1D bias
     # 1D input, 2D weight, no bias
@@ -3939,6 +3944,7 @@ def test_fn(is_sorted, return_inverse, return_counts):
     test_forward_logsoftmax()
     test_forward_sigmoid()
     test_forward_dense()
+    test_forward_linear()
     test_forward_avgpool1d()
     test_forward_avgpool2d()
     test_forward_avgpool3d()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 65e8b01fd2b3d..51dadb21b2a25 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -293,6 +293,7 @@ def verify_reshape(shape, newshape, oshape):
     verify_reshape((2, 3, 4), (-3, -2), (6, 4))
     verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
     verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
+    verify_reshape((1,), (), ())


 def test_reshape_fail():