Skip to content

Commit

Permalink
add unit test case
Browse files Browse the repository at this point in the history
  • Loading branch information
tomoyazhang committed Jan 13, 2022
1 parent 9bafceb commit d76fba8
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def flatten_to_nd(x, x_shape, nd=3):
out = _op.reshape(x, fold_constant(newshape))
return out

print("a_rank={} b_rank={}".format(a_rank, b_rank))
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
Expand Down Expand Up @@ -278,11 +279,13 @@ def flatten_to_nd(x, x_shape, nd=3):
[
out_batch,
_op.strided_slice(
b_shape, [infer_shape(b_shape)[0] - 2], [infer_shape(a_shape)[0]]
b_shape, [infer_shape(b_shape)[0] - 2], [infer_shape(b_shape)[0]]
),
],
0,
)
print(fold_constant(a_broadcasted_shape).data.numpy())
print(fold_constant(b_broadcasted_shape).data.numpy())
a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape))
b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape))
# Convert a and b into 3 dimensional tensors.
Expand Down
1 change: 1 addition & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32))
verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16))
verify_batch_matmul((4, 32, 16, 32), (1, 32, 32, 16), (4, 32, 16, 16))
# Test transb=False
verify_batch_matmul(
(2, 3, 4, 3),
Expand Down

0 comments on commit d76fba8

Please sign in to comment.