diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 94c7621e60af..103ddcb31111 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -86,32 +86,69 @@ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) { return *ret.first; } -class AnnotatedRegionSet::Creator : public ExprVisitor { +class AnnotatedRegionSet::Creator : protected MixedModeVisitor { public: Creator(const Op& region_begin_op, const Op& region_end_op) : begin_op_(region_begin_op), end_op_(region_end_op) {} + AnnotatedRegionSet Create(const Expr& expr) { + VisitExpr(expr); + return std::move(region_set_); + } + + void AddToArgRegion(Expr expr, Array args) { + // Merge argument regions and add itself to the region. + + // Find the first open region. + AnnotatedRegion region; + for (auto arg : args) { + const CallNode* end = arg.as(); + if (end && end->op == end_op_) { // Ignore closed regions. + continue; + } + + region = region_set_->GetRegion(arg); + if (region.defined()) { + break; + } + } + + // Try to merge open regions. + for (auto arg : args) { + const CallNode* end = arg.as(); + if (end && end->op == end_op_) { // Ignore closed regions. + continue; + } + + auto arg_region = region_set_->GetRegion(arg); + CHECK_EQ(region.defined(), arg_region.defined()) + << "Arg regions are inconsistent: " << AsText(expr); + if (region.defined() && region != arg_region) { + region_set_->MergeRegions(arg_region, region); + } + } + if (region.defined()) { + region_set_->AddToRegion(region, expr); + } + } + void VisitExpr_(const CallNode* call) { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { - // Propagate region to arguments - auto region = region_set_->GetRegion(GetRef(call)); - if (region.defined()) { - for (auto arg : call->args) { - region_set_->AddToRegion(region, arg); - } - } + AddToArgRegion(GetRef(call), call->args); } else if (call->op == begin_op_) { // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); + std::string target = call->attrs.as()->compiler; + // Check if the argument already belongs to a region auto region = region_set_->GetRegion(GetRef(call)); - if (!region.defined()) { - throw Error(ErrorBuilder() - << "Cannot find the corresponding region for start annotation:\n" - << AsText(GetRef(call), false)); - } + CHECK(!region.defined()); + + // Create a new region. + region = region_set_->MakeRegion(target); + region->nodes_.insert(GetRef(call)); region->ins_.push_back(GetRef(call)); } else { CHECK_EQ(call->op, end_op_); @@ -122,9 +159,8 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { // Check if the argument already belongs to a region auto region = region_set_->GetRegion(call->args[0]); if (!region.defined()) { - // Create a new region if the argument is not belonged to any regions yet. - region = region_set_->MakeRegion(target); - region->nodes_.insert(call->args[0]); + throw Error(ErrorBuilder() << "Cannot find the corresponding region for end annotation:\n" + << AsText(GetRef(call), false)); } else { // If the argument is belonged to a region, it must have the same target. // Otherwise we should see a region_begin op. @@ -133,83 +169,44 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { region->nodes_.insert(GetRef(call)); region->outs_.push_back(GetRef(call)); } - ExprVisitor::VisitExpr_(call); - } - - AnnotatedRegionSet Create(const Expr& expr) { - VisitExpr(expr); - return std::move(region_set_); } void VisitExpr_(const TupleNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - for (auto field : op->fields) { - region_set_->AddToRegion(region, field); - } - } - ExprVisitor::VisitExpr_(op); + AddToArgRegion(GetRef(op), op->fields); } void VisitExpr_(const TupleGetItemNode* g) { - auto region = region_set_->GetRegion(GetRef(g)); - if (region.defined()) { - region_set_->AddToRegion(region, g->tuple); - } - ExprVisitor::VisitExpr_(g); - } - - void VisitExpr_(const FunctionNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - for (auto param : op->params) { - region_set_->AddToRegion(region, param); - } - } - ExprVisitor::VisitExpr_(op); + Array args = {g->tuple}; + AddToArgRegion(GetRef(g), args); } void VisitExpr_(const LetNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->var); - region_set_->AddToRegion(region, op->value); - region_set_->AddToRegion(region, op->body); - } + Array args = {op->var, op->value, op->body}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const IfNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->cond); - region_set_->AddToRegion(region, op->true_branch); - region_set_->AddToRegion(region, op->false_branch); - } + Array args = {op->cond, op->true_branch, op->false_branch}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const RefCreateNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->value); - } + Array args = {op->value}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const RefReadNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->ref); - } + Array args = {op->ref}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const RefWriteNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->ref); - } + Array args = {op->ref}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index bc6b4b993ae8..4caac042f016 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -42,9 +42,9 @@ const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._mak // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. -class AnnotateTargetWrapper : public ExprMutator { +class AnnotateTargetRewriter : public ExprRewriter { public: - explicit AnnotateTargetWrapper(Array targets) : targets_(std::move(targets)) {} + explicit AnnotateTargetRewriter(Array targets) : targets_(std::move(targets)) {} /*! * \brief This function annotates a compiler end and a compiler begin to all arguments. @@ -108,29 +108,29 @@ class AnnotateTargetWrapper : public ExprMutator { return new_op; } - Expr VisitExpr_(const CallNode* cn) final { + Expr Rewrite_(const CallNode* pre, const Expr& post) final { // Supported targets for this node. The order implies the priority. std::vector supported_targets; - auto op_node = cn->op.as(); + auto op_node = pre->op.as(); // This graph has annotations, meaning that this is not the first time running this pass. - if (op_node && cn->op == compiler_begin_op) { + if (op_node && pre->op == compiler_begin_op) { // Bypass compiler begin due to lack of target information. It will be processed // when the following op handling arguments. - CHECK_EQ(cn->args.size(), 1U); - return VisitExpr(cn->args[0]); - } else if (op_node && cn->op == compiler_end_op) { + CHECK_EQ(pre->args.size(), 1U); + return post.as()->args[0]; + } else if (op_node && pre->op == compiler_end_op) { // Override compiler end with the new target. - CHECK_EQ(cn->args.size(), 1U); - auto input_expr = VisitExpr(cn->args[0]); + CHECK_EQ(pre->args.size(), 1U); + auto input_expr = post.as()->args[0]; CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); } // Peek the first argument. If it is compiler begin then this node had annotated by // another target before, so we also consider that target as a supported target. - const CallNode* first_arg_call = cn->args[0].as(); + const CallNode* first_arg_call = pre->args[0].as(); if (first_arg_call && first_arg_call->op == compiler_begin_op) { std::string arg_target = first_arg_call->attrs.as()->compiler; if (arg_target != "default") { @@ -142,21 +142,21 @@ class AnnotateTargetWrapper : public ExprMutator { if (op_node) { // TVM operators: Check target specific op checking function and add to supported_targets // if it is supported. - Op op = Downcast(cn->op); + Op op = Downcast(pre->op); CHECK(op.defined()); for (const auto& target : this->targets_) { if (!Op::HasAttr("target." + std::string(target))) { continue; } auto fannotate = Op::GetAttr("target." + std::string(target)); - if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { + if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) { supported_targets.push_back(target); } } - } else if (cn->op->IsInstance()) { + } else if (pre->op->IsInstance()) { // Composite function: Add the target of a composite function to supported_targets // if it is in the target list. - Function func = Downcast(cn->op); + Function func = Downcast(pre->op); CHECK(func.defined()); if (auto comp_name = func->GetAttr(attr::kComposite)) { @@ -181,23 +181,22 @@ class AnnotateTargetWrapper : public ExprMutator { std::string target = supported_targets[0]; // Visit and mutate arguments after the target of this op has been determined. - auto new_call = Downcast(ExprMutator::VisitExpr_(cn)); + Call post_call = Downcast(post); // Add annotations to each arg. - auto target_n_args = AnnotateArgs(new_call->args, target); + auto target_n_args = AnnotateArgs(post_call->args, target); Array compiler_begins = std::get<1>(target_n_args); - Call call = Call(new_call->op, compiler_begins, new_call->attrs); - call->checked_type_ = cn->checked_type_; + Call new_call = Call(post_call->op, compiler_begins, post_call->attrs); + new_call->checked_type_ = pre->checked_type_; // Update the target map. - op_expr_to_target_[call] = target; + op_expr_to_target_[new_call] = target; - return std::move(call); + return std::move(new_call); } - Expr VisitExpr_(const TupleNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const TupleNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(expr->fields); auto new_expr = Tuple(std::get<1>(target_n_args)); @@ -205,9 +204,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const TupleGetItemNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->tuple})); auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); @@ -215,7 +213,7 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const FunctionNode* fn) final { + Expr Rewrite_(const FunctionNode* fn, const Expr& post) final { Function func; Expr new_body; // don't step into composite functions @@ -223,8 +221,7 @@ class AnnotateTargetWrapper : public ExprMutator { func = GetRef(fn); new_body = func->body; } else { - auto new_e = ExprMutator::VisitExpr_(fn); - func = Downcast(new_e); + func = Downcast(post); new_body = func->body; if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); @@ -234,9 +231,8 @@ class AnnotateTargetWrapper : public ExprMutator { return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); } - Expr VisitExpr_(const LetNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto let = Downcast(new_e); + Expr Rewrite_(const LetNode* op, const Expr& post) final { + auto let = Downcast(post); auto target_n_args = AnnotateArgs({let->value, let->body}); auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); @@ -244,9 +240,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const IfNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const IfNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch}); CHECK_EQ(std::get<1>(target_n_args).size(), 3U); @@ -256,9 +251,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const RefCreateNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const RefCreateNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->value})); auto new_expr = RefCreate(std::get<1>(target_n_args)[0]); @@ -266,9 +260,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const RefReadNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const RefReadNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->ref})); auto new_expr = RefRead(std::get<1>(target_n_args)[0]); @@ -276,9 +269,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const RefWriteNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const RefWriteNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->ref, expr->value})); auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); @@ -294,7 +286,8 @@ class AnnotateTargetWrapper : public ExprMutator { }; Expr AnnotateTarget(const Expr& expr, const Array& targets) { - return AnnotateTargetWrapper(targets).Mutate(expr); + auto rewriter = AnnotateTargetRewriter(targets); + return PostOrderRewrite(expr, &rewriter); } } // namespace annotate_target diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 601be0f96bc4..6fbd0d513e79 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -53,7 +53,7 @@ namespace merge_compiler_region { static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); -class RegionMerger : public ExprVisitor { +class RegionMerger : public MixedModeVisitor { public: explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} @@ -131,7 +131,6 @@ class RegionMerger : public ExprVisitor { } merged_regions_.insert(region->GetID()); } - ExprVisitor::VisitExpr_(call); } private: @@ -140,11 +139,11 @@ class RegionMerger : public ExprVisitor { std::unordered_map> region_restrictions_; }; -class MergeAnnotations : public ExprMutator { +class MergeAnnotations : public ExprRewriter { public: explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { // Merge annotations which are now internal to a region. // This happens if we see a compiler begin next to a // compiler end and they're both in the same region. @@ -154,11 +153,12 @@ class MergeAnnotations : public ExprMutator { auto region1 = regions_->GetRegion(GetRef(call)); auto region2 = regions_->GetRegion(arg); if (region1 == region2) { - return VisitExpr(arg->args[0]); + auto post_arg = post.as()->args[0]; + return post_arg.as()->args[0]; } } } - return ExprMutator::VisitExpr_(call); + return post; } private: @@ -175,7 +175,7 @@ Expr MergeCompilerRegions(const Expr& expr) { // Remove annotations that are not in the region boundaries. MergeAnnotations merge_anno(regions); - return merge_anno.Mutate(expr); + return PostOrderRewrite(expr, &merge_anno); } } // namespace merge_compiler_region diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 14d57a92f106..2a4fd31041d7 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -522,8 +522,8 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "test_compiler_0") - gv0 = relay.GlobalVar("test_compiler_0") + func0 = set_func_attr(func0, "test_compiler", "test_compiler_2") + gv0 = relay.GlobalVar("test_compiler_2") mod[gv0] = func0 # function for conv2d @@ -536,8 +536,8 @@ def expected(): channels=16, padding=(1, 1)) func1 = relay.Function([data1, weight1], conv) - func1 = set_func_attr(func1, "test_compiler", "test_compiler_1") - gv1 = relay.GlobalVar("test_compiler_1") + func1 = set_func_attr(func1, "test_compiler", "test_compiler_0") + gv1 = relay.GlobalVar("test_compiler_0") mod[gv1] = func1 # main function @@ -630,7 +630,6 @@ def test_constant_propagation(): def expected(): mod = tvm.IRModule() - x = relay.const(ones) y = relay.var("y", shape=(8, 8)) x0 = relay.const(ones) y0 = relay.var("y0", shape=(8, 8)) @@ -712,12 +711,12 @@ def expected(): mod = tvm.IRModule() # function 0 - data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32")) - weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32")) - bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32")) - bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32")) - bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32")) - bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32")) + data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32")) + weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32")) + bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32")) + bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32")) + bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32")) + bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32")) conv_o = relay.nn.conv2d( data=data, @@ -730,12 +729,12 @@ def expected(): bn_var) relu_o = relay.nn.relu(bn_o[0]) - tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o)) + tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) - func0 = set_func_attr(func0, "test_target", "test_target_2") - gv0 = relay.GlobalVar("test_target_2") + func0 = set_func_attr(func0, "test_target", "test_target_0") + gv0 = relay.GlobalVar("test_target_0") mod[gv0] = func0 # body @@ -747,9 +746,9 @@ def expected(): bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32")) f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var) - f0_relu_o = relay.TupleGetItem(f0_o, 2) + f0_relu_o = relay.TupleGetItem(f0_o, 0) f0_mean_o = relay.TupleGetItem(f0_o, 1) - f0_var_o = relay.TupleGetItem(f0_o, 0) + f0_var_o = relay.TupleGetItem(f0_o, 2) f0_mean_abs = relay.abs(f0_mean_o) f0_var_abs = relay.abs(f0_var_o) @@ -791,22 +790,22 @@ def expected(): mod = tvm.IRModule() # function 1 - f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10)) + f1_cb1 = relay.var('test_target_0_i0', shape=(10, 10)) f1_O_1 = relay.abs(f1_cb1) f1_O_2 = relay.nn.relu(f1_O_1) f1_out = relay.Tuple((f1_O_2, f1_O_1)) func1 = relay.Function([f1_cb1], f1_out) - func1 = set_func_attr(func1, "test_target", "test_target_1") - gv1 = relay.GlobalVar("test_target_1") + func1 = set_func_attr(func1, "test_target", "test_target_0") + gv1 = relay.GlobalVar("test_target_0") mod[gv1] = func1 # function 0 - f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10)) - f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10)) + f2_cb3 = relay.var('test_target_1_i0', shape=(10, 10)) + f2_cb4 = relay.var('test_target_1_i1', shape=(10, 10)) f2_O_3 = relay.add(f2_cb3, f2_cb4) func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) - func0 = set_func_attr(func0, "test_target", "test_target_0") - gv0 = relay.GlobalVar("test_target_0") + func0 = set_func_attr(func0, "test_target", "test_target_1") + gv0 = relay.GlobalVar("test_target_1") mod[gv0] = func0 # body @@ -1109,22 +1108,22 @@ def expected(): mod = tvm.IRModule() # function 0 - f0_i0 = relay.var(target+"_1_i0", shape=(10, 10)) - f0_i1 = relay.var(target+"_1_i1") - f0_i2 = relay.var(target+"_1_i2") - f0_i3 = relay.var(target+"_1_i3") - f0_i4 = relay.var(target+"_1_i4") + f0_i0 = relay.var(target + "_0_i0", shape=(10, 10)) + f0_i1 = relay.var(target + "_0_i1") + f0_i2 = relay.var(target + "_0_i2") + f0_i3 = relay.var(target + "_0_i3") + f0_i4 = relay.var(target + "_0_i4") f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4) f0_n1 = f0_n0[1] f0_n2 = relay.nn.relu(f0_n0[0]) - f0_o0 = relay.Tuple([f0_n1, f0_n2]) + f0_o0 = relay.Tuple([f0_n2, f0_n1]) func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", target+"_1") - gv0 = relay.GlobalVar(target+"_1") + func0 = func0.with_attr("global_symbol", target + "_0") + gv0 = relay.GlobalVar(target + "_0") mod[gv0] = func0 # body @@ -1136,9 +1135,9 @@ def expected(): function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar) get_out0 = relay.TupleGetItem(function_out, 0) get_out1 = relay.TupleGetItem(function_out, 1) - out_2 = relay.tanh(get_out0) - out_3 = relay.log(get_out0) - out = relay.Tuple([get_out1, out_2, out_3]) + out_2 = relay.tanh(get_out1) + out_3 = relay.log(get_out1) + out = relay.Tuple([get_out0, out_2, out_3]) func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out) mod["main"] = func return mod