[PassManager] Implement pass manager tracing API #4782

Merged (7 commits) on Jan 28, 2020
Changes from 2 commits
19 changes: 19 additions & 0 deletions include/tvm/ir/transform.h
@@ -65,6 +65,14 @@
namespace tvm {
namespace transform {

// Forward declare for TraceFunc.
class PassInfo;

/*! \brief A callback for tracing passes, useful for debugging and logging. */
using TraceFunc = runtime::TypedPackedFunc<void(const IRModule& ir_module, const PassInfo& ctx, bool is_before)>;

/*!
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
@@ -88,6 +96,8 @@ class PassContextNode : public Object {
/*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass;

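/*! \brief A trace function to be invoked before and after each pass is run. */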
TraceFunc trace_func;

PassContextNode() = default;

void VisitAttrs(AttrVisitor* v) {
@@ -101,6 +111,7 @@
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};


/*!
* \brief PassContext that is used to configure the pass behavior.
*
@@ -146,6 +157,14 @@ class PassContext : public ObjectRef {
*/
TVM_DLL static PassContext Current();

/*!
* \brief Apply the trace function of the context to the module, with the given pass info.
* \param module The IRModule to trace.
* \param info The pass information.
* \param is_before Indicates whether the tracing is invoked before or after the pass.
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;

// accessor.
using ContainerType = PassContextNode;
class Internal;
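Seen from the Python side, a trace callback simply has to match the TraceFunc signature declared above: it receives the IRModule being transformed, the PassInfo of the current pass, and a flag saying whether the hook fires before or after that pass. A minimal sketch, not part of the PR; the function name and log format are illustrative, and PassInfo's name and opt_level fields are assumed to be reachable as node attributes from Python:

def print_pass_boundary(module, info, is_before):
    """Conforms to TraceFunc: (IRModule, PassInfo, bool) -> None."""
    phase = "before" if is_before else "after"
    # PassInfo exposes its name and opt_level through node attribute access.
    print("[pass trace] %s %s (opt level %s)" % (phase, info.name, info.opt_level))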
13 changes: 9 additions & 4 deletions python/tvm/relay/transform.py
@@ -78,7 +78,8 @@ def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
disabled_pass=None,
trace=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
@@ -99,7 +100,7 @@

self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
fallback_device, required,
disabled)
disabled, trace)

def __enter__(self):
_transform.EnterPassContext(self)
@@ -117,7 +118,8 @@ def current():
def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
disabled_pass=None,
trace=None):
"""Configure the build behavior by setting config variables.

Parameters
@@ -151,13 +153,16 @@ def build_config(opt_level=2,
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.

trace: Callable[[IRModule, PassInfo, bool], None], optional
A tracing function that is invoked before and after each pass, for debugging or introspection.

Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return PassContext(opt_level, fallback_device, required_pass,
disabled_pass)
disabled_pass, trace)


@register_relay_node
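For reference, here is a minimal end-to-end sketch of the new trace argument, assuming the Relay API of this revision (relay.Module, relay.build_config, relay.transform.Sequential); the callback, the events list, and the variable names are illustrative, not part of the PR:

from tvm import relay

events = []

def collect_trace(module, info, is_before):
    # Record every pass boundary hit while the context is active.
    events.append((info.name, "before" if is_before else "after"))

x = relay.var("x", shape=(1, 2, 3), dtype="float32")
func = relay.Function([x], relay.add(x, x))
mod = relay.Module({"main": func})

seq = relay.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.FoldConstant(),
])

# The trace callback is installed on the PassContext created by build_config.
with relay.build_config(opt_level=3, trace=collect_trace):
    mod = seq(mod)

print(events)  # pairs like ('InferType', 'before'), ('InferType', 'after'), ...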
8 changes: 8 additions & 0 deletions src/ir/transform.cc
@@ -84,6 +84,10 @@ PassContext PassContext::Create() {
return PassContext(make_object<PassContextNode>());
}

void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
  // No-op when no trace function has been registered on the context.
  if (this->operator->()->trace_func == nullptr) return;
  this->operator->()->trace_func(module, info, is_before);
}

class ModulePass;

/*!
@@ -231,8 +235,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}

@@ -414,10 +420,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
int fallback_device = args[1];
tvm::Array<tvm::PrimExpr> required = args[2];
tvm::Array<tvm::PrimExpr> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
pctx->trace_func = std::move(trace_func);
*ret = pctx;
});

3 changes: 2 additions & 1 deletion src/relay/ir/transform.cc
@@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;

pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
@@ -134,6 +134,7 @@
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}

31 changes: 31 additions & 0 deletions tests/python/relay/test_pass_manager.py
@@ -522,6 +522,37 @@ def test_print_ir(capfd):
assert "Dumping the module IR" in out
assert "multiply" in out

__TRACE_COUNTER__ = 0

def _tracer(module, info, is_before):
    global __TRACE_COUNTER__
    if bool(is_before):
        __TRACE_COUNTER__ += 1

def test_print_debug_callback():
    global __TRACE_COUNTER__
    shape = (1, 2, 3)
    tp = relay.TensorType(shape, "float32")
    x = relay.var("x", tp)
    y = relay.add(x, x)
    y = relay.multiply(y, relay.const(2, "float32"))
    func = relay.Function([x], y)

    seq = _transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.DeadCodeElimination()
    ])

    assert __TRACE_COUNTER__ == 0
    mod = relay.Module({"main": func})

    with relay.build_config(opt_level=3, trace=_tracer):
        mod = seq(mod)

    assert __TRACE_COUNTER__ == 4


if __name__ == "__main__":
    pytest.main()