diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index e94ef411d29d..db808eb3e70c 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -95,6 +95,7 @@ PatternWildcard = adt.PatternWildcard PatternVar = adt.PatternVar PatternConstructor = adt.PatternConstructor +PatternTuple = adt.PatternTuple Constructor = adt.Constructor TypeData = adt.TypeData Clause = adt.Clause diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index d340d7df3ff3..0c1a059bd852 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -246,9 +246,9 @@ def define_list_zip(self): t1 = Var("t1") t2 = Var("t2") cons_case = Clause(PatternTuple([PatternConstructor(self.cons, - [PatternVar(h2), PatternVar(t2)]), + [PatternVar(h1), PatternVar(t1)]), PatternConstructor(self.cons, - [PatternVar(h1), PatternVar(t1)])]), + [PatternVar(h2), PatternVar(t2)])]), self.cons(Tuple([h1, h2]), self.zip(t1, t2))) nil_case = Clause(PatternWildcard(), self.nil()) self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]), diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index c003fe788a11..d661be73ad02 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -311,14 +311,18 @@ def create_match_check(self, pattern: Pattern, data): if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)): return NameConstant(True) - # constructor patterns check whether the constructors match - # and also the matches of any nested patterns + conds = [] - # equiv: (arg.tag == patern_constructor.tag) - conds = [ast.Compare(ast.Attribute(data, 'tag', Load()), - [ast.Eq()], - [ast.Num(pattern.constructor.tag)])] + if isinstance(pattern, relay.PatternConstructor): + # constructor patterns check whether the constructors match + # and also the matches of any nested patterns + # equiv: (arg.tag == patern_constructor.tag) + conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()), + [ast.Eq()], + [ast.Num(pattern.constructor.tag)])) + + assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple)) # now check for any nested patterns for i in range(len(pattern.patterns)): nested_pat = pattern.patterns[i] diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 11a3d0403067..9841bd839488 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -210,19 +210,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, auto pattern = GetRef(pat); auto cond = std::make_shared(pattern->var, data); return TreeBranchNode::Make(cond, then_branch, else_branch); - } else { - auto pat = pattern.as(); - auto pattern = GetRef(pat); - auto tag = pattern->constructor->tag; + } else if (auto pcn = pattern.as()) { + auto tag = pcn->constructor->tag; size_t field_index = 0; - for (auto& p : pattern->patterns) { + for (auto& p : pcn->patterns) { auto d = std::make_shared(data, field_index); then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); field_index++; } auto cond = std::make_shared(data, tag); return TreeBranchNode::Make(cond, then_branch, else_branch); + } else { + auto pt = pattern.as(); + CHECK(pt) << "unhandled case: " << pattern; + size_t field_index = 0; + for (auto& p : pt->patterns) { + auto d = std::make_shared(data, field_index); + then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); + field_index++; + } + return then_branch; } }