diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 6f6c66ad21079..0e65758a2e1c9 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode { } bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const { - return equal(kind, other->kind); + return + equal(kind, other->kind) && + equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index cf626d702f4c6..5295e17d2e8b3 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -107,7 +107,7 @@ def test_func_type_sequal(): ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1, tvm.runtime.convert([tp1, tp3]), tvm.runtime.convert([tr1])) - translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1, + translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2, tvm.runtime.convert([tp2, tp4]), tvm.runtime.convert([tr2])) assert ft == translate_vars