Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][RELAY] Remove re-exports of tvm.transform #5337

Merged
merged 1 commit into from
Apr 15, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[RELAY] Remove re-exports of tvm.transform
tqchen committed Apr 14, 2020
commit 416c614d9a575fe837f1936f6003c01411c05479
8 changes: 8 additions & 0 deletions docs/api/python/ir.rst
Original file line number Diff line number Diff line change
@@ -21,3 +21,11 @@ tvm.ir
:members:
:imported-members:
:autosummary:


tvm.transform
-------------
.. automodule:: tvm.transform
:members:
:imported-members:
:autosummary:
2 changes: 1 addition & 1 deletion docs/dev/convert_layout.rst
Original file line number Diff line number Diff line change
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
# Convert the layout to NCHW
# RemoveUnunsedFunctions is used to clean up the graph.
seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
mod = seq(mod)
4 changes: 2 additions & 2 deletions docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
func = relay.Function([x], z2)
# Customize the optimization pipeline.
seq = _transform.Sequential([
seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for

.. code:: python
seq = _transform.Sequential([
seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
4 changes: 3 additions & 1 deletion include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
* \param show_meta_data Whether should we show meta data.
* \return The pass.
*/
TVM_DLL Pass PrintIR(std::string header);
TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);

} // namespace transform
} // namespace tvm
2 changes: 1 addition & 1 deletion python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
@@ -106,7 +106,7 @@ def _update_global_key(item, _):
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
"relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
7 changes: 5 additions & 2 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
@@ -329,16 +329,19 @@ def create_module_pass(pass_arg):
return create_module_pass


def PrintIR(header):
def PrintIR(header="", show_meta_data=False):
"""A special trace pass that prints the header and IR.
Parameters
----------
header : str
The header to be displayed along with the dump.
show_meta_data : bool
A boolean flag to indicate if meta data should be printed.
Returns
--------
The pass
"""
return _ffi_transform_api.PrintIR(header)
return _ffi_transform_api.PrintIR(header, show_meta_data)
11 changes: 0 additions & 11 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
@@ -128,20 +128,9 @@
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder

module_pass = transform.module_pass
function_pass = transform.function_pass

# Parser
fromtext = parser.fromtext

# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict

# Pass manager
PassInfo = transform.PassInfo
PassContext = transform.PassContext
Pass = transform.Pass
ModulePass = transform.ModulePass
FunctionPass = transform.FunctionPass
Sequential = transform.Sequential
8 changes: 4 additions & 4 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
@@ -210,10 +210,10 @@ def optimize(self):
opt_mod : tvm.IRModule
The optimized module.
"""
seq = transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType()])
seq = tvm.transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType()])
return seq(self.mod)

def _make_executor(self, expr=None):
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/transform.py
Original file line number Diff line number Diff line change
@@ -60,7 +60,7 @@ def @main(%quantized_data: Tensor[(200), int32]) -> Tensor[(200), int8] {
Returns
-------
ret : tvm.relay.Pass
ret : tvm.transform.Pass
The registered pass that canonicalizes QNN ops to Relay ops.
"""

@@ -108,7 +108,7 @@ def Legalize():
Returns
-------
ret : tvm.relay.Pass
ret : tvm.transform.Pass
The registered pass that legalizes QNN ops.
"""

33 changes: 18 additions & 15 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import tvm.ir
import tvm
from tvm.runtime import Object

from . import _quantize
@@ -240,7 +241,7 @@ def partition():
Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizePartition()
@@ -253,7 +254,7 @@ def annotate():
Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
@@ -267,7 +268,7 @@ def realize():
Returns
-------
ret: tvm.relay.Pass
ret: tvm.transform.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
optimize = tvm.transform.Sequential(
[_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])

if params:
mod['main'] = _bind_params(mod['main'], params)
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
"""
mod = prerequisite_optimize(mod, params)

calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
name="QuantizeCalibrate")
calibrate_pass = tvm.transform.module_pass(
calibrate(dataset), opt_level=1,
name="QuantizeCalibrate")
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
quantize_seq = tvm.transform.Sequential(quant_passes)
with tvm.transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
with quantize_context():
mod = quantize_seq(mod)

2 changes: 1 addition & 1 deletion python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@
from ..transform import gradient

def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass)
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
4 changes: 2 additions & 2 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
@@ -95,8 +95,8 @@ def optimize(self, prog: Expr):

# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
relay.transform.FuseOps(fuse_opt_level=0)])
opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
relay.transform.FuseOps(fuse_opt_level=0)])
mod = opts(mod)
optimized = mod['main']
return optimized if isinstance(unwrapped, Function) else optimized.body
Loading