diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 625d0285da732..22990d7d36cb8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -238,6 +238,7 @@ def flatten_to_nd(x, x_shape, nd=3): out = _op.reshape(x, fold_constant(newshape)) return out + print("a_rank={} b_rank={}".format(a_rank, b_rank)) # Determine the output batch dimension. if a_rank > b_rank: out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) @@ -278,11 +279,13 @@ def flatten_to_nd(x, x_shape, nd=3): [ out_batch, _op.strided_slice( - b_shape, [infer_shape(b_shape)[0] - 2], [infer_shape(a_shape)[0]] + b_shape, [infer_shape(b_shape)[0] - 2], [infer_shape(b_shape)[0]] ), ], 0, ) + print(fold_constant(a_broadcasted_shape).data.numpy()) + print(fold_constant(b_broadcasted_shape).data.numpy()) a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape)) b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape)) # Convert a and b into 3 dimensional tensors. diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 701906e4be40d..1a68d46d3a80a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1272,6 +1272,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None): verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4)) verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32)) verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16)) + verify_batch_matmul((4, 32, 16, 32), (1, 32, 32, 16), (4, 32, 16, 16)) # Test transb=False verify_batch_matmul( (2, 3, 4, 3),