[PYTORCH] Matmul fix for batch_matmul
siju-samuel committed May 15, 2020
1 parent 482e341 commit 094dfc4
Showing 2 changed files with 76 additions and 8 deletions.
50 changes: 42 additions & 8 deletions python/tvm/relay/frontend/pytorch.py
@@ -1242,13 +1242,47 @@ def _impl(inputs, input_types):
         return chunks
     return _impl

-def _matmul():
-    def _impl(inputs, input_types):
-        data0 = inputs[0]
-        data1 = inputs[1]
-        data1_t = _op.transpose(data1, axes=(1, 0))
-        return _op.nn.dense(data0, data1_t)
+def _matmul(prelude):
+    def _impl(inputs, input_types):
+
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
+        # Need to check input shape as batch matmul must be supported.
+        a_shape = _infer_shape(inputs_0, prelude.mod)
+        b_shape = _infer_shape(inputs_1, prelude.mod)
+
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2 or len(b_shape) > 2:
+            # Convert a and b into 3 dimensional tensors.
+            a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+            b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
+            # Broadcast b to match batch size of a
+            new_b_shape = list(_infer_shape(b, prelude.mod))
+            new_a_shape = _infer_shape(a, prelude.mod)
+            if new_a_shape[0] > new_b_shape[0]:
+                new_b_shape[0] = new_a_shape[0]
+                b = _op.broadcast_to(b, new_b_shape)
+            # Transpose matrix dimensions of b.
+            b = _op.transpose(b, [0, 2, 1])
+            # Perform a batch matmul.
+            output = _op.nn.batch_matmul(a, b)
+            # Reshape output to original dimensions.
+            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+
+        # Otherwise a simple dense op will get the job done.
+        if len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
+        else:
+            input_1 = _op.transpose(inputs_1, axes=(1, 0))
+
+        out = _op.nn.dense(inputs_0, input_1)
+
+        if len(b_shape) == 1:
+            out = _op.squeeze(out, axis=[-1])
+
+        return out
+
     return _impl
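The batch branch above flattens both operands to 3-D, broadcasts b along the batch axis when the batch sizes differ, and multiplies; b is transposed first because Relay's nn.batch_matmul(x, y) computes x @ y^T on 3-D inputs. A minimal NumPy sketch of the same shape logic (illustrative only, not code from this commit; matmul_like_converter is a made-up name):

import numpy as np

def matmul_like_converter(a, b):
    """Mirror the converter's shape handling with plain NumPy."""
    a_shape, b_shape = a.shape, b.shape
    if len(a_shape) > 2 or len(b_shape) > 2:
        # Collapse leading dims into a single batch axis (the Relay reshape).
        a3 = a.reshape(-1, a_shape[-2], a_shape[-1])
        b3 = b.reshape(-1, b_shape[-2], b_shape[-1])
        # Broadcast b's batch axis up to a's, as the converter does.
        if a3.shape[0] > b3.shape[0]:
            b3 = np.broadcast_to(b3, (a3.shape[0],) + b3.shape[1:])
        # np.matmul plays the role of batch_matmul plus the pre-transpose.
        out = a3 @ b3
        # Restore a's leading dims, matching the final Relay reshape.
        return out.reshape(*a_shape[:-2], a_shape[-2], b_shape[-1])
    return a @ b

# The broadcast case from the new tests: (10, 3, 4) x (4, 5) -> (10, 3, 5).
a = np.random.randn(10, 3, 4).astype("float32")
b = np.random.randn(4, 5).astype("float32")
assert np.allclose(matmul_like_converter(a, b), a @ b, atol=1e-5)

Note that the final reshape reuses a's leading dimensions, so the sketch (like the converter) assumes the left-hand operand carries the full batch shape.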


@@ -1695,7 +1729,7 @@ def _get_convert_map(prelude):
         "aten::alpha_dropout"           : _dropout(),
         "aten::mean"                    : _mean(),
         "aten::chunk"                   : _chunk(prelude),
-        "aten::matmul"                  : _matmul(),
+        "aten::matmul"                  : _matmul(prelude),
         "aten::expand"                  : _expand(),
         "aten::Int"                     : _int(),
         "prim::NumToTensor"             : _numtotensor(),
@@ -1755,7 +1789,7 @@ def _get_convert_map(prelude):
         "aten::rsub"                    : _rsub(),
         "aten::embedding"               : _embedding(),
         "aten::one_hot"                 : _one_hot(),
-        "aten::mm"                      : _matmul(),
+        "aten::mm"                      : _matmul(prelude),
         "relay::tensor_array_stack"     : _tensor_array_stack(prelude),
         "aten::add"                     : _add(prelude),
         "aten::add_"                    : _add(prelude),
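For 1-D and 2-D right-hand sides, the non-batch branch of _matmul above lowers to nn.dense, which computes data @ weight^T; a 1-D rhs is first expanded to a (1, K) weight and the resulting size-1 axis squeezed away. A NumPy sketch of that fallback (illustrative only; dense_fallback is a made-up name):

import numpy as np

def dense_fallback(a, b):
    # nn.dense(data, weight) computes data @ weight.T, so a 2-D rhs is
    # pre-transposed and a 1-D rhs is promoted to a (1, K) "weight".
    w = b[np.newaxis, :] if b.ndim == 1 else b.T
    out = a @ w.T
    # A 1-D rhs leaves a trailing axis of size 1; squeeze it away, as the
    # converter does with _op.squeeze(out, axis=[-1]).
    return out.squeeze(-1) if b.ndim == 1 else out

a = np.random.randn(3, 4).astype("float32")
v = np.random.randn(4).astype("float32")
assert np.allclose(dense_fallback(a, v), a @ v, atol=1e-6)
m = np.random.randn(4, 10).astype("float32")
assert np.allclose(dense_fallback(a, m), a @ m, atol=1e-5)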
34 changes: 34 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -2064,11 +2064,45 @@ def forward(self, *args):
     verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])


+def test_forward_matmul():
+    torch.set_grad_enabled(False)
+
+    class MatMul1(Module):
+        def forward(self, *args):
+            return torch.matmul(args[0], args[1])
+
+    # matrix x vector
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(4)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # matrix x matrix
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(4, 10)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # batched matrix x batched matrix
+    tensor1 = torch.randn(10, 3, 4)
+    tensor2 = torch.randn(10, 4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # batched matrix x broadcasted matrix
+    tensor1 = torch.randn(10, 3, 4)
+    tensor2 = torch.randn(4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # 4-D batched matrix x 4-D batched matrix
+    tensor1 = torch.randn(1, 12, 14, 64)
+    tensor2 = torch.randn(1, 12, 64, 14)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
     test_forward_subtract()
     test_forward_multiply()
+    test_forward_matmul()
     test_forward_rsub()
     test_forward_onehot()
     test_forward_embedding()
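The "batched matrix x broadcasted matrix" case relies on torch.matmul broadcasting a 2-D rhs across the batch dimension of a 3-D lhs. A quick cross-check of that reference semantics (not part of the commit; assumes a standard PyTorch install):

import torch

a = torch.randn(10, 3, 4)
b = torch.randn(4, 5)
# Broadcasting semantics: each batch of a is multiplied by the same b.
expected = torch.stack([a[i] @ b for i in range(a.shape[0])])
assert torch.allclose(torch.matmul(a, b), expected, atol=1e-6)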
