From 5701d2e78dff96f3a39cd390a1ce0c54122aa38a Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 13 Jun 2019 09:02:26 -0700 Subject: [PATCH] [Relay] Check match expressions for completeness (#3203) --- include/tvm/relay/pass.h | 31 +- python/tvm/relay/ir_pass.py | 18 ++ python/tvm/relay/prelude.py | 2 - src/relay/pass/match_exhaustion.cc | 250 ++++++++++++++++ src/relay/pass/type_infer.cc | 9 + .../python/relay/test_pass_unmatched_cases.py | 267 ++++++++++++++++++ 6 files changed, 574 insertions(+), 3 deletions(-) create mode 100644 src/relay/pass/match_exhaustion.cc create mode 100644 tests/python/relay/test_pass_unmatched_cases.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 81587339f2ad..977bb6793bb5 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -122,6 +122,24 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); */ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); +/*! + * \brief Compare two patterns for structural equivalence. + * + * This comparison operator respects scoping and compares + * patterns without regard to variable choice. + * + * For example: `A(x, _, y)` is equal to `A(z, _, a)`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + * for more details. + * + * \param t1 The left hand pattern. + * \param t2 The right hand pattern. + * + * \return true if equal, otherwise false + */ +TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); + /*! * \brief Add abstraction over a function * @@ -400,8 +418,19 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); TVM_DLL Expr ToGraphNormalForm(const Expr& e); /*! - * \brief Aggressive constant propagation/constant folding/inlining. + * \brief Finds cases that the given match expression does not catch, if any. + * + * \param match the match expression to test + * + * \param mod The module used for accessing global type var definitions, can be None. * + * \return Returns a list of cases (as patterns) that are not handled by the match + * expression. + */ +TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); + +/*! + * \brief Aggressive constant propagation/constant folding/inlining. * It will do as much computation in compile time as possible. * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * As a side effect, code size will explode. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ea34c6b1958b..8f1ceded76dd 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -652,3 +652,21 @@ def partial_evaluate(expr): The output expression. """ return _ir_pass.partial_evaluate(expr) + +def unmatched_cases(match, mod=None): + """ + Finds cases that the match expression does not catch, if any. + + Parameters + ---------- + match : tvm.relay.Match + The match expression + mod : Optional[tvm.relay.Module] + The module (defaults to an empty module) + + Returns + ------- + missing_patterns : [tvm.relay.Pattern] + Patterns that the match expression does not catch. + """ + return _ir_pass.unmatched_cases(match, mod) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index da75b9d00e13..17df61750afd 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -39,7 +39,6 @@ def define_list_adt(self): self.cons = Constructor("cons", [a, self.l(a)], self.l) self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) - def define_list_hd(self): """Defines a function to get the head of a list. Assume the list has at least one element. @@ -54,7 +53,6 @@ def define_list_hd(self): cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a]) - def define_list_tl(self): """Defines a function to get the tail of a list. diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc new file mode 100644 index 000000000000..173d6eacf528 --- /dev/null +++ b/src/relay/pass/match_exhaustion.cc @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file match_exhaustion.cc + * \brief Checking Relay match expression exhaustiveness. + * + * This file implements a function that checks whether a match + * expression is exhaustive, that is, whether a given match clause + * matches every possible case. This is important for ensuring + * code correctness, since hitting an unmatched case results in a + * dynamic error unless exhaustiveness is checked in advance. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Possible pattern match results */ +enum MatchResult : int { + kMatch = 0, // pattern matches + kClash = 1, // pattern conflicts + kUnspecified = 2, // ambiguous: candidate needs more constructors specified +}; + +class CandidateChecker : public PatternFunctor { + public: + explicit CandidateChecker() {} + + MatchResult Check(const Pattern& pat, const Pattern& candidate) { + return this->VisitPattern(pat, candidate); + } + + // for a constructor pattern, we must ensure that the candidate is + // a ConstructorPattern, that it has the same constructor, and + // that its fields match the subpatterns. + MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override { + auto* ctor_cand = cand.as(); + // attempting to match non-constructor to constructor pattern: need to specify + if (ctor_cand == nullptr) { + return MatchResult::kUnspecified; + } + + // check that constructors match + if (!op->constructor.same_as(ctor_cand->constructor)) { + return MatchResult::kClash; + } + + // now check that subpatterns match + CHECK(op->patterns.size() == ctor_cand->patterns.size()); + bool unspecified = false; + for (size_t i = 0; i < op->patterns.size(); i++) { + MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]); + // if we have a clash anywhere, then we can return clash + if (submatch == MatchResult::kClash) { + return MatchResult::kClash; + } + if (submatch == MatchResult::kUnspecified) { + unspecified = true; + } + } + // only return unspecified if we have ruled out a clash + if (unspecified) { + return MatchResult::kUnspecified; + } + return MatchResult::kMatch; + } + + // wildcard and var patterns always match + MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override { + return MatchResult::kMatch; + } + + MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override { + return MatchResult::kMatch; + } +}; + +// Returns list of arrays corresponding to Cartesian product of input list +Array> CartesianProduct(Array> fields) { + CHECK_NE(fields.size(), 0); + Array field_vals = fields[fields.size() - 1]; + Array> ret; + + // base case: this is the last field left + if (fields.size() == 1) { + for (auto val : field_vals) { + ret.push_back(Array{val}); + } + return ret; + } + + // if we have more fields left, get the sub-candidates by getting + // their cartesian product and appending the elements here onto those + Array> remaining_fields; + for (size_t i = 0; i < fields.size() - 1; i++) { + remaining_fields.push_back(fields[i]); + } + Array> candidates = CartesianProduct(remaining_fields); + for (auto val : field_vals) { + for (auto candidate : candidates) { + candidate.push_back(val); + ret.push_back(candidate); + } + } + return ret; +} + +// Expands all wildcards in the candidate pattern once, using the pattern +// to decide which constructors to insert. Returns a list of all possible expansions. +Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, + const Module& mod) { + auto ctor_cand = cand.as(); + PatternConstructor clause_ctor = Downcast(clause_pat); + auto gtv = Downcast(clause_ctor->constructor->belong_to); + + // for a wildcard node, create constructor nodes with wildcards for all args + if (!ctor_cand) { + TypeData td = mod->LookupDef(gtv); + // for each constructor add a candidate + Array ret; + for (auto constructor : td->constructors) { + Array args; + for (auto inp : constructor->inputs) { + args.push_back(PatternWildcardNode::make()); + } + ret.push_back(PatternConstructorNode::make(constructor, args)); + } + return ret; + } + + // for constructors, we will expand the wildcards in any field + // that is an ADT + Array> values_by_field; + for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { + auto* subpattern = clause_ctor->patterns[i].as(); + // for non-ADT fields, we can only have a wildcard for the value + if (!subpattern) { + values_by_field.push_back({PatternWildcardNode::make()}); + continue; + } + + // otherwise, recursively expand + values_by_field.push_back(ExpandWildcards(GetRef(subpattern), + ctor_cand->patterns[i], mod)); + } + + // generate new candidates using a cartesian product + auto all_subfields = CartesianProduct(values_by_field); + Array ret; + for (auto subfields : all_subfields) { + ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); + } + return ret; +} + +/*! + * \brief Finds cases that the match expression does not catch, if any. + * \return Returns a list of cases that are not handled by the match + * expression. + */ +Array UnmatchedCases(const Match& match, const Module& mod) { + /* algorithm: + * candidates = { Wildcard } + * while candidates not empty { + * cand = candidates.pop() + * for clause in clauses { + * if clause fails: next clause + * if clause matches candidate: next candidate + * if candidate is not specific enough: + * candidates += expand_possible_wildcards(cand) + * next candidate + * } + * failed_candidates += { cand } + * } + * return failed_candidates + */ + std::stack candidates; + candidates.push(PatternWildcardNode::make()); + CandidateChecker checker; + + Array failures; + + while (!candidates.empty()) { + Pattern cand = candidates.top(); + candidates.pop(); + + bool failure = true; + for (auto clause : match->clauses) { + // if the check fails, we move on to the next + MatchResult check = checker.Check(clause->lhs, cand); + if (check == MatchResult::kClash) { + continue; + } + + // either success or we need to generate more candidates; + // either way, we're done with this candidate + failure = false; + if (check == MatchResult::kUnspecified) { + auto new_candidates = ExpandWildcards(clause->lhs, cand, mod); + for (auto candidate : new_candidates) { + candidates.push(candidate); + } + } + break; + } + + if (failure) { + failures.push_back(cand); + } + } + + return failures; +} + +// expose for testing only +TVM_REGISTER_API("relay._ir_pass.unmatched_cases") +.set_body_typed(const Match&, + const Module&)>([](const Match& match, + const Module& mod_ref) { + Module call_mod = mod_ref; + if (!call_mod.defined()) { + call_mod = ModuleNode::make({}, {}); + } + return UnmatchedCases(match, call_mod); + }); +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 3fde3c7e7b36..4b126e5299cf 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -293,6 +293,15 @@ class TypeInferencer : private ExprFunctor, GetType(c->rhs), op->span); } + + // check completness + Match match = GetRef(op); + Array unmatched_cases = UnmatchedCases(match, this->mod_); + if (unmatched_cases.size() != 0) { + LOG(WARNING) << "Match clause " << match << " does not handle the following cases: " + << unmatched_cases; + } + return rtype; } diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py new file mode 100644 index 000000000000..4f2bb20ad7d6 --- /dev/null +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -0,0 +1,267 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relay +from tvm.relay.prelude import Prelude +from tvm.relay.ir_pass import unmatched_cases + +def test_empty_match_block(): + # empty match block will not match anything, so it should return a wildcard pattern + v = relay.Var('v') + match = relay.Match(v, []) + + unmatched = unmatched_cases(match) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternWildcard) + + +def test_trivial_matches(): + # a match clause with a wildcard will match anything + v = relay.Var('v') + match = relay.Match(v, [ + relay.Clause(relay.PatternWildcard(), v) + ]) + assert len(unmatched_cases(match)) == 0 + + # same with a pattern var + w = relay.Var('w') + match = relay.Match(v, [ + relay.Clause(relay.PatternVar(w), w) + ]) + assert len(unmatched_cases(match)) == 0 + + +def test_single_constructor_adt(): + mod = relay.Module() + box = relay.GlobalTypeVar('box') + a = relay.TypeVar('a') + box_ctor = relay.Constructor('box', [a], box) + box_data = relay.TypeData(box, [a], [box_ctor]) + mod[box] = box_data + + v = relay.Var('v') + match = relay.Match(v, [ + relay.Clause(relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), v) + ]) + + # with one constructor, having one pattern constructor case is exhaustive + assert len(unmatched_cases(match, mod)) == 0 + + # this will be so if we nest the constructors too + nested_pattern = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + box_ctor, + [relay.PatternConstructor(box_ctor, + [relay.PatternConstructor( + box_ctor, + [relay.PatternWildcard()])])]), v) + ]) + assert len(unmatched_cases(nested_pattern, mod)) == 0 + + +def test_too_specific_match(): + mod = relay.Module() + p = Prelude(mod) + + v = relay.Var('v') + match = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()])]), v) + ]) + + unmatched = unmatched_cases(match, mod) + + # will not match nil or a list of length 1 + nil_found = False + single_length_found = False + assert len(unmatched) == 2 + for case in unmatched: + assert isinstance(case, relay.PatternConstructor) + if case.constructor == p.nil: + nil_found = True + if case.constructor == p.cons: + assert isinstance(case.patterns[1], relay.PatternConstructor) + assert case.patterns[1].constructor == p.nil + single_length_found = True + assert nil_found and single_length_found + + # if we add a wildcard, this should work + new_match = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()])]), v), + relay.Clause(relay.PatternWildcard(), v) + ]) + assert len(unmatched_cases(new_match, mod)) == 0 + + +def test_multiple_constructor_clauses(): + mod = relay.Module() + p = Prelude(mod) + + v = relay.Var('v') + match = relay.Match(v, [ + # list of length exactly 1 + relay.Clause( + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.nil, [])]), v), + # list of length exactly 2 + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.nil, []) + ])]), v), + # empty list + relay.Clause( + relay.PatternConstructor(p.nil, []), v), + # list of length 2 or more + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()])]), v) + ]) + assert len(unmatched_cases(match, mod)) == 0 + + +def test_missing_in_the_middle(): + mod = relay.Module() + p = Prelude(mod) + + v = relay.Var('v') + match = relay.Match(v, [ + # list of length exactly 1 + relay.Clause( + relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternConstructor(p.nil, [])]), v), + # empty list + relay.Clause( + relay.PatternConstructor(p.nil, []), v), + # list of length 3 or more + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternWildcard(), + relay.PatternConstructor( + p.cons, + [relay.PatternWildcard(), + relay.PatternConstructor( + p.cons, + [relay.PatternWildcard(), + relay.PatternWildcard()])])]), + v) + ]) + + # fails to match a list of length exactly two + unmatched = unmatched_cases(match, mod) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternConstructor) + assert unmatched[0].constructor == p.cons + assert isinstance(unmatched[0].patterns[1], relay.PatternConstructor) + assert unmatched[0].patterns[1].constructor == p.cons + assert isinstance(unmatched[0].patterns[1].patterns[1], relay.PatternConstructor) + assert unmatched[0].patterns[1].patterns[1].constructor == p.nil + + +def test_mixed_adt_constructors(): + mod = relay.Module() + box = relay.GlobalTypeVar('box') + a = relay.TypeVar('a') + box_ctor = relay.Constructor('box', [a], box) + box_data = relay.TypeData(box, [a], [box_ctor]) + mod[box] = box_data + + p = Prelude(mod) + + v = relay.Var('v') + box_of_lists_inc = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + box_ctor, + [relay.PatternConstructor(p.cons, [ + relay.PatternWildcard(), relay.PatternWildcard()])]), v) + ]) + + # will fail to match a box containing an empty list + unmatched = unmatched_cases(box_of_lists_inc, mod) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternConstructor) + assert unmatched[0].constructor == box_ctor + assert len(unmatched[0].patterns) == 1 and unmatched[0].patterns[0].constructor == p.nil + + box_of_lists_comp = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + box_ctor, [relay.PatternConstructor(p.nil, [])]), v), + relay.Clause( + relay.PatternConstructor( + box_ctor, [relay.PatternConstructor(p.cons, [ + relay.PatternWildcard(), relay.PatternWildcard()])]), v) + ]) + assert len(unmatched_cases(box_of_lists_comp, mod)) == 0 + + list_of_boxes_inc = relay.Match(v, [ + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternWildcard()]), v) + ]) + + # fails to match empty list of boxes + unmatched = unmatched_cases(list_of_boxes_inc, mod) + assert len(unmatched) == 1 + assert isinstance(unmatched[0], relay.PatternConstructor) + assert unmatched[0].constructor == p.nil + + list_of_boxes_comp = relay.Match(v, [ + # exactly one box + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.nil, [])]), v), + # exactly two boxes + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.cons, [ + relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.nil, []) + ])]), v), + # exactly three boxes + relay.Clause( + relay.PatternConstructor( + p.cons, [relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.cons, [ + relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.cons, [ + relay.PatternConstructor(box_ctor, [relay.PatternWildcard()]), + relay.PatternConstructor(p.nil, []) + ])])]), v), + # one or more boxes + relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternWildcard(), + relay.PatternWildcard()]), v), + # no boxes + relay.Clause(relay.PatternConstructor(p.nil, []), v) + ]) + assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0