diff --git a/src/relay/op/tensor/math.cc b/src/relay/op/tensor/math.cc index 6d1dabb497e0..ef3ac8accbf2 100644 --- a/src/relay/op/tensor/math.cc +++ b/src/relay/op/tensor/math.cc @@ -59,6 +59,10 @@ bool EinsumRel(const Array& types, int num_inputs, const Attrs& attrs, } // Check the input tuple consists of tensors with consistent dtype. + if (tensor_tuple->fields[0].as()) { + return false; + } + ICHECK(tensor_tuple->fields[0].as()); const auto& first = Downcast(tensor_tuple->fields[0]); const DataType dtype = first->dtype; std::vector> input_shapes; diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index eefec17aa30a..739db69e10f1 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -18,7 +18,7 @@ import pytest import tvm from tvm import relay, tir -from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard +from tvm.relay.dataflow_pattern import TuplePattern, TupleGetItemPattern, is_op, wildcard from tvm.relay.testing import run_opt_pass @@ -979,5 +979,50 @@ def _check_type_false(extract): check_result(pattern_table_true, before(), expected_true()) +def test_einsum_reshape_pattern(): + """Test MergeComposite does not cause error with einsum operator.""" + + def make_einsum_reshape_pattern(): + x = wildcard() + x = is_op("reshape")(x) | x + y = wildcard() + y = is_op("reshape")(y) | y + z = is_op("einsum")(TuplePattern([x, y])) + r = is_op("reshape")(z) | z + return r + + pattern_table = [ + ( + "einsum_reshape", + make_einsum_reshape_pattern(), + ) + ] + + def before(): + a = relay.var("a", shape=(10, 10)) + b = relay.var("b", shape=(10, 10)) + c = relay.reshape(a, [20, 5]) + d = relay.reshape(b, [20, 5]) + r = relay.einsum([c, d], "...ab,...cb->...ac") + return relay.Function([a, b], r) + + def expected(): + a = relay.var("a", shape=(10, 10)) + b = relay.var("b", shape=(10, 10)) + c = relay.reshape(a, [20, 5]) + d = relay.reshape(b, [20, 5]) + r = relay.einsum([c, d], "...ab,...cb->...ac") + func = relay.Function([a, b], r) + func = func.with_attr("Composite", "einsum_reshape") + func = func.with_attr("PartitionedFromPattern", "reshape_reshape_Tuple_einsum_") + + input0 = relay.var("a", shape=(10, 10)) + input1 = relay.var("b", shape=(10, 10)) + output = func(input0, input1) + return relay.Function([input0, input1], output) + + check_result(pattern_table, before(), expected()) + + if __name__ == "__main__": tvm.testing.main()