diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst
index 1a264ef9c012..8c2fc9a21469 100644
--- a/docs/dev/relay_add_pass.rst
+++ b/docs/dev/relay_add_pass.rst
@@ -22,17 +22,15 @@ Adding a Compiler Pass to Relay
 
 Compiler passes are the primary interface for both extending Relay's feature
 set and for performing optimizations on Relay programs. By writing a compiler
-pass, you can then modify the AST and/or collect information about the AST,
-depending on your goal. Indeed, some of Relay's most important "built-in"
-features (e.g., autodiff and type inference) are nothing more than compiler
-passes.
+pass, you can modify the AST or collect information about the AST,
+depending on your goal. Indeed, some of Relay's most important built-in
+features (e.g., autodiff and type inference) are nothing more than "standard"
+compiler passes.
 
-At a high level, there are three key components to writing a pass:
+At a high level, there are two key components to writing a pass:
 
 - Creating one or more C++ classes that traverse the program
-- Registering an API endpoint (a TVM packed function) with the
-  ``TVM_REGISTER_API`` macro that performs the pass
-- Wrapping the Python API hook in a neater interface
+- Wrapping the traversal implementation and its metadata in the pass manager API so it can neatly interface with the :ref:`relay-pass-infra`
 
 To begin, we'll give an overview of the key mechanisms for writing a compiler
 pass. Then, we'll walk through a concrete example of the constant-folding
@@ -183,9 +181,9 @@ Example: Constant Folding
 -------------------------
 
 In order to better understand the process of writing a pass, we will look at
-the constant folding pass (found in ``src/relay/pass/fold_constant.cc`` and
-in ``python/tvm/relay/ir_pass.py``) as a guide, because it is a relatively
-simple pass that incorporates both types of traversals.
+the constant folding pass (found in `src/relay/pass/fold_constant.cc`_)
+as a guide, because it is a relatively simple pass that incorporates
+both types of traversals.
 
 Constant folding involves evaluating expressions in the program that only
 involve constant values, then replacing those expressions with the result
@@ -327,32 +325,82 @@ all of the arguments are constant (using ``ConstantChecker``). Evaluating
 the call produces a **value**, so we use a helper method ``ValueToExpr``
 to allow us to place the evaluated expression back into the AST.
 
-Now, we construct the public interface ``FoldConstant`` to our constant
-folder, which is a standalone function outside of the ``ConstantFolder``
-class. ``FoldConstant`` takes an expression and internally creates and uses a
+Now, we construct a more convenient interface ``FoldConstant`` for our constant
+folder. ``FoldConstant`` is a standalone function outside of the ``ConstantFolder``
+class that takes an expression and internally creates and uses a
 ``ConstantFolder`` instance (the full definition can be found in
-``include/tvm/relay/pass.h``).
+`src/relay/pass/fold_constant.cc`_).
 
-To allow other C++ modules to use our pass, we declare the public interface
-in ``src/relay/pass/pass.h``:
+
+Registering a Pass with the Pass Manager
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+*Note: please see the documentation on the :ref:`relay-pass-infra` for more specific detail on this subject.*
+
+With the AST traversers written, the pass can be registered to become a TVM
+API endpoint with the following code:
 
 .. code:: c
 
-    TVM_DLL Expr FoldConstant(const Expr& expr);
+    namespace transform {
 
-Registering an API Endpoint
-~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    Pass FoldConstant() {
+      runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+        [=](Function f, Module m, PassContext pc) {
+          return Downcast<Function>(FoldConstant(f));
+      };
+      return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
+    }
 
-With the AST traversers written, the pass can be registered to become a TVM
-API endpoint with the following code snippet:
+    } // namespace transform
+
+If the ``Pass`` object produced by the above code is given to the pass infrastructure,
+it will ensure that the AST traversal is applied to every function in the
+given Relay module, which is the behavior one would expect for a constant folding
+pass (it should fold all constants where possible).
+
+The function ``CreateFunctionPass``
+allows for registering the optimization level of the pass (in this case, 2), which can
+be used to group together passes based on their general utility, a name for the pass,
+and any dependencies for the pass. A pass's dependencies are given as a list of any passes
+whose results are necessary to be able to run the current pass. ``FoldConstant`` does not
+have any dependencies, but many Relay passes do depend on having type information,
+so ``InferType`` is a common dependency; others may depend on the program's being in
+A-normal form, via the ``ToANormalForm`` pass.
+
+Note that the ``PassContext`` object contains information a pass uses for
+error reporting and configuration options; ``FoldConstant`` does not need
+this information but other passes may reference their ``PassContext`` objects.
+
+The pass can now be invoked via the pass infrastructure, though it's a good idea to
+also add a Python binding for the pass, as in this code snippet:
 
 .. code:: c
 
-    TVM_REGISTER_API("relay._ir_pass.FoldConstant")
-    .set_body([](TVMArgs args, TVMRetValue *ret) {
-        *ret = FoldConstant(args[0]);
-    });
+    TVM_REGISTER_API("relay._transform.FoldConstant")
+    .set_body_typed(FoldConstant);
+
+Once ``Pass`` objects are defined in the above fashion, they can be invoked using the
+pass infrastructure's ``Sequential`` construct, which takes a list of passes and applies
+them in sequence to a Relay module, obtaining a transformed module as a result. For example,
+the below code applies both the ``FoldConstant`` and ``ToANormalForm`` passes
+(one after the other) to each function in ``mod`` and obtains a new module.
+
+.. code:: python
+
+    seq = transform.Sequential([
+        relay.transform.FoldConstant(),
+        relay.transform.ToANormalForm()
+    ])
+    new_mod = seq(mod)
+
+More detail about registration can be found in :ref:`tvm-runtime-system` and more
+information about the pass manager interface can be found in :ref:`relay-pass-infra`.
+Relay's standard passes are listed in `include/tvm/relay/transform.h`_ and implemented
+in `src/relay/pass/`_.
+
+.. _include/tvm/relay/transform.h: https://github.com/dmlc/tvm/blob/master/include/tvm/relay/transform.h
+
+.. _src/relay/pass: https://github.com/dmlc/tvm/tree/master/src/relay/pass
 
-And the pass can now be used in C++ and Python, though it's a good idea to
-wrap the API in Python, as described in :ref:`relay-add-op`. More detail
-about registration can be found in :ref:`tvm-runtime-system`.
+.. _src/relay/pass/fold_constant.cc: https://github.com/dmlc/tvm/blob/master/src/relay/pass/fold_constant.cc
diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst
index 4c699b12fd78..b358e957b258 100644
--- a/docs/dev/relay_pass_infra.rst
+++ b/docs/dev/relay_pass_infra.rst
@@ -15,8 +15,10 @@
     specific language governing permissions and limitations
     under the License.
 
-Relay Pass Infra
-==================================
+.. _relay-pass-infra:
+
+Relay Pass Infrastructure
+=========================
 
 Relay features a series of optimization passes which improve performance metrics
 of models such as mean inference, memory footprint, or power consumption for