diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index eb26f08e23056..4f2bb20ad7d68 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -146,6 +146,44 @@ def test_multiple_constructor_clauses(): assert len(unmatched_cases(match, mod)) == 0 +def test_missing_in_the_middle(): + mod = relay.Module() + p = Prelude(mod) + + v = relay.Var('v') + match = relay.Match(v, [ + # list of length exactly 1 + relay.Clause( + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.nil, [])]), v), + # empty list + relay.Clause( + relay.PatternConstructor(p.nil, []), v), + # list of length 3 or more + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor( + p.cons, + [relay.PatternWildcard(), + relay.PatternConstructor( + p.cons, + [relay.PatternWildcard(), + relay.PatternWildcard()])])]), + v) + ]) + + # fails to match a list of length exactly two + unmatched = unmatched_cases(match, mod) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternConstructor) + assert unmatched[0].constructor == p.cons + assert isinstance(unmatched[0].patterns[1], relay.PatternConstructor) + assert unmatched[0].patterns[1].constructor == p.cons + assert isinstance(unmatched[0].patterns[1].patterns[1], relay.PatternConstructor) + assert unmatched[0].patterns[1].patterns[1].constructor == p.nil + + def test_mixed_adt_constructors(): mod = relay.Module() box = relay.GlobalTypeVar('box')