From 5243393c51ea53111c9e3e476824f0ca5772629f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Paul=20L=C3=BCcke?= Date: Thu, 10 Oct 2024 07:58:56 +0200 Subject: [PATCH] [tkw] Progress on compiler description (#207) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This also adds lots of notes, todos and general structure Signed-off-by: Martin Lücke --- iree/turbine/kernel/wave/docs/mlsys/tkw.tex | 76 +++++++++++++++++++-- 1 file changed, 72 insertions(+), 4 deletions(-) diff --git a/iree/turbine/kernel/wave/docs/mlsys/tkw.tex b/iree/turbine/kernel/wave/docs/mlsys/tkw.tex index 7035f20f..81cb5111 100644 --- a/iree/turbine/kernel/wave/docs/mlsys/tkw.tex +++ b/iree/turbine/kernel/wave/docs/mlsys/tkw.tex @@ -220,7 +220,8 @@ \subsection{Wave Programming Model} The advantages of the wave programming model are that it allows developers to write kernels at the same level of abstraction as the native hardware matrix multiply accumulate (MMA) instructions which operate at the granularity of waves giving them low-level control from a high-level abstraction. - +% Should we compare to the Triton programming model here? +% Possibly we could have roughly something like figure 3 of the Triton paper \subsection{Syntax \& Semantics} The Wave language partitions programs into two distinct regions as can be seen in Listing \ref{lst:gemm}. @@ -240,7 +241,7 @@ \subsubsection{Constraints} explicitly specify the canonical shape of the program in the \texttt{HardwareConstraint} by using the \texttt{vector\_shapes} keyword. For more examples of this, see Appendix \ref{appendix:samples}. \\ \\ -The constraints of the program serve multiple purposes. First, then separate out the distribution strategy from the kernel. +The constraints of the program serve multiple purposes. First, they separate out the distribution strategy from the kernel. Second, they result in much simpler kernels because kernel authors do not need to keep track the offsets to different memory locations as the compiler takes care of this. @@ -251,7 +252,7 @@ \subsubsection{Kernel} even though the inputs are specified as type \texttt{Memory} with shape \texttt{[M, K], [N, K], [M, N]}, the actual shape of the memory is determined by the constraints. In order to simplify the kernel, we write the kernel using the original symbolic shapes and let the compiler determine the actual shapes of the memory buffers. - +% Note: The question that comes up with this phrasing: Could these shapes then just be left out. In reality they are at least important to figure out the indexing dimensions. \begin{lstlisting}[language=Wave, frame=single, breaklines, caption={Mixed-precision $C = A \times B^{T}$ expressed in Wave.}, captionpos=b, label={lst:gemm}] constraints = [WorkgroupConstraint(M, BLOCK_M, 0)] @@ -283,6 +284,8 @@ \subsubsection{Kernel} write(loop, c, elements_per_thread=ELEMS) \end{lstlisting} +\newpage + \section{Wave Compiler} \label{section:wave_compiler} The Wave compiler is a Python-based compiler designed to process and optimize kernels written in the Wave language. It leverages symbolic types to represent memory access patterns and perform reasoning about them. The compilation process involves several key steps: @@ -341,7 +344,72 @@ \subsection{Tracing with torch.fx} The Wave compiler utilizes torch.fx~\cite{reed2022torch} for symbolic tracing of the kernel. This process involves executing the Python kernel program with special \emph{Proxy} objects that act as placeholders for actual values. As the program runs, torch.fx records all definitions and function calls, effectively capturing the computational structure of the kernel. The result is a comprehensive representation of the kernel's logic in the form of the torch.fx graph IR. By leveraging torch.fx, Wave strikes a balance between the flexibility of a Python-embedded DSL and the power of a custom compiler, allowing for specialized optimizations while maintaining an accessible and extensible programming model. \subsection{Intermediate Representation} -We extended the torch.fx intermediate representation (IR) with custom primitives and types to accurately model the Wave programming model. Our enhancements include custom primitives representing Wave-specific constructs such as wave-level operations and memory accesses. +We extended the torch.fx intermediate representation with custom primitives and types to accurately model the Wave programming model. Our enhancements include custom primitives representing Wave-specific constructs such as wave-level operations and memory accesses. Representing them explicitly in the IR simplifies type inference and enables easier transformations on the IR. At the same time we keep compatibility with torch.fx IR in order to reuse existing tooling for e.g. visualization. + + +\subsection{Lowering Wave programming to thread programming} +% Note: A description of the programming model will already be in the previous chapter. So I can take that as a given. +A wave kernel is expressed following an SPMD programming model in the granularity of a single wave. While the target programming model for the output IR follows SPMD as well, it operates on the granularity of a single thread with explicit data movement and synchronization. +In consequence the computation graph needs to be expanded according to the input sizes and the distribution of data to threads to model the instructions for each thread. +\smallskip +As each thread has to load different data depending on its position in the launch grid %or do we call this differently? +we prepend the IR with operations to get the \texttt{\footnotesize thread\_id} and \texttt{\footnotesize workgroup\_id} for the relevant dimensions. + +First, we determine to which level each dimension has to expand according to the input sizes and constraints. +% small example? +We start at the leaf nodes of the kernel and follow def-uses upward until we reach the kernel inputs. We each node we reach we: +\begin{enumerate} + \item Determine the dimensions this node needs to index in + \item \ldots +\end{enumerate} +% TODO: Preliminary, maybe better to just express this as an algorihm? + +% TODO: Can we produce a figure (or graph) of pre-expansion and post-expansion? Maybe only expansion in a single dimension if this gets too large otherwise? + + +% Thought: +% We name this programming model \emph{SIWT} (Single Instruction, Wave of Threads) denoting that a single instruction is executed by a wave of threads. This is inspired by the SIMT execution model where instructions are executed by all threads in lockstep. + + +\subsection{Instruction Scheduling} +% Describe the deep type of decisions we can take already on this level with the vast information we still have available + + +\subsection{Optimization Passes} +After tracing, the compiler executes a series of optimization passes on the computational graph, such as: +\textbf{barrier insertion} % We could give more specifics here on when we insert barriers +\ldots + +\subsection{Lowering to MLIR} + +% TODO: mention symbolic types, their optimization and lowering here +% Also mention for which sizes we use them: Tensor shapes, ... +% basically the sympy walker we have. + +The final stage of the compilation process involves lowering the optimized computational graph to several MLIR dialects: % TODO: possibly present this differently, for now a list is fine. +%We target the amdgpu and gpu dialects to model GPU + +% Possibly simplify if this is too specific +\begin{itemize} + \item Intrinsics used in the kernel are directly mapped to operations of the \texttt{\footnotesize amdgpu} dialect + \item The \texttt{\footnotesize gpu} dialect is used to model general GPU concepts, such as \texttt{\footnotesize thread\_id} + \item The \texttt{\footnotesize scf} dialect is used to model loops + \item The \texttt{\footnotesize llvm} dialect is used to emit scheduling barriers that preserve our instruction scheduling decisions + \item We use the \texttt{\footnotesize scf}, \texttt{\footnotesize arith}, and \texttt{\footnotesize math} dialects to model loops and arithmetic. + \item Furthermore we use the \texttt{\footnotesize memref} and \texttt{\footnotesize func} dialects +\end{itemize} + + + +% mention LLVM scheduling barriers + +\subsection{Integration} +TODO: briefly describe integration into Pytorch (compile to vmfb + call with torch tensors) \& IREE (\ldots)? + + + +\bigskip +In summary, the Wave compiler combines symbolic computation, multi-stage lowering, and GPU-specific optimizations to translate high-level Wave kernels into efficient, low-level code suitable for execution on GPUs. Its use of symbolic types and constraint-based programming model allows for powerful optimizations while maintaining a high level of abstraction for kernel developers. \section{Numerical Experiments} \label{section:numerical_experiments}