Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Fix bug in re-processing call node in MergeComposite pass #4879

Merged
merged 10 commits into from
Feb 17, 2020
25 changes: 16 additions & 9 deletions src/relay/pass/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class MergeCompositeWrapper : public ExprMutator {
* a new Relay expression ready to be wrapped into a composite function.
*/
Expr ExtractPattern(const Call& pattern, const Call& root,
Map<std::string, Array<Expr>>* var_map) {
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
return Expr();
Expand All @@ -99,14 +99,20 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>()) {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
// if we've already processed this call node, return the previous result
if (call_map->find(arg) != call_map->end()) {
new_arg = (*call_map)[arg];
} else {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map, call_map);
call_map->Set(arg, new_arg);
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map);
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
Expand Down Expand Up @@ -155,7 +161,8 @@ class MergeCompositeWrapper : public ExprMutator {
Call pattern = Downcast<Call>(pattern_);
CHECK(pattern.defined());
Map<std::string, Array<Expr>> args_map;
auto extract = ExtractPattern(pattern, call, &args_map);
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) {
auto free_vars = FreeVars(extract);
// make the composite function
Expand Down