[BYOC] Configurable optimize pass for PartitionGraph #6777

Closed
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
@@ -377,7 +377,7 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
*
* \return The pass.
*/
-TVM_DLL Pass PartitionGraph();
+TVM_DLL Pass PartitionGraph(runtime::PackedFunc foptimize = nullptr);

/*!
* \brief Inline the global functions marked as `inline` in a given Relay
12 changes: 11 additions & 1 deletion python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -54,6 +54,16 @@ def partition_for_arm_compute_lib(mod, params=None):
-------
ret : annotated and partitioned module.
"""

+    def optimize(mod):
+        # Pass allow_missing=True so a missing registration yields None rather
+        # than raising, letting us report a more actionable error below.
+        foptimize = tvm._ffi.get_global_func("relay.ext.arm_compute_lib.optimize", True)
+        if foptimize is None:
+            raise RuntimeError(
+                "Failed to get the Arm Compute Library optimization pass. "
+                "Did you build with USE_ARM_COMPUTE_LIB=ON?"
+            )
+        return foptimize(mod)

if params:
mod["main"] = bind_params_by_name(mod["main"], params)

@@ -62,7 +72,7 @@ def partition_for_arm_compute_lib(mod, params=None):
transform.InferType(),
transform.MergeComposite(arm_compute_lib_pattern_table()),
transform.AnnotateTarget("arm_compute_lib"),
-transform.PartitionGraph(),
+transform.PartitionGraph(optimize),
Contributor:

Why can't we just run the optimize pass before PartitionGraph here? Is there a reason why it has to be inside PartitionGraph()?

Contributor (Author):

This pass applies to each partitioned function, so it has to be called after the partitioned function has been created. See #6068.

Contributor:

Hmm, then why not run it after PartitionGraph, using kCompiler as a filter?

Contributor (Author):

As mentioned in that PR, we can definitely do that. In fact, we are planning an RFC right now, so depending on the RFC we probably won't need this PR anymore.

Contributor:

OK, sounds good :). I think PartitionGraph should just perform partitioning, unless it's really impossible to perform the needed post-processing after the pass.
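For readers following this thread: the alternative being discussed could look roughly like the sketch below — a minimal, hypothetical post-PartitionGraph rewrite that filters on the kCompiler ("Compiler") attribute. FoldConstant and the "test_target" name are arbitrary placeholders, and the sketch assumes each partitioned function is self-contained:

```python
import tvm
from tvm import relay


def optimize_partitions(mod, compiler="test_target"):
    """Optimize every function tagged with the given Compiler attribute."""
    for gvar, func in mod.functions.items():
        if not isinstance(func, relay.Function):
            continue
        if not func.attrs or "Compiler" not in func.attrs.keys():
            continue
        if func.attrs["Compiler"] != compiler:
            continue
        # Relay function passes skip functions carrying a Compiler attribute,
        # so rebuild the function without attrs, optimize, then restore them.
        stripped = relay.Function(func.params, func.body, func.ret_type, func.type_params)
        func_mod = tvm.IRModule.from_expr(stripped)
        func_mod = relay.transform.InferType()(func_mod)
        func_mod = relay.transform.FoldConstant()(func_mod)
        new_func = func_mod["main"]
        for key in func.attrs.keys():
            new_func = new_func.with_attr(key, func.attrs[key])
        mod[gvar] = new_func
    return mod
```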

]
)

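Note that the `optimize` helper above still resolves the optimization through the global registry; it only makes the lookup explicit on the Python side. For reference, a Python-registered hook of the same shape might look like this sketch ("test_target" is a placeholder name mirroring the unit test at the bottom of this diff):

```python
import tvm
from tvm import relay


@tvm.register_func("relay.ext.test_target.optimize")
def _optimize(mod):
    # Receives each partitioned function wrapped in its own IRModule and must
    # return an IRModule; FoldConstant is an arbitrary example pass.
    return relay.transform.FoldConstant()(mod)
```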
11 changes: 8 additions & 3 deletions python/tvm/relay/transform/transform.py
@@ -627,7 +627,7 @@ def EliminateCommonSubexpr(fskip=None):

Parameters
----------
-fskip: Callable
+fskip: Optional[Callable]
The callback function that decides whether an expression should be
skipped.

@@ -681,16 +681,21 @@ def LambdaLift():
return _ffi_api.LambdaLift()


-def PartitionGraph():
+def PartitionGraph(foptimize=None):
"""Partition a Relay program into regions that can be executed on different
backends.

Parameters
----------
+foptimize: Optional[Callable]
+    The callback function that optimizes each partitioned Relay function.

Returns
-------
ret: tvm.transform.Pass
The registered pass that partitions the Relay program.
"""
-return _ffi_api.PartitionGraph()
+return _ffi_api.PartitionGraph(foptimize)


def AnnotateTarget(targets):
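As a usage sketch of the new parameter (the target name is a placeholder whose annotation rules would need to be registered separately, and FoldConstant is an arbitrary choice):

```python
import tvm
from tvm import relay
from tvm.relay import transform


def fold_constants(mod):
    # Called once per partitioned function, each wrapped in its own IRModule.
    return transform.FoldConstant()(mod)


seq = tvm.transform.Sequential(
    [
        transform.InferType(),
        transform.AnnotateTarget("test_target"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(fold_constants),
    ]
)
partitioned_mod = seq(mod)  # `mod` is an existing tvm.IRModule
```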
21 changes: 13 additions & 8 deletions src/relay/transforms/partition_graph.cc
@@ -113,7 +113,8 @@ struct RegionFuncMetadata {

class Partitioner : public MixedModeMutator {
public:
-explicit Partitioner(const IRModule& module) : module_(module) {
+explicit Partitioner(const IRModule& module, runtime::PackedFunc foptimize)
+    : module_(module), foptimize_(foptimize) {
for (auto f : module->functions) {
GlobalVar f_var = f.first;
BaseFunc f_func = f.second;
@@ -308,12 +309,12 @@ class Partitioner : public MixedModeMutator {
if (!params_bind.empty()) {
global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
}
-std::string ext_opt = "relay.ext." + target + ".optimize";
-auto pf = tvm::runtime::Registry::Get(ext_opt);
-if (pf != nullptr) {
+// Optimize the partitioned function using the user-specified optimization pass.
+if (foptimize_ != nullptr) {
   auto mod = IRModule::FromExpr(global_region_func);
   mod = transform::InferType()(mod);
-  mod = (*pf)(mod);
+  mod = foptimize_(mod);
   global_region_func = Downcast<Function>(mod->Lookup("main"));
 }

@@ -392,6 +393,8 @@ class Partitioner : public MixedModeMutator {

/*!\brief The IRModule used for partitioning. */
IRModule module_;
+/*!\brief The optimize pass applied to each partitioned function. */
+const PackedFunc foptimize_;
};

IRModule RemoveDefaultAnnotations(IRModule module) {
@@ -484,7 +487,7 @@ IRModule FlattenTupleOutputs(IRModule module) {

namespace transform {

-Pass PartitionGraph() {
+Pass PartitionGraph(const PackedFunc foptimize) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> flatten_tuples = [=](IRModule m,
PassContext pc) {
// There could be compiler_end annotations on tuples
@@ -503,8 +506,10 @@ Pass PartitionGraph() {
return partitioning::RemoveDefaultAnnotations(m);
};

-runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
-    [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
+runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
+    [=](IRModule m, PassContext pc) {
+      return partitioning::Partitioner(m, foptimize).Partition();
+    };

auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0, "FlattenNestedTuples", {});
auto remove_default_pass = CreateModulePass(remove_defaults, 0, "RemoveDefaultAnnotations", {});
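One plumbing note: a Python callable passed to transform.PartitionGraph is converted to a runtime::PackedFunc at the FFI boundary, which is why foptimize_ can be stored and invoked directly from C++. A minimal sketch of that round trip (the registry name is made up):

```python
import tvm


@tvm.register_func("demo.noop_optimize")  # hypothetical registry name
def _noop(mod):
    return mod


f = tvm.get_global_func("demo.noop_optimize")
assert isinstance(f, tvm.runtime.PackedFunc)  # the calling convention the C++ side sees
```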
4 changes: 1 addition & 3 deletions tests/python/relay/test_pass_partition_graph.py
@@ -1327,8 +1327,6 @@ def test_extern_opt():
def Optimize(mod):
return relay.transform.FoldConstant()(mod)

-tvm.register_func("relay.ext.test_target.optimize", Optimize)
-
x = relay.var("x", shape=(2, 2))
y0 = relay.var("y0", shape=(2, 2))
y1 = relay.var("y1", shape=(2, 2))
@@ -1342,7 +1340,7 @@ def Optimize(mod):
mod = tvm.IRModule()
mod["main"] = f
mod = transform.InferType()(mod)
-mod = transform.PartitionGraph()(mod)
+mod = transform.PartitionGraph(Optimize)(mod)

try:
t0 = mod["test_target_0"]