diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 64f2278c31036..1dcf957426c0c 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -561,6 +561,8 @@ constexpr const char* kParams = "__params__"; constexpr const char* kExternalSymbol = "ExternalSymbol"; /*! \brief Mark if the function should be avoided being optimized. */ constexpr const char* kSkipOptimization = "SkipOptimization"; +/*! \brief Treat the function as a composite operator. */ +constexpr const char* kComposite = "Composite"; } // namespace attr } // namespace relay diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 26b20e01c6236..cfca4a6ed3b2c 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -513,6 +513,31 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): return _transform.Legalize(legalize_map_attr_name) +def MergeComposite(pattern_table): + """Merge multiple operators into a single composite relay function. + + Parameters + ---------- + pattern_table : list(tuple) + A list of (pattern_name, pattern) tuples. + The order of the patterns in the list will determine the order + of priority in which they are matched. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that merges operators into a single composite + relay function. + """ + pattern_names = [] + patterns = [] + for pattern_name, pattern in pattern_table: + pattern_names.append(pattern_name) + patterns.append(pattern) + + return _transform.MergeComposite(pattern_names, patterns) + + def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_deivce`, mark which device an expression should be scheduled to. diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc new file mode 100644 index 0000000000000..28bf8fa8c33ab --- /dev/null +++ b/src/relay/pass/merge_composite.cc @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/pass/merge_composite.cc + * \brief Merges expressions matching patterns into functions marked + * as 'composite'. This is primarily intended to be used alongside the + * external codegen infrastructure to support the case where multiple + * Relay operators map to a single external operator. 
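+ *
+ * As a rough usage sketch (via the Python API added in this change; the pattern
+ * helper below is illustrative only, not part of this file):
+ *
+ *   def make_add_relu_pattern():
+ *       x, y = relay.var('x'), relay.var('y')
+ *       return relay.nn.relu(relay.add(x, y))
+ *
+ *   pattern_table = [("add_relu", make_add_relu_pattern())]
+ *   mod = relay.transform.MergeComposite(pattern_table)(mod)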
+ */
+
+#include <tvm/te/operation.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace merge_composite {
+
+class MergeCompositeWrapper : public ExprMutator {
+ public:
+  explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
+    : pattern_name_(pattern_name), pattern_(pattern) {}
+
+  Expr ExtractPattern(const Var& pattern, const Expr& root,
+                      Map<std::string, Array<Expr>>* var_map) {
+    if (var_map->find(pattern->name_hint()) == var_map->end()) {
+      // if we haven't encountered this var yet, make a new free var and associate
+      // it with the value at 'root'
+      auto free_var = VarNode::make(pattern->name_hint(), Type());
+      var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
+      return std::move(free_var);
+    } else {
+      // if we have encountered this var already, return the free var that was created
+      auto vars = (*var_map)[pattern->name_hint()];
+      auto free_var = vars[0];
+      auto graph_expr = vars[1];
+      // make sure to first check they both map to the same node in the graph
+      if (graph_expr != root) {
+        return Expr();
+      }
+      return (*var_map)[pattern->name_hint()][0];
+    }
+  }
+
+  Expr ExtractPattern(const Constant& pattern, const Expr& root,
+                      Map<std::string, Array<Expr>>* var_map) {
+    return root;
+  }
+
+  /*!
+   * \brief Try to extract a given pattern from a graph as a subgraph.
+   * \param pattern The pattern to extract.
+   * \param root The graph to extract from.
+   * \param var_map A map between free vars in the subgraph and nodes in the graph.
+   * \return The extracted subgraph.
+   *
+   * \note How does this work?
+   *
+   * A pattern consists of a Relay expression containing only operator call nodes, constants
+   * and free variables. The free variables indicate where the pattern can 'attach' to your
+   * graph. This function takes the final call node of the pattern and the call node currently
+   * being traversed in the Relay graph. It traverses through the pattern in lockstep with the
+   * call node from the graph (referred to as the 'root' node here) to check that they're
+   * identical. If at any point they differ, an empty expression is returned to signify that the
+   * extraction failed. If a free var is reached in the pattern, the corresponding value in the
+   * root is associated with the name of the free var (via the var_map) so that when we
+   * construct the composite function, the inputs match up correctly with the rest of the graph.
+   * The return value of this function when successful is a new Relay expression ready to be
+   * wrapped into a composite function.
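+   *
+   * As a small illustration (expressions written informally, not actual test code): for the
+   * pattern relu(add(%x, %y)) matched against the graph call relu(add(%a, %b)), the walk
+   * matches relu with relu and add with add, then reaches the free vars %x and %y. At that
+   * point var_map holds {"x": (new_var_x, %a), "y": (new_var_y, %b)} and the function returns
+   * relu(add(new_var_x, new_var_y)), which becomes the body of the composite function while
+   * %a and %b become its call arguments.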
+   */
+  Expr ExtractPattern(const Call& pattern, const Call& root,
+                      Map<std::string, Array<Expr>>* var_map) {
+    // check to make sure both calls are to operators (not functions)
+    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
+      return Expr();
+    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
+      return Expr();
+
+    unsigned int i = 0;
+    Array<Expr> new_args;
+    for (const auto& arg : pattern->args) {
+      Expr new_arg;
+      if (arg->IsInstance<CallNode>()) {
+        // fail if the root argument is not also a call node
+        if (!root->args[i]->IsInstance<CallNode>()) {
+          return Expr();
+        }
+        // if it's a call node, recursively call this function
+        new_arg = ExtractPattern(Downcast<Call>(arg),
+                                 Downcast<Call>(root->args[i]),
+                                 var_map);
+      } else if (arg->IsInstance<VarNode>()) {
+        // if there's a var in the pattern, it must be a free var
+        // so call the function to update the var_map
+        new_arg = ExtractPattern(Downcast<Var>(arg),
+                                 root->args[i],
+                                 var_map);
+      } else if (arg->IsInstance<ConstantNode>()) {
+        // if there's a constant, simply get the corresponding
+        // value of the constant from the root
+        new_arg = ExtractPattern(Downcast<Constant>(arg),
+                                 root->args[i],
+                                 var_map);
+      }
+      if (!new_arg.defined()) {
+        return Expr();
+      }
+      new_args.push_back(new_arg);
+      i++;
+    }
+    return CallNode::make(root->op, new_args, root->attrs);
+  }
+
+  Expr VisitExpr_(const CallNode* cn) {
+    Call call = GetRef<Call>(cn);
+    if (call->op->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(call->op);
+      CHECK(func.defined());
+      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      // don't step into existing composite functions
+      if (name_node && name_node->value != "") {
+        tvm::Array<Expr> new_args;
+        for (const auto& arg : call->args) {
+          auto new_e = this->Mutate(arg);
+          new_args.push_back(new_e);
+        }
+        return CallNode::make(call->op, new_args, call->attrs);
+      }
+    }
+
+    Expr expr = ExprMutator::VisitExpr_(cn);
+    call = Downcast<Call>(expr);
+    if (!call->op->IsInstance<OpNode>())
+      return std::move(call);
+
+    // only call patterns are supported
+    Call pattern = Downcast<Call>(pattern_);
+    CHECK(pattern.defined());
+    Map<std::string, Array<Expr>> args_map;
+    auto extract = ExtractPattern(pattern, call, &args_map);
+    if (extract.defined()) {
+      auto free_vars = FreeVars(extract);
+      // make the composite function
+      auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
+      f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
+      f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
+      // find the expressions associated with the free vars using the args_map
+      // this tells us which expressions should be given as inputs to the composite function
+      Array<Expr> args;
+      for (const auto& free_var : free_vars) {
+        args.push_back(args_map[free_var->name_hint()][1]);
+      }
+      auto new_call = CallNode::make(f, args);
+      return std::move(new_call);
+    }
+    return std::move(call);
+  }
+
+ private:
+  /*! \brief The name of the pattern to match */
+  std::string pattern_name_;
+  /*! \brief The pattern to match */
+  Expr pattern_;
+};
+
+Expr MergeComposite(const Expr& expr,
+                    const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
+  CHECK_EQ(pattern_names.size(), patterns.size());
+  Expr merged_expr = expr;
+  // merge the patterns one-by-one in order
+  for (size_t i = 0; i < patterns.size(); i++) {
+    std::string pattern_name = pattern_names[i]->value;
+    Expr pattern = patterns[i];
+    merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr);
+  }
+  return merged_expr;
+}
+
+} // namespace merge_composite
+
+namespace transform {
+
+Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
+                    const tvm::Array<Expr>& patterns) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+    [=](Function f, IRModule m, PassContext pc) {
+      return Downcast<Function>(
+          relay::merge_composite::MergeComposite(f, pattern_names, patterns));
+    };
+  auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
+  return func_pass;
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
+.set_body_typed(MergeComposite);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
new file mode 100644
index 0000000000000..4f785d7c915ec
--- /dev/null
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -0,0 +1,609 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for merge composite."""
+from tvm import expr
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+
+"""
+The merge composite pass is designed to merge multiple relay operators that
+match a given pattern into a single relay function.
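+
+A pattern is itself an ordinary Relay expression built from free variables. As a rough sketch
+(mirroring the `make_conv_bias_relu_pattern` helper defined later in this file), the pass is
+driven like this:
+
+    x, y, z = relay.var('x'), relay.var('y'), relay.var('z')
+    pattern = relay.nn.relu(relay.nn.bias_add(relay.nn.conv2d(x, y), z))
+    pattern_table = [("conv2d_bias_relu", pattern)]
+    mod = relay.transform.MergeComposite(pattern_table)(mod)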
+ +For example suppose we have the graph: + + conv2d + | (merge composite pass) + bias_add ====> conv2d_bias_relu + | (our target) + relu + +Our Relay IR before the pass: + fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32], + %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] { + %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1]) + /* ty=Tensor[(1, 256, 28, 28), float32] */; + %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */; + nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */ + } + +Our Relay IR after the pass: + fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32], + %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] { + %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32], + %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") -> + Tensor[(1, 256, 28, 28), float32] { + %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */; + %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */; + nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */ + }; + %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */ + } + +As you can see in the second relay example, the pattern we specified has been wrapped +in a function. The function is then called, producing the same result as the first relay +example. + +One convenient use for this pass is to offload multiple operators to a single external +codegen function. +""" + + +def make_add_sub_mul_pattern(): + """Create a pattern to match the following graph. + + add sub + \ / + \ / + mul + """ + x = relay.var('x') + y = relay.var('y') + add_node = relay.add(x, y) + sub_node = relay.subtract(x, y) + mul_node = relay.multiply(add_node, sub_node) + return mul_node + + +def make_add_relu_pattern(): + """Create a pattern to match the following graph. + + add + | + relu + """ + x = relay.var('x') + y = relay.var('y') + add_node = relay.add(x, y) + r = relay.nn.relu(add_node) + return r + + +def make_conv_bias_relu_pattern(): + """Create a pattern to match the following graph. + + conv2d + | + bias_add + | + relu + """ + x = relay.var('x') + y = relay.var('y') + z = relay.var('z') + conv_node = relay.nn.conv2d(x, y) + bias_node = relay.nn.bias_add(conv_node, z) + r = relay.nn.relu(bias_node) + return r + + +def test_simple_merge(): + """Test composite function is correctly produced from simple graph. + + We could expect the pattern `make_add_relu_pattern` to be merged + into a single op `add_relu`. 
+ + a b + \ / a b + add ====> \ / + | add_relu + relu + + """ + pattern_table = [ + ("add_relu", make_add_relu_pattern()) + ] + + def before(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + add_node = relay.add(a, b) + r = relay.nn.relu(add_node) + return relay.Function([a, b], r) + + def expected(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + + # merged function + r = relay.Call(add_relu, [a, b]) + return relay.Function([a, b], r) + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(expected(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + +def test_branch_merge(): + """Test composite function is correctly produced from branching graph. + + We would expect the pattern `make_add_sub_mul_pattern` to be merged + into a single op `add_sub_mul`. + + a b a b + \/ \/ + add sub a b + \ / \/ + \ / add_sub_mul + mul c | + / \ \ | + c / c | ====> add_sub_mul + \/ \/ | + add sub | + \ / relu + \ / + mul + | + | + relu + """ + + pattern_table = [ + ("add_sub_mul", make_add_sub_mul_pattern()) + ] + + def before(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + c = relay.var('c', shape=(10, 10)) + add_node = relay.add(a, b) + sub_node = relay.subtract(a, b) + mul_node = relay.multiply(add_node, sub_node) + add_node_2 = relay.add(c, mul_node) + sub_node_2 = relay.subtract(c, mul_node) + mul_node_2 = relay.multiply(add_node_2, sub_node_2) + r = relay.nn.relu(mul_node_2) + return relay.Function([a, b, c], r) + + def expected(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + c = relay.var('c', shape=(10, 10)) + + # add_sub_mul function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + sub_node = relay.subtract(in_1, in_2) + mul_node = relay.multiply(add_node, sub_node) + add_sub_mul = relay.Function([in_1, in_2], mul_node) + + # merged function + add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) + add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1]) + r = relay.nn.relu(add_sub_mul_2) + return relay.Function([a, b, c], r) + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(expected(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + +def test_multiple_patterns(): + """Test different patterns are merged correctly in the graph. + + We would expect the pattern `make_conv_bias_relu_pattern` to be merged + into a single op `conv_bias_relu`. We would also expect `make_add_relu_pattern` + to be merged into a single op `add_relu`. 
+ + data kernel + \ / + \ / + conv2d data kernel bias + | \ | / + | bias conv2d_bias_relu + | / | + bias_add ====> | a + | | / + relu a add_relu + \ / | + add | b + | | / + relu b mul + | / + mul + """ + pattern_table = [ + ("conv2d_bias_relu", make_conv_bias_relu_pattern()), + ("add_relu", make_add_relu_pattern()) + ] + + def before(): + data = relay.var('data', shape=(1, 512, 28, 28)) + kernel = relay.var('kernel', shape=(256, 512, 1, 1)) + bias = relay.var('bias', shape=(256,)) + a = relay.var('a', shape=(1, 256, 28, 28)) + b = relay.var('b', shape=(1, 256, 28, 28)) + + conv_node = relay.nn.conv2d(data, + kernel, + kernel_size=(1, 1), + padding=(0, 0), + strides=(1, 1)) + + bias_node = relay.nn.bias_add(conv_node, bias) + relu_node = relay.nn.relu(bias_node) + add_node = relay.add(relu_node, a) + relu_node_2 = relay.nn.relu(add_node) + r = relay.multiply(relu_node_2, b) + return relay.Function([data, kernel, bias, a, b], r) + + def expected(): + data = relay.var('data', shape=(1, 512, 28, 28)) + kernel = relay.var('kernel', shape=(256, 512, 1, 1)) + bias = relay.var('bias', shape=(256,)) + a = relay.var('a', shape=(1, 256, 28, 28)) + b = relay.var('b', shape=(1, 256, 28, 28)) + + # conv_bias_relu function + in_1 = relay.var('in_1', shape=(1, 512, 28, 28)) + in_2 = relay.var('in_2', shape=(256, 512, 1, 1)) + in_3 = relay.var('in_3', shape=(256,)) + + conv_node = relay.nn.conv2d(in_1, + in_2, + kernel_size=(1, 1), + padding=(0, 0), + strides=(1, 1)) + + bias_node = relay.nn.bias_add(conv_node, in_3) + r = relay.nn.relu(bias_node) + conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) + + # add_relu function + in_4 = relay.var('in_4', shape=(1, 256, 28, 28)) + in_5 = relay.var('in_5', shape=(1, 256, 28, 28)) + add_node = relay.add(in_4, in_5) + r = relay.nn.relu(add_node) + add_relu = relay.Function([in_4, in_5], r) + + # merged function + conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) + add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a]) + r = relay.multiply(add_relu_1, b) + return relay.Function([data, kernel, bias, a, b], r) + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(expected(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + +def test_merge_order(): + """Test that patterns are merged in the order they exist in the pattern table. + + There can be cases where one pattern is a subgraph of another, in which case + it is not clear which match should take priority. The priority should come + from the order in which the patterns are declared in the pattern table. The + first patterns will be merged with highest priority and the last with lowest. 
+ + A: B: C: + add add abs + | | | + abs abs relu + | + relu + + """ + + def pattern_A(): + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + out = relay.nn.relu(out) + return out + + def pattern_B(): + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + return out + + def pattern_C(): + x = relay.var('x') + out = relay.abs(x) + out = relay.nn.relu(x) + return out + + def before(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + out = relay.add(input_1, input_2) + out = relay.abs(out) + out = relay.nn.relu(out) + return relay.Function([input_1, input_2], out) + + def after_A_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + out = relay.nn.relu(out) + merged_func = relay.Function([x, y], out) + merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) + merged_func = merged_func.set_attribute('Composite', expr.StringImm('A')) + ret = relay.Call(merged_func, [input_1, input_2]) + return relay.Function([input_1, input_2], ret) + + def after_B_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + merged_func = relay.Function([x, y], out) + merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) + merged_func = merged_func.set_attribute('Composite', expr.StringImm('B')) + merged_call = relay.Call(merged_func, [input_1, input_2]) + ret = relay.nn.relu(merged_call) + return relay.Function([input_1, input_2], ret) + + def after_C_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + add = relay.add(input_1, input_2) + x = relay.var('x') + out = relay.abs(x) + out = relay.nn.relu(out) + merged_func = relay.Function([x], out) + merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) + merged_func = merged_func.set_attribute('Composite', expr.StringImm('C')) + ret = relay.Call(merged_func, [add]) + return relay.Function([input_1, input_2], ret) + + # check A highest priority + pattern_table = [ + ("A", pattern_A()), + ("B", pattern_B()), + ("C", pattern_C()), + ] + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + # check B highest priority + pattern_table = [ + ("B", pattern_A()), + ("C", pattern_B()), + ("A", pattern_C()), + ] + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + # check C highest priority + pattern_table = [ + ("C", pattern_A()), + ("A", pattern_B()), + ("B", pattern_C()), + ] + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + +def test_parallel_merge(): + """Tests that parallel patterns relying on the same inputs are correctly 
merged. + + The test graph is difficult to draw out as ascii art. It is essentially two parallel + add-sub-mul units which both consume input_1 and input_2 with their results being multiplied + to give the output. We expect both parallel branches should get merged and both should still + consume the same input variables, input_1 and input_2.""" + + def before(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + branch_1_add = relay.add(input_1, input_2) + branch_1_sub = relay.subtract(input_1, input_2) + branch_1 = relay.multiply(branch_1_add, branch_1_sub) + branch_2_add = relay.add(input_1, input_2) + branch_2_sub = relay.subtract(input_1, input_2) + branch_2 = relay.multiply(branch_2_add, branch_2_sub) + out = relay.multiply(branch_1, branch_2) + return relay.Function([input_1, input_2], out) + + def after(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + y = relay.var('y') + branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) + func_1 = relay.Function([x, y], branch_1) + call_1 = relay.Call(func_1, [input_1, input_2]) + x1 = relay.var('x1') + y1 = relay.var('y1') + branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) + func_2 = relay.Function([x1, y1], branch_2) + call_2 = relay.Call(func_2, [input_1, input_2]) + out = relay.multiply(call_1, call_2) + return relay.Function([input_1, input_2], out) + + pattern_table = [ + ("add_sub_mul", make_add_sub_mul_pattern()) + ] + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + +def test_multiple_input_subgraphs(): + """Test the case when multiple input subgraphs feed into another subgraph. + + (1) (2) (3) (4) + add add add add + | | | | + relu relu relu relu + \ / \ / + \ / \ / + add sub + \ / + \ / + \ / + mul + + ----> When 1=3 and 2=4 (Case 'A') + + add_relu add_relu + \ / + \ / + add_sub_mul + + ----> When 1!=3 and 2!=4 (Case 'B') + + add_relu add_relu add_relu add_relu + \ / \ / + \ / \ / + add sub + \ / + -------- ----- + \ / + mul + + The difference in behaviour comes from the fact that add_sub_mul expects that the + inputs to add and sub are identical (the same two relay expressions). So when you + have 4 independent inputs, the pattern should not be merged. 
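+
+    For instance, in case 'A' below, add and sub share the same two inputs (add_relu_1 and
+    add_relu_2), so the add_sub_mul pattern matches; in case 'B' they take four independent
+    add_relu results, so only the add_relu subgraphs are merged.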
+ """ + + def before(): + before_funcs = {} + inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)] + add_relu_1 = relay.add(inputs[0], inputs[1]) + add_relu_1 = relay.nn.relu(add_relu_1) + add_relu_2 = relay.add(inputs[2], inputs[3]) + add_relu_2 = relay.nn.relu(add_relu_2) + add_relu_3 = relay.add(inputs[4], inputs[5]) + add_relu_3 = relay.nn.relu(add_relu_3) + add_relu_4 = relay.add(inputs[6], inputs[7]) + add_relu_4 = relay.nn.relu(add_relu_4) + add = relay.add(add_relu_1, add_relu_2) + sub = relay.subtract(add_relu_3, add_relu_4) + out = relay.multiply(add, sub) + before_funcs['B'] = relay.Function(inputs, out) + sub = relay.subtract(add_relu_1, add_relu_2) + out = relay.multiply(add, sub) + before_funcs['A'] = relay.Function(inputs[:4], out) + return before_funcs + + def after_A(): + inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(4)] + x = relay.var('x') + y = relay.var('y') + add_relu_1 = relay.add(x, y) + add_relu_1 = relay.nn.relu(add_relu_1) + add_relu_1 = relay.Function([x, y], add_relu_1) + add_relu_1 = add_relu_1.set_attribute('Primitive', expr.IntImm('int32', 1)) + add_relu_1 = add_relu_1.set_attribute('Composite', expr.StringImm('add_relu')) + add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) + x1 = relay.var('x1') + y1 = relay.var('y1') + add_relu_2 = relay.add(x1, y1) + add_relu_2 = relay.nn.relu(add_relu_2) + add_relu_2 = relay.Function([x1, y1], add_relu_2) + add_relu_2 = add_relu_2.set_attribute('Primitive', expr.IntImm('int32', 1)) + add_relu_2 = add_relu_2.set_attribute('Composite', expr.StringImm('add_relu')) + add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) + x2 = relay.var('x2') + y2 = relay.var('y2') + add = relay.add(x2, y2) + sub = relay.subtract(x2, y2) + add_sub_mul = relay.multiply(add, sub) + add_sub_mul = relay.Function([x2, y2], add_sub_mul) + add_sub_mul = add_sub_mul.set_attribute('Primitive', expr.IntImm('int32', 1)) + add_sub_mul = add_sub_mul.set_attribute('Composite', expr.StringImm('add_sub_mul')) + add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) + return relay.Function(inputs, add_sub_mul_call) + + def after_B(): + inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)] + add_relu_calls = [] + for i in range(4): + x = relay.var('x' + str(i)) + y = relay.var('x' + str(i)) + add_relu = relay.add(x, y) + add_relu = relay.nn.relu(add_relu) + add_relu = relay.Function([x, y], add_relu) + add_relu = add_relu.set_attribute('Primitive', expr.IntImm('int32', 1)) + add_relu = add_relu.set_attribute('Composite', expr.StringImm('add_relu')) + add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) + add_relu_calls.append(add_relu_call) + + add = relay.add(add_relu_calls[0], add_relu_calls[1]) + sub = relay.subtract(add_relu_calls[2], add_relu_calls[3]) + out = relay.multiply(add, sub) + return relay.Function(inputs, out) + + pattern_table = [ + ("add_sub_mul", make_add_sub_mul_pattern()), + ("add_relu", make_add_relu_pattern()) + ] + # check case 'A' + result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_A(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, expected) + + # check case 'B' + result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result) + expected = run_opt_pass(after_B(), relay.transform.InferType()) + assert 
relay.analysis.alpha_equal(result, expected) + + +if __name__ == "__main__": + test_simple_merge() + test_branch_merge() + test_multiple_patterns() + test_merge_order() + test_parallel_merge() + test_multiple_input_subgraphs() \ No newline at end of file
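
For reference, a quick way to inspect the output of the pass outside these unit tests (a sketch;
`before` and `pattern_table` stand in for any of the graphs and pattern tables defined above):

    import tvm
    from tvm import relay

    mod = tvm.IRModule.from_expr(before())
    mod = relay.transform.MergeComposite(pattern_table)(mod)
    print(mod["main"])  # composite calls appear as inner fn(...) marked Composite="<pattern_name>"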