diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 93129cf57a279..da7e3ffceb3f1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -98,6 +98,8 @@ class PassContextNode : public RelayNode { tvm::Array required_pass; /*! \brief The list of disabled passes. */ tvm::Array disabled_pass; + /*! \brief The list of passes that will be dump for debugging. */ + tvm::Array dump_pass; PassContextNode() = default; @@ -106,6 +108,7 @@ class PassContextNode : public RelayNode { v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); + v->Visit("dump_pass", &dump_pass); } static constexpr const char* _type_key = "relay.PassContext"; diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 2805e0b429fa0..0fb24c566045b 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -71,33 +71,43 @@ class PassContext(RelayNode): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. + + dump_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that will be dupmed to help debugging. """ def __init__(self, opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + dump_pass=None): + + def _check_type(var, ty): + if not isinstance(var, ty): + var_name = [k for k, v in locals().items() if v == var][0] + raise TypeError(var_name + " is expected to be the type of " + str(ty)) + if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type elif isinstance(fallback_device, TVMContext): fallback_device = fallback_device.device_type - if not isinstance(fallback_device, int): - raise TypeError("required_pass is expected to be the type of " + - "int/str/TVMContext.") + _check_type(fallback_device, (str, int, TVMContext)) - required = list(required_pass) if required_pass else [] - if not isinstance(required, (list, tuple)): - raise TypeError("required_pass is expected to be the type of " + - "list/tuple/set.") + required_pass = required_pass if required_pass else [] + _check_type(required_pass, (list, tuple)) + + disabled_pass = disabled_pass if disabled_pass else [] + _check_type(disabled_pass, (list, tuple)) - disabled = list(disabled_pass) if disabled_pass else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled_pass is expected to be the type of " + - "list/tuple/set.") + dump_pass = dump_pass if dump_pass else [] + _check_type(dump_pass, (list, tuple)) - self.__init_handle_by_constructor__(_transform.PassContext, opt_level, - fallback_device, required, - disabled) + self.__init_handle_by_constructor__(_transform.PassContext, + opt_level, + fallback_device, + required_pass, + disabled_pass, + dump_pass) def __enter__(self): _transform.EnterPassContext(self) @@ -115,12 +125,13 @@ def current(): def build_config(opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + dump_pass=None): """Configure the build behavior by setting config variables. Parameters ---------- - opt_level: int, optional + opt_level: Optional[int] Optimization level. The optimization pass name and level are as the following: @@ -137,23 +148,28 @@ def build_config(opt_level=2, "EliminateCommonSubexpr": 3, } - fallback_device : int, str, or tvm.TVMContext, optional + fallback_device : Optional[Union[Int, String, tvm.TVMContext]] The fallback device. It is also used as the default device for operators without specified device during heterogeneous execution. - required_pass: set of str, optional + required_pass: Optional[List[String]] Optimization passes that are required regardless of optimization level. - disabled_pass: set of str, optional + disabled_pass: Optional[List[String]] Optimization passes to be disabled during optimization. + dump_pass: Optional[List[String]] + Optimization passes that will be dumped to help debugging. Users can + provide the interested pass to make a debugging dump. Or they can + simply provide "All" to dump the module IR after each individual pass. + Returns ------- pass_context: PassContext The pass context for optimizations. """ return PassContext(opt_level, fallback_device, required_pass, - disabled_pass) + disabled_pass, dump_pass) @register_relay_node diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index d63d9121fe27e..80419ed32914b 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -392,6 +392,7 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const { return ctx->opt_level >= info->opt_level; } +// Get a pass from the registry. Pass GetPass(const std::string& pass_name) { using tvm::runtime::Registry; std::string fpass_name = "relay._transform." + pass_name; @@ -401,6 +402,21 @@ Pass GetPass(const std::string& pass_name) { return (*f)(); } +// A helper function to apply a pass. +Module ApplyPass(const Pass& pass, + const PassContext& pass_ctx, + Module mod) { + mod = pass(mod, pass_ctx); + if (PassArrayContains(pass_ctx->dump_pass, pass->Info()->name) || + PassArrayContains(pass_ctx->dump_pass, "All") || + PassArrayContains(pass_ctx->dump_pass, "all")) { + LOG(INFO) << "Dumping the module IR after applying: " + << pass->Info()->name << std::endl + << AsText(mod) << std::endl; + } + return mod; +} + // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. @@ -410,14 +426,19 @@ Module SequentialNode::operator()(const Module& module, for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); - if (!PassEnabled(pass_info)) continue; + if (!PassEnabled(pass_info)) { + if (PassArrayContains(pass_ctx->dump_pass, pass_info->name)) { + LOG(INFO) << "Skip dumping IR for disabled pass " << pass_info->name << "\n"; + } + continue; + } // resolve dependencies for (const auto& it : pass_info->required) { const auto* name = it.as(); CHECK(name); - mod = GetPass(name->value)(mod, pass_ctx); + mod = ApplyPass(GetPass(name->value), pass_ctx, mod); } - mod = pass(mod, pass_ctx); + mod = ApplyPass(pass, pass_ctx, mod); } return mod; } @@ -533,10 +554,12 @@ TVM_REGISTER_API("relay._transform.PassContext") int fallback_device = args[1]; tvm::Array required = args[2]; tvm::Array disabled = args[3]; + tvm::Array dump = args[4]; pctx->opt_level = opt_level; pctx->fallback_device = fallback_device; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); + pctx->dump_pass = std::move(dump); *ret = pctx; }); @@ -549,17 +572,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << runtime::DeviceName(node->fallback_device) << "\n"; - p->stream << "\trequired passes: [" << node->opt_level; + p->stream << "\trequired passes: ["; for (const auto& it : node->required_pass) { p->stream << it << " "; } p->stream << "]\n"; - p->stream << "\tdisabled passes: [" << node->opt_level; + p->stream << "\tdisabled passes: ["; for (const auto& it : node->disabled_pass) { p->stream << it << " "; } p->stream << "]"; + + p->stream << "\tdumping IR for passes: ["; + for (const auto& it : node->dump_pass) { + p->stream << it << " "; + } + p->stream << "]"; }); class PassContext::Internal { diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 930dbe0451983..cd45ff24c2886 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -504,6 +504,72 @@ def expected(): assert analysis.alpha_equal(zz, zexpected) +def test_dump_pass(): + shape = (1, 2, 3) + tp = relay.TensorType(shape, "float32") + x = relay.var("x", tp) + y = relay.add(x, x) + y = relay.multiply(y, relay.const(2, "float32")) + func = relay.Function([x], y) + + seq = _transform.Sequential([ + relay.transform.InferType(), + relay.transform.FoldConstant(), + relay.transform.DeadCodeElimination() + ]) + + def redirect_output(call): + """Redirect the C++ logging info.""" + import sys + import os + import threading + stderr_fileno = sys.stderr.fileno() + stderr_save = os.dup(stderr_fileno) + stderr_pipe = os.pipe() + os.dup2(stderr_pipe[1], stderr_fileno) + os.close(stderr_pipe[1]) + output = '' + + def record(): + nonlocal output + while True: + data = os.read(stderr_pipe[0], 1024) + if not data: + break + output += data.decode("utf-8") + + t = threading.Thread(target=record) + t.start() + call() + os.close(stderr_fileno) + t.join() + os.close(stderr_pipe[0]) + os.dup2(stderr_save, stderr_fileno) + os.close(stderr_save) + + return output + + def test_dump_one(): + def run_pass(): + mod = relay.Module({"main": func}) + with relay.build_config(opt_level=3, dump_pass=["FoldConstant"]): + mod = seq(mod) + out = redirect_output(run_pass) + assert "FoldConstant" in out + + def test_dump_all(): + def run_pass(): + mod = relay.Module({"main": func}) + with relay.build_config(opt_level=3, dump_pass=["All"]): + mod = seq(mod) + out = redirect_output(run_pass) + assert "InferType" in out + assert "FoldConstant" in out + assert "DeadCodeElimination" in out + + test_dump_one() + test_dump_all() + if __name__ == "__main__": test_function_class_pass() test_module_class_pass()