Skip to content

Commit

Permalink
[Relay][Pass] Fix bug in re-processing call node in MergeComposite pa…
Browse files Browse the repository at this point in the history
…ss (#4879)

* Fix bug in re-processing call node

* Add test

* Add to main

* temp changes to work from another machine

* fix rest of tests

* fix test_reuse_call_merge

* fix merge

Co-authored-by: Jon Soifer <[email protected]>
  • Loading branch information
soiferj and jonso4 authored Feb 17, 2020
1 parent 0b2d11a commit 27a0284
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/relay/pass/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class MergeCompositeWrapper : public ExprMutator {
* a new Relay expression ready to be wrapped into a composite function.
*/
Expr ExtractPattern(const Call& pattern, const Call& root,
Map<std::string, Array<Expr>>* var_map) {
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
return Expr();
Expand All @@ -99,14 +99,20 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>()) {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
// if we've already processed this call node, return the previous result
if (call_map->find(arg) != call_map->end()) {
new_arg = (*call_map)[arg];
} else {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map, call_map);
call_map->Set(arg, new_arg);
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map);
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
Expand Down Expand Up @@ -155,7 +161,8 @@ class MergeCompositeWrapper : public ExprMutator {
Call pattern = Downcast<Call>(pattern_);
CHECK(pattern.defined());
Map<std::string, Array<Expr>> args_map;
auto extract = ExtractPattern(pattern, call, &args_map);
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) {
auto free_vars = FreeVars(extract);
// make the composite function
Expand Down
82 changes: 82 additions & 0 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,26 @@ def make_conv_bias_relu_pattern():
return r


def make_add_add_add_pattern():
"""Create a pattern to match the following graph.
Useful for testing re-using a call node.
x y
/ \ /
| add
\ | \
add |
| /
add
"""
x = relay.var('x')
y = relay.var('y')
add_node = relay.add(x, y)
add_node_1 = relay.add(x, add_node)
r = relay.add(add_node_1, add_node)
return r


def test_simple_merge():
"""Test composite function is correctly produced from simple graph.
Expand Down Expand Up @@ -239,6 +259,67 @@ def expected():
assert relay.analysis.alpha_equal(result, expected)


def test_reuse_call_merge():
"""Test composite function is correctly produced from simple graph
which re-uses call nodes.
We could expect the pattern `make_add_add_add` to be merged
into a single op `add_add_add`.
x y
\ / \
sub | x y
/ | / \ / |
| add ====> sub |
\ | \ | /
add | add_add_add
| /
add
"""
pattern_table = [
("add_add_add", make_add_add_add_pattern())
]

def before():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
sub_node = relay.subtract(a, b)

# pattern
add_node = relay.add(sub_node, b)
add_node_1 = relay.add(sub_node, add_node)
r = relay.add(add_node_1, add_node)

return relay.Function([a, b], r)

def expected():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))

# add_relu_add function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_add_add = add_add_add.set_attribute("Composite",
tir.StringImm("add_add_add"))

# merged function
sub_node = relay.subtract(a, b)
call = relay.Call(add_add_add, [sub_node, b])
return relay.Function([a, b], call)

result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)


def test_multiple_patterns():
"""Test different patterns are merged correctly in the graph.
Expand Down Expand Up @@ -608,3 +689,4 @@ def after_B():
test_merge_order()
test_parallel_merge()
test_multiple_input_subgraphs()
test_reuse_call_merge()

0 comments on commit 27a0284

Please sign in to comment.