diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index cbd6a88e584e..ce8b592a0f6d 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -377,7 +377,7 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
  *
  * \return The pass.
  */
-TVM_DLL Pass PartitionGraph();
+TVM_DLL Pass PartitionGraph(runtime::PackedFunc foptimize = nullptr);
 
 /*!
  * \brief Inline the global functions marked as `inline` in a given Relay
diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 8dfb3b7e0bf4..b42ca9faaefc 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -54,6 +54,16 @@ def partition_for_arm_compute_lib(mod, params=None):
     -------
     ret : annotated and partitioned module.
     """
+
+    def optimize(mod):
+        foptimize = tvm._ffi.get_global_func("relay.ext.arm_compute_lib.optimize")
+        if foptimize is None:
+            raise RuntimeError(
+                "Failed to get the Arm Compute Library optimization pass. "
+                "Did you build with USE_ARM_COMPUTE_LIB=ON?"
+            )
+        return foptimize(mod)
+
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
 
@@ -62,7 +72,7 @@ def partition_for_arm_compute_lib(mod, params=None):
             transform.InferType(),
             transform.MergeComposite(arm_compute_lib_pattern_table()),
             transform.AnnotateTarget("arm_compute_lib"),
-            transform.PartitionGraph(),
+            transform.PartitionGraph(optimize),
         ]
     )
 
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 060547e4c4d7..f1bcfd467cae 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -627,7 +627,7 @@ def EliminateCommonSubexpr(fskip=None):
 
     Parameters
     ----------
-    fskip: Callable
+    fskip: Optional[Callable]
         The callback function that decides whether an expression should be
        skipped.
 
@@ -681,16 +681,21 @@ def LambdaLift():
     return _ffi_api.LambdaLift()
 
 
-def PartitionGraph():
+def PartitionGraph(foptimize=None):
     """Partition a Relay program into regions that can be executed on
     different backends.
 
+    Parameters
+    ----------
+    foptimize: Optional[Callable]
+        The callback function that optimizes the partitioned Relay functions.
+
     Returns
     -------
     ret: tvm.transform.Pass
         The registered pass that partitions the Relay program.
     """
-    return _ffi_api.PartitionGraph()
+    return _ffi_api.PartitionGraph(foptimize)
 
 
 def AnnotateTarget(targets):
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 75bc46387cc6..ccf0fdb7e204 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -113,7 +113,8 @@ struct RegionFuncMetadata {
 
 class Partitioner : public MixedModeMutator {
  public:
-  explicit Partitioner(const IRModule& module) : module_(module) {
+  explicit Partitioner(const IRModule& module, runtime::PackedFunc foptimize)
+      : module_(module), foptimize_(foptimize) {
     for (auto f : module->functions) {
       GlobalVar f_var = f.first;
       BaseFunc f_func = f.second;
@@ -308,12 +309,12 @@ class Partitioner : public MixedModeMutator {
     if (!params_bind.empty()) {
       global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
     }
-    std::string ext_opt = "relay.ext." + target + ".optimize";
-    auto pf = tvm::runtime::Registry::Get(ext_opt);
-    if (pf != nullptr) {
+
+    // Optimize the partitioned function using the user-specified optimization pass.
+    if (foptimize_ != nullptr) {
       auto mod = IRModule::FromExpr(global_region_func);
       mod = transform::InferType()(mod);
-      mod = (*pf)(mod);
+      mod = foptimize_(mod);
       global_region_func = Downcast<Function>(mod->Lookup("main"));
     }
 
@@ -392,6 +393,8 @@ class Partitioner : public MixedModeMutator {
 
   /*!\brief The IRModule used for partitioning. */
   IRModule module_;
+  /*!\brief The optimization pass for the partitioned functions. */
+  const PackedFunc foptimize_;
 };
 
 IRModule RemoveDefaultAnnotations(IRModule module) {
@@ -484,7 +487,7 @@ IRModule FlattenTupleOutputs(IRModule module) {
 
 namespace transform {
 
-Pass PartitionGraph() {
+Pass PartitionGraph(const PackedFunc foptimize) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> flatten_tuples =
       [=](IRModule m, PassContext pc) {
         // There could be compiler_end annotations on tuples
@@ -503,8 +506,10 @@ Pass PartitionGraph() {
     return partitioning::RemoveDefaultAnnotations(m);
   };
 
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
-      [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
+                                                                            PassContext pc) {
+    return partitioning::Partitioner(m, foptimize).Partition();
+  };
 
   auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0, "FlattenNestedTuples", {});
   auto remove_default_pass = CreateModulePass(remove_defaults, 0, "RemoveDefaultAnnotations", {});
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 2fd440e1c2c9..1e5fe133aca7 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -1327,8 +1327,6 @@ def test_extern_opt():
     def Optimize(mod):
         return relay.transform.FoldConstant()(mod)
 
-    tvm.register_func("relay.ext.test_target.optimize", Optimize)
-
     x = relay.var("x", shape=(2, 2))
     y0 = relay.var("y0", shape=(2, 2))
     y1 = relay.var("y1", shape=(2, 2))
@@ -1342,7 +1340,7 @@ def Optimize(mod):
     mod = tvm.IRModule()
    mod["main"] = f
     mod = transform.InferType()(mod)
-    mod = transform.PartitionGraph()(mod)
+    mod = transform.PartitionGraph(Optimize)(mod)
 
     try:
         t0 = mod["test_target_0"]
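
Usage sketch (not part of the patch): with this change an external codegen integration can hand PartitionGraph its per-function optimization as a plain callable instead of registering a `relay.ext.<target>.optimize` global. The snippet below is adapted from the updated `test_extern_opt` test; the `test_target` name and the FoldConstant callback are only illustrative.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import transform
    from tvm.relay.op.annotation import compiler_begin, compiler_end

    def optimize(mod):
        # Invoked by PartitionGraph on the IRModule built from each partitioned function.
        return relay.transform.FoldConstant()(mod)

    # Manually annotate a region for an illustrative external compiler "test_target".
    x = relay.var("x", shape=(2, 2))
    c = relay.const(np.ones((2, 2), dtype="float32"))
    add = relay.add(compiler_begin(x, "test_target"), compiler_begin(c, "test_target"))
    func = relay.Function([x], compiler_end(add, "test_target"))

    mod = tvm.IRModule.from_expr(func)
    mod = transform.InferType()(mod)
    mod = transform.PartitionGraph(optimize)(mod)  # the callback replaces the registry lookup
    print(mod)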