From 50999fa634858508b4b15d96e4b866db967ccdf8 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 3 Sep 2024 10:30:08 +0800 Subject: [PATCH] [Doc] Deep Dive TensorIR This PR adds a new section in the documentation to introduce the TensorIR abstraction, its learning resources, and tutorials. --- docs/conf.py | 2 + docs/deep_dive/tensor_ir/abstraction.rst | 73 +++++ docs/deep_dive/tensor_ir/index.rst | 31 ++ docs/deep_dive/tensor_ir/learning.rst | 253 ++++++++++++++++ docs/deep_dive/tensor_ir/tutorials/README.txt | 2 + .../deep_dive/tensor_ir/tutorials/creation.py | 285 ++++++++++++++++++ .../tensor_ir/tutorials/transformation.py | 173 +++++++++++ docs/index.rst | 9 + 8 files changed, 828 insertions(+) create mode 100644 docs/deep_dive/tensor_ir/abstraction.rst create mode 100644 docs/deep_dive/tensor_ir/index.rst create mode 100644 docs/deep_dive/tensor_ir/learning.rst create mode 100644 docs/deep_dive/tensor_ir/tutorials/README.txt create mode 100644 docs/deep_dive/tensor_ir/tutorials/creation.py create mode 100644 docs/deep_dive/tensor_ir/tutorials/transformation.py diff --git a/docs/conf.py b/docs/conf.py index c933653233b1..e648dcd670a0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -424,6 +424,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): # New tutorial structure under docs folder tvm_path.joinpath("docs", "get_started", "tutorials"), tvm_path.joinpath("docs", "how_to", "tutorials"), + tvm_path.joinpath("docs", "deep_dive", "tensor_ir", "tutorials"), ] gallery_dirs = [ @@ -442,6 +443,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func): # New tutorial structure under docs folder "get_started/tutorials/", "how_to/tutorials/", + "deep_dive/tensor_ir/tutorials/", ] diff --git a/docs/deep_dive/tensor_ir/abstraction.rst b/docs/deep_dive/tensor_ir/abstraction.rst new file mode 100644 index 000000000000..fc11d7f39156 --- /dev/null +++ b/docs/deep_dive/tensor_ir/abstraction.rst @@ -0,0 +1,73 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _tir-abstraction: + +Tensor Program Abstraction +-------------------------- +Before we dive into the details of TensorIR, let's first introduce what is a primitive tensor +function. Primitive tensor functions are functions that correspond to a single "unit" of +computational operation. For example, a convolution operation can be a primitive tensor function, +and a fused convolution + relu operation can also be a primitive tensor function. +Usually, a typical abstraction for primitive tensor function implementation contains the following +elements: multi-dimensional buffers, loop nests that drive the tensor computations, and finally, +the compute statements themselves. + +.. 
code:: python + + from tvm.script import tir as T + + @T.prim_func + def main( + A: T.Buffer((128,), "float32"), + B: T.Buffer((128,), "float32"), + C: T.Buffer((128,), "float32"), + ) -> None: + for i in range(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + C[vi] = A[vi] + B[vi] + +Key Elements of Tensor Programs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The demonstrated primitive tensor function calculates the element-wise sum of two vectors. +The function: + +- Accepts three **multi-dimensional buffers** as parameters, and generates one **multi-dimensional + buffer** as output. +- Incorporates a solitary **loop nest** ``i`` that facilitates the computation. +- Features a singular **compute statement** that calculates the element-wise sum of the two + vectors. + +Extra Structure in TensorIR +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Crucially, we are unable to execute arbitrary transformations on the program, as certain +computations rely on the loop's sequence. Fortunately, the majority of primitive tensor +functions we focus on possess favorable properties, such as independence among loop iterations. +For instance, the aforementioned program includes block and iteration annotations: + +- The **block annotation** ``with T.block("C")`` signifies that the block is the fundamental + computation unit designated for scheduling. A block may encompass a single computation + statement, multiple computation statements with loops, or opaque intrinsics such as Tensor + Core instructions. +- The **iteration annotation** ``T.axis.spatial``, indicating that variable ``vi`` is mapped + to ``i``, and all iterations are independent. + +While this information isn't crucial for *executing* the specific program, it proves useful when +transforming the program. Consequently, we can confidently parallelize or reorder loops associated +with ``vi``, provided we traverse all the index elements from 0 to 128. diff --git a/docs/deep_dive/tensor_ir/index.rst b/docs/deep_dive/tensor_ir/index.rst new file mode 100644 index 000000000000..432d47116a3c --- /dev/null +++ b/docs/deep_dive/tensor_ir/index.rst @@ -0,0 +1,31 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _tensor-ir: + +TensorIR +======== +TensorIR is one of the core abstraction in Apache TVM Unity stack, which is used to +represent and optimize the primitive tensor functions. + +.. toctree:: + :maxdepth: 2 + + abstraction + learning + tutorials/creation + tutorials/transformation diff --git a/docs/deep_dive/tensor_ir/learning.rst b/docs/deep_dive/tensor_ir/learning.rst new file mode 100644 index 000000000000..7ca0a1514fbd --- /dev/null +++ b/docs/deep_dive/tensor_ir/learning.rst @@ -0,0 +1,253 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership.  The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License.  You may obtain a copy of the License at
+
+..  http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied.  See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+.. _tir-learning:
+
+Understand TensorIR Abstraction
+===============================
+TensorIR is the tensor program abstraction in Apache TVM, one of the standard
+machine learning compilation frameworks. The principal objective of tensor program abstraction
+is to depict loops and associated hardware acceleration options, including threading, the
+application of specialized hardware instructions, and memory access.
+
+To help our explanations, let us use the following sequence of tensor computations as
+a motivating example. Specifically, for two :math:`128 \times 128` matrices ``A`` and ``B``,
+let us perform the following two steps of tensor computations.
+
+.. math::
+
+    Y_{i, j} &= \sum_k A_{i, k} \times B_{k, j} \\
+    C_{i, j} &= \mathrm{relu}(Y_{i, j}) = \max(Y_{i, j}, 0)
+
+
+The above computations resemble a typical primitive tensor function commonly seen in neural
+networks: a linear layer with a relu activation. We use TensorIR to depict the above
+computations as follows.
+
+Before we invoke TensorIR, let's use native Python code with NumPy to show the computation:
+
+.. code:: python
+
+    def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
+        Y = np.empty((128, 128), dtype="float32")
+        for i in range(128):
+            for j in range(128):
+                for k in range(128):
+                    if k == 0:
+                        Y[i, j] = 0
+                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
+        for i in range(128):
+            for j in range(128):
+                C[i, j] = max(Y[i, j], 0)
+
+With the low-level NumPy example in mind, we are now ready to introduce TensorIR. The code block
+below shows a TensorIR implementation of ``mm_relu``. The code is written in a
+language called TVMScript, which is a domain-specific dialect embedded in the Python AST.
+
+.. code:: python
+
+    @tvm.script.ir_module
+    class MyModule:
+        @T.prim_func
+        def mm_relu(A: T.Buffer((128, 128), "float32"),
+                    B: T.Buffer((128, 128), "float32"),
+                    C: T.Buffer((128, 128), "float32")):
+            Y = T.alloc_buffer((128, 128), dtype="float32")
+            for i, j, k in T.grid(128, 128, 128):
+                with T.block("Y"):
+                    vi = T.axis.spatial(128, i)
+                    vj = T.axis.spatial(128, j)
+                    vk = T.axis.reduce(128, k)
+                    with T.init():
+                        Y[vi, vj] = T.float32(0)
+                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
+            for i, j in T.grid(128, 128):
+                with T.block("C"):
+                    vi = T.axis.spatial(128, i)
+                    vj = T.axis.spatial(128, j)
+                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
+
+
+Next, let's investigate the elements of the above TensorIR program.
+
+Function Parameters and Buffers
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+**The function parameters correspond to the same set of parameters as the NumPy function.**
+
+.. code:: python
+
+    # TensorIR
+    def mm_relu(A: T.Buffer((128, 128), "float32"),
+                B: T.Buffer((128, 128), "float32"),
+                C: T.Buffer((128, 128), "float32")):
+        ...
+    # NumPy
+    def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
+        ...
+
+Here ``A``, ``B``, and ``C`` take the type ``T.Buffer``, with shape
+argument ``(128, 128)`` and data type ``float32``. This additional information
+helps the compilation process generate code that is specialized to the shape and data
+type.
+
+**Similarly, TensorIR also uses a buffer type for the intermediate result allocation.**
+
+.. code:: python
+
+    # TensorIR
+    Y = T.alloc_buffer((128, 128), dtype="float32")
+    # NumPy
+    Y = np.empty((128, 128), dtype="float32")
+
+Loop Iterations
+~~~~~~~~~~~~~~~
+**There is also a direct correspondence of loop iterations.**
+
+``T.grid`` is syntactic sugar in TensorIR that lets us write multiple nested iterators.
+
+.. code:: python
+
+    # TensorIR with `T.grid`
+    for i, j, k in T.grid(128, 128, 128):
+        ...
+    # TensorIR with `range`
+    for i in range(128):
+        for j in range(128):
+            for k in range(128):
+                ...
+    # NumPy
+    for i in range(128):
+        for j in range(128):
+            for k in range(128):
+                ...
+
+Computational Block
+~~~~~~~~~~~~~~~~~~~
+A significant distinction lies in the computational statements:
+**TensorIR incorporates an additional construct termed** ``T.block``.
+
+.. code:: python
+
+    # TensorIR
+    with T.block("Y"):
+        vi = T.axis.spatial(128, i)
+        vj = T.axis.spatial(128, j)
+        vk = T.axis.reduce(128, k)
+        with T.init():
+            Y[vi, vj] = T.float32(0)
+        Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
+    # NumPy
+    vi, vj, vk = i, j, k
+    if vk == 0:
+        Y[vi, vj] = 0
+    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
+
+A **block** represents a fundamental computation unit within TensorIR. Importantly,
+a block encompasses more information than standard NumPy code. It comprises a set of block axes
+``(vi, vj, vk)`` and the computations delineated around them.
+
+.. code:: python
+
+    vi = T.axis.spatial(128, i)
+    vj = T.axis.spatial(128, j)
+    vk = T.axis.reduce(128, k)
+
+The above three lines declare the **key properties** of the block axes in the following syntax.
+
+.. code:: python
+
+    [block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])
+
+These three lines convey the following details:
+
+- They specify the binding of ``vi``, ``vj``, ``vk`` (in this instance, to ``i``, ``j``, ``k``).
+- They declare the original range intended for ``vi``, ``vj``, ``vk``
+  (the 128 in ``T.axis.spatial(128, i)``).
+- They announce the properties of the iterators (spatial, reduce).
+
+Block Axis Properties
+~~~~~~~~~~~~~~~~~~~~~
+Let's delve deeper into the properties of the block axes. These properties signify the axes'
+relationship to the computation in progress. The block comprises three axes ``vi``, ``vj``, and
+``vk``, while the block reads the buffers ``A[vi, vk]`` and ``B[vk, vj]`` and writes the buffer
+``Y[vi, vj]``. Strictly speaking, the block performs (reduction) updates to ``Y``, which we label
+as a write for the time being, since we don't require the value of ``Y`` from another block.
+
+Significantly, for a fixed value of ``vi`` and ``vj``, the computation block yields a point
+value at a spatial location of ``Y`` (``Y[vi, vj]``) that is independent of other locations in ``Y``
+(with different ``vi``, ``vj`` values). We can refer to ``vi``, ``vj`` as **spatial axes**, since
+they directly correspond to the beginning of a spatial region of the buffer that the block writes to.
+The axes involved in reduction (``vk``) are designated as **reduce axes**.
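+
+To make the distinction concrete, the following NumPy-style sketch (an illustrative
+addition, not part of the TensorIR program above) shows what a single block instance
+computes: for fixed spatial indices ``vi`` and ``vj``, the result only aggregates over
+the reduce axis ``vk``, so different spatial locations can be computed independently
+and in any order.
+
+.. code:: python
+
+    # NumPy-style sketch of one instance of block "Y"
+    def compute_block_instance(A, B, vi, vj):
+        acc = 0.0                      # corresponds to the T.init() statement
+        for vk in range(128):          # reduce axis: iterations combine into acc
+            acc += A[vi, vk] * B[vk, vj]
+        return acc                     # the single point value Y[vi, vj]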
+
+Why Extra Information in Block
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+One crucial observation is that the additional information (block axis ranges and their
+properties) makes the block **self-contained**: it specifies the iterations the block is
+supposed to carry out independently of the external loop nest ``i, j, k``.
+
+The block axis information also provides additional properties that help us validate the
+correctness of the external loops that are used to carry out the computation. For example,
+the code block below results in an error because the block expects an iterator of size 128,
+but we only bind it to a for loop of size 127.
+
+.. code:: python
+
+    # wrong program due to loop and block iteration mismatch
+    for i in range(127):
+        with T.block("C"):
+            vi = T.axis.spatial(128, i)
+            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+            error here due to iterator size mismatch
+            ...
+
+Sugars for Block Axes Binding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+In situations where each of the block axes is directly mapped to an outer loop iterator,
+we can use ``T.axis.remap`` to declare all the block axes in a single line.
+
+.. code:: python
+
+    # SSR means the properties of the axes are "spatial", "spatial", "reduce"
+    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+
+which is equivalent to
+
+.. code:: python
+
+    vi = T.axis.spatial(range_of_i, i)
+    vj = T.axis.spatial(range_of_j, j)
+    vk = T.axis.reduce (range_of_k, k)
+
+So we can also write the program as follows.
+
+.. code:: python
+
+    @tvm.script.ir_module
+    class MyModuleWithAxisRemapSugar:
+        @T.prim_func
+        def mm_relu(A: T.Buffer((128, 128), "float32"),
+                    B: T.Buffer((128, 128), "float32"),
+                    C: T.Buffer((128, 128), "float32")):
+            Y = T.alloc_buffer((128, 128), dtype="float32")
+            for i, j, k in T.grid(128, 128, 128):
+                with T.block("Y"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        Y[vi, vj] = T.float32(0)
+                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
+            for i, j in T.grid(128, 128):
+                with T.block("C"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
diff --git a/docs/deep_dive/tensor_ir/tutorials/README.txt b/docs/deep_dive/tensor_ir/tutorials/README.txt
new file mode 100644
index 000000000000..bbbd7d3e5a20
--- /dev/null
+++ b/docs/deep_dive/tensor_ir/tutorials/README.txt
@@ -0,0 +1,2 @@
+Deep Dive: TensorIR
+-------------------
diff --git a/docs/deep_dive/tensor_ir/tutorials/creation.py b/docs/deep_dive/tensor_ir/tutorials/creation.py
new file mode 100644
index 000000000000..51481fb2e325
--- /dev/null
+++ b/docs/deep_dive/tensor_ir/tutorials/creation.py
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+.. 
_tir-creation: + +TensorIR Creation +----------------- +In this section, we will introduce the methods to write a TensorIR function +in Apache TVM Unity. This tutorial presumes familiarity with the fundamental concepts of TensorIR. +If not already acquainted, please refer to :ref:`tir-learning` initially. + +.. note:: + + This tutorial concentrates on the construction of **standalone** TensorIR functions. The + techniques presented here are not requisite for end users to compile Relax models. + +""" + +###################################################################### +# Create TensorIR using TVMScript +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# The most straightforward way to create a TensorIR function via TVMScript. +# TVMScript is a TVM Python dialect that represents TensorIR in TVM. +# +# .. important:: +# +# While TVMScript employs Python syntax and AST, ensuring full compatibility +# with Python tools like auto-completion and linting, it is not a native Python +# language and cannot be executed by a Python interpreter. +# +# More precisely, the decorator **@tvm.script** extracts the Python AST from +# the decorated function, subsequently parsing it into TensorIR. +# +# Standard Format +# *************** +# Let's take an example of ``mm_relu`` from :ref:`tir-learning`. Here is the complete +# format of the ir_module and in TVMScript: + + +import numpy as np +import tvm +from tvm.script import ir as I +from tvm.script import tir as T + + +@I.ir_module +class MyModule: + @T.prim_func + def mm_relu( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + Y = T.alloc_buffer((128, 128), dtype="float32") + for i in range(128): + for j in range(128): + for k in range(128): + with T.block("Y"): + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + vk = T.axis.reduce(128, k) + T.reads(A[vi, vk], B[vk, vj]) + T.writes(Y[vi, vj]) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i in range(128): + for j in range(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + vj = T.axis.spatial(128, j) + T.reads(Y[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) + + +###################################################################### +# Concise with Syntactic Sugar +# **************************** +# For ease of writing, we can employ the following syntactic sugar to +# streamline the code: +# +# - Utilize ``T.grid`` to condense nested loops; +# - Employ ``T.axis.remap`` to abbreviate block iterator annotations; +# - Exclude ``T.reads`` and ``T.writes`` for blocks whose content can +# be inferred from the block body; + + +@I.ir_module +class ConciseModule: + @T.prim_func + def mm_relu( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + Y = T.alloc_buffer((128, 128), dtype="float32") + for i, j, k in T.grid(128, 128, 128): + with T.block("Y"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) + + +###################################################################### +# We can use the following code to verify that the two modules are equivalent: + +print(tvm.ir.structural_equal(MyModule, ConciseModule)) + 
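+
+######################################################################
+# As an extra sanity check (an illustrative addition to this tutorial), we can also
+# build the concise module and compare its output with a NumPy reference. The
+# ``llvm`` target below assumes a generic CPU build environment is available.
+
+a_np = np.random.uniform(size=(128, 128)).astype("float32")
+b_np = np.random.uniform(size=(128, 128)).astype("float32")
+c_np = np.maximum(a_np @ b_np, 0)  # NumPy reference: matmul followed by relu
+
+a_nd = tvm.nd.array(a_np)
+b_nd = tvm.nd.array(b_np)
+c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
+
+static_lib = tvm.build(ConciseModule, target="llvm")
+static_lib(a_nd, b_nd, c_nd)
+np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5)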
+
+######################################################################
+# Interactive with Python Variables
+# *********************************
+# Despite TVMScript not being executed by a Python interpreter, limited
+# interaction with Python is feasible. For instance, Python variables can
+# be used to determine the shape and data type of a TensorIR function.
+
+# Python variables
+M = N = K = 128
+dtype = "float32"
+
+
+# IRModule in TVMScript
+@I.ir_module
+class ConciseModuleFromPython:
+    @T.prim_func
+    def mm_relu(
+        A: T.Buffer((M, K), dtype),
+        B: T.Buffer((K, N), dtype),
+        C: T.Buffer((M, N), dtype),
+    ):
+        Y = T.alloc_buffer((M, N), dtype)
+        for i, j, k in T.grid(M, N, K):
+            with T.block("Y"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    Y[vi, vj] = T.cast(T.float32(0), dtype)
+                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j in T.grid(M, N):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))
+
+
+######################################################################
+# Check the equivalence:
+
+print(tvm.ir.structural_equal(ConciseModule, ConciseModuleFromPython))
+
+
+######################################################################
+# TensorIR Function with Dynamic Shapes
+# *************************************
+# So far, all the shapes in the examples have been static constants. TensorIR also
+# supports dynamic shapes: the buffer dimensions can be symbolic variables that are
+# bound at runtime, so a single function can handle inputs of different sizes.
+
+
+@I.ir_module
+class DynamicShapeModule:
+    @T.prim_func
+    def mm_relu(a: T.handle, b: T.handle, c: T.handle):
+        # Dynamic shape definition
+        M, N, K = T.int32(), T.int32(), T.int32()
+
+        # Bind the input buffers with the dynamic shapes
+        A = T.match_buffer(a, [M, K], dtype)
+        B = T.match_buffer(b, [K, N], dtype)
+        C = T.match_buffer(c, [M, N], dtype)
+        Y = T.alloc_buffer((M, N), dtype)
+        for i, j, k in T.grid(M, N, K):
+            with T.block("Y"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    Y[vi, vj] = T.cast(T.float32(0), dtype)
+                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j in T.grid(M, N):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = T.max(Y[vi, vj], T.cast(T.float32(0), dtype))
+
+
+######################################################################
+# Now let's check the runtime dynamic shape inference:
+
+
+def evaluate_dynamic_shape(lib: tvm.runtime.Module, m: int, n: int, k: int):
+    A = tvm.nd.array(np.random.uniform(size=(m, k)).astype("float32"))
+    B = tvm.nd.array(np.random.uniform(size=(k, n)).astype("float32"))
+    C = tvm.nd.array(np.zeros((m, n), dtype="float32"))
+    lib(A, B, C)
+    return C.numpy()
+
+
+# Compile lib only once
+dyn_shape_lib = tvm.build(DynamicShapeModule, target="llvm")
+# Able to handle different shapes
+print(evaluate_dynamic_shape(dyn_shape_lib, m=4, n=4, k=4))
+print(evaluate_dynamic_shape(dyn_shape_lib, m=64, n=64, k=128))

+######################################################################
+# Create TensorIR using Tensor Expression
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Often, the specifics of TensorIR are disregarded in favor of expressing the computation
+# more succinctly, leading to the pragmatic generation of TensorIR. This is where Tensor
+# Expression (TE) becomes relevant.
+#
+# Tensor Expression (TE) serves as a domain-specific language delineating a sequence of
+# computations through an expression-like API.
+#
+# .. note::
+#
+#   Tensor Expression comprises two components within the TVM stack: the expression and the
+#   schedule. The expression is the domain-specific language embodying the computation pattern,
+#   which is precisely what we're addressing in this section. Conversely, the TE schedule is
+#   the legacy scheduling method, which has been superseded by the TensorIR schedule in the
+#   TVM Unity stack.
+#
+# Create Static-Shape Functions
+# *****************************
+# We use the same ``mm_relu`` example from the last subsection to demonstrate the
+# TE creation method.
+
+from tvm import te
+
+A = te.placeholder((128, 128), "float32", name="A")
+B = te.placeholder((128, 128), "float32", name="B")
+k = te.reduce_axis((0, 128), "k")
+Y = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
+C = te.compute((128, 128), lambda i, j: te.max(Y[i, j], 0), name="C")
+
+######################################################################
+# Here ``te.compute`` takes the signature ``te.compute(output_shape, fcompute)``,
+# where the ``fcompute`` function describes how we want to compute the value of each
+# element ``Y[i, j]`` for a given index:
+#
+# .. code:: python
+#
+#   lambda i, j: te.sum(A[i, k] * B[k, j], axis=k)
+#
+# The aforementioned lambda expression encapsulates the computation:
+# :math:`Y_{i, j} = \sum_k A_{i, k} \times B_{k, j}`. Upon defining the computation,
+# we can formulate a TensorIR function by incorporating the pertinent parameters of interest.
+# In this specific instance, we aim to construct a function with two input parameters **A, B**
+# and one output parameter **C**.
+
+te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
+TEModule = tvm.IRModule({"mm_relu": te_func})
+TEModule.show()
+
+######################################################################
+# Create Dynamic-Shape Functions
+# ******************************
+# We can also create a dynamic-shape function using Tensor Expression. The only difference
+# is that we need to specify the shapes of the input tensors as symbolic variables.
+
+# Declare symbolic variables
+M, N, K = te.var("m"), te.var("n"), te.var("k")
+A = te.placeholder((M, K), "float32", name="A")
+B = te.placeholder((K, N), "float32", name="B")
+k = te.reduce_axis((0, K), "k")
+Y = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="Y")
+C = te.compute((M, N), lambda i, j: te.max(Y[i, j], 0), name="C")
+
+dyn_te_func = te.create_prim_func([A, B, C]).with_attr({"global_symbol": "mm_relu"})
+DynamicTEModule = tvm.IRModule({"mm_relu": dyn_te_func})
+DynamicTEModule.show()
diff --git a/docs/deep_dive/tensor_ir/tutorials/transformation.py b/docs/deep_dive/tensor_ir/tutorials/transformation.py
new file mode 100644
index 000000000000..1dcf8e7ab5c8
--- /dev/null
+++ b/docs/deep_dive/tensor_ir/tutorials/transformation.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _tir-transform: + +Transformation +-------------- +In this section, we will get to the main ingredients of the compilation flows - +transformations of primitive tensor functions. +""" + +###################################################################### +# In the :ref:`previous section `, we have given an example of how to write +# ``mm_relu`` using TensorIR. In practice, there can be multiple ways to implement +# the same functionality, and each implementation can result in different performance. +# +# .. note:: +# This tutorial primarily illustrates the application of TensorIR Transformation, +# rather than delving into optimization techniques. +# +# First, let's take a look at the implementation of ``mm_relu`` in the previous section: + +import tvm +from tvm.script import ir as I +from tvm.script import tir as T + + +@I.ir_module +class MyModule: + @T.prim_func + def main( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + Y = T.alloc_buffer((128, 128)) + for i, j, k in T.grid(128, 128, 128): + with T.block("Y"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + Y[vi, vj] = T.float32(0) + Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = T.max(Y[vi, vj], T.float32(0)) + + +###################################################################### +# Before we transform the function, let's first evaluate the performance of the +# original implementation. + +import numpy as np + +a_np = np.random.uniform(size=(128, 128)).astype("float32") +b_np = np.random.uniform(size=(128, 128)).astype("float32") +c_np = a_np @ b_np + +a_nd = tvm.nd.array(a_np) +b_nd = tvm.nd.array(b_np) +c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32")) + + +def evaluate(mod: tvm.IRModule): + lib = tvm.build(mod, target="llvm") + # check correctness + lib(a_nd, b_nd, c_nd) + np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5) + # evaluate performance + f_timer = lib.time_evaluator("main", tvm.cpu()) + print(f_timer(a_nd, b_nd, c_nd)) + + +evaluate(MyModule) + +###################################################################### +# Initialization Schedule +# *********************** +# We initiate the process of code transformation by establishing a Schedule helper class, +# utilizing the provided **MyModule** as input. + +sch = tvm.tir.Schedule(MyModule) + +###################################################################### +# Loop Tiling +# *********** +# Subsequently, we execute the requisite operations to acquire a reference to +# block **Y** and its associated loops. + +block_Y = sch.get_block("Y") +i, j, k = sch.get_loops(block_Y) + +###################################################################### +# We now proceed to execute the transformations. The initial modification involves +# splitting loop ``j`` into two separate loops, with the inner loop possessing a +# length of 4. 
It is crucial to understand that the transformation process is procedural;
+# thus, inadvertently running this step twice will yield an error stating that the
+# variable ``j`` no longer exists.

+j0, j1 = sch.split(j, factors=[None, 4])
+
+######################################################################
+# The outcome of the transformation can be examined, as it is retained within ``sch.mod``.
+
+sch.mod.show()
+
+######################################################################
+# Following the initial transformation phase, two supplementary loops, ``j_0`` and ``j_1``,
+# have been generated with respective ranges of 32 and 4. The subsequent
+# action involves reordering these two loops.
+
+sch.reorder(j0, k, j1)
+sch.mod.show()
+evaluate(sch.mod)
+
+######################################################################
+# Leverage Localities
+# *******************
+# Subsequently, we will execute two additional transformation steps to achieve a different
+# variant. First, we employ a primitive known as **reverse_compute_at** to relocate block
+# **C** to an inner loop of **Y**.
+
+block_C = sch.get_block("C")
+sch.reverse_compute_at(block_C, j0)
+sch.mod.show()
+
+######################################################################
+# Rewrite Reduction
+# *****************
+# Until now, the reduction initialization and the update step have been maintained together
+# within a single block body. This amalgamated form facilitates loop transformations,
+# as the outer loops ``i``, ``j`` of initialization and updates generally need to remain
+# synchronized.
+#
+# Following the loop transformations, we can segregate the initialization of ``Y``'s elements
+# from the reduction update via the **decompose_reduction** primitive.
+
+sch.decompose_reduction(block_Y, k)
+sch.mod.show()
+evaluate(sch.mod)
+
+######################################################################
+# Trace the Transformation
+# ************************
+# The TensorIR schedule is procedural: transformations are executed in a
+# step-by-step manner. We can trace the transformations by printing either the schedule
+# or the history of the schedule.
+#
+# We've already seen the scheduled module by printing ``sch.mod``. We can also print the
+# history of the schedule with ``sch.trace``.
+
+sch.trace.show()
+
+######################################################################
+# Alternatively, we can output the IRModule in conjunction with the historical trace.
+
+sch.show()
diff --git a/docs/index.rst b/docs/index.rst
index fdfaa56f7454..2ea00862ac1b 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -43,6 +43,15 @@ driving its costs down.
 
    how_to/index
 
+.. The Deep Dive content is comprehensive;
+.. we maintain a ``maxdepth`` of 2 to display more information on the main page.
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Deep Dive
+
+   deep_dive/tensor_ir/index
+
 .. toctree::
    :maxdepth: 1
    :caption: API Reference