From 94a75fab518d311845ad9d248dfcd08f44395eac Mon Sep 17 00:00:00 2001
From: padreofthegame
Date: Mon, 6 Feb 2023 14:40:40 +0100
Subject: [PATCH] [PyTorch] Fix in matmul function that enables working with
 all sizes of input tensors

---
 python/tvm/relay/frontend/pytorch.py          | 119 +++++++++---------
 tests/python/frontend/pytorch/test_forward.py |  77 +++++++++---
 2 files changed, 121 insertions(+), 75 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index fde2bfb26356..919ac65f504a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1871,70 +1871,77 @@ def baddbmm(self, inputs, _):
         return beta * input + alpha * _op.nn.batch_matmul(batch1, batch2, transpose_b=False)
 
     def matmul(self, inputs, input_types):
+        assert len(inputs) == 2, "Two tensors to be multiplied are expected."
 
-        inputs_0 = inputs[0]
-        inputs_1 = inputs[1]
+        a = inputs[0]
+        b = inputs[1]
 
         # Need to check input shape as batch matmul must be supported.
-        a_shape = self.infer_shape_with_prelude(inputs_0)
-        b_shape = self.infer_shape_with_prelude(inputs_1)
-
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if len(a_shape) > 2 and len(b_shape) > 2:
-            # Convert a into a 3 dimensional tensors.
-            need_reshape_output = False
-            if len(a_shape) != 3:
-                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
-                need_reshape_output = True
-            else:
-                a = inputs_0
-
-            # Transpose matrix dimensions of b.
-            trans_axes = list(range(len(b_shape)))
-            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
-            b = _op.transpose(inputs_1, trans_axes)
-
-            # Convert b into a 3 dimensional tensor. Note that the last two dimensions
-            # are transposed.
-            if len(b_shape) != 3:
-                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
-
-            # Perform a batch matmul.
-            output = _op.nn.batch_matmul(a, b)
-
-            # Reshape output to original dimensions.
-            if need_reshape_output:
-                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
-            return output
-        elif len(a_shape) > 2:
-            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
-        elif len(a_shape) == 1:
-            return _op.squeeze(_op.nn.matmul(_op.expand_dims(inputs_0, axis=0), inputs_1), axis=[0])
-
-        if len(b_shape) > 2:
-            trans_axes = list(range(len(b_shape)))
-            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
-            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
-        elif len(b_shape) == 2:
-            input_1 = _op.transpose(inputs_1, axes=(1, 0))
-        elif len(b_shape) == 1:
-            input_1 = _op.expand_dims(inputs_1, 0, 1)
+        a_shape = self.infer_shape_with_prelude(a)
+        b_shape = self.infer_shape_with_prelude(b)
+
+        a_ndims = len(a_shape)
+        b_ndims = len(b_shape)
 
-        out = _op.nn.dense(inputs_0, input_1)
+        # Check if both tensors are at least 1D.
+        if a_ndims == 0 or b_ndims == 0:
+            msg = "Both arguments to matmul must be at least 1D."
+            raise AssertionError(msg)
+
+        # Check if tensors can be multiplied.
+        b_mulaxis = b_shape[-2] if b_ndims > 1 else b_shape[0]
+        if a_shape[-1] != b_mulaxis:
+            msg = "Tensors being multiplied do not have compatible shapes."
+            raise AssertionError(msg)
 
-        if len(b_shape) == 1:
-            out = _op.squeeze(out, axis=[-1])
+        # If 1D, remember axis that should be deleted at the end
+        squeeze_dims = []
+        if a_ndims == 1:
+            a = _op.expand_dims(a, axis=0)
+            squeeze_dims += [-2]
+            a_ndims = 2
+            a_shape = (1,) + a_shape
+
+        if b_ndims == 1:
+            b = _op.expand_dims(b, axis=1)
+            squeeze_dims += [-1]
+            b_ndims = 2
+            b_shape = b_shape + (1,)
+
+        # Compute result
+        if a_ndims == 2 and b_ndims == 2:
+            # Result is obtained using matmul
+            out = _op.nn.dense(a, _op.transpose(b))
+        else:
+            # Result is obtained using batch_matmul
+            batch_shape = [1] * (max(a_ndims, b_ndims) - 2)
 
-        # Reshape output into a N dimensional tensor when a or b dim > 2
-        if len(a_shape) > 2:
-            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
-        elif len(b_shape) > 2:
-            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
-            out = _op.reshape(
-                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
+            for i, j in enumerate(reversed(a_shape[:-2])):
+                batch_shape[i] = j
+
+            for i, j in enumerate(reversed(b_shape[:-2])):
+                # Need to check if axis can be broadcasted
+                if batch_shape[i] == 1 or j == 1 or batch_shape[i] == j:
+                    batch_shape[i] = max(batch_shape[i], j)
+                else:
+                    msg = "Batch dimensions are not broadcastable."
+                    raise AssertionError(msg)
+
+            batch_shape = batch_shape[::-1]
+
+            a = _op.broadcast_to(a, batch_shape + list(a_shape[-2:]))
+            b = _op.broadcast_to(b, batch_shape + list(b_shape[-2:]))
+
+            out = _op.nn.batch_matmul(
+                _op.reshape(a, [-1, *a_shape[-2:]]),
+                _op.reshape(b, [-1, *b_shape[-2:]]),
+                transpose_b=False,
             )
 
-        return out
+            out_shape = batch_shape + [a_shape[-2]] + [b_shape[-1]]
+            out = _op.reshape(out, out_shape)
+
+        return _op.squeeze(out, axis=squeeze_dims)
 
     def expand(self, inputs, input_types):
         data_in = inputs[0]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 82992d287ace..d5cf05309060 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3977,42 +3977,81 @@ class MatMul1(Module):
         def forward(self, *args):
             return torch.matmul(args[0], args[1])
 
-    # matrix x vector
-    tensor1 = torch.randn(3, 4)
+    # vector x vector - 1D x 1D
+    tensor1 = torch.randn(4)
     tensor2 = torch.randn(4)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
-    # vector x matrix
+    # vector x matrix - 1D x 2D
     tensor1 = torch.randn(4)
     tensor2 = torch.randn(4, 3)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+
+    # vector x batched_matrix - 1D x ND
+    tensor1 = torch.randn(5)
+    tensor2 = torch.randn(2, 3, 5, 4)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
-    # matrix x matrix
+    # matrix x vector - 2D - 1D
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(4)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+
+    # matrix x matrix - 2D x 2D
     tensor1 = torch.randn(10, 4)
     tensor2 = torch.randn(4, 10)
     verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
-    # batched matrix x batched matrix
-    tensor1 = torch.randn(10, 3, 4)
-    tensor2 = torch.randn(10, 4, 5)
+    # broadcasted matrix x batched matrix - 2D x ND
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(2, 3, 4, 5)
     verify_model(
         MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
     )
 
-    # batched matrix x broadcasted matrix
+    # batched matrix x vector - ND x 1D
+    tensor1 = torch.randn(2, 3, 4, 5)
+    tensor2 = torch.randn(5)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
+
+    # batched matrix x broadcasted matrix - ND x 2D
    tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
-    # broadcasted matrix x batched matrix
-    tensor1 = torch.randn(10, 4)
-    tensor2 = torch.randn(3, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+    # batched matrix x batched matrix - ND x ND
+    tensor1 = torch.randn(2, 10, 3, 4)
+    tensor2 = torch.randn(2, 10, 4, 5)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
-    # batched matrix x batched matrix
-    tensor1 = torch.randn(1, 12, 14, 64)
-    tensor2 = torch.randn(1, 12, 64, 14)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    # batched matrix x broadcasted matrix - ND x ND
+    tensor1 = torch.randn(2, 5, 3, 4)
+    tensor2 = torch.randn(2, 1, 4, 5)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
+
+    # broadcasted matrix x batched matrix - ND x ND
+    tensor1 = torch.randn(2, 1, 5, 4)
+    tensor2 = torch.randn(2, 5, 4, 3)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
+
+    # broadcasted matrix x broadcasted matrix - ND x ND
+    tensor1 = torch.randn(3, 2, 3, 1, 5, 4)
+    tensor2 = torch.randn(2, 1, 5, 4, 3)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
 
 def test_forward_index():
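
For readers who want to sanity-check the shape semantics the new matmul conversion follows, here is a minimal standalone sketch in plain NumPy (not part of the patch; the helper name matmul_like_torch is illustrative only, not a TVM or PyTorch API). It mirrors the same steps as the rewritten converter: promote 1D inputs, check the reduction axis, broadcast the leading batch dimensions, run one flattened 3D batch matmul, reshape back to the broadcast shape, and squeeze the promoted axes.

import numpy as np


def matmul_like_torch(a, b):
    """Reference sketch of torch.matmul-style shape handling."""
    assert a.ndim >= 1 and b.ndim >= 1, "Both arguments must be at least 1D."

    squeeze_dims = []
    if a.ndim == 1:  # promote to a row matrix, drop axis -2 at the end
        a = a[np.newaxis, :]
        squeeze_dims.append(-2)
    if b.ndim == 1:  # promote to a column matrix, drop axis -1 at the end
        b = b[:, np.newaxis]
        squeeze_dims.append(-1)

    assert a.shape[-1] == b.shape[-2], "Tensors do not have compatible shapes."

    # Broadcast the leading (batch) dimensions, like batch_shape in the patch.
    batch = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a = np.broadcast_to(a, batch + a.shape[-2:]).reshape(-1, *a.shape[-2:])
    b = np.broadcast_to(b, batch + b.shape[-2:]).reshape(-1, *b.shape[-2:])

    # One flattened 3D batch matmul, then restore the broadcast batch shape.
    out = np.einsum("bij,bjk->bik", a, b).reshape(*batch, a.shape[-2], b.shape[-1])
    return out.squeeze(axis=tuple(squeeze_dims)) if squeeze_dims else out


# Mirrors the new "broadcasted matrix x batched matrix - 2D x ND" test case.
x, y = np.random.randn(10, 4), np.random.randn(2, 3, 4, 5)
print(matmul_like_torch(x, y).shape)  # (2, 3, 10, 5)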