Skip to content

Commit

Permalink
[PassManager] Implement pass manager tracing API (apache#4782)
Browse files Browse the repository at this point in the history
* Implement pass tracing API

* Set is_before correctly

* Add docs for trace function

* Fix lint

* Remove PDB

* Ensure trace_func is set before calling

* Fix conditional
  • Loading branch information
jroesch authored and alexwong committed Feb 28, 2020
1 parent 2086d92 commit 652b21c
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 5 deletions.
20 changes: 20 additions & 0 deletions docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
dump out the module IR when ``FoldConstant`` is done. Users can plug in this
pass after any pass they want to debug for viewing the optimization effect.

There is a more flexible debugging mechanism also exposed by the build configuration
object. One can pass a tracing function which can be used to execute arbitrary code
before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``,
and a boolean indicating whether you are executing before, or after a pass.
An example is below.

.. code:: python
def print_ir(mod, info, is_before):
"""Print the name of the pass, the IR, only before passes execute."""
if is_before:
print(f"Running pass: {}", info)
print(mod)
with relay.build_config(opt_level=3, trace=print_ir):
with tvm.target.create("llvm"):
# Perform the optimizations.
mod = seq(mod)
For more pass infra related examples in Python and C++, please refer to
`tests/python/relay/test_pass_manager.py`_ and
`tests/cpp/relay_transform_sequential.cc`_, respectively.
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@
namespace tvm {
namespace transform {

// Forward declare for TraceFunc.
class PassInfo;

/*! \brief A callback for tracing passes, useful for debugging and logging.
*
*/
using TraceFunc =
runtime::TypedPackedFunc<void(const IRModule& ir_module,
const PassInfo& ctx,
bool is_before)>;

/*!
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
Expand All @@ -88,6 +99,8 @@ class PassContextNode : public Object {
/*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass;

TraceFunc trace_func;

PassContextNode() = default;

void VisitAttrs(AttrVisitor* v) {
Expand All @@ -101,6 +114,7 @@ class PassContextNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};


/*!
* \brief PassContext that is used to configure the pass behavior.
*
Expand Down Expand Up @@ -146,6 +160,14 @@ class PassContext : public ObjectRef {
*/
TVM_DLL static PassContext Current();

/*!
* \brief Apply the tracing functions of the context to the module, with the info.
* \param module The IRModule to trace.
* \param info The pass information.
* \param is_before Indicated whether the tracing is before or after a pass.
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;

// accessor.
using ContainerType = PassContextNode;
class Internal;
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
disabled_pass=None,
trace=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
Expand All @@ -99,7 +100,7 @@ def __init__(self,

self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
fallback_device, required,
disabled)
disabled, trace)

def __enter__(self):
_transform.EnterPassContext(self)
Expand All @@ -117,7 +118,8 @@ def current():
def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None):
disabled_pass=None,
trace=None):
"""Configure the build behavior by setting config variables.
Parameters
Expand Down Expand Up @@ -151,13 +153,16 @@ def build_config(opt_level=2,
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.
trace: Callable[[IRModule, PassInfo, bool], None]
A tracing function for debugging or introspection.
Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return PassContext(opt_level, fallback_device, required_pass,
disabled_pass)
disabled_pass, trace)


@register_relay_node
Expand Down
11 changes: 11 additions & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ PassContext PassContext::Create() {
return PassContext(make_object<PassContextNode>());
}

void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->trace_func != nullptr) {
pass_ctx_node->trace_func(module, info, is_before);
}
}

class ModulePass;

/*!
Expand Down Expand Up @@ -231,8 +238,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}

Expand Down Expand Up @@ -414,10 +423,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
int fallback_device = args[1];
tvm::Array<tvm::PrimExpr> required = args[2];
tvm::Array<tvm::PrimExpr> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
pctx->trace_func = std::move(trace_func);
*ret = pctx;
});

Expand Down
3 changes: 2 additions & 1 deletion src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;

pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
Expand All @@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}

Expand Down
30 changes: 30 additions & 0 deletions tests/python/relay/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,36 @@ def test_print_ir(capfd):
assert "Dumping the module IR" in out
assert "multiply" in out

__TRACE_COUNTER__ = 0

def _tracer(module, info, is_before):
global __TRACE_COUNTER__
if bool(is_before):
__TRACE_COUNTER__ += 1

def test_print_debug_callback():
global __TRACE_COUNTER__
shape = (1, 2, 3)
tp = relay.TensorType(shape, "float32")
x = relay.var("x", tp)
y = relay.add(x, x)
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)

seq = _transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.DeadCodeElimination()
])

assert __TRACE_COUNTER__ == 0
mod = relay.Module({"main": func})

with relay.build_config(opt_level=3, trace=_tracer):
mod = seq(mod)

assert __TRACE_COUNTER__ == 4


if __name__ == "__main__":
pytest.main()

0 comments on commit 652b21c

Please sign in to comment.