Skip to content

Commit

Permalink
[Fix] Relay ONNX frontend: MatMul bug when multiplying [A, B, M, N] by [1, B, N, K]
Browse files Browse the repository at this point in the history
  • Loading branch information
tomoyazhang committed Jan 12, 2022
1 parent 0d2340c commit 5a3ea3b
Showing 1 changed file with 34 additions and 12 deletions.
46 changes: 34 additions & 12 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,6 @@ def flatten_to_nd(x, x_shape, nd=3):
out = _op.reshape(x, fold_constant(newshape))
return out

b_type = infer_type(inputs[1])
# Convert to dense if the second matrix is 2d and non-dynamic
if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
a = flatten_to_nd(inputs[0], a_shape, 2)
b = _op.transpose(inputs[1])
output = _op.nn.dense(a, b, out_dtype=out_dtype)
else:
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(inputs[0], a_shape, 3)
b = flatten_to_nd(inputs[1], b_shape, 3)
# Perform a NN batch matmul.
output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
Expand All @@ -268,6 +256,40 @@ def flatten_to_nd(x, x_shape, nd=3):
],
0,
)

b_type = infer_type(inputs[1])
# Convert to dense if the second matrix is 2d and non-dynamic.
if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
    # Collapse a to 2 dimensions and transpose b before handing both to dense.
    a = flatten_to_nd(inputs[0], a_shape, 2)
    b = _op.transpose(inputs[1])
    output = _op.nn.dense(a, b, out_dtype=out_dtype)
else:
    # Broadcast a and b to the shared batch shape first, so operands whose
    # batch dimensions differ (e.g. [A, B, M, N] * [1, B, N, K]) line up
    # before they are flattened for batch_matmul.
    # Each target shape is out_batch followed by the operand's own last two
    # (matrix) dimensions.
    a_broadcasted_shape = _op.concatenate(
        [
            out_batch,
            _op.strided_slice(
                a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0]]
            ),
        ],
        0,
    )
    # Both slice bounds must come from b_shape's own length. Using the
    # length of a_shape as the end bound drops b's trailing dimensions
    # whenever a has a lower rank than b.
    b_broadcasted_shape = _op.concatenate(
        [
            out_batch,
            _op.strided_slice(
                b_shape, [infer_shape(b_shape)[0] - 2], [infer_shape(b_shape)[0]]
            ),
        ],
        0,
    )
    a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape))
    b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape))
    # Convert a and b into 3 dimensional tensors.
    a = flatten_to_nd(a, shape_of(a), 3)
    b = flatten_to_nd(b, shape_of(b), 3)
    # Perform a NN batch matmul.
    output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Reshape output to original dimensions.
final_shape = _op.concatenate(
[
Expand Down

0 comments on commit 5a3ea3b

Please sign in to comment.