Update documentation (#199)
This PR adds more information about the language.

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod authored and saienduri committed Oct 28, 2024
1 parent 14ef6fe commit e5e77e9
Showing 1 changed file with 183 additions and 22 deletions.
205 changes: 183 additions & 22 deletions iree/turbine/kernel/wave/docs/mlsys/tkw.tex
@@ -20,6 +20,62 @@
% Use the following line for the initial blind version submitted for review:
\usepackage{mlsys2024}

% For code listings
\usepackage{sourcecodepro}
\usepackage[T1]{fontenc}
\usepackage{listings}
\usepackage[dvipsnames]{xcolor}
\definecolor{commentgreen}{RGB}{2,112,10}
\definecolor{eminence}{RGB}{108,48,130}
\definecolor{weborange}{RGB}{255,165,0}
\definecolor{frenchplum}{RGB}{129,20,83}
%Define Colors
\definecolor{gray}{RGB}{102,102,102} %#666666
\definecolor{lightblue}{RGB}{0,102,153} %#006699
\definecolor{lightgreen}{RGB}{102,153,0} %#669900
\definecolor{bluegreen}{RGB}{51,153,126} %#33997e
\definecolor{magenta}{RGB}{217,74,122} %#d94a7a
\definecolor{orange}{RGB}{226,102,26} %#e2661a
\definecolor{purple}{RGB}{125,71,147} %#7d4793
\definecolor{green}{RGB}{113,138,98} %#718a62

\usepackage[framemethod=tikz]{mdframed}
\lstdefinelanguage{Wave}{
language=Python,
classoffset=1,
morekeywords={WorkgroupConstraint, TilingConstraint, WaveConstraint, HardwareConstraint},
keywordstyle=\color{lightblue},
classoffset=2,
morekeywords={Memory, Register},
keywordstyle=\color{lightgreen},
classoffset=3,
morekeywords={reduction, read, write, mma},
keywordstyle=\color{magenta},
classoffset=4,
morekeywords={@wave, @reduction},
keywordstyle=\color{orange},
sensitive=false, % keywords are not case-sensitive
}

\lstset{
language={Wave},
basicstyle={\scriptsize\ttfamily},
identifierstyle={\scriptsize\ttfamily},
commentstyle={\scriptsize\itshape\ttfamily},
keywordstyle={\scriptsize\bfseries\ttfamily},
ndkeywordstyle={\scriptsize\ttfamily},
stringstyle={\scriptsize\ttfamily},
frame={tb},
breaklines=true,
columns=[l]{fullflexible},
xrightmargin=0em,
xleftmargin=0em,
numberstyle={\scriptsize},
stepnumber=1,
numbersep=1em,
lineskip=-0.5ex,
}

% If accepted, instead use the following line for the camera-ready submission:
% \usepackage[accepted]{mlsys2024}

@@ -30,7 +86,7 @@
\begin{document}

\twocolumn[
\mlsystitle{Wave : A Python DSL for High Performance Machine Learning}
\mlsystitle{Wave : A Symbolic Python DSL and Compiler for High Performance Machine Learning}

% It is OKAY to include author information, even for blind
% submissions: the style file will automatically remove it for you
@@ -116,38 +172,140 @@ \section{Introduction}
on vendor-specific libraries such as cuDNN \cite{chetlur_cudnn_2014} to achieve high performance.
These libraries are performant but are black boxes consisting of
hand-written kernels and often do not support the full set of
fused operators encountered in machine learning models.
operators encountered in machine learning models.
To address these limitations, recent work has focused on developing
Python domain-specific languages (DSLs) that allow developers to achieve high performance
while reducing kernel complexity. Triton \cite{tillet_triton_2019}
is a popular Python DSL that exposes a workgroup-level programming
model and allows developers to author high-performance kernels.
However, Triton kernels often get quite complex and start to
resemble hand-written kernels as the computation they implement grows in complexity.
Furthermore, fusion of Triton kernels is limited to a few operators
and remains an open problem.

In this paper, we introduce Wave, a Python DSL for high performance machine learning.
Wave exposes a subgroup (wave or warp) level programming model that allows
for much simpler kernels compared to Triton. Through the use of constraints, Wave forces developers to
come up with the distribution strategy for their kernel:
which dimensions are parallel, which are sequential, and how to distribute those
dimensions across the memory and compute hierarchy of the GPU. This allows for a separation
between the kernel and the distribution strategy and makes the kernel simpler.
Wave also embraces symbolic data types using sympy to represent the shapes and
memory access patterns of tensors in the kernel.
In the programmability versus performance tradeoff, Triton demonstrated that it is possible to
achieve high performance while maintaining a high level of programmability. However,
Triton kernels often become quite complex as the computation they implement grows. Most of this complexity
comes from exposing a pointer-based approach to accessing and manipulating memory.
\\ \\
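To make the contrast concrete, Listing \ref{lst:triton_pointers} is a generic sketch (not taken from any particular kernel) of the pointer arithmetic a Triton kernel typically performs to load one tile of a matrix using Triton's \texttt{tl} (\texttt{triton.language}) API; the stride and block-size names are illustrative assumptions.
\begin{lstlisting}[language=Python, frame=single, breaklines, caption={Sketch of pointer-based tile addressing in the style of Triton (illustrative, not from a specific kernel).}, captionpos=b, label={lst:triton_pointers}]
# Load one BLOCK_M x BLOCK_K tile of A in a Triton-style kernel.
# Strides and block sizes (assumed constexpr) are illustrative;
# boundary masks are omitted for brevity.
pid_m = tl.program_id(0)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_K)
# Every element address is computed explicitly from row/column strides.
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
a_tile = tl.load(a_ptrs)
\end{lstlisting}
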
In this paper, we introduce Wave, a Python DSL and compiler for high performance machine learning.
Wave exposes a subgroup (wave or warp) level programming model and uses constraints
to specify the distribution strategy for the kernel. This allows for a separation between
the kernel and the distribution strategy and results in simpler kernels. The language
and compiler make extensive use of symbolic data types to represent tensor shapes and memory access patterns,
which makes it easier to reason about the kernel. We demonstrate that Wave can achieve competitive performance
with Triton and hand-tuned kernels on core machine learning operators such as matrix multiplication,
convolutions and attention.
In summary, the contributions of this paper are as follows:
\begin{itemize}
\item \textbf{Wave language} (Section \ref{section:wave_language}): A Python DSL that exposes a subgroup programming model for GPUs. The language
defines constraints that separate distribution strategies from the description of the core computation. Tensor shapes and address spaces
are represented using symbolic types (via sympy).
\item \textbf{Wave compiler} (Section \ref{section:wave_compiler}): A Python compiler that uses symbolic types
to represent memory access patterns and reason about them. The compiler uses torch.fx to trace the kernel,
then runs a series of compiler optimization passes and finally lowers the computation graph to MLIR and LLVM.
\item \textbf{Numerical Experiments} (Section \ref{section:numerical_experiments}): Numerical experiments on
matrix multiplication, convolutions and attention that demonstrate the performance of Wave kernels and show
that they are on par with existing DSLs and hand-tuned libraries.

\end{itemize}

\section{Wave Language}
\label{section:wave_language}
In this section, we will go through the Wave language and its features using matrix multiplication as an example. See Listing \ref{lst:gemm} for the full code listing.
\subsection{Wave Programming Model}
Wave programs follow the single-program multiple data (SPMD) programming model where the kernel is written
at the level of execution of a single wave or warp. While waves are core to how programs are executed on
GPUs, most GPU languages do not expose them directly to the developer. CUDA, for example, allows workgroup-
and thread-level programming but does not expose the wave level, while Triton only exposes workgroup-level programming.
The advantage of the wave programming model is that it allows developers to write kernels at the same level of
abstraction as the native hardware matrix multiply accumulate (MMA) instructions, which operate at the granularity of waves,
giving them low-level control from a high-level abstraction.


\subsection{Syntax \& Semantics}
The Wave language partitions programs into two distinct regions, as can be seen in Listing \ref{lst:gemm}.
The first part of the program consists of constraints, which are new constructs introduced by the language.

\subsubsection{Constraints}
Constraints are used to represent the distribution strategy of a kernel. Each constraint operates on a particular
dimension and specifies how that dimension is to be distributed. In the matrix multiplication example of Listing \ref{lst:gemm},
the \texttt{WorkgroupConstraint} on symbolic dimension \texttt{M/N} states that the \texttt{M/N} dimension is distributed among workgroup dimension 0/1
with a tile size of \texttt{BLOCK\_M/BLOCK\_N}. The \texttt{WaveConstraint} on the \texttt{M/N} dimension states that the \texttt{M/N} dimension is then further distributed among waves
with a tile size of \texttt{BLOCK\_M / 2} / \texttt{BLOCK\_N / 2}. The \texttt{TilingConstraint} on \texttt{K} specifies that the \texttt{K} dimension is tiled with a tile size of \texttt{BLOCK\_K}
in a sequential for loop. Finally, the \texttt{HardwareConstraint} specifies hardware-specific parameters such as the number of threads per wave, the number of waves per block, and
the canonical shape of the program.
\\ \\
The canonical shape of the program specifies the minimum granularity of the operations in the program. In the matrix multiplication example, and for programs using MMA instructions in general, the canonical shape is
the shape of the MMA instruction, which is \texttt{M = 16, N = 16, K = 16}. For programs that do not use MMA instructions, users can
explicitly specify the canonical shape of the program in the \texttt{HardwareConstraint} by using the \texttt{vector\_shapes} keyword. For more examples of this,
see Appendix \ref{section:samples}.
\\ \\
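As an illustration, Listing \ref{lst:vector_shapes} sketches how the canonical shape might be specified explicitly for a kernel that uses no MMA instructions; the wave configuration and the values in the \texttt{vector\_shapes} dictionary are illustrative assumptions rather than values taken from a shipped kernel.
\begin{lstlisting}[language=Wave, frame=single, breaklines, caption={Sketch: specifying the canonical shape explicitly via the \texttt{vector\_shapes} keyword (values are illustrative).}, captionpos=b, label={lst:vector_shapes}]
# Sketch: canonical shape for a kernel without MMA instructions.
# vector_shapes maps each symbolic dimension to its extent in the
# canonical shape (the minimum granularity of operations).
constraints += [
    HardwareConstraint(threads_per_wave=64,
        waves_per_block=(1, 1, 1),
        vector_shapes={M: 1, N: 4})
]
\end{lstlisting}
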
The constraints of the program serve multiple purposes. First, they separate the distribution strategy from the kernel.
Second, they result in much simpler kernels, because kernel authors do not need to keep track of the offsets into different memory locations;
the compiler takes care of this.

\subsubsection{Kernel}
The kernel is the second part of the program and consists of the core computation. It is annotated with the \texttt{@wave} decorator, which
is used by the compiler to trace the kernel. In the matrix multiplication example, the inputs to the kernel are of type \texttt{Memory}, which
represents a memory buffer with a symbolic shape and address space (shared memory or global memory). In the kernel,
even though the inputs are specified as type \texttt{Memory} with shapes \texttt{[M, K]}, \texttt{[N, K]}, and \texttt{[M, N]}, the actual shape of the memory
is determined by the constraints. To keep the kernel simple, we write it using the original symbolic shapes and let the compiler
determine the actual shapes of the memory buffers.


\begin{lstlisting}[language=Wave, frame=single, breaklines, caption={Mixed-precision $C = A \times B^{T}$ expressed in Wave.}, captionpos=b, label={lst:gemm}]
constraints = [WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [TilingConstraint(K, BLOCK_K)]
constraints += [WaveConstraint(M, BLOCK_M / 2)]
constraints += [WaveConstraint(N, BLOCK_N / 2)]
constraints += [
    HardwareConstraint(threads_per_wave=64,
        waves_per_block=(2, 2, 1),
        mma_type=MMAType.F32_16x16x16_F16)
]

@wave(constraints)
def gemm(
    a: Memory[M, K, ADDRESS_SPACE, f16],
    b: Memory[N, K, ADDRESS_SPACE, f16],
    c: Memory[M, N, GLOBAL_ADDRESS_SPACE, f32],
):
    c_reg = Register[M, N, f32](0.0)

    @reduction(K, init_args=[c_reg])
    def loop(acc: Register[M, N, f32]) -> Register[M, N, f32]:
        a_reg = read(a, elements_per_thread=ELEMS)
        b_reg = read(b, elements_per_thread=ELEMS)
        acc = mma(a_reg, b_reg, acc)
        return acc

    write(loop, c, elements_per_thread=ELEMS)
\end{lstlisting}
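
For completeness, the symbolic dimensions and tuning parameters referenced in Listing \ref{lst:gemm} (\texttt{M}, \texttt{N}, \texttt{K}, \texttt{BLOCK\_M}, and so on) are declared ahead of the kernel as sympy-backed symbols. Listing \ref{lst:symbols} is a minimal sketch of such declarations; the \texttt{tkl.sym} module path is an assumption used here for illustration rather than a normative part of the language.
\begin{lstlisting}[language=Wave, frame=single, breaklines, caption={Sketch: declaring the symbolic dimensions and hyperparameters used in the matrix multiplication kernel (module path is illustrative).}, captionpos=b, label={lst:symbols}]
# Sketch: sympy-backed symbols for tensor dimensions, tile sizes,
# and other hyperparameters; bound to concrete values at compile time.
# The tkl.sym namespace is an assumed alias for the Wave language module.
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
ELEMS = tkl.sym.ELEMS
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
\end{lstlisting}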

\section{Wave Compiler}
\label{section:wave_compiler}

\section{Numerical Experiments}
\label{section:numerical_experiments}

\section{Related Work}
\label{section:related_work}

\section{Conclusions \& Future Work}
\label{section:conclusions}

\section{Acknowledgements}
\label{section:acknowledgements}

\section{Appendix: Sample Wave Programs}
\label{section:samples}


\iffalse
It has a Python based compiler that uses torch.fx tracing to define
and trace operators written in the language. The torch.fx graphs are then run through a series of optimization passes
on the computation graph and are finally lowered to MLIR and subsequently LLVM. This code generation flow allows compiler writers
to blend high productivity in Python with high performance from the MLIR and LLVM
code generation flow.
\\ \\
In summary, the contributions of this paper are as follows:
\begin{itemize}
\item A novel subgroup programming model for GPU with a Python DSL that separates distribution strategies from the core kernel allowing for simpler kernels,
\item A symbolic data type system that allows for reasoning about tensor shapes and memory access patterns in the kernel,
\item A Python compiler that leverages torch.fx for tracing and maps torch.fx graphs to MLIR and LLVM for high performance code generation.
\end{itemize}


\section{Memory Access Patterns}
@@ -167,6 +325,9 @@ \section{Memory Access Patterns}
access patterns with the determination of which access pattern to use based on
the minimization of an appropriate metric across the entire graph (see Section 3).

\fi

\newpage

\iffalse
\section{Electronic Submission}
