Skip to content

Commit

Permalink
[Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] (#9911)
Browse files Browse the repository at this point in the history
* [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K]

* fix line

Co-authored-by: tomoyazhang <[email protected]>
  • Loading branch information
willzhang4a58 and tomoyazhang authored Jan 15, 2022
1 parent 84ee90c commit 6eb4ed8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 18 deletions.
54 changes: 36 additions & 18 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,24 +238,6 @@ def flatten_to_nd(x, x_shape, nd=3):
out = _op.reshape(x, fold_constant(newshape))
return out

b_type = infer_type(inputs[1])
# Convert to dense if the second matrix is 2d and non-dynamic
if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
a = flatten_to_nd(inputs[0], a_shape, 2)
b = _op.transpose(inputs[1])
output = _op.nn.dense(a, b, out_dtype=out_dtype)
else:
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(inputs[0], a_shape, 3)
b = flatten_to_nd(inputs[1], b_shape, 3)
if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
# Transpose matrix dimensions of b.
bt = _op.transpose(b, [0, 2, 1])
# Perform a NT batch matmul.
output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype)
else:
# Perform a NN batch matmul.
output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
Expand All @@ -274,6 +256,42 @@ def flatten_to_nd(x, x_shape, nd=3):
],
0,
)

b_type = infer_type(inputs[1])
# Convert to dense if the second matrix is 2d and non-dynamic
if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
a = flatten_to_nd(inputs[0], a_shape, 2)
b = _op.transpose(inputs[1])
output = _op.nn.dense(a, b, out_dtype=out_dtype)
else:
# broadcast a and b
a_broadcasted_shape = _op.concatenate(
[
out_batch,
_op.strided_slice(a_shape, [a_rank - 2], [a_rank]),
],
0,
)
b_broadcasted_shape = _op.concatenate(
[
out_batch,
_op.strided_slice(b_shape, [b_rank - 2], [b_rank]),
],
0,
)
a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape))
b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape))
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(a, shape_of(a), 3)
b = flatten_to_nd(b, shape_of(b), 3)
if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
# Transpose matrix dimensions of b.
bt = _op.transpose(b, [0, 2, 1])
# Perform a NT batch matmul.
output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype)
else:
# Perform a NN batch matmul.
output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Reshape output to original dimensions.
final_shape = _op.concatenate(
[
Expand Down
1 change: 1 addition & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32))
verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16))
verify_batch_matmul((4, 32, 16, 32), (1, 32, 32, 16), (4, 32, 16, 16))
# Test transb=False
verify_batch_matmul(
(2, 3, 4, 3),
Expand Down

0 comments on commit 6eb4ed8

Please sign in to comment.