diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 60319d682f0c..234beec244ba 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -238,24 +238,6 @@ def flatten_to_nd(x, x_shape, nd=3): out = _op.reshape(x, fold_constant(newshape)) return out - b_type = infer_type(inputs[1]) - # Convert to dense if the second matrix is 2d and non-dynamic - if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): - a = flatten_to_nd(inputs[0], a_shape, 2) - b = _op.transpose(inputs[1]) - output = _op.nn.dense(a, b, out_dtype=out_dtype) - else: - # Convert a and b into 3 dimensional tensors. - a = flatten_to_nd(inputs[0], a_shape, 3) - b = flatten_to_nd(inputs[1], b_shape, 3) - if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]: - # Transpose matrix dimensions of b. - bt = _op.transpose(b, [0, 2, 1]) - # Perform a NT batch matmul. - output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype) - else: - # Perform a NN batch matmul. - output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False) # Determine the output batch dimension. if a_rank > b_rank: out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) @@ -274,6 +256,42 @@ def flatten_to_nd(x, x_shape, nd=3): ], 0, ) + + b_type = infer_type(inputs[1]) + # Convert to dense if the second matrix is 2d and non-dynamic + if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + a = flatten_to_nd(inputs[0], a_shape, 2) + b = _op.transpose(inputs[1]) + output = _op.nn.dense(a, b, out_dtype=out_dtype) + else: + # broadcast a and b + a_broadcasted_shape = _op.concatenate( + [ + out_batch, + _op.strided_slice(a_shape, [a_rank - 2], [a_rank]), + ], + 0, + ) + b_broadcasted_shape = _op.concatenate( + [ + out_batch, + _op.strided_slice(b_shape, [b_rank - 2], [b_rank]), + ], + 0, + ) + a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape)) + b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape)) + # Convert a and b into 3 dimensional tensors. + a = flatten_to_nd(a, shape_of(a), 3) + b = flatten_to_nd(b, shape_of(b), 3) + if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]: + # Transpose matrix dimensions of b. + bt = _op.transpose(b, [0, 2, 1]) + # Perform a NT batch matmul. + output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype) + else: + # Perform a NN batch matmul. + output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False) # Reshape output to original dimensions. final_shape = _op.concatenate( [ diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2bdc5f77cc7b..2e0d9274a2a7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1273,6 +1273,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None): verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4)) verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32)) verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16)) + verify_batch_matmul((4, 32, 16, 32), (1, 32, 32, 16), (4, 32, 16, 16)) # Test transb=False verify_batch_matmul( (2, 3, 4, 3),