diff --git a/docs/source/assets/kernel/k_vecs.png b/docs/source/assets/kernel/k_vecs.png
new file mode 100644
index 0000000000000..4b7be1385aa2e
Binary files /dev/null and b/docs/source/assets/kernel/k_vecs.png differ
diff --git a/docs/source/assets/kernel/key.png b/docs/source/assets/kernel/key.png
new file mode 100644
index 0000000000000..2059b608caeaa
Binary files /dev/null and b/docs/source/assets/kernel/key.png differ
diff --git a/docs/source/assets/kernel/logits_vec.png b/docs/source/assets/kernel/logits_vec.png
new file mode 100644
index 0000000000000..373eea45c23ad
Binary files /dev/null and b/docs/source/assets/kernel/logits_vec.png differ
diff --git a/docs/source/assets/kernel/q_vecs.png b/docs/source/assets/kernel/q_vecs.png
new file mode 100644
index 0000000000000..f55b3742f3c6a
Binary files /dev/null and b/docs/source/assets/kernel/q_vecs.png differ
diff --git a/docs/source/assets/kernel/query.png b/docs/source/assets/kernel/query.png
new file mode 100644
index 0000000000000..e2d15ebbfe26e
Binary files /dev/null and b/docs/source/assets/kernel/query.png differ
diff --git a/docs/source/assets/kernel/v_vec.png b/docs/source/assets/kernel/v_vec.png
new file mode 100644
index 0000000000000..bac3c10949f6c
Binary files /dev/null and b/docs/source/assets/kernel/v_vec.png differ
diff --git a/docs/source/assets/kernel/value.png b/docs/source/assets/kernel/value.png
new file mode 100644
index 0000000000000..f585c77b2e144
Binary files /dev/null and b/docs/source/assets/kernel/value.png differ
diff --git a/docs/source/dev/kernel/paged_attention.rst b/docs/source/dev/kernel/paged_attention.rst
new file mode 100644
index 0000000000000..6fcadeeec27b6
--- /dev/null
+++ b/docs/source/dev/kernel/paged_attention.rst
@@ -0,0 +1,525 @@
+vLLM Paged Attention
+====================
+
+- Currently, vLLM uses its own implementation of a multi-head query
+  attention kernel (``csrc/attention/attention_kernels.cu``).
+  This kernel is designed to be compatible with vLLM's paged KV caches,
+  where the key and value cache are stored in separate blocks (note that
+  this block concept differs from the GPU thread block, so in this
+  document I will refer to the vLLM paged attention block as "block" and
+  to the GPU thread block as "thread block").
+- To achieve high performance, this kernel relies on a specially
+  designed memory layout and access method, specifically when threads
+  read data from global memory to shared memory. The purpose of this
+  document is to provide a step-by-step, high-level explanation of the
+  kernel implementation for those who wish to learn about the vLLM
+  multi-head query attention kernel. After going through this document,
+  you should have a better understanding and find it easier to follow
+  the actual implementation.
+- Please note that this document may not cover all details, such as how
+  to calculate the correct index for the corresponding data or the dot
+  multiplication implementation. However, after reading this document
+  and becoming familiar with the high-level logic flow, it should be
+  easier for you to read the actual code and understand the details.
+
+Inputs
+------
+
+- The kernel function takes a list of arguments for the current thread
+  to perform its assigned work. The three most important arguments are
+  the input pointers ``q``, ``k_cache``, and ``v_cache``, which point to
+  the query, key, and value data in global memory that need to be read
+  and processed. The output pointer ``out`` points to global memory
+  where the result should be written. These four pointers actually refer
+  to multi-dimensional arrays, but each thread only accesses the portion
+  of data assigned to it. I have omitted all other runtime parameters
+  here for simplicity.
+
+  .. code:: cpp
+
+    template<
+      typename scalar_t,
+      int HEAD_SIZE,
+      int BLOCK_SIZE,
+      int NUM_THREADS,
+      int PARTITION_SIZE = 0>
+    __device__ void paged_attention_kernel(
+      ... // Other side args.
+      const scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
+      const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+      const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+      const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+      ... // Other side args.
+    )
+
+- There is also a list of template arguments above the function
+  signature that are determined at compile time. ``scalar_t`` represents
+  the data type of the query, key, and value data elements, such as
+  FP16. ``HEAD_SIZE`` indicates the number of elements in each head.
+  ``BLOCK_SIZE`` refers to the number of tokens in each block.
+  ``NUM_THREADS`` denotes the number of threads in each thread block.
+  ``PARTITION_SIZE`` represents the number of tensor parallel GPUs (for
+  simplicity, we assume this is 0 and tensor parallelism is disabled).
+- With these arguments, we need to perform a sequence of preparations.
+  This includes calculating the current head index, block index, and
+  other necessary variables. However, for now, we can ignore these
+  preparations and proceed directly to the actual calculations. It will
+  be easier to understand them once we grasp the entire flow. A
+  simplified sketch of these preparations is shown below.
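+- For readers who want a feel for what these preparations look like,
+  here is a simplified sketch (not the exact vLLM code) of how each
+  thread can derive the indices it works with. The variable names mirror
+  the ones used later in this document, ``THREAD_GROUP_SIZE`` is assumed
+  to be already defined, and the grid layout matches the "Grid" concept
+  described in the next section.
+
+  .. code:: cpp
+
+    // Simplified sketch of the per-thread "preparations" (not the exact
+    // vLLM code). Each thread block is launched for one (head, sequence,
+    // partition) triple, so those indices can be read off the grid.
+    const int head_idx = blockIdx.x;       // which attention head
+    const int seq_idx = blockIdx.y;        // which sequence in the batch
+    const int partition_idx = blockIdx.z;  // which partition (0 if unused)
+
+    // Position of this thread inside its thread block.
+    constexpr int WARP_SIZE = 32;
+    const int thread_idx = threadIdx.x;
+    const int warp_idx = thread_idx / WARP_SIZE;  // which warp
+    const int lane = thread_idx % WARP_SIZE;      // lane within the warp
+
+    // Position of this thread inside its thread group.
+    const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+    const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;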
+
+Concepts
+--------
+
+- Just before we dive into the calculation flow, I want to describe a
+  few concepts that are needed for later sections. However, you may skip
+  this section and return later if you encounter any confusing
+  terminology.
+- **Sequence**: A sequence represents a client request. For example, the
+  data pointed to by ``q`` has a shape of
+  ``[num_seqs, num_heads, head_size]``, which means there are a total of
+  ``num_seqs`` query sequences pointed to by ``q``. Since this kernel is
+  a single-query attention kernel, each sequence only has one query
+  token. Hence, ``num_seqs`` equals the total number of tokens that are
+  processed in the batch.
+- **Context**: The context consists of the generated tokens from the
+  sequence. For instance, ``["What", "is", "your"]`` are the context
+  tokens, and the input query token is ``"name"``. The model might
+  generate the token ``"?"``.
+- **Vec**: A vec is a list of elements that are fetched and calculated
+  together. For query and key data, the vec size (``VEC_SIZE``) is
+  determined so that each thread group can fetch and calculate 16 bytes
+  of data at a time. For value data, the vec size (``V_VEC_SIZE``) is
+  determined so that each thread can fetch and calculate 16 bytes of
+  data at a time. For example, if ``scalar_t`` is FP16 (2 bytes) and
+  ``THREAD_GROUP_SIZE`` is 2, ``VEC_SIZE`` will be 4, while
+  ``V_VEC_SIZE`` will be 8.
+- **Thread group**: A thread group is a small group of threads
+  (``THREAD_GROUP_SIZE``) that fetches and calculates one query token
+  and one key token at a time. Each thread handles only a portion of the
+  token data. The total number of elements processed by one thread group
+  is referred to as ``x``. For example, if the thread group contains 2
+  threads and the head size is 8, then thread 0 handles the query and
+  key elements at indices 0, 2, 4, 6, while thread 1 handles the
+  elements at indices 1, 3, 5, 7.
+- **Block**: The key and value cache data in vLLM are split into blocks.
+  Each block stores the data for a fixed number (``BLOCK_SIZE``) of
+  tokens at one head. Each block may contain only a portion of the whole
+  context tokens. For example, if the block size is 16 and the head size
+  is 128, then for one head, one block can store 16 * 128 = 2048
+  elements.
+- **Warp**: A warp is a group of 32 threads (``WARP_SIZE``) that execute
+  simultaneously on a streaming multiprocessor (SM). In this kernel,
+  each warp processes the calculation between one query token and the
+  key tokens of one entire block at a time (it may process multiple
+  blocks in multiple iterations). For example, if there are 4 warps and
+  6 blocks for one context, warp 0 handles blocks 0 and 4, warp 1
+  handles blocks 1 and 5, warp 2 handles block 2, and warp 3 handles
+  block 3.
+- **Thread block**: A thread block is a group of threads
+  (``NUM_THREADS``) that can access the same shared memory. Each thread
+  block contains multiple warps (``NUM_WARPS``), and in this kernel,
+  each thread block processes the calculation between one query token
+  and the key tokens of a whole context.
+- **Grid**: A grid is a collection of thread blocks and defines the
+  shape of the collection. In this kernel, the shape is
+  ``(num_heads, num_seqs, max_num_partitions)``. Therefore, each thread
+  block only handles the calculation for one head, one sequence, and one
+  partition.
+
+Query
+-----
+
+- This section introduces how query data is stored in memory and fetched
+  by each thread. As mentioned above, each thread group fetches one
+  query token's data, while each thread itself only handles a part of
+  one query token's data. Within each warp, every thread group fetches
+  the same query token data but multiplies it with different key token
+  data.
+
+  .. code:: cpp
+
+    const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+
+  .. figure:: ../../assets/kernel/query.png
+    :alt: query
+    :width: 70%
+    :align: center
+
+    Query data of one token at one head
+
+- Each thread defines its own ``q_ptr``, which points to the assigned
+  query token data in global memory. For example, if ``VEC_SIZE`` is 4
+  and ``HEAD_SIZE`` is 128, ``q_ptr`` points to data that contains a
+  total of 128 elements divided into 128 / 4 = 32 vecs.
+
+  .. figure:: ../../assets/kernel/q_vecs.png
+    :alt: q_vecs
+    :width: 70%
+    :align: center
+
+    ``q_vecs`` for one thread group
+
+  .. code:: cpp
+
+    __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+
+- Next, we need to read the global memory data pointed to by ``q_ptr``
+  into shared memory as ``q_vecs``. It is important to note that each
+  thread's vecs are assigned to a different row. For example, if
+  ``THREAD_GROUP_SIZE`` is 2, thread 0 handles the 0th row of vecs,
+  while thread 1 handles the 1st row. By reading the query data in this
+  way, neighboring threads like thread 0 and thread 1 read neighboring
+  memory, achieving memory coalescing and improving performance. A
+  sketch of this interleaved assignment follows at the end of this
+  section.
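+- To make the interleaved assignment concrete, the sketch below shows
+  one way the vec constants can be derived from the 16-byte fetch rule
+  and how a thread group could fill its row of ``q_vecs``. This is a
+  simplified sketch rather than the exact vLLM code; ``Q_vec`` is
+  assumed to be a vector type packing ``VEC_SIZE`` elements of
+  ``scalar_t``, and ``q_ptr`` and ``thread_group_offset`` are as defined
+  above.
+
+  .. code:: cpp
+
+    // Each thread group fetches 16 bytes of query data per step, so each
+    // thread in the group fetches 16 / THREAD_GROUP_SIZE bytes, i.e.
+    // VEC_SIZE elements of scalar_t.
+    constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
+    // One query token has HEAD_SIZE elements, split across the thread
+    // group in units of vecs.
+    constexpr int NUM_VECS_PER_THREAD = HEAD_SIZE / (THREAD_GROUP_SIZE * VEC_SIZE);
+
+    // Threads of one group read interleaved vecs so that neighboring
+    // threads touch neighboring memory (coalesced access).
+    for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
+      const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+      q_vecs[thread_group_offset][i] =
+          *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    }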
+
+Key
+---
+
+- Similar to the "Query" section, this section introduces the memory
+  layout and assignment for keys. While each thread group only handles
+  one query token per kernel run, it may handle multiple key tokens
+  across multiple iterations. Meanwhile, each warp will process multiple
+  blocks of key tokens in multiple iterations, ensuring that the entire
+  thread block has processed all context tokens after the kernel run. In
+  this context, "handle" refers to performing the dot multiplication
+  between query data and key data.
+
+  .. code:: cpp
+
+    const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride
+                                    + physical_block_offset * x;
+
+- Unlike ``q_ptr``, ``k_ptr`` in each thread points to a different key
+  token at different iterations. As shown above, ``k_ptr`` points to the
+  key token data based on ``k_cache`` at the assigned block, assigned
+  head, and assigned token (the index arithmetic behind this pointer is
+  sketched at the end of this section).
+
+  .. figure:: ../../assets/kernel/key.png
+    :alt: key
+    :width: 70%
+    :align: center
+
+    Key data of all context tokens at one head
+
+- The diagram above illustrates the memory layout for key data. It
+  assumes that ``BLOCK_SIZE`` is 16, ``HEAD_SIZE`` is 128, ``x`` is 8,
+  ``THREAD_GROUP_SIZE`` is 2, and there are a total of 4 warps. Each
+  rectangle represents all the elements for one key token at one head,
+  which will be processed by one thread group. The left half shows the
+  key token data of one block (16 tokens) handled by warp 0, while the
+  right half represents the remaining key token data for other warps or
+  iterations. Inside each rectangle, there are a total of 32 vecs (128
+  elements for one token) that will be processed by 2 threads (one
+  thread group) separately.
+
+  .. figure:: ../../assets/kernel/k_vecs.png
+    :alt: k_vecs
+    :width: 70%
+    :align: center
+
+    ``k_vecs`` for one thread
+
+  .. code:: cpp
+
+    K_vec k_vecs[NUM_VECS_PER_THREAD]
+
+- Next, we need to read the key token data from ``k_ptr`` and store it
+  in register memory as ``k_vecs``. We use register memory for
+  ``k_vecs`` because it will only be accessed by one thread once,
+  whereas ``q_vecs`` will be accessed by multiple threads multiple
+  times. Each ``k_vecs`` contains multiple vectors for later
+  calculation. Each vec is set at each inner iteration. The assignment
+  of vecs allows neighboring threads in a warp to read neighboring
+  memory together, which again promotes memory coalescing. For instance,
+  thread 0 will read vec 0, while thread 1 will read vec 1. In the next
+  inner loop, thread 0 will read vec 2, while thread 1 will read vec 3,
+  and so on.
+- You may still be a little confused about the overall flow. Don't
+  worry, please keep reading the next "QK" section. It will illustrate
+  the query and key calculation flow in a clearer and higher-level
+  manner.
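+- As promised above, here is a hypothetical helper (not part of vLLM)
+  that spells out the flat offset of a single key element in the
+  ``[num_blocks, num_kv_heads, head_size/x, block_size, x]`` layout,
+  assuming the cache is stored contiguously in row-major order. The
+  stride names mirror those used in the ``k_ptr`` snippet above, and the
+  parameters correspond to the template arguments of the kernel.
+
+  .. code:: cpp
+
+    #include <cstdint>
+
+    // Hypothetical illustration of the k_cache index arithmetic, not the
+    // actual kernel code.
+    inline int64_t k_cache_offset(int block_number, int kv_head_idx,
+                                  int head_elem, int token_in_block,
+                                  int num_kv_heads, int head_size,
+                                  int block_size, int x) {
+      const int64_t kv_head_stride =
+          int64_t(head_size / x) * block_size * x;   // = head_size * block_size
+      const int64_t kv_block_stride = int64_t(num_kv_heads) * kv_head_stride;
+      return int64_t(block_number) * kv_block_stride
+           + int64_t(kv_head_idx) * kv_head_stride
+           + (head_elem / x) * int64_t(block_size) * x  // which group of x elements
+           + int64_t(token_in_block) * x                // which token inside the block
+           + (head_elem % x);                           // position inside the group
+    }
+
+- Note how all ``x`` elements of one token's group are stored next to
+  each other; this is what allows one thread group to fetch its 16 bytes
+  of a key token with vectorized accesses.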
+
+QK
+---
+
+- As shown in the pseudo code below, before the entire for-loop block,
+  we fetch the query data for one token and store it in ``q_vecs``.
+  Then, in the outer for loop, we iterate through different ``k_ptrs``
+  that point to different tokens and prepare the ``k_vecs`` in the inner
+  for loop. Finally, we perform the dot multiplication between
+  ``q_vecs`` and each ``k_vecs``.
+
+  .. code:: cpp
+
+    q_vecs = ...
+    for ... {
+      k_ptr = ...
+      for ... {
+        k_vecs[i] = ...
+      }
+      ...
+      float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs);
+    }
+
+- As mentioned before, each thread fetches only part of the query and
+  key token data at a time. However, there is a cross-thread-group
+  reduction inside ``Qk_dot<>::dot``, so the ``qk`` returned here is not
+  just the dot product of partial query and key token data, but the full
+  result between the entire query and key token data.
+- For example, if the value of ``HEAD_SIZE`` is 128 and
+  ``THREAD_GROUP_SIZE`` is 2, each thread's ``k_vecs`` will contain a
+  total of 64 elements. However, the returned ``qk`` is actually the
+  result of the dot multiplication between 128 query elements and 128
+  key elements. If you want to learn more about the details of the dot
+  multiplication and reduction, you may refer to the implementation of
+  ``Qk_dot<>::dot``. However, for the sake of simplicity, I will not
+  cover it in this document.
+
+Softmax
+-------
+
+- Next, we need to calculate the normalized softmax for all ``qk``\ s,
+  as shown below, where each :math:`x` represents a ``qk``. To do this,
+  we must obtain the reduced value of ``qk_max`` (:math:`m(x)`) and the
+  ``exp_sum`` (:math:`\ell(x)`) of all ``qk``\ s. The reduction should
+  be performed across the entire thread block, encompassing results
+  between the query token and all context key tokens.
+
+  .. math::
+    :nowrap:
+
+    \begin{gather*}
+    m(x) := \max_i \; x_i \\
+    f(x) := \left[\begin{array}{lll} e^{x_1 - m(x)} & \ldots & e^{x_B - m(x)} \end{array}\right] \\
+    \ell(x) := \sum_i f(x)_i \\
+    \operatorname{softmax}(x) := \frac{f(x)}{\ell(x)}
+    \end{gather*}
+
+``qk_max`` and ``logits``
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- Right after we get the ``qk`` result, we can set the temporary
+  ``logits`` result to ``qk`` (in the end, ``logits`` should store the
+  normalized softmax result). We can also compare and collect the
+  ``qk_max`` over all ``qk``\ s calculated by the current thread group.
+
+  .. code:: cpp
+
+    if (thread_group_offset == 0) {
+      const bool mask = token_idx >= context_len;
+      logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+      qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+    }
+
+- Please note that ``logits`` here is in shared memory, so each thread
+  group sets the fields for its own assigned context tokens. Overall,
+  the size of ``logits`` should be the number of context tokens.
+
+  .. code:: cpp
+
+    for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+    }
+
+    if (lane == 0) {
+      red_smem[warp_idx] = qk_max;
+    }
+
+- Then we need to get the reduced ``qk_max`` within each warp. The main
+  idea is to let the threads in a warp communicate with each other and
+  obtain the final max ``qk`` for the warp.
+
+  .. code:: cpp
+
+    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+    }
+    qk_max = VLLM_SHFL_SYNC(qk_max, 0);
+
+- Finally, we can get the reduced ``qk_max`` for the whole thread block
+  by comparing the ``qk_max`` from all warps in this thread block. Then
+  we need to broadcast the final result to each thread. A condensed
+  sketch of this two-stage reduction is shown below.
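+- To summarize the pattern above, here is a condensed, hedged sketch of
+  the same two-stage max reduction written with raw CUDA shuffle
+  intrinsics instead of vLLM's ``VLLM_SHFL_*`` helper macros. It assumes
+  ``WARP_SIZE``, ``NUM_WARPS``, ``lane``, and ``warp_idx`` are defined
+  as in the earlier snippets, and unlike the actual kernel it reduces
+  all the way down to lane granularity in the first stage.
+
+  .. code:: cpp
+
+    __shared__ float red_smem[NUM_WARPS];
+
+    // Stage 1: butterfly reduction inside each warp.
+    float val = qk_max;
+    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+      val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask));
+    }
+    // Lane 0 of each warp now holds that warp's maximum.
+    if (lane == 0) {
+      red_smem[warp_idx] = val;
+    }
+    __syncthreads();
+
+    // Stage 2: reduce the per-warp maxima and broadcast to every thread.
+    val = (lane < NUM_WARPS) ? red_smem[lane] : -FLT_MAX;  // FLT_MAX from <cfloat>
+    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+      val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask));
+    }
+    val = __shfl_sync(0xffffffff, val, 0);  // broadcast from lane 0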
+
+``exp_sum``
+~~~~~~~~~~~
+
+- Similar to ``qk_max``, we need to get the reduced sum value from the
+  entire thread block too.
+
+  .. code:: cpp
+
+    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+      float val = __expf(logits[i] - qk_max);
+      logits[i] = val;
+      exp_sum += val;
+    }
+    ...
+    exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum);
+
+- First, each thread sums the exp values of its assigned ``logits``
+  entries and, at the same time, converts each entry of ``logits`` from
+  ``qk`` to ``exp(qk - qk_max)``. Please note that the ``qk_max`` here
+  is already the max ``qk`` across the whole thread block. Then we can
+  reduce ``exp_sum`` across the whole thread block just like we did for
+  ``qk_max``.
+
+  .. code:: cpp
+
+    const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+    for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+      logits[i] *= inv_sum;
+    }
+
+- Finally, with the reduced ``qk_max`` and ``exp_sum``, we can obtain
+  the final normalized softmax result as ``logits``. This ``logits``
+  variable will be used for the dot multiplication with the value data
+  in later steps. Now, it stores the normalized softmax result of ``qk``
+  for all assigned context tokens.
+
+Value
+-----
+
+.. figure:: ../../assets/kernel/value.png
+  :alt: value
+  :width: 70%
+  :align: center
+
+  Value data of all context tokens at one head
+
+.. figure:: ../../assets/kernel/logits_vec.png
+  :alt: logits_vec
+  :width: 50%
+  :align: center
+
+  ``logits_vec`` for one thread
+
+.. figure:: ../../assets/kernel/v_vec.png
+  :alt: v_vec
+  :width: 70%
+  :align: center
+
+  List of ``v_vec`` for one thread
+
+- Now we need to retrieve the value data and perform the dot
+  multiplication with ``logits``. Unlike query and key, there is no
+  thread group concept for value data. As shown in the diagram, unlike
+  the key token memory layout, elements from the same column correspond
+  to the same value token. For one block of value data, there are
+  ``HEAD_SIZE`` rows and ``BLOCK_SIZE`` columns that are split into
+  multiple ``v_vec``\ s.
+- Each thread always fetches ``V_VEC_SIZE`` elements from the same set
+  of ``V_VEC_SIZE`` tokens at a time. As a result, a single thread
+  retrieves multiple ``v_vec``\ s from different rows and the same
+  columns through multiple inner iterations. Each ``v_vec`` needs to be
+  dot multiplied with the corresponding ``logits_vec``, which is also
+  ``V_VEC_SIZE`` elements from ``logits``. Overall, with multiple inner
+  iterations, each warp processes one block of value tokens, and with
+  multiple outer iterations, the whole context's value tokens are
+  processed.
+
+  .. code:: cpp
+
+    float accs[NUM_ROWS_PER_THREAD];
+    for ... { // Iteration over different blocks.
+      logits_vec = ...
+      for ... { // Iteration over different rows.
+        v_vec = ...
+        ...
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+
+- As shown in the pseudo code above, in the outer loop, similar to
+  ``k_ptr``, ``logits_vec`` iterates over different blocks and reads
+  ``V_VEC_SIZE`` elements from ``logits``. In the inner loop, each
+  thread reads ``V_VEC_SIZE`` elements from the same tokens as a
+  ``v_vec`` and performs the dot multiplication. It is important to note
+  that in each inner iteration, the thread fetches elements at different
+  head positions for the same tokens. The dot result is then accumulated
+  in ``accs``. Therefore, each entry of ``accs`` is mapped to a head
+  position assigned to the current thread.
+- For example, if ``BLOCK_SIZE`` is 16 and ``V_VEC_SIZE`` is 8, each
+  thread fetches 8 value elements for 8 tokens at a time. Each element
+  comes from a different token at the same head position. If
+  ``HEAD_SIZE`` is 128 and ``WARP_SIZE`` is 32, a warp needs to fetch
+  ``WARP_SIZE * V_VEC_SIZE = 256`` elements for each inner loop. This
+  means there are a total of 128 * 16 / 256 = 8 inner iterations for a
+  warp to handle a whole block of value tokens, and ``accs`` in each
+  thread contains 8 elements accumulated at 8 different head positions.
+  For thread 0, the ``accs`` variable will have 8 elements, which are
+  the 0th, 16th, ..., 112th elements of a value head, accumulated from
+  all 8 assigned tokens. The sketch below reproduces this mapping.
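+- The worked example above can be reproduced with a few lines of plain
+  C++. The sketch below derives the per-thread row mapping from the same
+  constants; the names mirror those used in the kernel, but the
+  derivations here are illustrative assumptions rather than a copy of
+  the actual code.
+
+  .. code:: cpp
+
+    #include <cstdio>
+
+    int main() {
+      constexpr int BLOCK_SIZE = 16, HEAD_SIZE = 128;
+      constexpr int WARP_SIZE = 32, V_VEC_SIZE = 8;
+      // How many v_vecs cover one row (one head position) of a block.
+      constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;        // 2
+      // How many rows a warp covers in one inner iteration.
+      constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;  // 16
+      // How many rows (accs entries) each thread is responsible for.
+      constexpr int NUM_ROWS_PER_THREAD = HEAD_SIZE / NUM_ROWS_PER_ITER; // 8
+
+      const int lane = 0;  // thread 0 of the warp
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        std::printf("accs[%d] accumulates head position %d\n", i, row_idx);
+      }
+      return 0;  // prints head positions 0, 16, ..., 112 for thread 0
+    }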
+
+LV
+---
+
+- Now, we need to perform a reduction of ``accs`` within each warp. This
+  reduction allows each thread to accumulate the ``accs`` for its
+  assigned head positions over all tokens in one block.
+
+  .. code:: cpp
+
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      float acc = accs[i];
+      for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+        acc += VLLM_SHFL_XOR_SYNC(acc, mask);
+      }
+      accs[i] = acc;
+    }
+
+- Next, we perform a reduction of ``accs`` across all warps, allowing
+  each thread to have the accumulation of ``accs`` for its assigned head
+  positions over all context tokens. Please note that each ``accs`` in
+  every thread only stores the accumulation for a portion of the
+  elements of the entire head for all context tokens. However, overall,
+  all results for the output have been calculated; they are just stored
+  in the register memory of different threads.
+
+  .. code:: cpp
+
+    float* out_smem = reinterpret_cast<float*>(shared_mem);
+    for (int i = NUM_WARPS; i > 1; i /= 2) {
+      // Upper warps write to shared memory.
+      ...
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        ...
+        dst[row_idx] = accs[i];
+      }
+
+      // Lower warps update the output.
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        ...
+        accs[i] += src[row_idx];
+      }
+
+      // Write out the accs.
+    }
+
+Output
+------
+
+- Now we can write all of the calculated results from local register
+  memory to the final output in global memory.
+
+  .. code:: cpp
+
+    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                            + head_idx * max_num_partitions * HEAD_SIZE
+                            + partition_idx * HEAD_SIZE;
+
+- First, we need to define the ``out_ptr`` variable, which points to the
+  start address of the assigned sequence and assigned head.
+
+  .. code:: cpp
+
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+
+- Finally, we need to iterate over the different assigned head positions
+  and write out the corresponding accumulated result based on
+  ``out_ptr``.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index e90481845c4ff..c0250bf99f7ae 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -98,6 +98,7 @@ Documentation
    :caption: Developer Documentation
 
    dev/engine/engine_index
+   dev/kernel/paged_attention
 
 Indices and tables
 ==================