diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index fbfc1d0ff28f..3d9f0b777f46 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -9,7 +9,6 @@ for users to implement and use passes more conveniently. """ import types -from abc import abstractmethod from . import _ir_pass from . import _make @@ -57,15 +56,14 @@ def __init__(self): @register_relay_node class Pass(RelayNode): - """The base class of all passes. This class is designed as a pure virtual - class that will be implemented by the subclasses. + """The base class of all passes. All methods here are just simple wrappers + that are implemented in the backend. They are defined for users to + conveniently interact with the base class. """ - @abstractmethod def set_pass_context(self, pass_ctx): """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. It is an - abstract method that will be implemented by each subclass. + could be shared by different passes for sequential passes. Parameters ---------- @@ -73,18 +71,17 @@ def set_pass_context(self, pass_ctx): The context that is used to help perform a certain pass or a series of passes. """ - raise NotImplementedError("Pure virtual function is not implemented.") + if not isinstance(pass_ctx, PassContext): + raise TypeError("pass_ctx is expected to be the PassContext type") + _ir_pass.SetContext(self, pass_ctx) - @abstractmethod def get_pass_info(self): - """Get the pass meta. It is an abstract method that will be implemented - by each subclass.""" - raise NotImplementedError("Pure virtual function is not implemented.") + """Get the pass meta.""" + return _ir_pass.GetPassInfo(self) - @abstractmethod def __call__(self, mod): - """Execute the pass. It is an abstract function that will be - implemented by subclasses. + """Execute the pass. Note that for sequential pass, the dependency among + different passes will be resolved in the backend. Parameters ---------- @@ -96,7 +93,7 @@ def __call__(self, mod): mod : tvm.relay.Module The updated module after applying this pass. """ - raise NotImplementedError("Pure virtual function is not implemented.") + return _ir_pass.RunPass(self, mod) @register_relay_node @@ -123,37 +120,6 @@ def __init__(self, name, opt_level, pass_func, required=None): name, opt_level, required, pass_func) - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain module pass. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _ir_pass.SetContext(self, pass_ctx) - - def get_pass_info(self): - """Get the meta data for module pass.""" - return _ir_pass.GetPassInfo(self) - - def __call__(self, mod): - """Execute a module pass. - - Parameters - ---------- - mod : tvm.relay.Module - The module that the module pass is executed on. - - Returns - ------- - ret : tvm.relay.Module - The updated module. - """ - return _ir_pass.RunModulePass(self, mod) - @register_relay_node class FunctionPass(Pass): @@ -180,37 +146,6 @@ def __init__(self, name, opt_level, pass_func, required=None): name, opt_level, required, pass_func) - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform the function pass. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _ir_pass.SetContext(self, pass_ctx) - - def get_pass_info(self): - """Get the meta data for function pass.""" - return _ir_pass.GetPassInfo(self) - - def __call__(self, mod): - """Execute a function pass. - - Parameters - ---------- - mod : tvm.relay.Module - The module that the function pass is executed on. - - Returns - ------- - ret : tvm.relay.Module - The updated module. - """ - return _ir_pass.RunFunctionPass(self, mod) - @register_relay_node class SequentialPass(Pass): @@ -241,37 +176,6 @@ def __init__(self, name, opt_level, passes, required=None, disabled=None): self.__init_handle_by_constructor__(_ir_pass.CreateSequentialPass, name, opt_level, passes, required, disabled) - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. - - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a series of passes. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _ir_pass.SetContext(self, pass_ctx) - - def get_pass_info(self): - """Get the meta data for sequential pass.""" - return _ir_pass.GetPassInfo(self) - - def __call__(self, mod): - """Execute a sequence of passes. - - Parameters - ---------- - mod : tvm.relay.Module - The module that the function pass is executed on. - - Returns - ------- - ret : tvm.relay.Module - The updated module. - """ - return _ir_pass.RunSequentialPass(self, mod) def create_module_pass(pass_name, opt_level, pass_func, required=None): diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 5478dcd6ff6c..052436506404 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -248,7 +248,7 @@ PassInfo PassInfoNode::make(std::string name, int opt_level, tvm::Array required) { auto pass_info = make_node(); pass_info->name = std::move(name); - pass_info->opt_level = std::move(opt_level); + pass_info->opt_level = opt_level; pass_info->required = std::move(required); return PassInfo(pass_info); } @@ -446,12 +446,12 @@ TVM_REGISTER_API("relay._ir_pass.CreateModulePass") *ret = CreateModulePass(name, opt_level, required, pass_func); }); -TVM_REGISTER_API("relay._ir_pass.RunModulePass") +TVM_REGISTER_API("relay._ir_pass.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { - ModulePass pass = args[0]; + Pass pass = args[0]; Module mod = args[1]; CHECK(pass.defined()) - << "Running a pass on undefined ModulePass is not allowed." + << "Running an undefined pass is not allowed." << "\n"; const auto* pn = pass.operator->(); @@ -477,17 +477,6 @@ TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass") *ret = CreateFunctionPass(name, opt_level, required, pass_func); }); -TVM_REGISTER_API("relay._ir_pass.RunFunctionPass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - FunctionPass pass = args[0]; - Module mod = args[1]; - CHECK(pass.defined()) - << "Running a pass on undefined ModulePass is not allowed." - << "\n"; - const auto* pn = pass.operator->(); - *ret = (*pn)(mod); -}); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionPassNode* node, tvm::IRPrinter* p) { @@ -509,17 +498,6 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") *ret = SequentialPassNode::make(pass_info, passes, disabled); }); -TVM_REGISTER_API("relay._ir_pass.RunSequentialPass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - SequentialPass pass = args[0]; - Module mod = args[1]; - CHECK(pass.defined()) - << "Running passes on undefined SequentialPass is not allowed." - << "\n"; - const auto* pn = pass.operator->(); - *ret = (*pn)(mod); -}); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SequentialPassNode* node, tvm::IRPrinter* p) {