Skip to content

Commit

Permalink
[RELAY] Pass infra cleanup (#3336)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jun 11, 2019
1 parent d6c4aba commit c9a2f3d
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 165 deletions.
5 changes: 3 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
v->Visit("required", &required);
}

TVM_DLL static PassInfo make(int opt_level, std::string name,
TVM_DLL static PassInfo make(int opt_level,
std::string name,
tvm::Array<tvm::Expr> required);

static constexpr const char* _type_key = "relay.PassInfo";
Expand Down Expand Up @@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference();
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \return The pass.
* \return The pass.
*/
TVM_DLL Pass InferType();

Expand Down
314 changes: 155 additions & 159 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
# pylint: disable=invalid-name
"""
This file contains the pass manager for Relay which exposes different
granularity of interfaces for users to implement and use passes more
conveniently.
Relay pass transformation infrastructure.
"""
import types

Expand All @@ -39,19 +35,19 @@ class PassInfo(RelayNode):
Parameters
----------
name : str
The pass name.
opt_level : int
The optimization level of this pass.
name : str
The pass name.
required : List[str]
The list of passes that are required by a certain pass.
"""

def __init__(self, name, opt_level, required=None):
self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level,
required)
def __init__(self, opt_level, name, required=None):
self.__init_handle_by_constructor__(
_transform.PassInfo, opt_level, name, required)


@register_relay_node
Expand Down Expand Up @@ -194,7 +190,7 @@ class ModulePass(Pass):
`module_pass`, because the design of the `module_pass` API is flexible
enough to handle the creation of a module pass in different manners. In
addition, all members of a module pass can be accessed from the base class.
The same rule applies to FunctionPass and Sequential as well.
The same rule applies to FunctionPass as well.
"""


Expand Down Expand Up @@ -250,153 +246,6 @@ def __init__(self,
passes, opt_level, name, required)


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

return _transform.CreateModulePass(
pass_func, opt_level, name if name else pass_func.__name__,
required)

if pass_func:
return create_module_pass(pass_func)
return create_module_pass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

return _transform.CreateFunctionPass(
pass_func, opt_level, name if name else pass_func.__name__,
required)

if pass_func:
return create_function_pass(pass_func)
return create_function_pass


def InferType():
"""Infer the type of an expr.
Expand Down Expand Up @@ -593,3 +442,150 @@ def PartialEvaluate():
The registered pass that performs partial evaluation on an expression.
"""
return _transform.PartialEvaluate()


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeModulePass(pass_func, info)

if pass_func:
return create_module_pass(pass_func)
return create_module_pass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeFunctionPass(pass_func, info)

if pass_func:
return create_function_pass(pass_func)
return create_function_pass
8 changes: 4 additions & 4 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(ModulePassNode);

TVM_REGISTER_API("relay._transform.CreateModulePass")
.set_body_typed(CreateModulePass);
TVM_REGISTER_API("relay._transform.MakeModulePass")
.set_body_typed(ModulePassNode::make);

TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expand All @@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(FunctionPassNode);

TVM_REGISTER_API("relay._transform.CreateFunctionPass")
.set_body_typed(CreateFunctionPass);
TVM_REGISTER_API("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
Expand Down
Loading

0 comments on commit c9a2f3d

Please sign in to comment.