Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jul 23, 2019
1 parent bb2790d commit 9008e71
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 13 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
PatternTuple = adt.PatternTuple
Constructor = adt.Constructor
TypeData = adt.TypeData
Clause = adt.Clause
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ def define_list_zip(self):
t1 = Var("t1")
t2 = Var("t2")
cons_case = Clause(PatternTuple([PatternConstructor(self.cons,
[PatternVar(h2), PatternVar(t2)]),
[PatternVar(h1), PatternVar(t1)]),
PatternConstructor(self.cons,
[PatternVar(h1), PatternVar(t1)])]),
[PatternVar(h2), PatternVar(t2)])]),
self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
nil_case = Clause(PatternWildcard(), self.nil())
self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]),
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,14 +311,18 @@ def create_match_check(self, pattern: Pattern, data):
if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
return NameConstant(True)

# constructor patterns check whether the constructors match
# and also the matches of any nested patterns
conds = []

# equiv: (arg.tag == patern_constructor.tag)
conds = [ast.Compare(ast.Attribute(data, 'tag', Load()),
[ast.Eq()],
[ast.Num(pattern.constructor.tag)])]
if isinstance(pattern, relay.PatternConstructor):
# constructor patterns check whether the constructors match
# and also the matches of any nested patterns

# equiv: (arg.tag == patern_constructor.tag)
conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
[ast.Eq()],
[ast.Num(pattern.constructor.tag)]))

assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
# now check for any nested patterns
for i in range(len(pattern.patterns)):
nested_pat = pattern.patterns[i]
Expand Down
18 changes: 13 additions & 5 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,19 +210,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pat = pattern.as<PatternConstructorNode>();
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
} else if (auto pcn = pattern.as<PatternConstructorNode>()) {
auto tag = pcn->constructor->tag;

size_t field_index = 0;
for (auto& p : pattern->patterns) {
for (auto& p : pcn->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
}
}

Expand Down

0 comments on commit 9008e71

Please sign in to comment.