diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index e173bc32734f6..d9555570d9fa0 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -261,37 +261,26 @@ class Partitioner : public MixedModeMutator { } /*! - * \brief Create a function and its function call for the given region. If the function has - * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes - * will be created to serve output consumers. + * \brief Check if an expr is a constant or a tuple that only contain constants. */ - void CreateFunction(AnnotatedRegion region, const CallNode* end_node) { - // Create fields which is a unique list of outputs. - Array fields; - std::unordered_map out_expr_to_idx; - int out_idx = 0; - for (auto region_end_node : region->GetOutputs()) { - auto ret_node = Downcast(region_end_node)->args[0]; - // Don't duplicate outputs. - if (!out_expr_to_idx.count(ret_node)) { - auto ret_expr = MixedModeMutator::VisitExpr(ret_node); - fields.push_back(ret_expr); - out_expr_to_idx[ret_node] = out_idx++; - } - } + bool IsConstant(const Expr& expr) const { + if (expr->IsInstance()) return true; + if (!expr->IsInstance()) return false; + const auto* tn = expr.as(); + return std::all_of(tn->fields.begin(), tn->fields.end(), + [](const Expr& e) { return e->IsInstance(); }); + } + /*! + * \brief Create a call to the function that represents a region. + * \note The customized optimization pipeline will be invoked as well to + * optimize each function that is handled by external codegen. + */ + Call CreateRegionCall(AnnotatedRegion region, const Array& fields, + const CallNode* end_node) { Array params; Array param_expr; Map params_bind; - - auto IsConstant = [](const Expr& expr) { - if (expr->IsInstance()) return true; - if (!expr->IsInstance()) return false; - const auto* tn = expr.as(); - return std::all_of(tn->fields.begin(), tn->fields.end(), - [](const Expr& e) { return e->IsInstance(); }); - }; - for (auto pair : region_func_meta_[region].args) { params.push_back(pair.first); if (IsConstant(pair.second)) { @@ -314,6 +303,18 @@ class Partitioner : public MixedModeMutator { std::string target = end_node->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); + // Constant propagation + if (!params_bind.empty()) { + global_region_func = Downcast(relay::Bind(global_region_func, params_bind)); + } + std::string ext_opt = "relay.ext." + target + ".optimize"; + auto pf = tvm::runtime::Registry::Get(ext_opt); + if (pf != nullptr) { + auto mod = IRModule::FromExpr(global_region_func); + mod = (*pf)(mod); + global_region_func = Downcast(mod->Lookup("main")); + } + global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name)); global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); @@ -321,11 +322,6 @@ class Partitioner : public MixedModeMutator { WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target)); global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); - // Constant propagation - if (!params_bind.empty()) { - global_region_func = Downcast(relay::Bind(global_region_func, params_bind)); - } - std::string fname = name; CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists"; // Create a global function and add it to the IRModule for the region. @@ -340,6 +336,31 @@ class Partitioner : public MixedModeMutator { auto call = Call(glob_func, param_expr); region_func_meta_[region].func_call = call; + return call; + } + + /*! + * \brief Create a function and its function call for the given region. If the function has + * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes + * will be created to serve output consumers. + */ + void CreateFunction(AnnotatedRegion region, const CallNode* end_node) { + // Create fields which is a unique list of outputs. + Array fields; + std::unordered_map out_expr_to_idx; + int out_idx = 0; + for (auto region_end_node : region->GetOutputs()) { + auto ret_node = Downcast(region_end_node)->args[0]; + // Don't duplicate outputs. + if (!out_expr_to_idx.count(ret_node)) { + auto ret_expr = MixedModeMutator::VisitExpr(ret_node); + fields.push_back(ret_expr); + out_expr_to_idx[ret_node] = out_idx++; + } + } + + Call call = CreateRegionCall(region, fields, end_node); + // Create output expr(s) for the function call. if (out_expr_to_idx.size() == 1) { // Single output direcly uses the call node as the output expr. diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 8dc5344b00be5..84474f63ade91 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -1286,6 +1286,37 @@ def test_tuple_output_exec(): [(10, 10), (10, 10)], [(a_data + b_data), (a_data - b_data)]) +def test_extern_opt(): + def Optimize(mod): + return relay.transform.FoldConstant()(mod) + + tvm.register_func("relay.ext.test_target.optimize", Optimize) + + x = relay.var('x', shape=(2, 2)) + y0 = relay.var('y0', shape=(2, 2)) + y1 = relay.var('y1', shape=(2, 2)) + yy0 = relay.annotation.compiler_begin(y0, 'test_target') + yy1 = relay.annotation.compiler_begin(y1, 'test_target') + z = yy0 + yy1 + end = relay.annotation.compiler_end(z, 'test_target') + f = relay.Function([x, y0, y1], end * x) + c = np.ones(shape=(2, 2), dtype="float32") + f = bind_params_by_name(f, {"y0": tvm.nd.array(c), "y1": tvm.nd.array(c)}) + mod = tvm.IRModule() + mod["main"] = f + mod = transform.PartitionGraph()(mod) + + try: + t0 = mod["test_target_0"] + except: + raise KeyError("test_target_0 not found") + + assert isinstance(t0.body, relay.Constant) + expected = np.empty([2, 2]) + expected.fill(2) + tvm.testing.assert_allclose(t0.body.data.asnumpy(), expected, rtol=1e-5, + atol=1e-5) + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() @@ -1305,3 +1336,4 @@ def test_tuple_output_exec(): test_constant_tuples() test_flatten_tuple_output() test_tuple_output_exec() + test_extern_opt()