diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 93129cf57a27..575b57360af5 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -540,6 +540,13 @@ TVM_DLL Pass CanonicalizeCast(); */ TVM_DLL Pass EtaExpand(); +/*! + * \brief Print the IR for a module to help debugging. + * + * \return the pass. + */ +TVM_DLL Pass DebugPrint(); + } // namespace transform /*! diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 2805e0b429fa..b59ee2771824 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -529,6 +529,18 @@ def CanonicalizeCast(): return _transform.CanonicalizeCast() +def DebugPrint(): + """ + Print the IR for a module to help debugging. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that prints the module IR. + """ + return _transform.DebugPrint() + + def gradient(expr, mod=None, mode='higher_order'): """ Transform the input function, diff --git a/src/relay/pass/dump_ir.cc b/src/relay/pass/dump_ir.cc new file mode 100644 index 000000000000..0070994aab17 --- /dev/null +++ b/src/relay/pass/dump_ir.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file src/relay/pass/debug_print.cc + * + * \brief Print the module IR to help debugging. + */ +#include +#include + +namespace tvm { +namespace relay { + +namespace transform { + +Pass DebugPrint() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m); + return m; + }; + return CreateModulePass(pass_func, 0, "DebugPrint", {}); +} + +TVM_REGISTER_API("relay._transform.DebugPrint") +.set_body_typed(DebugPrint); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 930dbe045198..91a3f343ece9 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -504,6 +504,62 @@ def expected(): assert analysis.alpha_equal(zz, zexpected) +def test_debug_print(): + shape = (1, 2, 3) + tp = relay.TensorType(shape, "float32") + x = relay.var("x", tp) + y = relay.add(x, x) + y = relay.multiply(y, relay.const(2, "float32")) + func = relay.Function([x], y) + + seq = _transform.Sequential([ + relay.transform.InferType(), + relay.transform.FoldConstant(), + relay.transform.DebugPrint(), + relay.transform.DeadCodeElimination() + ]) + + def redirect_output(call): + """Redirect the C++ logging info.""" + import sys + import os + import threading + stderr_fileno = sys.stderr.fileno() + stderr_save = os.dup(stderr_fileno) + stderr_pipe = os.pipe() + os.dup2(stderr_pipe[1], stderr_fileno) + os.close(stderr_pipe[1]) + output = '' + + def record(): + nonlocal output + while True: + data = os.read(stderr_pipe[0], 1024) + if not data: + break + output += data.decode("utf-8") + + t = threading.Thread(target=record) + t.start() + call() + os.close(stderr_fileno) + t.join() + os.close(stderr_pipe[0]) + os.dup2(stderr_save, stderr_fileno) + os.close(stderr_save) + + return output + + def run_pass(): + mod = relay.Module({"main": func}) + with relay.build_config(opt_level=3): + mod = seq(mod) + + out = redirect_output(run_pass) + assert "Dumping the module IR" in out + assert "multiply" in out + + if __name__ == "__main__": test_function_class_pass() test_module_class_pass() @@ -512,3 +568,4 @@ def expected(): test_sequential_pass() test_sequential_with_scoping() test_pass_info() + test_debug_print()