Skip to content

Commit

Permalink
fix some pass docs (apache#3767)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and wweic committed Aug 16, 2019
1 parent 5daff57 commit a651e86
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 63 deletions.
21 changes: 0 additions & 21 deletions docs/api/python/relay/ir_pass.rst

This file was deleted.

44 changes: 44 additions & 0 deletions docs/api/python/relay/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,50 @@ tvm.relay.transform

.. autofunction:: tvm.relay.transform.function_pass

.. autofunction:: tvm.relay.transform.InferType

.. autofunction:: tvm.relay.transform.FoldScaleAxis

.. autofunction:: tvm.relay.transform.BackwardFoldScaleAxis

.. autofunction:: tvm.relay.transform.ForwardFoldScaleAxis

.. autofunction:: tvm.relay.transform.SimplifyInference

.. autofunction:: tvm.relay.transform.CanonicalizeOps

.. autofunction:: tvm.relay.transform.DeadCodeElimination

.. autofunction:: tvm.relay.transform.FoldConstant

.. autofunction:: tvm.relay.transform.FuseOps

.. autofunction:: tvm.relay.transform.CombineParallelConv2D

.. autofunction:: tvm.relay.transform.AlterOpLayout

.. autofunction:: tvm.relay.transform.Legalize

.. autofunction:: tvm.relay.transform.RewriteAnnotatedOps

.. autofunction:: tvm.relay.transform.ToANormalForm

.. autofunction:: tvm.relay.transform.ToCPS

.. autofunction:: tvm.relay.transform.EtaExpand

.. autofunction:: tvm.relay.transform.ToGraphNormalForm

.. autofunction:: tvm.relay.transform.EliminateCommonSubexpr

.. autofunction:: tvm.relay.transform.PartialEvaluate

.. autofunction:: tvm.relay.transform.CanonicalizeCast

.. autofunction:: tvm.relay.transform.LambdaLift

.. autofunction:: tvm.relay.transform.PrintIR

.. autoclass:: tvm.relay.transform.Pass
:members:

Expand Down
6 changes: 4 additions & 2 deletions docs/dev/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ In this part of documentation, we share the rationale for the specific choices m

runtime
debugger
nnvm_json_spec
nnvm_overview
hybrid_script
relay_intro
relay_add_op
relay_pass_infra
relay_add_pass
virtual_machine
codebase_walkthrough
inferbound
nnvm_json_spec
nnvm_overview
51 changes: 14 additions & 37 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

@register_relay_node
class PassInfo(RelayNode):
"""The class that contains the meta data required by a pass. It is the
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
This class can be extended by adding new members when more meta data is
needed.
Expand Down Expand Up @@ -132,11 +132,12 @@ def build_config(opt_level=2,
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
}
fallback_device : int, str, or tvm.TVMContext, optional
Expand Down Expand Up @@ -250,30 +251,6 @@ def __init__(self,
passes, opt_level, name, required)


def infer_type(expr, mod=None):
"""Infer the type of an expr.
Adding Function into a Module will change it's binding,
and some passes need type inference to work without binding modification.
However, InferType() work by putting stuff into a Module, thus changing all the binding.
This is an escape patch that allow type inference without binding changing.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The input module
Returns
-------
ret : tvm.relay.Expr
The output expression.
"""
return _transform.infer_type(expr, mod)


def InferType():
"""Infer the type of an expr.
Expand All @@ -297,7 +274,7 @@ def FoldScaleAxis():
Note
----
Internally, we will call backward_fold_scale_axis before using
forward_fold_scale_axis. As backward folding targets common conv-bn
forward_fold_scale_axis as backward folding targets the common conv->bn
pattern.
"""
return _transform.FoldScaleAxis()
Expand All @@ -314,8 +291,8 @@ def BackwardFoldScaleAxis():
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
before using forward_fold_scale_axis as backward folding targets the common
conv->bn pattern.
"""
return _transform.BackwardFoldScaleAxis()

Expand All @@ -331,8 +308,8 @@ def ForwardFoldScaleAxis():
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
before using forward_fold_scale_axis, as backward folding targets the
common conv->bn pattern.
"""
return _transform.ForwardFoldScaleAxis()

Expand All @@ -350,9 +327,9 @@ def SimplifyInference():


def CanonicalizeOps():
""" Canonicalize special operators to basic operators.
This can simplify followed analysis. (e.g. expanding bias_add to
expand_dims and broadcast_add.)
"""Canonicalize special operators to basic operators.
This can simplify followed analysis, e.g. expanding bias_add to
expand_dims and broadcast_add.
Returns
-------
Expand All @@ -363,7 +340,7 @@ def CanonicalizeOps():


def DeadCodeElimination(inline_once=False):
"""Remove expressions which does not effect the program result (dead code).
"""Remove expressions that do not have any users (dead code).
Parameters
----------
Expand All @@ -379,7 +356,7 @@ def DeadCodeElimination(inline_once=False):


def FoldConstant():
"""Fold the constant expression in expr.
"""Fold the constant expressions in a Relay program.
Returns
-------
Expand Down Expand Up @@ -513,7 +490,7 @@ def EtaExpand():


def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression
"""Turn a Relay program in A Normal Form into Graph Normal Form
Returns
-------
Expand Down
3 changes: 0 additions & 3 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -826,9 +826,6 @@ Function InferType(const Function& func,
return Downcast<Function>(func_ret);
}

TVM_REGISTER_API("relay._transform.infer_type")
.set_body_typed<Expr(Expr, Module)>([](Expr l, Module r) { return InferType(l, r); });

namespace transform {

Pass InferType() {
Expand Down

0 comments on commit a651e86

Please sign in to comment.