Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay][Transform] Support Dumping IR to help debugging
Browse files Browse the repository at this point in the history
zhiics committed Jul 7, 2019
1 parent 287078c commit bfeff9f
Showing 4 changed files with 126 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
@@ -540,6 +540,13 @@ TVM_DLL Pass CanonicalizeCast();
*/
TVM_DLL Pass EtaExpand();

/*!
* \brief Print the IR for a module to help debugging.
*
* \return the pass.
*/
TVM_DLL Pass DebugPrint();

} // namespace transform

/*!
12 changes: 12 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
@@ -529,6 +529,18 @@ def CanonicalizeCast():
return _transform.CanonicalizeCast()


def DebugPrint():
"""
Print the IR for a module to help debugging.
Returns
-------
ret : tvm.relay.Pass
The registered pass that prints the module IR.
"""
return _transform.DebugPrint()


def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
50 changes: 50 additions & 0 deletions src/relay/pass/dump_ir.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
*
* \file src/relay/pass/debug_print.cc
*
* \brief Print the module IR to help debugging.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

namespace transform {

Pass DebugPrint() {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m);
return m;
};
return CreateModulePass(pass_func, 0, "DebugPrint", {});
}

TVM_REGISTER_API("relay._transform.DebugPrint")
.set_body_typed(DebugPrint);

} // namespace transform

} // namespace relay
} // namespace tvm
57 changes: 57 additions & 0 deletions tests/python/relay/test_pass_manager.py
Original file line number Diff line number Diff line change
@@ -504,6 +504,62 @@ def expected():
assert analysis.alpha_equal(zz, zexpected)


def test_debug_print():
shape = (1, 2, 3)
tp = relay.TensorType(shape, "float32")
x = relay.var("x", tp)
y = relay.add(x, x)
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)

seq = _transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.DebugPrint(),
relay.transform.DeadCodeElimination()
])

def redirect_output(call):
"""Redirect the C++ logging info."""
import sys
import os
import threading
stderr_fileno = sys.stderr.fileno()
stderr_save = os.dup(stderr_fileno)
stderr_pipe = os.pipe()
os.dup2(stderr_pipe[1], stderr_fileno)
os.close(stderr_pipe[1])
output = ''

def record():
nonlocal output
while True:
data = os.read(stderr_pipe[0], 1024)
if not data:
break
output += data.decode("utf-8")

t = threading.Thread(target=record)
t.start()
call()
os.close(stderr_fileno)
t.join()
os.close(stderr_pipe[0])
os.dup2(stderr_save, stderr_fileno)
os.close(stderr_save)

return output

def run_pass():
mod = relay.Module({"main": func})
with relay.build_config(opt_level=3):
mod = seq(mod)

out = redirect_output(run_pass)
assert "Dumping the module IR" in out
assert "multiply" in out


if __name__ == "__main__":
test_function_class_pass()
test_module_class_pass()
@@ -512,3 +568,4 @@ def expected():
test_sequential_pass()
test_sequential_with_scoping()
test_pass_info()
test_debug_print()

0 comments on commit bfeff9f

Please sign in to comment.