From 111bf2867fa840471b8771616efc22e7703d0b2d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jan 2020 03:25:52 -0800 Subject: [PATCH] [PassManager] Implement pass manager tracing API (#4782) * Implement pass tracing API * Set is_before correctly * Add docs for trace function * Fix lint * Remove PDB * Ensure trace_func is set before calling * Fix conditional --- docs/dev/relay_pass_infra.rst | 20 +++++++++++++++++ include/tvm/ir/transform.h | 22 ++++++++++++++++++ python/tvm/relay/transform.py | 13 +++++++---- src/ir/transform.cc | 11 +++++++++ src/relay/ir/transform.cc | 3 ++- tests/python/relay/test_pass_manager.py | 30 +++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 5 deletions(-) diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index 60d2b72968880..b4f3f6b0b7c9e 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will dump out the module IR when ``FoldConstant`` is done. Users can plug in this pass after any pass they want to debug for viewing the optimization effect. +There is a more flexible debugging mechanism also exposed by the build configuration +object. One can pass a tracing function which can be used to execute arbitrary code +before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``, +and a boolean indicating whether you are executing before, or after a pass. +An example is below. + +.. code:: python + + def print_ir(mod, info, is_before): + """Print the name of the pass, the IR, only before passes execute.""" + if is_before: + print(f"Running pass: {}", info) + print(mod) + + with relay.build_config(opt_level=3, trace=print_ir): + with tvm.target.create("llvm"): + # Perform the optimizations. + mod = seq(mod) + + For more pass infra related examples in Python and C++, please refer to `tests/python/relay/test_pass_manager.py`_ and `tests/cpp/relay_transform_sequential.cc`_, respectively. diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index c606b348f0991..2afcb17a59f49 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -65,6 +65,17 @@ namespace tvm { namespace transform { +// Forward declare for TraceFunc. +class PassInfo; + +/*! \brief A callback for tracing passes, useful for debugging and logging. + * + */ +using TraceFunc = + runtime::TypedPackedFunc; + /*! * \brief PassContextNode contains the information that a pass can rely on, * such as analysis results. @@ -88,6 +99,8 @@ class PassContextNode : public Object { /*! \brief The list of disabled passes. */ Array disabled_pass; + TraceFunc trace_func; + PassContextNode() = default; void VisitAttrs(AttrVisitor* v) { @@ -101,6 +114,7 @@ class PassContextNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; + /*! * \brief PassContext that is used to configure the pass behavior. * @@ -146,6 +160,14 @@ class PassContext : public ObjectRef { */ TVM_DLL static PassContext Current(); + /*! + * \brief Apply the tracing functions of the context to the module, with the info. + * \param module The IRModule to trace. + * \param info The pass information. + * \param is_before Indicated whether the tracing is before or after a pass. + */ + TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + // accessor. using ContainerType = PassContextNode; class Internal; diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index c4fbde60a6eb9..26b20e01c6236 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -78,7 +78,8 @@ def __init__(self, opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + trace=None): if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type elif isinstance(fallback_device, TVMContext): @@ -99,7 +100,7 @@ def __init__(self, self.__init_handle_by_constructor__(_transform.PassContext, opt_level, fallback_device, required, - disabled) + disabled, trace) def __enter__(self): _transform.EnterPassContext(self) @@ -117,7 +118,8 @@ def current(): def build_config(opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + trace=None): """Configure the build behavior by setting config variables. Parameters @@ -151,13 +153,16 @@ def build_config(opt_level=2, disabled_pass: set of str, optional Optimization passes to be disabled during optimization. + trace: Callable[[IRModule, PassInfo, bool], None] + A tracing function for debugging or introspection. + Returns ------- pass_context: PassContext The pass context for optimizations. """ return PassContext(opt_level, fallback_device, required_pass, - disabled_pass) + disabled_pass, trace) @register_relay_node diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 1da010c5979d6..14bd063b0169d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -84,6 +84,13 @@ PassContext PassContext::Create() { return PassContext(make_object()); } +void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->trace_func != nullptr) { + pass_ctx_node->trace_func(module, info, is_before); + } +} + class ModulePass; /*! @@ -231,8 +238,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod, << " with opt level: " << pass_info->opt_level; CHECK(mod.defined()); + pass_ctx.Trace(mod, pass_info, true); IRModule updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); + pass_ctx.Trace(updated_mod, pass_info, false); return updated_mod; } @@ -414,10 +423,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext") int fallback_device = args[1]; tvm::Array required = args[2]; tvm::Array disabled = args[3]; + TraceFunc trace_func = args[4]; pctx->opt_level = opt_level; pctx->fallback_device = fallback_device; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); + pctx->trace_func = std::move(trace_func); *ret = pctx; }); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index ac0f36cf2205f..d5cd5c91ff374 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, << pass_info->name << " with opt level: " << pass_info->opt_level; - + pass_ctx.Trace(mod, pass_info, true); // Execute the pass function and return a new module. IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); std::vector > updates; @@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, for (const auto& pair : updates) { updated_mod->Add(pair.first, pair.second, true); } + pass_ctx.Trace(updated_mod, pass_info, false); return updated_mod; } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index e02e917dbb627..bd055eebdbde0 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -522,6 +522,36 @@ def test_print_ir(capfd): assert "Dumping the module IR" in out assert "multiply" in out +__TRACE_COUNTER__ = 0 + +def _tracer(module, info, is_before): + global __TRACE_COUNTER__ + if bool(is_before): + __TRACE_COUNTER__ += 1 + +def test_print_debug_callback(): + global __TRACE_COUNTER__ + shape = (1, 2, 3) + tp = relay.TensorType(shape, "float32") + x = relay.var("x", tp) + y = relay.add(x, x) + y = relay.multiply(y, relay.const(2, "float32")) + func = relay.Function([x], y) + + seq = _transform.Sequential([ + relay.transform.InferType(), + relay.transform.FoldConstant(), + relay.transform.DeadCodeElimination() + ]) + + assert __TRACE_COUNTER__ == 0 + mod = relay.Module({"main": func}) + + with relay.build_config(opt_level=3, trace=_tracer): + mod = seq(mod) + + assert __TRACE_COUNTER__ == 4 + if __name__ == "__main__": pytest.main()