
Commit

[PyTorch] Fix in matmul function that enables working with all sizes of input tensors
padreofthegame committed Feb 8, 2023
1 parent c36ae1c commit 94a75fa
Showing 2 changed files with 121 additions and 75 deletions.
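For context, the converter has to reproduce torch.matmul's rank-dependent shape rules: 1D operands are promoted to matrices and squeezed back out, and leading (batch) dimensions broadcast against each other. A quick refresher before the diff, with shapes chosen purely for illustration:

import torch

v = torch.randn(4)
m = torch.randn(4, 3)

# 1D x 1D -> 0-d tensor (dot product)
assert torch.matmul(v, v).dim() == 0
# 1D x 2D -> the vector acts as a 1 x 4 matrix; the leading 1 is squeezed away
assert torch.matmul(v, m).shape == (3,)
# 1D x ND -> the vector broadcasts against the batch dimensions
assert torch.matmul(torch.randn(5), torch.randn(2, 3, 5, 4)).shape == (2, 3, 4)
# ND x 1D -> likewise, with the trailing matrix axis squeezed
assert torch.matmul(torch.randn(2, 3, 4, 5), torch.randn(5)).shape == (2, 3, 4)
# 2D x ND -> batch dimensions broadcast, matrix dimensions multiply
assert torch.matmul(torch.randn(10, 4), torch.randn(2, 3, 4, 5)).shape == (2, 3, 10, 5)

These are exactly the cases the rewritten matmul converter below must handle.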
119 changes: 63 additions & 56 deletions python/tvm/relay/frontend/pytorch.py
@@ -1871,70 +1871,77 @@ def baddbmm(self, inputs, _):
         return beta * input + alpha * _op.nn.batch_matmul(batch1, batch2, transpose_b=False)
 
     def matmul(self, inputs, input_types):
-        inputs_0 = inputs[0]
-        inputs_1 = inputs[1]
-
-        # Need to check input shape as batch matmul must be supported.
-        a_shape = self.infer_shape_with_prelude(inputs_0)
-        b_shape = self.infer_shape_with_prelude(inputs_1)
-
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if len(a_shape) > 2 and len(b_shape) > 2:
-            # Convert a into a 3 dimensional tensors.
-            need_reshape_output = False
-            if len(a_shape) != 3:
-                a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
-                need_reshape_output = True
-            else:
-                a = inputs_0
-
-            # Transpose matrix dimensions of b.
-            trans_axes = list(range(len(b_shape)))
-            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
-            b = _op.transpose(inputs_1, trans_axes)
-
-            # Convert b into a 3 dimensional tensor. Note that the last two dimensions
-            # are transposed.
-            if len(b_shape) != 3:
-                b = _op.reshape(b, [-1, b_shape[-1], b_shape[-2]])
-
-            # Perform a batch matmul.
-            output = _op.nn.batch_matmul(a, b)
-
-            # Reshape output to original dimensions.
-            if need_reshape_output:
-                return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
-            return output
-        elif len(a_shape) > 2:
-            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
-        elif len(a_shape) == 1:
-            return _op.squeeze(_op.nn.matmul(_op.expand_dims(inputs_0, axis=0), inputs_1), axis=[0])
-
-        if len(b_shape) > 2:
-            trans_axes = list(range(len(b_shape)))
-            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
-            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
-        elif len(b_shape) == 2:
-            input_1 = _op.transpose(inputs_1, axes=(1, 0))
-        elif len(b_shape) == 1:
-            input_1 = _op.expand_dims(inputs_1, 0, 1)
-
-        out = _op.nn.dense(inputs_0, input_1)
-
-        if len(b_shape) == 1:
-            out = _op.squeeze(out, axis=[-1])
-
-        # Reshape output into a N dimensional tensor when a or b dim > 2
-        if len(a_shape) > 2:
-            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
-        elif len(b_shape) > 2:
-            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
-            out = _op.reshape(
-                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
-            )
-
-        return out
+        assert len(inputs) == 2, "Two tensors to be multiplied are expected."
+
+        a = inputs[0]
+        b = inputs[1]
+
+        a_shape = self.infer_shape_with_prelude(a)
+        b_shape = self.infer_shape_with_prelude(b)
+
+        a_ndims = len(a_shape)
+        b_ndims = len(b_shape)
+
+        # Check if both tensors are at least 1D.
+        if a_ndims == 0 or b_ndims == 0:
+            msg = "Both arguments to matmul must be at least 1D."
+            raise AssertionError(msg)
+
+        # Check if tensors can be multiplied.
+        b_mulaxis = b_shape[-2] if b_ndims > 1 else b_shape[0]
+        if a_shape[-1] != b_mulaxis:
+            msg = "Tensors being multiplied do not have compatible shapes."
+            raise AssertionError(msg)
+
+        # If 1D, remember axis that should be deleted at the end
+        squeeze_dims = []
+        if a_ndims == 1:
+            a = _op.expand_dims(a, axis=0)
+            squeeze_dims += [-2]
+            a_ndims = 2
+            a_shape = (1,) + a_shape
+
+        if b_ndims == 1:
+            b = _op.expand_dims(b, axis=1)
+            squeeze_dims += [-1]
+            b_ndims = 2
+            b_shape = b_shape + (1,)
+
+        # Compute result
+        if a_ndims == 2 and b_ndims == 2:
+            # Result is obtained using matmul
+            out = _op.nn.dense(a, _op.transpose(b))
+        else:
+            # Result is obtained using batch_matmul
+            batch_shape = [1] * (max(a_ndims, b_ndims) - 2)
+
+            for i, j in enumerate(reversed(a_shape[:-2])):
+                batch_shape[i] = j
+
+            for i, j in enumerate(reversed(b_shape[:-2])):
+                # Need to check if axis can be broadcasted
+                if batch_shape[i] == 1 or j == 1 or batch_shape[i] == j:
+                    batch_shape[i] = max(batch_shape[i], j)
+                else:
+                    msg = "Batch dimensions are not broadcastable."
+                    raise AssertionError(msg)
+
+            batch_shape = batch_shape[::-1]
+
+            a = _op.broadcast_to(a, batch_shape + list(a_shape[-2:]))
+            b = _op.broadcast_to(b, batch_shape + list(b_shape[-2:]))
+
+            out = _op.nn.batch_matmul(
+                _op.reshape(a, [-1, *a_shape[-2:]]),
+                _op.reshape(b, [-1, *b_shape[-2:]]),
+                transpose_b=False,
+            )
+
+            out_shape = batch_shape + [a_shape[-2]] + [b_shape[-1]]
+            out = _op.reshape(out, out_shape)
+
+        return _op.squeeze(out, axis=squeeze_dims)
 
     def expand(self, inputs, input_types):
         data_in = inputs[0]
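The heart of the new code path is the loop that broadcasts the batch dimensions of the two operands before they are flattened for nn.batch_matmul. The same logic as a standalone plain-Python sketch (the helper name broadcast_batch_shape is ours, for illustration only, not part of the patch):

def broadcast_batch_shape(a_shape, b_shape):
    # Mirror the converter's loop: align batch dims from the right,
    # take the max where one side is 1, and reject true mismatches.
    batch_shape = [1] * (max(len(a_shape), len(b_shape)) - 2)
    for i, j in enumerate(reversed(a_shape[:-2])):
        batch_shape[i] = j
    for i, j in enumerate(reversed(b_shape[:-2])):
        if not (batch_shape[i] == 1 or j == 1 or batch_shape[i] == j):
            raise AssertionError("Batch dimensions are not broadcastable.")
        batch_shape[i] = max(batch_shape[i], j)
    return batch_shape[::-1]

# Matches the broadcasted x broadcasted test case added below:
assert broadcast_batch_shape((3, 2, 3, 1, 5, 4), (2, 1, 5, 4, 3)) == [3, 2, 3, 5]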
77 changes: 58 additions & 19 deletions tests/python/frontend/pytorch/test_forward.py
@@ -3977,42 +3977,81 @@ class MatMul1(Module):
         def forward(self, *args):
             return torch.matmul(args[0], args[1])
 
-    # matrix x vector
-    tensor1 = torch.randn(3, 4)
+    # vector x vector - 1D x 1D
+    tensor1 = torch.randn(4)
     tensor2 = torch.randn(4)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
-    # vector x matrix
+    # vector x matrix - 1D x 2D
     tensor1 = torch.randn(4)
     tensor2 = torch.randn(4, 3)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+
+    # vector x batched_matrix - 1D x ND
+    tensor1 = torch.randn(5)
+    tensor2 = torch.randn(2, 3, 5, 4)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
+
+    # matrix x vector - 2D x 1D
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(4)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
-    # matrix x matrix
+    # matrix x matrix - 2D x 2D
     tensor1 = torch.randn(10, 4)
     tensor2 = torch.randn(4, 10)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+
+    # broadcasted matrix x batched matrix - 2D x ND
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(2, 3, 4, 5)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
-    # batched matrix x batched matrix
-    tensor1 = torch.randn(10, 3, 4)
-    tensor2 = torch.randn(10, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    # batched matrix x vector - ND x 1D
+    tensor1 = torch.randn(2, 3, 4, 5)
+    tensor2 = torch.randn(5)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
-    # batched matrix x broadcasted matrix
+    # batched matrix x broadcasted matrix - ND x 2D
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
-    # broadcasted matrix x batched matrix
-    tensor1 = torch.randn(10, 4)
-    tensor2 = torch.randn(3, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    # batched matrix x batched matrix - ND x ND
+    tensor1 = torch.randn(2, 10, 3, 4)
+    tensor2 = torch.randn(2, 10, 4, 5)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
-    # batched matrix x batched matrix
-    tensor1 = torch.randn(1, 12, 14, 64)
-    tensor2 = torch.randn(1, 12, 64, 14)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    # batched matrix x broadcasted matrix - ND x ND
+    tensor1 = torch.randn(2, 5, 3, 4)
+    tensor2 = torch.randn(2, 1, 4, 5)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
+
+    # broadcasted matrix x batched matrix - ND x ND
+    tensor1 = torch.randn(2, 1, 5, 4)
+    tensor2 = torch.randn(2, 5, 4, 3)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
+
+    # broadcasted matrix x broadcasted matrix - ND x ND
+    tensor1 = torch.randn(3, 2, 3, 1, 5, 4)
+    tensor2 = torch.randn(2, 1, 5, 4, 3)
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
 
 def test_forward_index():
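To exercise the updated converter outside the test harness, one can trace a matmul module and import it through TVM's PyTorch frontend. A minimal sketch, assuming a TVM build that includes this commit; the input names "a" and "b" are arbitrary, and the ND x 1D shapes are one of the newly tested cases:

import torch
import tvm
from tvm import relay

class MatMul(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

# ND x 1D: batch matmul path with expand_dims/squeeze around it.
a = torch.randn(2, 3, 4, 5)
b = torch.randn(5)
scripted = torch.jit.trace(MatMul().eval(), (a, b))

# Import into Relay; shapes are supplied per input name.
mod, params = relay.frontend.from_pytorch(
    scripted, [("a", tuple(a.shape)), ("b", tuple(b.shape))]
)
print(mod)  # the matmul lowers to nn.batch_matmul plus reshape/squeeze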
