diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 500e0dcb0c9a1..b3e749ba680ad 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -40,35 +40,44 @@ using namespace backend; * purpose. Only several binary options are covered. Users * may need to extend them to cover more operators. */ -class CodegenC : public ExprVisitor, public CodegenCBase { +class CodegenC : public relay::ExprFunctor(const Expr&)>, + public CodegenCBase { public: explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } - void VisitExpr_(const VarNode* node) final { + std::vector VisitExpr(const Expr& expr) final { + if (visited_.count(expr)) return visited_.at(expr); + + std::vector output; + if (expr.as()) { + output = VisitExpr_(expr.as()); + } else if (expr.as()) { + output = VisitExpr_(expr.as()); + } else if (expr.as()) { + output = VisitExpr_(expr.as()); + } else { + LOG(FATAL) << "DNNL codegen doesn't support: " << expr->GetTypeKey(); + } + visited_[expr] = output; + return output; + } + + std::vector VisitExpr_(const VarNode* node) final { ext_func_args_.push_back(GetRef(node)); - out_.clear(); Output output; output.name = node->name_hint(); - out_.push_back(output); + return {output}; } - void VisitExpr_(const ConstantNode* cn) final { - Constant constant = GetRef(cn); - if (visited_.count(constant)) { - // Note this is for demostration purpose. ConstantNode doesn't necessarily - // belong to calls. We need to revisit this when tuples come into play. - out_.push_back(visited_[constant]); - return; - } + std::vector VisitExpr_(const ConstantNode* cn) final { + // Note this is for demonstration purpose. ConstantNode doesn't necessarily + // belong to calls. We need to revisit this when tuples come into play. std::ostringstream decl_stream; std::ostringstream buf_stream; - out_.clear(); Output output; output.name = "const_" + std::to_string(const_idx_++); - out_.push_back(output); - visited_[constant] = output; runtime::NDArray array = cn->data; const auto& shape = array.Shape(); @@ -99,9 +108,11 @@ class CodegenC : public ExprVisitor, public CodegenCBase { } buf_stream << "};"; ext_func_body.insert(ext_func_body.begin(), buf_stream.str()); + + return {output}; } - void VisitExpr_(const CallNode* call) final { + std::vector VisitExpr_(const CallNode* call) final { std::ostringstream macro_stream; std::ostringstream decl_stream; std::ostringstream buf_stream; @@ -138,8 +149,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase { bool first = true; decl_stream << func_name << "("; for (size_t i = 0; i < call->args.size(); ++i) { - VisitExpr(call->args[i]); - for (auto out : out_) { + auto res = VisitExpr(call->args[i]); + for (auto out : res) { if (!first) { decl_stream << ", "; } @@ -162,13 +173,14 @@ class CodegenC : public ExprVisitor, public CodegenCBase { ext_func_body.push_back(decl_stream.str()); // Update output buffer - out_.clear(); + // Note C codegen only handles TensorType. Therefore, we don't flatten + // tuples and only return a single vaule. Output output; output.name = out; output.dtype = dtype; output.need_copy = true; output.size = out_size; - out_.push_back(output); + return {output}; } /*! @@ -176,12 +188,12 @@ class CodegenC : public ExprVisitor, public CodegenCBase { * * \return The emitted code. */ - std::string JIT() { + std::string JIT(const std::vector& out) { // Write function macros for (auto decl : func_decl_) { code_stream_ << decl << "\n"; } - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_); + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); } private: @@ -202,9 +214,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase { /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; /*! \brief The name and index pairs for output. */ - std::vector out_; - /*! \brief The cached expressions. */ - std::unordered_map visited_; + std::unordered_map, ObjectHash, ObjectEqual> visited_; }; class CSourceCodegen : public CSourceModuleCodegenBase { @@ -216,8 +226,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase { auto sid = GetExtSymbol(func); CodegenC builder(sid); - builder.VisitExpr(func->body); - code_stream_ << builder.JIT(); + auto out = builder.VisitExpr(func->body); + code_stream_ << builder.JIT(out); } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 1b953f3c44671..be8fcfff4fcfc 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -165,9 +165,11 @@ class CodegenCBase { /*! * \brief Emit the code for external runtime. * + * \param out The outputs. + * * \return The code string. */ - virtual std::string JIT() = 0; + virtual std::string JIT(const std::vector& out) = 0; /*! * \brief A common interface that is used by various external runtime to diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 7f3aabf6e0160..58b85d28d7a0b 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -128,42 +128,50 @@ std::vector Add(const CallNode* call) { // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement // all utilities and make a base class for users to implement. -class CodegenDNNL : public ExprVisitor, public CodegenCBase { +class CodegenDNNL : public relay::ExprFunctor(const Expr&)>, + public CodegenCBase { public: explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; } - void VisitExpr_(const VarNode* node) final { + std::vector VisitExpr(const Expr& expr) final { + if (visited_.count(expr)) return visited_.at(expr); + + std::vector output; + if (expr.as()) { + output = VisitExpr_(expr.as()); + } else if (expr.as()) { + output = VisitExpr_(expr.as()); + } else if (expr.as()) { + output = VisitExpr_(expr.as()); + } else if (expr.as()) { + output = VisitExpr_(expr.as()); + } else { + LOG(FATAL) << "DNNL codegen doesn't support: " << expr->GetTypeKey(); + } + visited_[expr] = output; + return output; + } + + std::vector VisitExpr_(const VarNode* node) final { ext_func_args_.push_back(GetRef(node)); - out_.clear(); Output output; output.name = node->name_hint(); - out_.push_back(output); + return {output}; } - void VisitExpr_(const TupleGetItemNode* op) final { - VisitExpr(op->tuple); - CHECK(out_.size() > static_cast(op->index)); + std::vector VisitExpr_(const TupleGetItemNode* op) final { + auto res = VisitExpr(op->tuple); + CHECK_GT(res.size(), static_cast(op->index)); // Only keep the item we want for the child node. // FIXME(@comaniac): The other items should still be requried for the primary outputs. - auto item = out_[op->index]; - out_.clear(); - out_.push_back(item); + return {res[op->index]}; } - void VisitExpr_(const ConstantNode* cn) final { - Constant constant = GetRef(cn); - if (visited_.count(constant)) { - out_.push_back(visited_[constant]); - return; - } - - out_.clear(); + std::vector VisitExpr_(const ConstantNode* cn) final { Output output; output.name = "const_" + std::to_string(const_idx_++); output.dtype = "float"; - out_.push_back(output); - visited_[constant] = output; runtime::NDArray array = cn->data; @@ -176,16 +184,23 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now."; std::ostringstream buf_stream; - buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n"; const float* ptr = static_cast(array.ToDLPack()->dl_tensor.data); - for (int64_t i = 0; i < num_elems; i++) { - buf_stream << " " << output.name << "[" << i << "] = " << ptr[i] << ";\n"; + + // Allocate large arrays on the static section to avoids stakc overflow. + // Note that this would probably increase compilation time as the source + // file could be really large. + buf_stream << "static float " << output.name << "[" << num_elems <<"] = {"; + for (int64_t i = 0; i < num_elems - 1; i++) { + buf_stream << ptr[i] << ","; } + if (num_elems > 0) buf_stream << ptr[num_elems - 1]; + buf_stream << "};\n"; ext_func_body.insert(ext_func_body.begin(), buf_stream.str()); + return {output}; } - void VisitExpr_(const CallNode* call) final { + std::vector VisitExpr_(const CallNode* call) final { GenerateBodyOutput ret; if (const auto* func = call->op.as()) { ret = GenerateCompositeFunctionCall(func, call); @@ -193,16 +208,14 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { ret = GenerateOpCall(call); } - out_.clear(); - for (size_t i = 0; i < ret.outputs.size(); ++i) { - buf_decl_.push_back(ret.buffers[i]); - out_.push_back(ret.outputs[i]); - } + buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end()); + std::vector out = ret.outputs; ext_func_body.push_back(ret.decl); + return ret.outputs; } - std::string JIT(void) { - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_); + std::string JIT(const std::vector& out) { + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); } private: @@ -215,8 +228,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { std::vector GetArgumentNames(const CallNode* call) { std::vector arg_names; for (size_t i = 0; i < call->args.size(); ++i) { - VisitExpr(call->args[i]); - for (auto out : out_) { + auto res = VisitExpr(call->args[i]); + for (const auto& out : res) { arg_names.push_back(out.name); } } @@ -331,17 +344,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { */ int buf_idx_{0}; /*! \brief The index of global constants. */ - int const_idx_ = 0; + int const_idx_{0}; /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ Array ext_func_args_; /*! \brief statement of the function that will be compiled using DNNL kernels. */ std::vector ext_func_body; /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; - /*! \brief The name of the the outputs. */ - std::vector out_; /*! \brief The cached expressions. */ - std::unordered_map visited_; + std::unordered_map, ObjectHash, ObjectEqual> visited_; }; /*! @@ -361,8 +372,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { auto sid = GetExtSymbol(func); CodegenDNNL builder(sid); - builder.VisitExpr(func->body); - code_stream_ << builder.JIT(); + auto out = builder.VisitExpr(func->body); + code_stream_ << builder.JIT(out); } /*! diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index fa9c8c4f40a25..d5b653d66ea6f 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -148,25 +148,42 @@ class Partitioner : public ExprMutator { CHECK_EQ(call->args.size(), 1U); // Traverse the rest graph. - auto input_expr = VisitExpr(call->args[0]); + Expr parent = call->args[0]; + auto input_expr = VisitExpr(parent); + + // Backtrace the parent to find the LCA node that is not a begin/ end op + while (const auto* parent_call = parent.as()) { + if (parent_call->op == compiler_begin_op || + parent_call->op == compiler_end_op) { + parent = parent_call->args[0]; + } else { + break; + } + } AnnotatedRegion sg = GetRegion(GetRef(call)); int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - // The type of the created variable is the same as the compiler_begin - // node. - std::string target = call->attrs.as()->compiler; - std::string varname = - target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index); - auto var = Var(varname, GetRef(call)->checked_type_); - - auto cand = std::make_pair(var, input_expr); - if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) == - region_args[sg].end()) { - region_args[sg].push_back(cand); - } - return std::move(var); + if (shared_output_.count(parent) && shared_output_[parent].count(sg)) { + return shared_output_[parent][sg]; + } else { + // The type of the created variable is the same as the compiler_begin + // node. + std::string target = call->attrs.as()->compiler; + std::string varname = + target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index); + auto var = Var(varname, GetRef(call)->checked_type_); + + std::pair cand = std::make_pair(var, input_expr); + + if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) == + region_args[sg].end()) { + region_args[sg].push_back(cand); + } + shared_output_[parent][sg] = var; + return std::move(var); + } } else { CHECK_EQ(call->op, compiler_end_op); // The annotation node is inserted on edge so it must have only one @@ -474,6 +491,12 @@ class Partitioner : public ExprMutator { * belongs to */ std::unordered_map regions_sets_; + + /*!\brief Cache the output that is shared by different nodes. */ + using RegionOutputMap = std::unordered_map; + std::unordered_map shared_output_; + + /*!\brief The IRModule used for partitioning. */ IRModule module_; }; diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index c7d9626931d0e..1272c352af787 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -300,6 +300,14 @@ def visit_call(self, call): check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) +def set_func_attr(func, compile_name, symbol_name): + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", compile_name) + func = func.with_attr("global_symbol", symbol_name) + return func + + def test_extern_ccompiler_default_ops(): def expected(): mod = tvm.IRModule() @@ -310,10 +318,7 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([x0, y0], add) - func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", "ccompiler") - func = func.with_attr("global_symbol", "ccompiler_0") + func = set_func_attr(func, "ccompiler", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [x, y]) @@ -380,32 +385,28 @@ def test_extern_dnnl(): def expected(): data0 = relay.var("data", shape=(ishape), dtype=dtype) - input0 = relay.var("input0", shape=(w1shape), dtype=dtype) - input1 = relay.var("input1", shape=(w1shape), dtype=dtype) + input0 = relay.var("input", shape=(w1shape), dtype=dtype) depthwise_conv2d_1 = relay.nn.conv2d(data0, input0, kernel_size=(3, 3), padding=(1, 1), groups=32) depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, - input1, + input0, kernel_size=(3, 3), padding=(1, 1), groups=32) out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) - func = relay.Function([data0, input0, input1], out) - func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", "dnnl") - func = func.with_attr("global_symbol", "dnnl_0") + func = relay.Function([data0, input0], out) + func = set_func_attr(func, "dnnl", "dnnl_0") glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func data = relay.var("data", shape=(ishape), dtype=dtype) weight = relay.var("input", shape=(w1shape), dtype=dtype) - main_f = relay.Function([data, weight], glb_var(data, weight, weight)) + main_f = relay.Function([data, weight], glb_var(data, weight)) mod["main"] = main_f return mod @@ -444,7 +445,7 @@ def get_func(): check_result(mod, {"data": i_data, "weight1": w1_data}, (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) -@pytest.mark.skip(reason="fix constant node before opening this case") + def test_extern_dnnl_mobilenet(): if not tvm.get_global_func("relay.ext.dnnl", True): print("skip because DNNL codegen is not available") @@ -521,10 +522,7 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", "test_compiler") - func0 = func0.with_attr("global_symbol", "test_compiler_0") + func0 = set_func_attr(func0, "test_compiler", "test_compiler_0") gv0 = relay.GlobalVar("test_compiler_0") mod[gv0] = func0 @@ -538,10 +536,7 @@ def expected(): channels=16, padding=(1, 1)) func1 = relay.Function([data1, weight1], conv) - func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Compiler", "test_compiler") - func1 = func1.with_attr("global_symbol", "test_compiler_1") + func1 = set_func_attr(func1, "test_compiler", "test_compiler_1") gv1 = relay.GlobalVar("test_compiler_1") mod[gv1] = func1 @@ -610,10 +605,7 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", "test_compiler") - func0 = func0.with_attr("global_symbol", "test_compiler_0") + func0 = set_func_attr(func0, "test_compiler", "test_compiler_0") # main function data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) @@ -645,10 +637,7 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([y0], add) - func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", "ccompiler") - func = func.with_attr("global_symbol", "ccompiler_0") + func = set_func_attr(func, "ccompiler", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [y]) @@ -745,10 +734,7 @@ def expected(): func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) - func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", "test_target") - func0 = func0.with_attr("global_symbol", "test_target_2") + func0 = set_func_attr(func0, "test_target", "test_target_2") gv0 = relay.GlobalVar("test_target_2") mod[gv0] = func0 @@ -810,11 +796,7 @@ def expected(): f1_O_2 = relay.nn.relu(f1_O_1) f1_out = relay.Tuple((f1_O_2, f1_O_1)) func1 = relay.Function([f1_cb1], f1_out) - - func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Compiler", "test_target") - func1 = func1.with_attr("global_symbol", "test_target_1") + func1 = set_func_attr(func1, "test_target", "test_target_1") gv1 = relay.GlobalVar("test_target_1") mod[gv1] = func1 @@ -823,11 +805,7 @@ def expected(): f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10)) f2_O_3 = relay.add(f2_cb3, f2_cb4) func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) - - func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", "test_target") - func0 = func0.with_attr("global_symbol", "test_target_0") + func0 = set_func_attr(func0, "test_target", "test_target_0") gv0 = relay.GlobalVar("test_target_0") mod[gv0] = func0 @@ -973,6 +951,93 @@ def test_exec(mod, params, ref_mod, ref_params, out_shape): # test_exec(mod, params, ref_mod, ref_params, (1, 1000)) +def test_multiple_use_of_an_output(): + def expected_same_output_region(): + mod = tvm.IRModule() + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + z = relay.var("z", shape=(8, 8)) + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + log = relay.log(x0) + sub = x0 - y0 + mul = log * sub + # The partitioned graph contains log, subtract, and multiply + func = relay.Function([x0, y0], mul) + func = set_func_attr(func, "ccompiler", "ccompiler_0") + glb_0 = relay.GlobalVar("ccompiler_0") + mod[glb_0] = func + add = x + y + call = relay.Call(glb_0, [add, z]) + main = relay.Function([x, y, z], call) + mod["main"] = main + return mod + + def expected_different_output_region(): + mod = tvm.IRModule() + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + z = relay.var("z", shape=(8, 8)) + + # The partitioned graph contains log + i0 = relay.var("i0", shape=(8, 8)) + log = relay.log(i0) + func = relay.Function([i0], log) + func = set_func_attr(func, "ccompiler", "ccompiler_0") + glb_0 = relay.GlobalVar("ccompiler_0") + mod[glb_0] = func + + # The partitioned graph contains subtract + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + sub = x0 - y0 + func = relay.Function([x0, y0], sub) + func = set_func_attr(func, "ccompiler", "ccompiler_1") + glb_1 = relay.GlobalVar("ccompiler_1") + mod[glb_1] = func + + add = x + y + call_log = relay.Call(glb_0, [add]) + call_sub = relay.Call(glb_1, [add, z]) + main = relay.Function([x, y, z], call_log * call_sub) + mod["main"] = main + return mod + + def get_mod(): + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + z = relay.var("z", shape=(8, 8)) + add = x + y + sub = add - z + log = relay.log(add) + sub1 = log * sub + f = relay.Function([x, y, z], sub1) + mod = tvm.IRModule() + mod["main"] = f + return mod + + def test_same_output_region(): + mod = get_mod() + mod = WhiteListAnnotator(["subtract", "log", "multiply"], "ccompiler")(mod) + mod = transform.MergeCompilerRegions()(mod) + mod = transform.PartitionGraph()(mod) + + expected_mod = expected_same_output_region() + assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True) + + def test_different_output_region(): + mod = get_mod() + mod = WhiteListAnnotator(["subtract", "log"], "ccompiler")(mod) + mod = transform.MergeCompilerRegions()(mod) + mod = transform.PartitionGraph()(mod) + + expected_mod = expected_different_output_region() + assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True) + + test_same_output_region() + test_different_output_region() + + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op()