Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] Pass infra cleanup #3336

Merged
merged 1 commit into from
Jun 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
v->Visit("required", &required);
}

TVM_DLL static PassInfo make(int opt_level, std::string name,
TVM_DLL static PassInfo make(int opt_level,
std::string name,
tvm::Array<tvm::Expr> required);

static constexpr const char* _type_key = "relay.PassInfo";
Expand Down Expand Up @@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference();
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \return The pass.
* \return The pass.
*/
TVM_DLL Pass InferType();

Expand Down
314 changes: 155 additions & 159 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
# pylint: disable=invalid-name
"""
This file contains the pass manager for Relay which exposes different
granularity of interfaces for users to implement and use passes more
conveniently.
Relay pass transformation infrastructure.
"""
import types

Expand All @@ -39,19 +35,19 @@ class PassInfo(RelayNode):

Parameters
----------
name : str
The pass name.

opt_level : int
The optimization level of this pass.

name : str
The pass name.

required : List[str]
The list of passes that are required by a certain pass.
"""

def __init__(self, name, opt_level, required=None):
self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level,
required)
def __init__(self, opt_level, name, required=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should we put opt_level first? I don't even think most passes should be on a standard opt-level, we end up in the current system we have where people arbitrarily assign pass numbers to passes and when passes are non-robust we just bump them to a higher pass number.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should focus on names and enabling passes by name in most cases imo.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the change just to be consistent with the rest of the positional argument, we can debate whether it is a good idea to put name or opt_level first. I am fine either way

self.__init_handle_by_constructor__(
_transform.PassInfo, opt_level, name, required)


@register_relay_node
Expand Down Expand Up @@ -194,7 +190,7 @@ class ModulePass(Pass):
`module_pass`, because the design of the `module_pass` API is flexible
enough to handle the creation of a module pass in different manners. In
addition, all members of a module pass can be accessed from the base class.
The same rule applies to FunctionPass and Sequential as well.
The same rule applies to FunctionPass as well.
"""


Expand Down Expand Up @@ -250,153 +246,6 @@ def __init__(self,
passes, opt_level, name, required)


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.

Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.

opt_level : int
The optimization level of this module pass.

name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.

required : Optional[List[str]]
The list of passes that the module pass is dependent on.

Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.

Examples
--------
The following code creates a module level pass and adds an abs function to
the module.

.. code-block:: python

@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod

module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

return _transform.CreateModulePass(
pass_func, opt_level, name if name else pass_func.__name__,
required)

if pass_func:
return create_module_pass(pass_func)
return create_module_pass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.

Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.

opt_level : int
The optimization level of this module pass.

name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.

required : Optional[List[str]]
The list of passes that the module pass is dependent on.

Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.

Examples
--------
The following code creates a function level pass that performs constant
folding.

.. code-block:: python

@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)

function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

return _transform.CreateFunctionPass(
pass_func, opt_level, name if name else pass_func.__name__,
required)

if pass_func:
return create_function_pass(pass_func)
return create_function_pass


def InferType():
"""Infer the type of an expr.

Expand Down Expand Up @@ -593,3 +442,150 @@ def PartialEvaluate():
The registered pass that performs partial evaluation on an expression.
"""
return _transform.PartialEvaluate()


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.

Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.

opt_level : int
The optimization level of this module pass.

name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.

required : Optional[List[str]]
The list of passes that the module pass is dependent on.

Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.

Examples
--------
The following code creates a module level pass and adds an abs function to
the module.

.. code-block:: python

@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod

module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeModulePass(pass_func, info)

if pass_func:
return create_module_pass(pass_func)
return create_module_pass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.

Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.

opt_level : int
The optimization level of this module pass.

name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.

required : Optional[List[str]]
The list of passes that the module pass is dependent on.

Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.

Examples
--------
The following code creates a function level pass that performs constant
folding.

.. code-block:: python

@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)

function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeFunctionPass(pass_func, info)

if pass_func:
return create_function_pass(pass_func)
return create_function_pass
8 changes: 4 additions & 4 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(ModulePassNode);

TVM_REGISTER_API("relay._transform.CreateModulePass")
.set_body_typed(CreateModulePass);
TVM_REGISTER_API("relay._transform.MakeModulePass")
.set_body_typed(ModulePassNode::make);

TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expand All @@ -481,8 +481,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(FunctionPassNode);

TVM_REGISTER_API("relay._transform.CreateFunctionPass")
.set_body_typed(CreateFunctionPass);
TVM_REGISTER_API("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
Expand Down
Loading