diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc index 28bf8fa8c33a..4e1094b617e9 100644 --- a/src/relay/pass/merge_composite.cc +++ b/src/relay/pass/merge_composite.cc @@ -87,7 +87,7 @@ class MergeCompositeWrapper : public ExprMutator { * a new Relay expression ready to be wrapped into a composite function. */ Expr ExtractPattern(const Call& pattern, const Call& root, - Map>* var_map) { + Map>* var_map, Map* call_map) { // check to make sure both calls are to operators (not functions) if (!pattern->op->IsInstance() || !root->op->IsInstance()) return Expr(); @@ -99,14 +99,20 @@ class MergeCompositeWrapper : public ExprMutator { for (const auto& arg : pattern->args) { Expr new_arg; if (arg->IsInstance()) { - // fail if the root argument is not also a call node - if (!root->args[i]->IsInstance()) { - return Expr(); + // if we've already processed this call node, return the previous result + if (call_map->find(arg) != call_map->end()) { + new_arg = (*call_map)[arg]; + } else { + // fail if the root argument is not also a call node + if (!root->args[i]->IsInstance()) { + return Expr(); + } + // if it's a call node, recursively call this function + new_arg = ExtractPattern(Downcast(arg), + Downcast(root->args[i]), + var_map, call_map); + call_map->Set(arg, new_arg); } - // if it's a call node, recursively call this function - new_arg = ExtractPattern(Downcast(arg), - Downcast(root->args[i]), - var_map); } else if (arg->IsInstance()) { // if there's a var in the pattern, it must be a free var // so call the function to update the var_map @@ -155,7 +161,8 @@ class MergeCompositeWrapper : public ExprMutator { Call pattern = Downcast(pattern_); CHECK(pattern.defined()); Map> args_map; - auto extract = ExtractPattern(pattern, call, &args_map); + Map call_map; + auto extract = ExtractPattern(pattern, call, &args_map, &call_map); if (extract.defined()) { auto free_vars = FreeVars(extract); // make the composite function diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 4f5acc707a52..b96a89b1f483 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -110,6 +110,26 @@ def make_conv_bias_relu_pattern(): return r +def make_add_add_add_pattern(): + """Create a pattern to match the following graph. + Useful for testing re-using a call node. + + x y + / \ / + | add + \ | \ + add | + | / + add + """ + x = relay.var('x') + y = relay.var('y') + add_node = relay.add(x, y) + add_node_1 = relay.add(x, add_node) + r = relay.add(add_node_1, add_node) + return r + + def test_simple_merge(): """Test composite function is correctly produced from simple graph. @@ -239,6 +259,67 @@ def expected(): assert relay.analysis.alpha_equal(result, expected) +def test_reuse_call_merge(): + """Test composite function is correctly produced from simple graph + which re-uses call nodes. + + We could expect the pattern `make_add_add_add` to be merged + into a single op `add_add_add`. + + x y + \ / \ + sub | x y + / | / \ / | + | add ====> sub | + \ | \ | / + add | add_add_add + | / + add + + """ + pattern_table = [ + ("add_add_add", make_add_add_add_pattern()) + ] + + def before(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + sub_node = relay.subtract(a, b) + + # pattern + add_node = relay.add(sub_node, b) + add_node_1 = relay.add(sub_node, add_node) + r = relay.add(add_node_1, add_node) + + return relay.Function([a, b], r) + + def expected(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu_add function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + add_node_1 = relay.add(in_1, add_node) + add_node_2 = relay.add(add_node_1, add_node) + add_add_add = relay.Function([in_1, in_2], add_node_2) + add_add_add = add_add_add.set_attribute("Primitive", + tir.IntImm("int32", 1)) + add_add_add = add_add_add.set_attribute("Composite", + tir.StringImm("add_add_add")) + + # merged function + sub_node = relay.subtract(a, b) + call = relay.Call(add_add_add, [sub_node, b]) + return relay.Function([a, b], call) + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(expected(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + def test_multiple_patterns(): """Test different patterns are merged correctly in the graph. @@ -608,3 +689,4 @@ def after_B(): test_merge_order() test_parallel_merge() test_multiple_input_subgraphs() + test_reuse_call_merge() \ No newline at end of file