From c9a2f3da5b4cda38829b933597c48d3bdff28083 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 11 Jun 2019 10:55:24 -0700 Subject: [PATCH] [RELAY] Pass infra cleanup (#3336) --- include/tvm/relay/transform.h | 5 +- python/tvm/relay/transform.py | 314 ++++++++++++------------ src/relay/pass/pass_manager.cc | 8 +- tests/python/relay/test_pass_manager.py | 7 + 4 files changed, 169 insertions(+), 165 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 793bc981ea61..f579f1c7ba91 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode { v->Visit("required", &required); } - TVM_DLL static PassInfo make(int opt_level, std::string name, + TVM_DLL static PassInfo make(int opt_level, + std::string name, tvm::Array required); static constexpr const char* _type_key = "relay.PassInfo"; @@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference(); * type information filled in, as well as it's checked type field * populated with the result type. * - * \return The pass. + * \return The pass. */ TVM_DLL Pass InferType(); diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 38079b010e7d..b76c2361605c 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -14,13 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return -# pylint: disable=unidiomatic-typecheck # pylint: disable=invalid-name """ -This file contains the pass manager for Relay which exposes different -granularity of interfaces for users to implement and use passes more -conveniently. +Relay pass transformation infrastructure. """ import types @@ -39,19 +35,19 @@ class PassInfo(RelayNode): Parameters ---------- - name : str - The pass name. - opt_level : int The optimization level of this pass. + name : str + The pass name. + required : List[str] The list of passes that are required by a certain pass. """ - def __init__(self, name, opt_level, required=None): - self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level, - required) + def __init__(self, opt_level, name, required=None): + self.__init_handle_by_constructor__( + _transform.PassInfo, opt_level, name, required) @register_relay_node @@ -194,7 +190,7 @@ class ModulePass(Pass): `module_pass`, because the design of the `module_pass` API is flexible enough to handle the creation of a module pass in different manners. In addition, all members of a module pass can be accessed from the base class. - The same rule applies to FunctionPass and Sequential as well. + The same rule applies to FunctionPass as well. """ @@ -250,153 +246,6 @@ def __init__(self, passes, opt_level, name, required) -def module_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a module pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created module level pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the module pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_module_pass : Union[Callable, ModulePass] - The callable that will create a module pass is returned when - pass_func is not passed in. Otherwise, a ModulePass object will be - directly created. - - Examples - -------- - The following code creates a module level pass and adds an abs function to - the module. - - .. code-block:: python - - @relay.transform.module_pass(opt_level=2) - def transform(mod, ctx): - tp = relay.TensorType((10,), "float32") - x = relay.var("x", tp) - gv = relay.GlobalVar("var") - func = relay.Function([x], relay.abs(x)) - new_mod = relay.Module({gv: func}) - new_mod.update(mod) - return new_mod - - module_pass = transform - assert isinstance(module_pass, transform.ModulePass) - assert module_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = module_pass(m) - # Now a function abs should be added to the module m. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the module pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_module_pass(pass_func): - """Internal function that creates a module pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _transform.CreateModulePass( - pass_func, opt_level, name if name else pass_func.__name__, - required) - - if pass_func: - return create_module_pass(pass_func) - return create_module_pass - - -def function_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a function pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created function pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the function pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_function_pass : Union[Callable, FunctionPass] - The callable that will create a function pass is returned when - pass_func is not passed in. Otherwise, a FunctionPass object will be - created. - - Examples - -------- - The following code creates a function level pass that performs constant - folding. - - .. code-block:: python - - @relay.transform.function_pass(opt_level=2) - def transform(func, ctx): - return ir_pass.fold_constant(func) - - function_pass = transform - assert isinstance(function_pass, transform.FunctionPass) - assert function_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = function_pass(m) - # Now constant folding should have been applied to every function in - # the provided module m. And the updated module will be returned. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the funtion pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_function_pass(pass_func): - """Internal function that creates a function pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _transform.CreateFunctionPass( - pass_func, opt_level, name if name else pass_func.__name__, - required) - - if pass_func: - return create_function_pass(pass_func) - return create_function_pass - - def InferType(): """Infer the type of an expr. @@ -593,3 +442,150 @@ def PartialEvaluate(): The registered pass that performs partial evaluation on an expression. """ return _transform.PartialEvaluate() + + +def module_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a module pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created module level pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the module pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_module_pass : Union[Callable, ModulePass] + The callable that will create a module pass is returned when + pass_func is not passed in. Otherwise, a ModulePass object will be + directly created. + + Examples + -------- + The following code creates a module level pass and adds an abs function to + the module. + + .. code-block:: python + + @relay.transform.module_pass(opt_level=2) + def transform(mod, ctx): + tp = relay.TensorType((10,), "float32") + x = relay.var("x", tp) + gv = relay.GlobalVar("var") + func = relay.Function([x], relay.abs(x)) + new_mod = relay.Module({gv: func}) + new_mod.update(mod) + return new_mod + + module_pass = transform + assert isinstance(module_pass, transform.ModulePass) + assert module_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = module_pass(m) + # Now a function abs should be added to the module m. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the module pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_module_pass(pass_func): + """Internal function that creates a module pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + fname = name if name else pass_func.__name__ + info = PassInfo(opt_level, fname, required) + return _transform.MakeModulePass(pass_func, info) + + if pass_func: + return create_module_pass(pass_func) + return create_module_pass + + +def function_pass(pass_func=None, opt_level=None, name=None, required=None): + """Create a function pass. This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module/Function, PassContext) -> + Module/Function]] + The implemented optimization pass. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + The callable that will create a function pass is returned when + pass_func is not passed in. Otherwise, a FunctionPass object will be + created. + + Examples + -------- + The following code creates a function level pass that performs constant + folding. + + .. code-block:: python + + @relay.transform.function_pass(opt_level=2) + def transform(func, ctx): + return ir_pass.fold_constant(func) + + function_pass = transform + assert isinstance(function_pass, transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now constant folding should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the funtion pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_function_pass(pass_func): + """Internal function that creates a function pass""" + if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + + fname = name if name else pass_func.__name__ + info = PassInfo(opt_level, fname, required) + return _transform.MakeFunctionPass(pass_func, info) + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 782bb6a5980f..500bdce742a0 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_API("relay._transform.CreateModulePass") -.set_body_typed(CreateModulePass); +TVM_REGISTER_API("relay._transform.MakeModulePass") +.set_body_typed(ModulePassNode::make); TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_REGISTER_API("relay._transform.CreateFunctionPass") -.set_body_typed(CreateFunctionPass); +TVM_REGISTER_API("relay._transform.MakeFunctionPass") +.set_body_typed(FunctionPassNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionPassNode* node, diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 7fdef3fa8b9c..7505aa9ab981 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -259,6 +259,12 @@ def test_pass_run(): test_pass_run() +def test_pass_info(): + info = relay.transform.PassInfo(opt_level=1, name="xyz") + assert info.opt_level == 1 + assert info.name == "xyz" + + def test_sequential_pass(): shape = (10, ) dtype = 'float32' @@ -449,3 +455,4 @@ def expected(): test_function_pass() test_sequential_pass() test_sequential_with_scoping() + test_pass_info()