From 1f3849378884be54c16efcd23c78f168c5c99ecd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jan 2020 16:35:45 -0800 Subject: [PATCH 1/7] Implement pass tracing API --- include/tvm/ir/transform.h | 19 ++++++++++++++++ python/tvm/relay/transform.py | 13 +++++++---- src/ir/transform.cc | 8 +++++++ src/relay/ir/transform.cc | 3 ++- tests/python/relay/test_pass_manager.py | 30 +++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 5 deletions(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index c606b348f099..03aba40628a3 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -65,6 +65,14 @@ namespace tvm { namespace transform { +// Forward declare for TraceFunc. +class PassInfo; + +/*! \brief A callback for tracing passes, useful for debugging and logging. + * + */ +using TraceFunc = runtime::TypedPackedFunc; + /*! * \brief PassContextNode contains the information that a pass can rely on, * such as analysis results. @@ -88,6 +96,8 @@ class PassContextNode : public Object { /*! \brief The list of disabled passes. */ Array disabled_pass; + TraceFunc trace_func; + PassContextNode() = default; void VisitAttrs(AttrVisitor* v) { @@ -101,6 +111,7 @@ class PassContextNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; + /*! * \brief PassContext that is used to configure the pass behavior. * @@ -146,6 +157,14 @@ class PassContext : public ObjectRef { */ TVM_DLL static PassContext Current(); + /*! + * \brief Apply the tracing functions of the context to the module, with the info. + * \param module The IRModule to trace. + * \param info The pass information. + * \param is_before Indicated whether the tracing is before or after a pass. + */ + TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + // accessor. using ContainerType = PassContextNode; class Internal; diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index c4fbde60a6eb..26b20e01c623 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -78,7 +78,8 @@ def __init__(self, opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + trace=None): if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type elif isinstance(fallback_device, TVMContext): @@ -99,7 +100,7 @@ def __init__(self, self.__init_handle_by_constructor__(_transform.PassContext, opt_level, fallback_device, required, - disabled) + disabled, trace) def __enter__(self): _transform.EnterPassContext(self) @@ -117,7 +118,8 @@ def current(): def build_config(opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + trace=None): """Configure the build behavior by setting config variables. Parameters @@ -151,13 +153,16 @@ def build_config(opt_level=2, disabled_pass: set of str, optional Optimization passes to be disabled during optimization. + trace: Callable[[IRModule, PassInfo, bool], None] + A tracing function for debugging or introspection. + Returns ------- pass_context: PassContext The pass context for optimizations. """ return PassContext(opt_level, fallback_device, required_pass, - disabled_pass) + disabled_pass, trace) @register_relay_node diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 1da010c5979d..d14a5b472f6b 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -84,6 +84,10 @@ PassContext PassContext::Create() { return PassContext(make_object()); } +void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { + this->operator->()->trace_func(module, info, is_before); +} + class ModulePass; /*! @@ -231,8 +235,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod, << " with opt level: " << pass_info->opt_level; CHECK(mod.defined()); + pass_ctx.Trace(mod, pass_info, true); IRModule updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); + pass_ctx.Trace(updated_mod, pass_info, true); return updated_mod; } @@ -414,10 +420,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext") int fallback_device = args[1]; tvm::Array required = args[2]; tvm::Array disabled = args[3]; + TraceFunc trace_func = args[4]; pctx->opt_level = opt_level; pctx->fallback_device = fallback_device; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); + pctx->trace_func = std::move(trace_func); *ret = pctx; }); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index ac0f36cf2205..516103fc48ce 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, << pass_info->name << " with opt level: " << pass_info->opt_level; - + pass_ctx.Trace(mod, pass_info, true); // Execute the pass function and return a new module. IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); std::vector > updates; @@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, for (const auto& pair : updates) { updated_mod->Add(pair.first, pair.second, true); } + pass_ctx.Trace(updated_mod, pass_info, true); return updated_mod; } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index e02e917dbb62..d9e17a3ae62d 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -522,6 +522,36 @@ def test_print_ir(capfd): assert "Dumping the module IR" in out assert "multiply" in out +__TRACE_COUNTER__ = 0 + +def _tracer(module, info, is_before): + global __TRACE_COUNTER__ + if is_before: + __TRACE_COUNTER__ += 1 + +def test_print_debug_callback(): + global __TRACE_COUNTER__ + shape = (1, 2, 3) + tp = relay.TensorType(shape, "float32") + x = relay.var("x", tp) + y = relay.add(x, x) + y = relay.multiply(y, relay.const(2, "float32")) + func = relay.Function([x], y) + + seq = _transform.Sequential([ + relay.transform.InferType(), + relay.transform.FoldConstant(), + relay.transform.DeadCodeElimination() + ]) + + assert __TRACE_COUNTER__ == 0 + mod = relay.Module({"main": func}) + + with relay.build_config(opt_level=3, trace=_tracer): + mod = seq(mod) + + assert __TRACE_COUNTER__ == 4 + if __name__ == "__main__": pytest.main() From 8b5a85f6dced9fd05f3105fc9bd6473a7c6f36bd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jan 2020 16:42:53 -0800 Subject: [PATCH 2/7] Set is_before correctly --- src/ir/transform.cc | 2 +- src/relay/ir/transform.cc | 2 +- tests/python/relay/test_pass_manager.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index d14a5b472f6b..87a3c807306f 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -238,7 +238,7 @@ IRModule ModulePassNode::operator()(const IRModule& mod, pass_ctx.Trace(mod, pass_info, true); IRModule updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); - pass_ctx.Trace(updated_mod, pass_info, true); + pass_ctx.Trace(updated_mod, pass_info, false); return updated_mod; } diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 516103fc48ce..d5cd5c91ff37 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -134,7 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, for (const auto& pair : updates) { updated_mod->Add(pair.first, pair.second, true); } - pass_ctx.Trace(updated_mod, pass_info, true); + pass_ctx.Trace(updated_mod, pass_info, false); return updated_mod; } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index d9e17a3ae62d..b06fe3308182 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -526,7 +526,8 @@ def test_print_ir(capfd): def _tracer(module, info, is_before): global __TRACE_COUNTER__ - if is_before: + import pdb; pdb.set_trace() + if bool(is_before): __TRACE_COUNTER__ += 1 def test_print_debug_callback(): From 7669b002df4a1992fe2e31449a2f537011a33a7b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jan 2020 16:52:24 -0800 Subject: [PATCH 3/7] Add docs for trace function --- docs/dev/relay_pass_infra.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index 60d2b7296888..b4f3f6b0b7c9 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will dump out the module IR when ``FoldConstant`` is done. Users can plug in this pass after any pass they want to debug for viewing the optimization effect. +There is a more flexible debugging mechanism also exposed by the build configuration +object. One can pass a tracing function which can be used to execute arbitrary code +before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``, +and a boolean indicating whether you are executing before, or after a pass. +An example is below. + +.. code:: python + + def print_ir(mod, info, is_before): + """Print the name of the pass, the IR, only before passes execute.""" + if is_before: + print(f"Running pass: {}", info) + print(mod) + + with relay.build_config(opt_level=3, trace=print_ir): + with tvm.target.create("llvm"): + # Perform the optimizations. + mod = seq(mod) + + For more pass infra related examples in Python and C++, please refer to `tests/python/relay/test_pass_manager.py`_ and `tests/cpp/relay_transform_sequential.cc`_, respectively. From baa7c8ddba2ed3ed7a726b2a7245278e0ea867e9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jan 2020 16:53:39 -0800 Subject: [PATCH 4/7] Fix lint --- include/tvm/ir/transform.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 03aba40628a3..2afcb17a59f4 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -71,7 +71,10 @@ class PassInfo; /*! \brief A callback for tracing passes, useful for debugging and logging. * */ -using TraceFunc = runtime::TypedPackedFunc; +using TraceFunc = + runtime::TypedPackedFunc; /*! * \brief PassContextNode contains the information that a pass can rely on, From 418756ddafd4c27f37e38593d369dbc9bb798819 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Jan 2020 16:56:01 -0800 Subject: [PATCH 5/7] Remove PDB --- tests/python/relay/test_pass_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index b06fe3308182..bd055eebdbde 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -526,7 +526,6 @@ def test_print_ir(capfd): def _tracer(module, info, is_before): global __TRACE_COUNTER__ - import pdb; pdb.set_trace() if bool(is_before): __TRACE_COUNTER__ += 1 From e769f0536a06c46050e72fa107562d1d04b921c3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jan 2020 00:10:40 -0800 Subject: [PATCH 6/7] Ensure trace_func is set before calling --- src/ir/transform.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 87a3c807306f..13448f21a2ca 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -85,7 +85,10 @@ PassContext PassContext::Create() { } void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { - this->operator->()->trace_func(module, info, is_before); + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->trace_func == nullptr) { + pass_ctx_node->trace_func(module, info, is_before); + } } class ModulePass; From e702dd3460fedbb231031b605d1d0e384ce0c6af Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Jan 2020 00:14:33 -0800 Subject: [PATCH 7/7] Fix conditional --- src/ir/transform.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 13448f21a2ca..14bd063b0169 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -86,7 +86,7 @@ PassContext PassContext::Create() { void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func == nullptr) { + if (pass_ctx_node->trace_func != nullptr) { pass_ctx_node->trace_func(module, info, is_before); } }