
[MergeComposite] Fix InferType when module contains Prelude #5797

Merged (1 commit) on Jun 16, 2020
5 changes: 4 additions & 1 deletion — python/tvm/relay/testing/__init__.py

```diff
@@ -23,6 +23,7 @@
 from tvm import te
 import tvm.relay as relay
 import tvm.relay.op as op
+from tvm.relay import Prelude


 from . import mlp
@@ -44,9 +45,11 @@
 from .py_converter import to_python, run_as_python
 from ..transform import gradient

-def run_opt_pass(expr, opt_pass):
+def run_opt_pass(expr, opt_pass, import_prelude=False):
     assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
+    if import_prelude:
+        Prelude(mod)
     mod = opt_pass(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
```
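The `import_prelude` flag matters because the old `InferType` rebuilt a module from just the expression under test, dropping every other global definition, including everything the Prelude adds. A minimal sketch of that failure mode, using a plain dict as a stand-in for `IRModule` (illustrative names only, not the TVM API):

```python
# Toy stand-in for an IRModule: a dict from global name to definition.
def resolves(expr_refs, mod):
    """Type checking can only succeed if every global the expression
    references is defined in the module it is checked against."""
    return all(name in mod for name in expr_refs)

base_mod = {"prelude.map": "fn (f, xs) { ... }"}  # added by Prelude(mod)
main_refs = ["prelude.map"]                       # main calls into the Prelude

# Old behaviour: IRModule::FromExpr builds a fresh module holding only
# the expression, so the Prelude definitions are lost and checking fails.
fresh_mod = {"main": "..."}
print(resolves(main_refs, fresh_mod))   # False

# New behaviour: copy the caller's module and update "main" in place,
# keeping the Prelude definitions visible.
fixed_mod = dict(base_mod, main="...")
print(resolves(main_refs, fixed_mod))   # True
```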
12 changes: 7 additions & 5 deletions — src/relay/transforms/merge_composite.cc

```diff
@@ -36,22 +36,24 @@ namespace tvm {
 namespace relay {
 namespace merge_composite {

-Function InferType(const Function& expr) {
-  auto mod = IRModule::FromExpr(expr);
+Function InferType(const Function& expr, const IRModule& m) {
+  IRModule mod(m);
+  mod->Update(mod->GetGlobalVar("main"), expr);
   mod = transform::InferType()(mod);
   return Downcast<Function>(mod->Lookup("main"));
 }

 Expr MergeComposite(const Function& func, const Array<runtime::String>& pattern_names,
-                    const Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks) {
+                    const Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks,
+                    const IRModule& m) {
   CHECK_EQ(pattern_names.size(), patterns.size());
   Function merged_func = func;
   // merge the patterns one-by-one in order
   for (size_t i = 0; i < patterns.size(); i++) {
     Map<String, ObjectRef> attrs;
     attrs.Set("Composite", pattern_names[i]);
     merged_func = Downcast<Function>(PartitionPattern(patterns[i], merged_func, attrs, checks[i]));
-    merged_func = InferType(merged_func);
+    merged_func = InferType(merged_func, m);
   }
   return std::move(merged_func);
 }
@@ -65,7 +67,7 @@ Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
         return Downcast<Function>(
-            relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks));
+            relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks, m));
       };
   auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
   return func_pass;
```

Review discussion on the `mod->Update(mod->GetGlobalVar("main"), expr);` line:

Contributor: Since we now update the original module with the new transformed function, should we update the corresponding function instead of `main`?

Contributor: Hmm, you're right; I've been reading the pass infrastructure a little more today. `FunctionPass`, however, doesn't seem to tell a pass which Function it is visiting: https://github.com/apache/incubator-tvm/blob/8578096853eec5711bfcc9a3a68145fd12a135cb/src/relay/ir/transform.cc#L123-L132. I guess we can either change that API (which touches a lot of passes), or invert this Map (https://github.com/apache/incubator-tvm/blob/4347b41a5e64a2a453297b371232d6e101051b3c/include/tvm/ir/module.h#L53), find the global var, and store it in the class.

Contributor: Or maybe we can rewrite MergeComposite as a module pass, so that we can iterate over the functions ourselves.

Contributor (Author): For now, no other function in the module calls `main`, so it is safe to replace `main`. If we are inferring the function mutated from `main`, that is exactly what we want; if we are inferring some other function, `main` just becomes a harmless duplicate of it. Also, given only a mutated function, we cannot look it up in the original `global_var_map_`, because we don't know the name of its corresponding global variable.

Contributor: Makes sense, although I still prefer to avoid dummy modules. I think it's fine to let this PR in first, and maybe I can fix that later on.
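The "invert this Map" idea from the thread can be sketched with the same toy module model (plain Python stand-ins, not the TVM API): reversing the name-to-function map lets code that only holds a function recover the global it is bound to, assuming each function object appears under exactly one name.

```python
# Toy module: global name -> function object (stand-ins for
# GlobalVar -> Function). Assumes each function has a unique identity.
mod = {
    "main": ("fn", "body-main"),
    "helper": ("fn", "body-helper"),
}

# Invert the map once, so a pass that only receives a Function can
# recover the name of the global currently being transformed.
global_of = {func: name for name, func in mod.items()}

current = mod["helper"]
name = global_of[current]
print(name)  # helper

# With the name in hand, the pass can update the right global rather
# than overwriting "main".
mod[name] = ("fn", "body-helper-typed")
print(mod["helper"])  # ('fn', 'body-helper-typed')
```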
6 changes: 3 additions & 3 deletions — tests/python/relay/test_pass_merge_composite.py

```diff
@@ -163,9 +163,9 @@ def make_bn_relu_pattern():
     r = is_op('nn.relu')(tuple_get_item_node)
     return r

-def check_result(pattern_table, graph, expected_graph):
+def check_result(pattern_table, graph, expected_graph, import_prelude=False):
     """Utility function to check merge composite results."""
-    result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table))
+    result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude)
     assert not relay.analysis.free_vars(result), \
         "Found free vars in the result graph: {0}".format(str(result))
     expected = run_opt_pass(expected_graph, relay.transform.InferType())
@@ -213,7 +213,7 @@ def expected():
         r = relay.Call(add_relu, [a, b])
         return relay.Function([a, b], r)

-    check_result(pattern_table, before(), expected())
+    check_result(pattern_table, before(), expected(), import_prelude=True)
```
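The other alternative raised in the review, rewriting MergeComposite as a module pass, can be contrasted with the current function-pass shape in a toy sketch (dict stand-ins and hypothetical driver names, not the TVM pass API):

```python
def run_function_pass(pass_func, mod):
    # Function-pass shape: the infrastructure drives the loop and hands
    # the pass each Function (plus the module), but not its GlobalVar.
    return {name: pass_func(func, mod) for name, func in mod.items()}

def run_module_pass(pass_func, mod):
    # Module-pass shape: the pass receives the whole module and iterates
    # it itself, so it always knows which global it is updating.
    return pass_func(mod)

toy_mod = {"main": 1, "helper": 2}

print(run_function_pass(lambda func, mod: func * 2, toy_mod))
# {'main': 2, 'helper': 4}
print(run_module_pass(lambda m: {k: v * 2 for k, v in m.items()}, toy_mod))
# {'main': 2, 'helper': 4}
```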