diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index b55a599676a2d..6f44cfbffc16a 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -109,9 +109,7 @@ class PassNode : public RelayNode { virtual std::vector Required() const = 0; /*! - * \brief Execute the optimization pass using a functor. This functor invokes - * the `run` method to perform a real optimization on a certain type - * of node. + * \brief Execute the optimization pass using a functor. * * \param mod The module that an optimization pass runs on. * @@ -145,7 +143,7 @@ class Pass : public NodeRef { * * \param name The name of the module pass. * \param opt_level The optimization level of the module pass. - * \param pass_func The curried packed function that contains the optimization. + * \param pass_func The packed function that contains the optimization. * * \return The created module pass. */ @@ -158,7 +156,7 @@ Pass CreateModulePass( * * \param name The name of the function pass. * \param opt_level The optimization level of the function pass. - * \param pass_func The curried packed function that contains the optimization. + * \param pass_func The packed function that contains the optimization. * * \return The created function pass. */ diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 860926ba01efa..174e91638a546 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -23,7 +23,7 @@ class PassContext(RelayNode): """The basis where a Relay optimization/analysis runs on. Each pass context contains a number of auxiliary information that is used to help an optimization pass. Such information includes the error reporter - to record the errors of during the performing the optimization, etc. + to record the errors of during the optimization, etc. """ def __init__(self): @@ -61,7 +61,7 @@ def set_pass_context(self, pass_ctx): """ if not isinstance(pass_ctx, PassContext): raise TypeError("pass_ctx is expected to be the PassContext type") - return _ir_pass.SetContext(self, pass_ctx) + _ir_pass.SetContext(self, pass_ctx) def __call__(self, mod): """Execute the pass. It is an abstract function that will be @@ -92,7 +92,7 @@ class ModulePass(Pass): opt_level : int The optimization level of this pass. - pass_func : Callable[PassContext: tvm.relay.Module -> tvm.relay.Module] + pass_func : Callable[(tvm.relay.Module, PassContext) -> tvm.relay.Module] The callback function that sketches a certain optimization. """ @@ -128,7 +128,8 @@ class FunctionPass(Pass): opt_level : int The optimization level of this pass. - pass_func : Callable[PassContext: tvm.relay.Function -> tvm.relay.Function] + pass_func : Callable[(tvm.relay.Function, PassContext) -> + tvm.relay.Function] The callback function that sketches a certain optimization. """ @@ -205,7 +206,7 @@ def create_module_pass(pass_name, opt_level, pass_func): opt_level : int The optimization level of this pass. - pass_func : Optional[Callable[PassContext: Module/Function -> + pass_func : Optional[Callable[(Module/Function, PassContext) -> Module/Function]] The implemented optimization pass. @@ -232,7 +233,7 @@ def create_function_pass(pass_name, opt_level, pass_func): opt_level : int The optimization level of this pass. - pass_func : Optional[Callable[PassContext: Module/Function -> + pass_func : Optional[Callable[(Module/Function, PassContext) -> Module/Function]] The implemented optimization pass. diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index f89aae5d786af..ce613543aed5f 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -37,7 +37,7 @@ class ModulePassNode : public PassNode { } /*! - * \brief Run a function pass on a certain module. + * \brief Run a module pass on a certain module. * * \param mod The module that an optimization pass runs on. *