Skip to content

Commit

Permalink
[Bugfix] [Relay] fix a bug caused by IncompleteTypeNode in EinsumRel …
Browse files Browse the repository at this point in the history
…while doing MergeComposite (#14556)

* [Bugfix] fix a bug caused by IncompleteTypeNode in EinsumRel while doing MergeComposite

Co-authored-by: Rui Wang <[email protected]>
  • Loading branch information
kfeng123 and Rui Wang authored Apr 11, 2023
1 parent f622e7f commit 8554e7a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/relay/op/tensor/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ bool EinsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

// Check the input tuple consists of tensors with consistent dtype.
if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
return false;
}
ICHECK(tensor_tuple->fields[0].as<TensorTypeNode>());
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
const DataType dtype = first->dtype;
std::vector<Array<PrimExpr>> input_shapes;
Expand Down
47 changes: 46 additions & 1 deletion tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest
import tvm
from tvm import relay, tir
from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard
from tvm.relay.dataflow_pattern import TuplePattern, TupleGetItemPattern, is_op, wildcard
from tvm.relay.testing import run_opt_pass


Expand Down Expand Up @@ -979,5 +979,50 @@ def _check_type_false(extract):
check_result(pattern_table_true, before(), expected_true())


def test_einsum_reshape_pattern():
"""Test MergeComposite does not cause error with einsum operator."""

def make_einsum_reshape_pattern():
x = wildcard()
x = is_op("reshape")(x) | x
y = wildcard()
y = is_op("reshape")(y) | y
z = is_op("einsum")(TuplePattern([x, y]))
r = is_op("reshape")(z) | z
return r

pattern_table = [
(
"einsum_reshape",
make_einsum_reshape_pattern(),
)
]

def before():
a = relay.var("a", shape=(10, 10))
b = relay.var("b", shape=(10, 10))
c = relay.reshape(a, [20, 5])
d = relay.reshape(b, [20, 5])
r = relay.einsum([c, d], "...ab,...cb->...ac")
return relay.Function([a, b], r)

def expected():
a = relay.var("a", shape=(10, 10))
b = relay.var("b", shape=(10, 10))
c = relay.reshape(a, [20, 5])
d = relay.reshape(b, [20, 5])
r = relay.einsum([c, d], "...ab,...cb->...ac")
func = relay.Function([a, b], r)
func = func.with_attr("Composite", "einsum_reshape")
func = func.with_attr("PartitionedFromPattern", "reshape_reshape_Tuple_einsum_")

input0 = relay.var("a", shape=(10, 10))
input1 = relay.var("b", shape=(10, 10))
output = func(input0, input1)
return relay.Function([input0, input1], output)

check_result(pattern_table, before(), expected())


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 8554e7a

Please sign in to comment.