Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PassManager] Implement pass manager tracing API #4782

Merged
merged 7 commits into from
Jan 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
dump out the module IR when ``FoldConstant`` is done. Users can plug in this
pass after any pass they want to debug for viewing the optimization effect.

There is a more flexible debugging mechanism also exposed by the build configuration
object. One can pass a tracing function which can be used to execute arbitrary code
before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``,
and a boolean indicating whether you are executing before, or after a pass.
An example is below.

.. code:: python

def print_ir(mod, info, is_before):
"""Print the name of the pass, the IR, only before passes execute."""
if is_before:
print(f"Running pass: {}", info)
print(mod)

with relay.build_config(opt_level=3, trace=print_ir):
with tvm.target.create("llvm"):
# Perform the optimizations.
mod = seq(mod)


For more pass infra related examples in Python and C++, please refer to
`tests/python/relay/test_pass_manager.py`_ and
`tests/cpp/relay_transform_sequential.cc`_, respectively.
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@
namespace tvm {
namespace transform {

// Forward declare for TraceFunc.
class PassInfo;

/*! \brief A callback for tracing passes, useful for debugging and logging.
*
*/
using TraceFunc =
runtime::TypedPackedFunc<void(const IRModule& ir_module,
const PassInfo& ctx,
bool is_before)>;

/*!
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
Expand All @@ -88,6 +99,8 @@ class PassContextNode : public Object {
/*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass;

TraceFunc trace_func;

PassContextNode() = default;

void VisitAttrs(AttrVisitor* v) {
Expand All @@ -101,6 +114,7 @@ class PassContextNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};


/*!
* \brief PassContext that is used to configure the pass behavior.
*
Expand Down Expand Up @@ -146,6 +160,14 @@ class PassContext : public ObjectRef {
*/
TVM_DLL static PassContext Current();

/*!
* \brief Apply the tracing functions of the context to the module, with the info.
* \param module The IRModule to trace.
* \param info The pass information.
* \param is_before Indicated whether the tracing is before or after a pass.
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;

// accessor.
using ContainerType = PassContextNode;
class Internal;
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
disabled_pass=None,
trace=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
Expand All @@ -99,7 +100,7 @@ def __init__(self,

self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
fallback_device, required,
disabled)
disabled, trace)

def __enter__(self):
_transform.EnterPassContext(self)
Expand All @@ -117,7 +118,8 @@ def current():
def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
disabled_pass=None,
trace=None):
"""Configure the build behavior by setting config variables.

Parameters
Expand Down Expand Up @@ -151,13 +153,16 @@ def build_config(opt_level=2,
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.

trace: Callable[[IRModule, PassInfo, bool], None]
A tracing function for debugging or introspection.

Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return PassContext(opt_level, fallback_device, required_pass,
disabled_pass)
disabled_pass, trace)


@register_relay_node
Expand Down
11 changes: 11 additions & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ PassContext PassContext::Create() {
return PassContext(make_object<PassContextNode>());
}

void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->trace_func != nullptr) {
pass_ctx_node->trace_func(module, info, is_before);
}
}

class ModulePass;

/*!
Expand Down Expand Up @@ -231,8 +238,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}

Expand Down Expand Up @@ -414,10 +423,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
int fallback_device = args[1];
tvm::Array<tvm::PrimExpr> required = args[2];
tvm::Array<tvm::PrimExpr> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
pctx->trace_func = std::move(trace_func);
*ret = pctx;
});

Expand Down
3 changes: 2 additions & 1 deletion src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;

pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
Expand All @@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}

Expand Down
30 changes: 30 additions & 0 deletions tests/python/relay/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,36 @@ def test_print_ir(capfd):
assert "Dumping the module IR" in out
assert "multiply" in out

__TRACE_COUNTER__ = 0

def _tracer(module, info, is_before):
global __TRACE_COUNTER__
if bool(is_before):
__TRACE_COUNTER__ += 1

def test_print_debug_callback():
global __TRACE_COUNTER__
shape = (1, 2, 3)
tp = relay.TensorType(shape, "float32")
x = relay.var("x", tp)
y = relay.add(x, x)
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)

seq = _transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.DeadCodeElimination()
])

assert __TRACE_COUNTER__ == 0
mod = relay.Module({"main": func})

with relay.build_config(opt_level=3, trace=_tracer):
mod = seq(mod)

assert __TRACE_COUNTER__ == 4


if __name__ == "__main__":
pytest.main()