From 8ce79b95706a2efe769183ac5f360da3fefadb0e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 10 Oct 2023 15:10:03 -0400 Subject: [PATCH] [Unity] Paged KV Cache as LM Support This PR introduces the PagedKVCache object to Relax runtime for the KV cache value management in batching settings. One test file is included. Note that this file does not contain the test of attention function/kernel. That part will be uploaded and tested separately. --- src/runtime/relax_vm/paged_kv_cache.cc | 661 ++++++++++++++++++ ...untime_builtin_paged_attention_kv_cache.py | 421 +++++++++++ 2 files changed, 1082 insertions(+) create mode 100644 src/runtime/relax_vm/paged_kv_cache.cc create mode 100644 tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc new file mode 100644 index 0000000000000..f3ee9b7acf086 --- /dev/null +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -0,0 +1,661 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/runtime/relax_vm/paged_kv_cache.cc + * \brief Runtime paged KV cache object for language models. + */ +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +//------------------------------------------- +// We keep the implementation private as +// they may subject to future changes. +// +// Users can interact with it through the +// runtime API function calls +//------------------------------------------- + +/*! + * \brief The paged KV cache for attention. + * - It supports managing the K/V data of **multiple sequences**. + * - It manages K/V values by doing paging along the sequence-length + * dimension with a configured page size. + * - To add a sequence to the cache, use AddSequence. + * - The basic example use of the paged KV cache after initialization + * in each round of model forwarding is the following: + * - step 1. use `ResetAppendLengths` to reset the appending information + * for preparation, + * - step 2. use `ReserveExtraLengthForAppend` to specify the length + * of K/V data to be appended for each sequence, + * - step 3. use `SyncAuxArrayToDevice` to synchronize auxiliary arrays + * to device for append/attention computation, + * - step 4. for each layer, use `Append` to append the K/V data to the + * cache, and then use `Attention` to compute attention results with + * Q data. + */ +class PagedAttentionKVCacheObj : public Object { + private: + /*! \brief The total number of sequences managed in the KV cache. */ + int64_t num_total_seqs_ = 0; + /*! \brief The number of pages that are in use by the sequences. */ + int64_t num_pages_in_use_ = 0; + /*! + * \brief The number of allocated pages, including the in-use pages + * and the pages released due to sequence removal. + */ + int64_t num_pages_allocated_ = 0; + + /********************* Configuration *********************/ + + /*! \brief The page size (the sequence length each page manages) of the cache. */ + const int64_t page_size_; + /*! \brief The number of layers in the model. */ + const int64_t num_layers_; + /*! \brief The number of heads in the model. */ + const int64_t num_heads_; + /*! \brief The number of features each head has. */ + const int64_t head_dim_; + + /*! \brief We fix int32 to be the index dtype of auxiliary data. */ + const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); + + /********************* Page Structures *********************/ + + /*! + * \brief The KV data managed by the KV cache. + * It has layout (num_pages, num_layers, 2, num_heads, page_size, head_dim). + * Along on the "2" dimension, index 0 stands for K and 1 stands for V. + */ + NDArray pages_; + /*! \brief The list of ids of released pages for page reuse. */ + std::vector free_page_ids_; + + /*! \brief The list of page ids assigned for each sequence in the cache. */ + std::vector> page_table_; + /*! \brief The lengths of each sequence in the cache. */ + std::vector seq_lengths_; + + /********************* Current Batch Info *********************/ + + /*! + * \brief The current lengths to append for each sequence. + * - The new K/V data appended to the cache must have the same length + * as stored in this array. + * - The Q data passed in for attention must also have the same length + * as stored. + * \note Invoke "ResetAppendLengths" to reset this array to all-zero. + */ + std::vector cur_append_lengths_; + + /********************* Auxiliary Arrays on Device *********************/ + //------------------------------------------- + // The following fields are auxiliary arrays on device. + // All of them are directly derivable from the fields above. + // We store them for efficient execution of attentions, + // cache append, etc. + //------------------------------------------- + /*! + * \brief A boolean flag indicating if the auxiliary arrays are dirty. + * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked. + */ + bool dirty_aux_data_device_ = false; + /*! + * \brief The page table indptr array on device. + * \note Since page table is a ragged data structure, we represent it + * in CSR format (which uses indptr and values below) on device. + */ + NDArray page_table_indptr_device_; + /*! \brief The page table value array on device. */ + NDArray page_table_values_device_; + /*! + * \brief The array storing "the number of used slots in the last page + * of each sequence" on device. Its values range in (0, page_size_]. + */ + NDArray last_page_offset_device_; + /*! + * \brief The append_length indptr array on device. + * \note Since the Q/K/V data may have raggedness in terms of lengths, + * we represent the the append lengths in CSR format. + */ + NDArray cur_append_length_indptr_device_; + /*! + * \brief The corresponding sequence id for each position along the + * length dimension of K/V data. It is used for efficient computation. + */ + NDArray cur_pos2seqid_device_; + /*! \brief A temporary buffer for efficient attention computation. */ + NDArray attn_tmp_buffer_; + + //------------------------------------------- + // For efficient memory management, the actual sizes of the arrays + // above are over allocated. + // We create a view for the actual shapes of each of the arrays + // after each synchronization and pass these views as input for + // attention/append. + //------------------------------------------- + NDArray page_table_indptr_view_; + NDArray page_table_values_view_; + NDArray last_page_offset_view_; + NDArray cur_append_length_indptr_view_; + NDArray cur_pos2seqid_view_; + + public: + /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */ + explicit PagedAttentionKVCacheObj(int64_t page_size, int64_t num_layers, int64_t num_heads, + int64_t head_dim, int64_t reserved_num_seqs, + int64_t reserved_num_pages, DLDataType dtype, DLDevice device) + : page_size_(page_size), num_layers_(num_layers), num_heads_(num_heads), head_dim_(head_dim) { + pages_ = NDArray::Empty({reserved_num_pages, num_layers, 2, num_heads, page_size, head_dim}, + dtype, device); + page_table_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + page_table_values_device_ = NDArray::Empty({reserved_num_pages}, dtype_aux_, device); + last_page_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); + cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + cur_pos2seqid_device_ = NDArray::Empty({reserved_num_pages * page_size}, dtype_aux_, device); + attn_tmp_buffer_ = NDArray::Empty({8 * 1024 * 1024}, DLDataType(DataType::Float(32)), device); + } + + /*! + * \brief Add a sequence to the KV cache. Returns the sequence id. + * \note This method adds a new sequence with **initial length zero**. + * Call `ReserveExtraLengthForAppend` and `Append` afterwards to reserve + * the append length and append data to the cache. + * \returns The id of the new sequence. + */ + int64_t AddSequence() { + page_table_.push_back({}); + seq_lengths_.push_back(0); + cur_append_lengths_.push_back(0); + int64_t seq_id = num_total_seqs_++; + return seq_id; + } + + /*! + * \brief Given a sequence id and a required extra length, allocate new + * new pages for the sequence until the total capacity can cover the + * current sequence length plus the required extra length. + * \note By reserving extra length for the sequence, the subsequent appended + * K/V data for the sequence must have the same length as the input length here. + * \param seq_id The id of the sequence to process. + * \param extra_length The extra length to reserve for the sequence. + */ + void ReserveExtraLengthForAppend(int64_t seq_id, int64_t extra_length) { + CHECK_GE(seq_id, 0) << "Input sequence id should be positive"; + CHECK_LT(seq_id, num_total_seqs_) + << "Invalid input sequence id " << seq_id << ", which is out of the range of [0, " + << num_total_seqs_ << ")."; + CHECK_GT(extra_length, 0) << "The input length should be positive."; + + // The reservation is based on the current sequence length. + // If "current sequence + input extra length" does not exceed the + // current capacity (number of pages * page size), no action is taken. + int64_t cur_npage = page_table_[seq_id].size(); + int64_t tgt_npage = (seq_lengths_[seq_id] + extra_length + page_size_ - 1) / page_size_; + for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) { + AllocatePageForSequence(seq_id); + } + seq_lengths_[seq_id] += extra_length; + cur_append_lengths_[seq_id] += extra_length; + dirty_aux_data_device_ = true; + } + + /*! + * \brief The entrance of attention. It takes an attention compute + * function together with the query data, invokes the attention + * function to complete the computation. + * \param f_attention The input attention compute function. + * \param q_data The query data. We support the following two layout settings: + * - in batch decode settings, q_data has layout (num_total_seqs, 1, num_heads, head_dim) + * where num_total_seqs should exactly equal to the number of sequences in the cache. + * - in other settings (single-sequence prefill, batch prefill, or speculation + * verification), q_data has the **flattened layout** to handle raggedness: + * (1, total query length, num_heads, head_dim). + * \param layer_id The model layer index of the current attention. + * \param output The attention output array. + * \param apply_rotary A boolean flag indicating if to apply RoPE to Q and K + * in attention computation. + * \param rotary_scale The RoPE scale if applicable. + * \param rotary_theta The RoPE theta if applicable. + * \note As a TODO item, we could move the rope related parameters to the + * f_attention closure in the future. + */ + void Attention(PackedFunc f_attention, NDArray q_data, int64_t layer_id, NDArray output, + bool apply_rotary = false, double rotary_scale = 1.0f, double rotary_theta = 1e4) { + // Check q_data shape validity. + CHECK_EQ(q_data->ndim, 4); + CHECK_GT(q_data->shape[1], 0); + CHECK_EQ(q_data->shape[2], num_heads_); + CHECK_EQ(q_data->shape[3], head_dim_); + CHECK(q_data.DataType() == pages_.DataType()); + + if (q_data->shape[0] > 1) { + CHECK_EQ(q_data->shape[0], num_total_seqs_); + CHECK_EQ(q_data->shape[1], 1); + } + int64_t ntoken = 0; + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + ntoken += cur_append_lengths_[seq_id]; + CHECK_LE(cur_append_lengths_[seq_id], seq_lengths_[seq_id]); + if (q_data->shape[0] > 1) { + CHECK_EQ(cur_append_lengths_[seq_id], 1); + } + } + CHECK_EQ(ntoken, q_data->shape[0] * q_data->shape[1]); + + // The auxiliary data structure on device must have been synchronized. + CHECK(!dirty_aux_data_device_) + << "The auxiliary arrays are not synchronized to device. Please call " + "`SyncAuxArrayToDevice` to synchronize before calling `Attention`."; + + f_attention(q_data, pages_, // + page_table_indptr_view_, page_table_values_view_, // + last_page_offset_view_, cur_append_length_indptr_view_, // + layer_id, attn_tmp_buffer_, output, apply_rotary, rotary_scale, rotary_theta); + } + + /*! + * \brief Append the k/v data to the cache on the given layer. + * Prior to Append, ReserveExtraLengthForAppend should be called to specify + * the length of append for each request. + * \param f_transpose_append The function that copies the input data to the + * cache data. It does a transpose-copy due to the layout difference between + * K/V and the cache data. + * \param k_data The input k data. + * \param v_data The input v data. + * \param layer_id The model layer index of the current append. + * \note We support the following two layout settings for K/V data: + * - in batch decode settings, k/v_data has layout (num_total_seqs, 1, num_heads, head_dim) + * where num_total_seqs should exactly equal to the number of sequences in the cache. + * - in other settings (single-sequence prefill, batch prefill, or speculation + * verification), k/v_data has the **flattened layout** to handle raggedness: + * (1, total query length, num_heads, head_dim). + */ + void Append(PackedFunc f_transpose_append, NDArray k_data, NDArray v_data, int64_t layer_id) { + // Check k/v_data shape validity + CHECK_EQ(k_data->ndim, 4); + CHECK_GT(k_data->shape[1], 0); + CHECK_EQ(k_data->shape[2], num_heads_); + CHECK_EQ(k_data->shape[3], head_dim_); + for (int i = 0; i < 4; ++i) { + CHECK_EQ(k_data->shape[i], v_data->shape[i]); + } + CHECK(k_data.DataType() == pages_.DataType()); + CHECK(v_data.DataType() == pages_.DataType()); + + if (k_data->shape[0] > 1) { + CHECK_EQ(k_data->shape[0], num_total_seqs_); + CHECK_EQ(k_data->shape[1], 1); + } + int64_t ntoken = 0; + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + ntoken += cur_append_lengths_[seq_id]; + CHECK_LE(cur_append_lengths_[seq_id], seq_lengths_[seq_id]); + if (k_data->shape[0] > 1) { + CHECK_EQ(cur_append_lengths_[seq_id], 1); + } + } + CHECK_EQ(ntoken, k_data->shape[0] * k_data->shape[1]); + ICHECK_EQ(ntoken, cur_pos2seqid_view_->shape[0]); + + // The auxiliary data structure on device must have been synchronized. + CHECK(!dirty_aux_data_device_) + << "The auxiliary arrays are not synchronized to device. Please call " + "`SyncAuxArrayToDevice` to synchronize before calling `Append`."; + + // Copy data + f_transpose_append(pages_, // + k_data.CreateView({ntoken, num_heads_, head_dim_}, k_data->dtype), + v_data.CreateView({ntoken, num_heads_, head_dim_}, v_data->dtype), + page_table_indptr_view_, page_table_values_view_, // + last_page_offset_view_, cur_append_length_indptr_view_, cur_pos2seqid_view_, + layer_id); + } + + /*! + * \brief Remove the given sequence from the cache. + * This includes erasing the sequence from data structures like page table, + * seq_lengths, etc. + * The id of all sequences on behind of it will be decreased by 1. + * \param seq_id The sequence to remove. + */ + void Remove(int64_t seq_id) { + CHECK_LT(seq_id, num_total_seqs_); + for (int32_t page_id : page_table_[seq_id]) { + FreePage(page_id); + } + page_table_.erase(page_table_.begin() + seq_id); + seq_lengths_.erase(seq_lengths_.begin() + seq_id); + cur_append_lengths_.erase(cur_append_lengths_.begin() + seq_id); + --num_total_seqs_; + dirty_aux_data_device_ = true; + } + + /*! + * \brief Pop the last `n` slots of K/V values for the given sequence. + * \param seq_id The sequence to be processed. + * \param n The length to pop. + */ + void PopN(int64_t seq_id, int64_t n) { + CHECK_LT(seq_id, num_total_seqs_); + CHECK_GE(n, 0); + CHECK_LE(n, seq_lengths_[seq_id]); + + // NOTE: this method does not free pages. + seq_lengths_[seq_id] -= n; + dirty_aux_data_device_ = true; + } + + /*! + * \brief Returning the cached K/V values in the form of "an array of NDArray". + * Each returned NDArray has layout (num_layers, 2, seqlen, num_heads, head_dim), + * where along on the "2" dimension, index 0 stands for K and 1 stands for V. + * \param f_copy_data The function used for copying data out from cache. + * \return The cached K/V values, one NDArray per sequence. + * \note This method is majorly for debug and testing purpose. + */ + Array DebugGetKV(PackedFunc f_copy_data) { + // The auxiliary data structure on device must have been synchronized. + CHECK(!dirty_aux_data_device_) + << "The auxiliary arrays are not synchronized to device. Please call " + "`SyncAuxArrayToDevice` to synchronize before calling `View`."; + + Array kv_values; + kv_values.reserve(num_total_seqs_); + + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + NDArray values = NDArray::Empty({num_layers_, 2, num_heads_, seq_lengths_[seq_id], head_dim_}, + pages_->dtype, pages_->device); + f_copy_data(pages_, page_table_indptr_view_, page_table_values_view_, values, seq_id); + kv_values.push_back(values); + } + return kv_values; + } + + /*! \brief Reset the values in cur_append_lengths to zeros */ + void ResetAppendLengths() { + ICHECK_EQ(cur_append_lengths_.size(), num_total_seqs_); + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + cur_append_lengths_[seq_id] = 0; + } + dirty_aux_data_device_ = true; + } + + /*! + * \brief Synchronize auxiliary arrays to device. The arrays include + * - page table (represented in CSR format on device), + * - the number of used slots in the last page of each sequence, + * - the lengths of append for each sequence, + * - the position-to-seqid array of each position in the appended data. + * \note This method resets the dirty flag to false, and needs to be + * invoked before running any computation on device (attention/append/...). + */ + void SyncAuxArrayToDevice() { + int64_t nbyte_aux = (dtype_aux_.bits * dtype_aux_.lanes + 7) / 8; + + // - Invariant checks + ICHECK_EQ(page_table_.size(), num_total_seqs_); + ICHECK_EQ(seq_lengths_.size(), num_total_seqs_); + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + ICHECK(!page_table_[seq_id].empty()); + } + // - Grow NDArrays when needed. + DeviceAuxNDArrayGrow(); + + // - Copy page table indptr and values and copy to device. + std::vector page_table_indptr_host = {0}; + std::vector page_table_values_host; + int64_t npage_in_use = 0; + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + // NOTE: The page table on host may contain the pages that are not + // in use by the sequence. + // Here we only copy the pages in use. + std::vector seq_page_table = page_table_[seq_id]; + int64_t npage = (seq_lengths_[seq_id] + page_size_ - 1) / page_size_; + ICHECK_LE(npage, static_cast(seq_page_table.size())); + page_table_values_host.insert(page_table_values_host.end(), seq_page_table.begin(), + seq_page_table.begin() + npage); + page_table_indptr_host.push_back(page_table_values_host.size()); + npage_in_use += seq_page_table.size(); + } + + ICHECK_EQ(npage_in_use, num_pages_in_use_); + ICHECK_EQ(page_table_indptr_host.size(), num_total_seqs_ + 1); + page_table_indptr_view_ = + page_table_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_); + page_table_indptr_view_.CopyFromBytes(page_table_indptr_host.data(), + (num_total_seqs_ + 1) * nbyte_aux); + page_table_values_view_ = page_table_values_device_.CreateView( + {static_cast(page_table_values_host.size())}, dtype_aux_); + page_table_values_view_.CopyFromBytes(page_table_values_host.data(), + page_table_values_host.size() * nbyte_aux); + + // - Compute `last_page_offset` from seq_lengths and copy to device. + std::vector last_page_offset_host; + last_page_offset_host.reserve(num_total_seqs_); + for (int32_t len : seq_lengths_) { + ICHECK_GT(len, 0); + last_page_offset_host.push_back((len - 1) % page_size_ + 1); + } + last_page_offset_view_ = last_page_offset_device_.CreateView({num_total_seqs_}, dtype_aux_); + last_page_offset_view_.CopyFromBytes(last_page_offset_host.data(), num_total_seqs_ * nbyte_aux); + + // - Compute append_length_indptr, pos2seqid and copy to device. + std::vector append_length_indptr = {0}; + std::vector pos2seqid; + append_length_indptr.reserve(num_total_seqs_ + 1); + + for (int64_t seq_id = 0; seq_id < num_total_seqs_; ++seq_id) { + append_length_indptr.push_back(append_length_indptr.back() + cur_append_lengths_[seq_id]); + for (int64_t pos = 0; pos < cur_append_lengths_[seq_id]; ++pos) { + pos2seqid.push_back(seq_id); + } + } + CHECK_EQ(append_length_indptr.back(), pos2seqid.size()); + ICHECK_EQ(append_length_indptr.size(), num_total_seqs_ + 1); + cur_append_length_indptr_view_ = + cur_append_length_indptr_device_.CreateView({num_total_seqs_ + 1}, dtype_aux_); + cur_append_length_indptr_view_.CopyFromBytes(append_length_indptr.data(), + (num_total_seqs_ + 1) * nbyte_aux); + cur_pos2seqid_view_ = + cur_pos2seqid_device_.CreateView({static_cast(pos2seqid.size())}, dtype_aux_); + cur_pos2seqid_view_.CopyFromBytes(pos2seqid.data(), pos2seqid.size() * nbyte_aux); + + // - Reset the dirty flag to false. + dirty_aux_data_device_ = false; + } + + /*! \brief Reset the KV cache. */ + void Clear() { + num_total_seqs_ = 0; + num_pages_in_use_ = 0; + num_pages_allocated_ = 0; + + free_page_ids_.clear(); + page_table_.clear(); + seq_lengths_.clear(); + cur_append_lengths_.clear(); + + dirty_aux_data_device_ = false; + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.vm.PagedAttentionKVCache"; + TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, Object); + + private: + /*! + * \brief Allocate a new page for the given sequence. + * This function updates the page table for the given sequence. + */ + void AllocatePageForSequence(int64_t seq_id) { + ICHECK_LT(seq_id, num_total_seqs_); + int32_t page_id = GetFreePage(); + page_table_[seq_id].push_back(page_id); + ++num_pages_in_use_; + } + + /*! \brief Get a new free page and return its id. */ + int32_t GetFreePage() { + // Find a page from the free page pools. + if (!free_page_ids_.empty()) { + int32_t page_id = free_page_ids_.back(); + free_page_ids_.pop_back(); + return page_id; + } + + // Allocate a new page. + int64_t reserved_num_pages = pages_->shape[0]; + if (num_pages_allocated_ < reserved_num_pages) { + return num_pages_allocated_++; + } + ICHECK_EQ(num_pages_allocated_, reserved_num_pages); + + // Grow the `pages` array by doubling its size. + ICHECK_EQ(pages_->ndim, 6); + std::vector new_shape(pages_->shape, pages_->shape + 6); + new_shape[0] = reserved_num_pages * 2; + DLDataType dtype = pages_->dtype; + NDArray new_pages = NDArray::Empty(new_shape, dtype, pages_->device); + new_pages.CreateView(pages_.Shape(), dtype).CopyFrom(pages_); + this->pages_ = new_pages; + // Also create a larger pos2seqid + this->cur_pos2seqid_device_ = NDArray::Empty({reserved_num_pages * 2 * page_size_}, dtype_aux_, + cur_pos2seqid_device_->device); + + return num_pages_allocated_++; + } + + /*! \brief Free the given page, putting it to the available free page pool. */ + void FreePage(int32_t page_id) { + free_page_ids_.push_back(page_id); + --num_pages_in_use_; + } + + /*! \brief Resize the auxiliary arrays on device as they grow. */ + void DeviceAuxNDArrayGrow() { + int64_t reserved_nseq = page_table_indptr_device_->shape[0] - 1; + ICHECK_EQ(last_page_offset_device_->shape[0], reserved_nseq); + while (num_total_seqs_ + 1 > reserved_nseq) { + reserved_nseq *= 2; + } + + DLDevice device = page_table_indptr_device_->device; + if (reserved_nseq != page_table_indptr_device_->shape[0] - 1) { + page_table_indptr_device_ = NDArray::Empty({reserved_nseq + 1}, dtype_aux_, device); + last_page_offset_device_ = NDArray::Empty({reserved_nseq}, dtype_aux_, device); + cur_append_length_indptr_device_ = NDArray::Empty({reserved_nseq + 1}, dtype_aux_, device); + } + + if (pages_->shape[0] > page_table_values_device_->shape[0]) { + page_table_values_device_ = NDArray::Empty({pages_->shape[0]}, dtype_aux_, device); + } + } +}; + +class PagedAttentionKVCache : public ObjectRef { + public: + static PagedAttentionKVCache Create(int64_t reserved_num_seqs, int64_t total_sequence_length, + int64_t page_size, int64_t num_layers, int64_t num_heads, + int64_t head_dim, NDArray init) { + int64_t reserved_num_pages = (total_sequence_length + page_size - 1) / page_size; + auto n = make_object(page_size, num_layers, num_heads, head_dim, + reserved_num_seqs, reserved_num_pages, + init->dtype, init->device); + return PagedAttentionKVCache(n); + } + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PagedAttentionKVCache, ObjectRef, PagedAttentionKVCacheObj); +}; + +TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); + +//------------------------------------------------- +// Register runtime functions +//------------------------------------------------- + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") + .set_body_typed([](ShapeTuple cache_config, int64_t num_layers_, int64_t num_heads_, + int64_t head_dim_, NDArray init) { + CHECK_EQ(cache_config.size(), 3); + return PagedAttentionKVCache::Create(cache_config[0], cache_config[1], cache_config[2], + num_layers_, num_heads_, head_dim_, init); + }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence") + .set_body_typed([](PagedAttentionKVCache cache) { return cache->AddSequence(); }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_reserve_extra_length_for_append") + .set_body_typed([](PagedAttentionKVCache cache, int seq_id, int extra_length) { + cache->ReserveExtraLengthForAppend(seq_id, extra_length); + }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_append") + .set_body_typed([](PagedAttentionKVCache cache, PackedFunc f_transpose_append, NDArray k_data, + NDArray v_data, int64_t layer_id) { + cache->Append(f_transpose_append, k_data, v_data, layer_id); + return cache; + }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_remove") + .set_body_typed([](PagedAttentionKVCache cache, int64_t seq_id) { cache->Remove(seq_id); }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv") + .set_body_typed([](PagedAttentionKVCache cache, PackedFunc f_view) { + return cache->DebugGetKV(f_view); + }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_popn") + .set_body_typed([](PagedAttentionKVCache cache, int64_t seq_id, int64_t n) { + cache->PopN(seq_id, n); + }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_reset_append_lengths") + .set_body_typed([](PagedAttentionKVCache cache) { cache->ResetAppendLengths(); }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device") + .set_body_typed([](PagedAttentionKVCache cache) { cache->SyncAuxArrayToDevice(); }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear") + .set_body_typed([](PagedAttentionKVCache cache) { cache->Clear(); }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention") + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() == 5 || args.size() == 8); + bool apply_rotary = false; + double rotary_scale = 1.0; + double rotary_theta = 1e4; + if (args.size() == 8) { + apply_rotary = args[4]; + rotary_scale = args[5]; + rotary_theta = args[6]; + } + PagedAttentionKVCache cache = args[0]; + cache->Attention(/*f_attention=*/args[1], /*q_data=*/args[2], /*layer_id=*/args[3], + /*output=*/args[args.size() - 1], apply_rotary, rotary_scale, rotary_theta); + }); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py new file mode 100644 index 0000000000000..a57a905550032 --- /dev/null +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py @@ -0,0 +1,421 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List + +import numpy as np +import tvm +import tvm.testing +from tvm.script import tir as T + + +reserved_nseq = 2 +total_seq_len = 128 +page_size = 8 +nlayer = 4 +nhead = 16 +nfeat = 32 +dtype = "float16" + + +fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") +fadd_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence") +freserve = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reserve_extra_length_for_append" +) +fappend = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_append") +freset_append_length = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reset_append_lengths" +) +fsync = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device") +fremove = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_remove") +fpopn = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_popn") +fclear = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_clear") + +# fmt: off +@T.prim_func +def transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_page_table_indptr: T.handle, + var_page_table_values: T.handle, + var_last_page_offset: T.handle, + var_append_length_indptr: T.handle, + var_pos2seqidx: T.handle, + layer_id: T.int32, +): + nseq = T.int32() + ntoken = T.int32() + nhead = T.int32() + nfeat = T.int32() + nlayer = T.int32() + npage = T.int32() + page_size = T.int32() + num_pages = T.int32() + + pages = T.match_buffer(var_pages, (num_pages, nlayer, 2, nhead, page_size, nfeat), "float16") + k_data = T.match_buffer(var_k_data, (ntoken, nhead, nfeat), "float16") + v_data = T.match_buffer(var_v_data, (ntoken, nhead, nfeat), "float16") + last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32") + page_table_indptr = T.match_buffer(var_page_table_indptr, (nseq + 1,), "int32") + page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32") + append_length_indptr = T.match_buffer(var_append_length_indptr, (nseq + 1,), "int32") + pos2seqidx = T.match_buffer(var_pos2seqidx, (ntoken,), "int32") + + for global_pos, h, f in T.grid(ntoken, nhead, nfeat): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], + layer_id, + 0, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size), + vf, + ] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], + layer_id, + 1, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size), + vf, + ] = v_data[vgpos, vh, vf] + + +@T.prim_func +def copy_cache( + var_pages: T.handle, + var_page_table_indptr: T.handle, + var_page_table_values: T.handle, + var_values: T.handle, + seq_id: T.int32, +): + nhead = T.int32() + nfeat = T.int32() + nlayer = T.int32() + seqlen = T.int32() + npage = T.int32() + page_size = T.int32() + num_pages = T.int32() + num_total_seqs_plus_1 = T.int32() + + pages = T.match_buffer(var_pages, (num_pages, nlayer, 2, nhead, page_size, nfeat), "float16") + page_table_indptr = T.match_buffer(var_page_table_indptr, (num_total_seqs_plus_1,), "int32") + page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32") + values = T.match_buffer(var_values, (nlayer, 2, nhead, seqlen, nfeat), "float16") + + for l, kv_idx, h, pos, f in T.grid(nlayer, 2, nhead, seqlen, nfeat): + with T.block("view"): + vl, vi, vh, vp, vf = T.axis.remap("SSSSS", [l, kv_idx, h, pos, f]) + values[vl, vi, vh, vp, vf] = pages[ + page_table_values[page_table_indptr[seq_id] + T.floordiv(vp, page_size)], + vl, + vi, + vh, + T.floormod(vp, page_size), + vf, + ] +# fmt: on + + +def verify_cached_values(cache, expected, f_copy_cache): + fview = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_debug_get_kv") + + actual = fview(cache, f_copy_cache) + assert len(actual) == len(expected) + for seq_actual, seq_expected in zip(actual, expected): + tvm.testing.assert_allclose(np.transpose(seq_actual.numpy(), [0, 1, 3, 2, 4]), seq_expected) + + +def build_tir_func(tir_funcs: List[tvm.tir.PrimFunc], target="llvm"): + return [tvm.build(tir_func, target=target).entry_func for tir_func in tir_funcs] + + +def test_paged_attention_kv_cache_append_prefill(): + f_transpose_append, f_copy_cache = build_tir_func([transpose_append, copy_cache]) + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]] + operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]] + operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]] + + current_nseq = 0 + cached_values = [] + for batch in operation_seq: + for seq_id, _ in batch: + if seq_id >= current_nseq: + seq_id_in_cache = fadd_sequence(cache) + assert seq_id_in_cache == seq_id + assert seq_id == current_nseq + current_nseq += 1 + + freset_append_length(cache) + for seq_id, append_length in batch: + freserve(cache, seq_id, append_length) + fsync(cache) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for seq_id, new_len in batch: + if seq_id >= len(cached_values): + assert seq_id == len(cached_values) + cached_values.append(np.zeros((nlayer, 2, 0, nhead, nfeat), dtype)) + + new_kv = np.random.rand(nlayer, 2, new_len, nhead, nfeat).astype(dtype) + cached_values[seq_id] = np.concatenate([cached_values[seq_id], new_kv], axis=2) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + # Verify + verify_cached_values(cache, cached_values, f_copy_cache) + + +def test_paged_attention_kv_cache_append_decode(): + f_transpose_append, f_copy_cache = build_tir_func([transpose_append, copy_cache]) + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [31, 21, 16, 3, 8, 7, 3] + nseq = len(initial_lengths) + + # Initial prefill + freset_append_length(cache) + for seq_id, append_length in enumerate(initial_lengths): + seq_id_in_cache = fadd_sequence(cache) + assert seq_id_in_cache == seq_id + freserve(cache, seq_id, append_length) + fsync(cache) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_copy_cache) + + # Decode + for _ in range(16): + decode_new_kv = np.random.rand(nlayer, 2, nseq, 1, nhead, nfeat).astype(dtype) + freset_append_length(cache) + for seq_id in range(nseq): + freserve(cache, seq_id, 1) + fsync(cache) + for seq_id in range(nseq): + cached_values[seq_id] = np.concatenate( + [cached_values[seq_id], decode_new_kv[:, :, seq_id, ...]], axis=2 + ) + for layer_id in range(nlayer): + keys = tvm.nd.array(decode_new_kv[layer_id, 0]) + values = tvm.nd.array(decode_new_kv[layer_id, 1]) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_copy_cache) + + +def test_paged_attention_kv_cache_remove(): + f_transpose_append, f_copy_cache = build_tir_func([transpose_append, copy_cache]) + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [31, 21, 16, 3, 8, 7, 3] + + # Initial prefill + freset_append_length(cache) + for seq_id, append_length in enumerate(initial_lengths): + seq_id_in_cache = fadd_sequence(cache) + assert seq_id_in_cache == seq_id + freserve(cache, seq_id, append_length) + fsync(cache) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_copy_cache) + + # Remove + while len(cached_values) > 2: + seq_id = np.random.randint(0, len(cached_values)) + fremove(cache, seq_id) + cached_values.pop(seq_id) + fsync(cache) + verify_cached_values(cache, cached_values, f_copy_cache) + + # Append after removal + seq_id = 2 + new_len = 29 + seq_id_in_cache = fadd_sequence(cache) + assert seq_id_in_cache == seq_id + freset_append_length(cache) + freserve(cache, seq_id, new_len) + fsync(cache) + new_kv = np.random.rand(nlayer, 2, new_len, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_copy_cache) + + +def test_paged_attention_kv_cache_popn(): + f_transpose_append, f_copy_cache = build_tir_func([transpose_append, copy_cache]) + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [20, 24, 26, 27] + nseq = len(initial_lengths) + + # Initial prefill + freset_append_length(cache) + for seq_id, append_length in enumerate(initial_lengths): + seq_id_in_cache = fadd_sequence(cache) + assert seq_id_in_cache == seq_id + freserve(cache, seq_id, append_length) + fsync(cache) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_copy_cache) + + # Pop n + for pop_length in [3, 13]: + for seq_id in range(nseq): + fpopn(cache, seq_id, pop_length) + cached_values[seq_id] = cached_values[seq_id][:, :, :-pop_length, ...] + fsync(cache) + verify_cached_values(cache, cached_values, f_copy_cache) + + # Decode after pop n + for _ in range(5): + decode_new_kv = np.random.rand(nlayer, 2, nseq, 1, nhead, nfeat).astype(dtype) + freset_append_length(cache) + for seq_id in range(nseq): + freserve(cache, seq_id, 1) + fsync(cache) + + for seq_id in range(nseq): + cached_values[seq_id] = np.concatenate( + [cached_values[seq_id], decode_new_kv[:, :, seq_id, ...]], axis=2 + ) + for layer_id in range(nlayer): + keys = tvm.nd.array(decode_new_kv[layer_id, 0]) + values = tvm.nd.array(decode_new_kv[layer_id, 1]) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_copy_cache) + + +def test_paged_attention_kv_cache_clear(): + f_transpose_append, f_copy_cache = build_tir_func([transpose_append, copy_cache]) + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [20, 24, 26, 27] + + # Initial prefill + freset_append_length(cache) + for seq_id, append_length in enumerate(initial_lengths): + seq_id_in_cache = fadd_sequence(cache) + assert seq_id_in_cache == seq_id + freserve(cache, seq_id, append_length) + fsync(cache) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_copy_cache) + + # Clear + fclear(cache) + verify_cached_values(cache, [], f_copy_cache) + + +if __name__ == "__main__": + test_paged_attention_kv_cache_append_prefill() + test_paged_attention_kv_cache_append_decode() + test_paged_attention_kv_cache_remove() + test_paged_attention_kv_cache_popn() + test_paged_attention_kv_cache_clear() + # Test for attention is not included at this moment + # since we do not have TIR attention functions yet.