-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay] Check match expressions for completeness (#3203)
- Loading branch information
1 parent
6e2c7ed
commit a698ad7
Showing
6 changed files
with
574 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file match_exhaustion.cc | ||
* \brief Checking Relay match expression exhaustiveness. | ||
* | ||
* This file implements a function that checks whether a match | ||
* expression is exhaustive, that is, whether a given match clause | ||
* matches every possible case. This is important for ensuring | ||
* code correctness, since hitting an unmatched case results in a | ||
* dynamic error unless exhaustiveness is checked in advance. | ||
*/ | ||
#include <tvm/relay/adt.h> | ||
#include <tvm/relay/error.h> | ||
#include <tvm/relay/expr_functor.h> | ||
#include <tvm/relay/pattern_functor.h> | ||
#include <tvm/relay/pass.h> | ||
#include <stack> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! \brief Possible pattern match results */ | ||
enum MatchResult : int { | ||
kMatch = 0, // pattern matches | ||
kClash = 1, // pattern conflicts | ||
kUnspecified = 2, // ambiguous: candidate needs more constructors specified | ||
}; | ||
|
||
class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> { | ||
public: | ||
explicit CandidateChecker() {} | ||
|
||
MatchResult Check(const Pattern& pat, const Pattern& candidate) { | ||
return this->VisitPattern(pat, candidate); | ||
} | ||
|
||
// for a constructor pattern, we must ensure that the candidate is | ||
// a ConstructorPattern, that it has the same constructor, and | ||
// that its fields match the subpatterns. | ||
MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override { | ||
auto* ctor_cand = cand.as<PatternConstructorNode>(); | ||
// attempting to match non-constructor to constructor pattern: need to specify | ||
if (ctor_cand == nullptr) { | ||
return MatchResult::kUnspecified; | ||
} | ||
|
||
// check that constructors match | ||
if (!op->constructor.same_as(ctor_cand->constructor)) { | ||
return MatchResult::kClash; | ||
} | ||
|
||
// now check that subpatterns match | ||
CHECK(op->patterns.size() == ctor_cand->patterns.size()); | ||
bool unspecified = false; | ||
for (size_t i = 0; i < op->patterns.size(); i++) { | ||
MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]); | ||
// if we have a clash anywhere, then we can return clash | ||
if (submatch == MatchResult::kClash) { | ||
return MatchResult::kClash; | ||
} | ||
if (submatch == MatchResult::kUnspecified) { | ||
unspecified = true; | ||
} | ||
} | ||
// only return unspecified if we have ruled out a clash | ||
if (unspecified) { | ||
return MatchResult::kUnspecified; | ||
} | ||
return MatchResult::kMatch; | ||
} | ||
|
||
// wildcard and var patterns always match | ||
MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override { | ||
return MatchResult::kMatch; | ||
} | ||
|
||
MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override { | ||
return MatchResult::kMatch; | ||
} | ||
}; | ||
|
||
// Returns list of arrays corresponding to Cartesian product of input list | ||
Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) { | ||
CHECK_NE(fields.size(), 0); | ||
Array<Pattern> field_vals = fields[fields.size() - 1]; | ||
Array<Array<Pattern>> ret; | ||
|
||
// base case: this is the last field left | ||
if (fields.size() == 1) { | ||
for (auto val : field_vals) { | ||
ret.push_back(Array<Pattern>{val}); | ||
} | ||
return ret; | ||
} | ||
|
||
// if we have more fields left, get the sub-candidates by getting | ||
// their cartesian product and appending the elements here onto those | ||
Array<Array<Pattern>> remaining_fields; | ||
for (size_t i = 0; i < fields.size() - 1; i++) { | ||
remaining_fields.push_back(fields[i]); | ||
} | ||
Array<Array<Pattern>> candidates = CartesianProduct(remaining_fields); | ||
for (auto val : field_vals) { | ||
for (auto candidate : candidates) { | ||
candidate.push_back(val); | ||
ret.push_back(candidate); | ||
} | ||
} | ||
return ret; | ||
} | ||
|
||
// Expands all wildcards in the candidate pattern once, using the pattern | ||
// to decide which constructors to insert. Returns a list of all possible expansions. | ||
Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, | ||
const Module& mod) { | ||
auto ctor_cand = cand.as<PatternConstructorNode>(); | ||
PatternConstructor clause_ctor = Downcast<PatternConstructor>(clause_pat); | ||
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to); | ||
|
||
// for a wildcard node, create constructor nodes with wildcards for all args | ||
if (!ctor_cand) { | ||
TypeData td = mod->LookupDef(gtv); | ||
// for each constructor add a candidate | ||
Array<Pattern> ret; | ||
for (auto constructor : td->constructors) { | ||
Array<Pattern> args; | ||
for (auto inp : constructor->inputs) { | ||
args.push_back(PatternWildcardNode::make()); | ||
} | ||
ret.push_back(PatternConstructorNode::make(constructor, args)); | ||
} | ||
return ret; | ||
} | ||
|
||
// for constructors, we will expand the wildcards in any field | ||
// that is an ADT | ||
Array<Array<Pattern>> values_by_field; | ||
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { | ||
auto* subpattern = clause_ctor->patterns[i].as<PatternConstructorNode>(); | ||
// for non-ADT fields, we can only have a wildcard for the value | ||
if (!subpattern) { | ||
values_by_field.push_back({PatternWildcardNode::make()}); | ||
continue; | ||
} | ||
|
||
// otherwise, recursively expand | ||
values_by_field.push_back(ExpandWildcards(GetRef<Pattern>(subpattern), | ||
ctor_cand->patterns[i], mod)); | ||
} | ||
|
||
// generate new candidates using a cartesian product | ||
auto all_subfields = CartesianProduct(values_by_field); | ||
Array<Pattern> ret; | ||
for (auto subfields : all_subfields) { | ||
ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); | ||
} | ||
return ret; | ||
} | ||
|
||
/*! | ||
* \brief Finds cases that the match expression does not catch, if any. | ||
* \return Returns a list of cases that are not handled by the match | ||
* expression. | ||
*/ | ||
Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) { | ||
/* algorithm: | ||
* candidates = { Wildcard } | ||
* while candidates not empty { | ||
* cand = candidates.pop() | ||
* for clause in clauses { | ||
* if clause fails: next clause | ||
* if clause matches candidate: next candidate | ||
* if candidate is not specific enough: | ||
* candidates += expand_possible_wildcards(cand) | ||
* next candidate | ||
* } | ||
* failed_candidates += { cand } | ||
* } | ||
* return failed_candidates | ||
*/ | ||
std::stack<Pattern> candidates; | ||
candidates.push(PatternWildcardNode::make()); | ||
CandidateChecker checker; | ||
|
||
Array<Pattern> failures; | ||
|
||
while (!candidates.empty()) { | ||
Pattern cand = candidates.top(); | ||
candidates.pop(); | ||
|
||
bool failure = true; | ||
for (auto clause : match->clauses) { | ||
// if the check fails, we move on to the next | ||
MatchResult check = checker.Check(clause->lhs, cand); | ||
if (check == MatchResult::kClash) { | ||
continue; | ||
} | ||
|
||
// either success or we need to generate more candidates; | ||
// either way, we're done with this candidate | ||
failure = false; | ||
if (check == MatchResult::kUnspecified) { | ||
auto new_candidates = ExpandWildcards(clause->lhs, cand, mod); | ||
for (auto candidate : new_candidates) { | ||
candidates.push(candidate); | ||
} | ||
} | ||
break; | ||
} | ||
|
||
if (failure) { | ||
failures.push_back(cand); | ||
} | ||
} | ||
|
||
return failures; | ||
} | ||
|
||
// expose for testing only | ||
TVM_REGISTER_API("relay._ir_pass.unmatched_cases") | ||
.set_body_typed<Array<Pattern>(const Match&, | ||
const Module&)>([](const Match& match, | ||
const Module& mod_ref) { | ||
Module call_mod = mod_ref; | ||
if (!call_mod.defined()) { | ||
call_mod = ModuleNode::make({}, {}); | ||
} | ||
return UnmatchedCases(match, call_mod); | ||
}); | ||
} // namespace relay | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.