Skip to content

Commit

Permalink
[Relay][Pass] Add pass to remove unused functions in relay module (ap…
Browse files Browse the repository at this point in the history
…ache#4334)

* [Relay][Pass] Add pass to remove unused functions in relay module

* Add tests

* Fix lint

* Fix visit order

* Add pass argument

* Fix
  • Loading branch information
wweic authored and zhiics committed Nov 20, 2019
1 parent 36f691f commit 39a6e7e
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,22 @@ def BackwardFoldScaleAxis():
"""
return _transform.BackwardFoldScaleAxis()

def RemoveUnusedFunctions(entry_functions=None):
"""Remove unused global relay functions in a relay module.
Parameters
----------
entry_functions: list[string]
The set of entry functions to start from.
Returns
-------
ret : tvm.relay.Pass
The registered pass to remove unused functions.
"""
if entry_functions is None:
entry_functions = ['main']
return _transform.RemoveUnusedFunctions(entry_functions)

def ForwardFoldScaleAxis():
"""Fold the scaling of axis into weights of conv2d/dense.
Expand Down
3 changes: 3 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ namespace transform {

Pass LambdaLift();
Pass InlinePrimitives();
Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);

Pass ManifestAlloc(Target target_host) {
auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
Expand Down Expand Up @@ -864,6 +865,8 @@ void VMCompiler::Compile(Module mod,

Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

Expand Down
134 changes: 134 additions & 0 deletions src/relay/backend/vm/removed_unused_funcs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/backend/vm/remove_unused_funcs.cc
* \brief Remove unused global relay functions in a relay module.
*/

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/logging.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <iostream>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace relay {
namespace vm {

/**
* \brief Detects all the functions that can be possibly called by entry function.
*/
struct CallTracer : ExprVisitor {
Module module_;

// Record the names of all encountered functions
std::unordered_set<std::string> called_funcs_;

// Record the expressions that are being visited
std::unordered_set<Expr, NodeHash, NodeEqual> visiting_;

explicit CallTracer(const Module& module)
: module_{module},
called_funcs_{},
visiting_{} {}

void VisitExpr_(const CallNode* call_node) final {
Expr op = call_node->op;
for (auto param : call_node->args) {
VisitExpr(param);
}
if (auto func_node = op.as<FunctionNode>()) {
auto func = GetRef<Function>(func_node);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
}
visiting_.insert(func);
VisitExpr(func);
} else if (auto global = op.as<GlobalVarNode>()) {
called_funcs_.insert(global->name_hint);
auto func = module_->Lookup(global->name_hint);
auto it = visiting_.find(func);
if (it != visiting_.end()) {
return;
}
visiting_.insert(func);
VisitExpr(func);
}
}

std::unordered_set<std::string> Trace(const std::string& entry) {
called_funcs_.insert(entry);
auto main_func = module_->Lookup(entry);
VisitExpr(main_func);
return called_funcs_;
}
};

/*!
* \brief Remove functions that are not used.
*
* \param module The Relay module.
* \param entry_funcs The set of functions that can be entry function.
*
* \return The module with dead functions removed.
*/
Module RemoveUnusedFunctions(const Module& module,
Array<tvm::Expr> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
auto* str_name = entry.as<ir::StringImm>();
auto funcs = CallTracer(module).Trace(str_name->value);
called_funcs.insert(funcs.cbegin(), funcs.cend());
}
auto existing_functions = module->functions;
for (auto f : existing_functions) {
auto it = called_funcs.find(f.first->name_hint);
if (it == called_funcs.end()) {
module->Remove(f.first);
}
}
return module;
}

} // namespace vm

namespace transform {

Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
};
return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
}

TVM_REGISTER_API("relay._transform.RemoveUnusedFunctions")
.set_body_typed(RemoveUnusedFunctions);

} // namespace transform

} // namespace relay
} // namespace tvm
75 changes: 75 additions & 0 deletions tests/python/relay/test_pass_remove_unused_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.prelude import Prelude

def test_remove_all_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
x = relay.var("x", shape=(1, 16))
mod["main"] = relay.Function([x], x)
mod = relay.transform.RemoveUnusedFunctions()(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['main'])

def test_remove_all_prelude_functions_but_referenced_functions():
mod = relay.Module()
p = Prelude(mod)
x = relay.var("x", shape=(1, 16))
id_func = relay.Function([x], x)
id_name = relay.GlobalVar('id_func')
mod[id_name] = id_func

mod["main"] = relay.Function([x], id_name(x))
mod = relay.transform.RemoveUnusedFunctions()(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['id_func', 'main'])

def test_keep_only_referenced_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
l = p.nil()
for i in [4, 3, 2, 1, 0]:
l = p.cons(relay.const(i), l)
body = p.hd(p.tl(p.tl(l)))
mod["main"] = relay.Function([], body)
mod = relay.transform.RemoveUnusedFunctions()(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main'])

def test_multiple_entry_functions():
mod = relay.Module()
p = Prelude(mod)
l = p.nil()
for i in [4, 3, 2, 1, 0]:
l = p.cons(relay.const(i), l)
body = p.hd(p.tl(p.tl(l)))
mod["main1"] = relay.Function([], body)

x = relay.var("x", shape=(1, 16))
id_func = relay.Function([x], x)
id_name = relay.GlobalVar('id_func')
mod[id_name] = id_func
mod["main2"] = relay.Function([x], id_name(x))
mod = relay.transform.RemoveUnusedFunctions(['main1', 'main2'])(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])

if __name__ == '__main__':
pytest.main()

0 comments on commit 39a6e7e

Please sign in to comment.