diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 39948ae52255..9cd6fb186de9 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -472,6 +472,430 @@ void ApplySoftmaxWithTemperature(NDArray logits, double temperature) {
 TVM_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature")
     .set_body_typed(ApplySoftmaxWithTemperature);
 
+////////////////////////////////////////////////////////////////////////
+
+class PagedAttentionKVCacheObj : public Object {
+ public:
+  int64_t num_total_seqs;
+  int64_t num_pages_in_use;
+  int64_t num_pages_allocated;
+
+  int64_t page_size;
+  int64_t nlayer;
+  int64_t nhead;
+  int64_t nfeat;
+
+  const DLDataType dtype_aux = DataType::Int(32, 1).operator DLDataType();
+
+  /********************* Page Structures *********************/
+
+  const int64_t page_chunk_size = 1;
+
+  NDArray pages;
+  // Pool of page ids that were allocated before and have been released.
+  std::vector<int32_t> free_page_ids;
+
+  // Per-sequence list of page ids, ordered by position in the sequence.
+  std::vector<std::vector<int32_t>> page_table;
+  // Current length (in tokens) of each sequence.
+  std::vector<int32_t> seqlen;
+
+  NDArray page_table_indptr_device;
+  NDArray page_table_values_device;
+  NDArray last_page_offset_device;
+
+  /********************* Current Batch Info *********************/
+
+  ShapeTuple cur_append_lengths;
+  NDArray cur_append_length_indptr_device;
+  NDArray cur_pos2seqidx_device;
+
+ public:
+  void Attention(PackedFunc f_attention, NDArray q_data, int64_t layer_id, NDArray output,
+                 bool apply_rotary = false, double rotary_scale = 1.0f, double rotary_theta = 1e4) {
+    // q_data: (b, len, nhead, nfeat)
+    CHECK_EQ(q_data->ndim, 4);
+    CHECK_GT(q_data->shape[1], 0);
+    CHECK_EQ(q_data->shape[2], nhead);
+    CHECK_EQ(q_data->shape[3], nfeat);
+    CHECK(q_data.DataType() == pages.DataType());
+
+    int64_t ntoken = 0;
+    for (int64_t i = 0; i < num_total_seqs; ++i) {
+      ntoken += cur_append_lengths[i];
+      if (q_data->shape[0] > 1) {
+        CHECK_EQ(cur_append_lengths[i], q_data->shape[1]);
+      }
+    }
+    if (q_data->shape[0] > 1) {
+      CHECK_EQ(q_data->shape[0], num_total_seqs);
+    }
+    CHECK_EQ(ntoken, q_data->shape[0] * q_data->shape[1]);
+
+    // Provide an 8MB float32 temp buffer for the attention kernel.
+    static NDArray tmp_buffer =
+        NDArray::Empty({8 * 1024 * 1024}, DLDataType(DataType::Float(32)), pages->device);
+
+    f_attention(q_data, pages,                                                                 //
+                page_table_indptr_device.CreateView({num_total_seqs + 1}, dtype_aux),         //
+                page_table_values_device.CreateView({num_pages_in_use}, dtype_aux),           //
+                last_page_offset_device.CreateView({num_total_seqs}, dtype_aux),              //
+                cur_append_length_indptr_device.CreateView({num_total_seqs + 1}, dtype_aux),  //
+                layer_id, tmp_buffer, output, apply_rotary, rotary_scale, rotary_theta);
+  }
+
+  void Prepare(Optional<ShapeTuple> opt_append_lengths) {
+    ShapeTuple append_lengths;
+    int64_t nseq =
+        opt_append_lengths.defined() ? opt_append_lengths.value().size() : num_total_seqs;
+    append_lengths =
+        opt_append_lengths.value_or(ShapeTuple(std::vector<int64_t>(/*count=*/nseq, /*value=*/1)));
+
+    CHECK_GE(nseq, num_total_seqs);
+    cur_append_lengths = append_lengths;
+
+    for (int64_t i = 0; i < nseq; ++i) {
+      ICHECK_LE(i, num_total_seqs);
+      if (i == num_total_seqs) {
+        // Initialize if it is a new sequence.
+        CHECK_GT(append_lengths[i], 0);
+        page_table.push_back({});
+        seqlen.push_back(0);
+        ++num_total_seqs;
+      }
+
+      if (append_lengths[i] == 0) {
+        continue;
+      }
+
+      int64_t cur_npage = (seqlen[i] + page_size - 1) / page_size;
+      int64_t tgt_npage = (seqlen[i] + append_lengths[i] + page_size - 1) / page_size;
+      for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) {
+        AllocatePage(i);
+      }
+      seqlen[i] += append_lengths[i];
+    }
+
+    // Sync auxiliary data structures to device
+    SyncDevice(/*sync_current_batch=*/true);
+  }
+
+  void Append(PackedFunc f_transpose_append, NDArray k_data, NDArray v_data, int64_t layer_id) {
+    // k/v_data: (b, len, nhead, nfeat)
+    CHECK_EQ(k_data->ndim, 4);
+    CHECK_GT(k_data->shape[1], 0);
+    CHECK_EQ(k_data->shape[2], nhead);
+    CHECK_EQ(k_data->shape[3], nfeat);
+
+    for (int i = 0; i < 4; ++i) {
+      CHECK_EQ(k_data->shape[i], v_data->shape[i]);
+    }
+    CHECK(k_data.DataType() == pages.DataType());
+    CHECK(k_data.DataType() == v_data.DataType());
+
+    int64_t ntoken = 0;
+    for (int64_t i = 0; i < num_total_seqs; ++i) {
+      ntoken += cur_append_lengths[i];
+      if (k_data->shape[0] > 1) {
+        CHECK_EQ(cur_append_lengths[i], k_data->shape[1]);
+      }
+    }
+    if (k_data->shape[0] > 1) {
+      CHECK_EQ(k_data->shape[0], num_total_seqs);
+    }
+    CHECK_EQ(ntoken, k_data->shape[0] * k_data->shape[1]);
+
+    // Copy data
+    f_transpose_append(pages,  //
+                       k_data.CreateView({ntoken, nhead, nfeat}, k_data->dtype),
+                       v_data.CreateView({ntoken, nhead, nfeat}, v_data->dtype),
+                       page_table_indptr_device.CreateView({num_total_seqs + 1}, dtype_aux),
+                       page_table_values_device.CreateView({num_pages_in_use}, dtype_aux),
+                       last_page_offset_device.CreateView({num_total_seqs}, dtype_aux),
+                       cur_append_length_indptr_device.CreateView({num_total_seqs + 1}, dtype_aux),
+                       cur_pos2seqidx_device.CreateView({ntoken}, dtype_aux),  //
+                       layer_id);
+  }
+
+  void Remove(int64_t seq_id) {
+    CHECK_LT(seq_id, num_total_seqs);
+    for (int32_t page_id : page_table[seq_id]) {
+      FreePage(page_id);
+    }
+    page_table.erase(page_table.begin() + seq_id);
+    seqlen.erase(seqlen.begin() + seq_id);
+    num_total_seqs -= 1;
+    SyncDevice();
+  }
+
+  Array<NDArray> View(PackedFunc f_view) {
+    Array<NDArray> kv_values;
+    kv_values.reserve(num_total_seqs);
+
+    for (int64_t seq_id = 0; seq_id < num_total_seqs; ++seq_id) {
+      NDArray values =
+          NDArray::Empty({nlayer, 2, nhead, seqlen[seq_id], nfeat}, pages->dtype, pages->device);
+      f_view(pages,                                                                 //
+             page_table_indptr_device.CreateView({num_total_seqs + 1}, dtype_aux),  //
+             page_table_values_device.CreateView({num_pages_in_use}, dtype_aux),    //
+             values, seq_id);
+      kv_values.push_back(values);
+    }
+    return kv_values;
+  }
+
+  void PopN(int64_t seq_id, int64_t n) {
+    CHECK_LT(seq_id, num_total_seqs);
+    CHECK_GE(n, 0);
+    CHECK_LE(n, seqlen[seq_id]);
+
+    int64_t cur_npage = (seqlen[seq_id] + page_size - 1) / page_size;
+    int64_t tgt_npage = (seqlen[seq_id] - n + page_size - 1) / page_size;
+    for (int64_t page_idx = cur_npage - 1; page_idx >= tgt_npage; --page_idx) {
+      ICHECK_EQ(page_idx, page_table[seq_id].size() - 1);
+      int64_t page_id = page_table[seq_id].back();
+      page_table[seq_id].pop_back();
+      FreePage(page_id);
+    }
+    seqlen[seq_id] -= n;
+  }
+
+  void Clear() {
+    num_total_seqs = 0;
+    num_pages_in_use = 0;
+    num_pages_allocated = 0;
+
+    free_page_ids.clear();
+    page_table.clear();
+    seqlen.clear();
+  }
+
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "relax.vm.PagedAttentionKVCache";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, Object);
+
+ private:
+  void AllocatePage(int64_t seq_id) {
+    ICHECK_LT(seq_id, num_total_seqs);
+    int32_t page_id = GetFreePage();
+    page_table[seq_id].push_back(page_id);
+    ++num_pages_in_use;
+  }
+
+  int32_t GetFreePage() {
+    if (!free_page_ids.empty()) {
+      int32_t page_id = free_page_ids.back();
+      free_page_ids.pop_back();
+      return page_id;
+    }
+
+    int64_t reserved_num_page_chunks = pages->shape[0];
+    if (num_pages_allocated < reserved_num_page_chunks * page_chunk_size) {
+      return num_pages_allocated++;
+    }
+    ICHECK_EQ(num_pages_allocated, reserved_num_page_chunks * page_chunk_size);
+
+    // All reserved pages are in use: grow the page chunks.
+    ICHECK_EQ(pages->ndim, 7);
+    std::vector<int64_t> new_shape(pages->shape, pages->shape + 7);
+    new_shape[0] = reserved_num_page_chunks * 2;
+    DLDataType dtype = pages->dtype;
+    NDArray new_pages = NDArray::Empty(new_shape, dtype, pages->device);
+    new_pages.CreateView(pages.Shape(), dtype).CopyFrom(pages);
+    this->pages = new_pages;
+    // Also create a larger pos2seqidx
+    this->cur_pos2seqidx_device =
+        NDArray::Empty({reserved_num_page_chunks * 2 * page_chunk_size * page_size}, dtype_aux,
+                       cur_pos2seqidx_device->device);
+
+    return num_pages_allocated++;
+  }
+
+  void FreePage(int32_t page_id) {
+    free_page_ids.push_back(page_id);
+    --num_pages_in_use;
+  }
+
+  void DeviceAuxNDArrayGrow() {
+    int64_t reserved_nseq = page_table_indptr_device->shape[0] - 1;
+    ICHECK_EQ(last_page_offset_device->shape[0], reserved_nseq);
+    while (num_total_seqs + 1 > reserved_nseq) {
+      reserved_nseq *= 2;
+    }
+
+    DLDevice device = page_table_indptr_device->device;
+    if (reserved_nseq != page_table_indptr_device->shape[0] - 1) {
+      page_table_indptr_device = NDArray::Empty({reserved_nseq + 1}, dtype_aux, device);
+      last_page_offset_device = NDArray::Empty({reserved_nseq}, dtype_aux, device);
+      cur_append_length_indptr_device = NDArray::Empty({reserved_nseq + 1}, dtype_aux, device);
+    }
+
+    if (pages->shape[0] * page_chunk_size > page_table_values_device->shape[0]) {
+      page_table_values_device =
+          NDArray::Empty({pages->shape[0] * page_chunk_size}, dtype_aux, device);
+    }
+  }
+
+  void SyncDevice(bool sync_current_batch = false) {
+    // Invariant checks
+    ICHECK_EQ(page_table.size(), num_total_seqs);
+    ICHECK_EQ(seqlen.size(), num_total_seqs);
+    for (int64_t seq_id = 0; seq_id < num_total_seqs; ++seq_id) {
+      ICHECK(!page_table[seq_id].empty());
+    }
+
+    // Grow NDArrays when needed.
+    DeviceAuxNDArrayGrow();
+
+    int64_t nbyte_aux = (dtype_aux.bits * dtype_aux.lanes + 7) / 8;
+    // Copy page table indptr and values
+    std::vector<int32_t> page_table_indptr_host = {0};
+    std::vector<int32_t> page_table_values_host;
+    for (auto& seq_page_table : page_table) {
+      page_table_values_host.insert(page_table_values_host.end(), seq_page_table.begin(),
+                                    seq_page_table.end());
+      page_table_indptr_host.push_back(page_table_values_host.size());
+    }
+    std::vector<int32_t> last_page_offset_host;
+    last_page_offset_host.reserve(num_total_seqs);
+    for (int32_t len : seqlen) {
+      ICHECK_GT(len, 0);
+      last_page_offset_host.push_back((len - 1) % page_size + 1);
+    }
+
+    ICHECK_EQ(page_table_values_host.size(), num_pages_in_use);
+    page_table_indptr_device
+        .CreateView({static_cast<int64_t>(page_table_indptr_host.size())}, dtype_aux)
+        .CopyFromBytes(page_table_indptr_host.data(), page_table_indptr_host.size() * nbyte_aux);
+    page_table_values_device
+        .CreateView({static_cast<int64_t>(page_table_values_host.size())}, dtype_aux)
+        .CopyFromBytes(page_table_values_host.data(), page_table_values_host.size() * nbyte_aux);
+    // Copy seqlen
+    last_page_offset_device.CreateView({num_total_seqs}, dtype_aux)
+        .CopyFromBytes(last_page_offset_host.data(), num_total_seqs * nbyte_aux);
+
+    if (sync_current_batch) {
+      std::vector<int32_t> append_length_indptr;
+      std::vector<int32_t> pos2seqidx;
+      append_length_indptr.reserve(num_total_seqs + 1);
+      append_length_indptr.push_back(0);
+
+      for (int64_t i = 0; i < num_total_seqs; ++i) {
+        append_length_indptr.push_back(append_length_indptr.back() + cur_append_lengths[i]);
+        for (int64_t pos = 0; pos < cur_append_lengths[i]; ++pos) {
+          pos2seqidx.push_back(i);
+        }
+      }
+      CHECK_EQ(append_length_indptr.back(), pos2seqidx.size());
+
+      cur_append_length_indptr_device
+          .CreateView({static_cast<int64_t>(append_length_indptr.size())}, dtype_aux)
+          .CopyFromBytes(append_length_indptr.data(), append_length_indptr.size() * nbyte_aux);
+      cur_pos2seqidx_device.CreateView({static_cast<int64_t>(pos2seqidx.size())}, dtype_aux)
+          .CopyFromBytes(pos2seqidx.data(), pos2seqidx.size() * nbyte_aux);
+    }
+  }
+};
+
+class PagedAttentionKVCache : public ObjectRef {
+ public:
+  static PagedAttentionKVCache Create(int64_t reserved_nseq, int64_t total_sequence_length,
+                                      int64_t page_size, int64_t nlayer, int64_t nhead,
+                                      int64_t nfeat, NDArray init) {
+    auto n = make_object<PagedAttentionKVCacheObj>();
+    n->num_total_seqs = 0;
+    n->num_pages_in_use = 0;
+    n->num_pages_allocated = 0;
+    n->page_size = page_size;
+    n->nlayer = nlayer;
+    n->nhead = nhead;
+    n->nfeat = nfeat;
+
+    DLDevice device = init->device;
+    int64_t reserved_num_page_chunks =
+        (((total_sequence_length + page_size - 1) / page_size * nlayer * 2) +
+         (n->page_chunk_size - 1)) /
+        n->page_chunk_size;
+    // Layout of `pages`: (num_page_chunks, nlayer, page_chunk_size, 2 (k/v), nhead, page_size, nfeat)
+    n->pages = NDArray::Empty(
+        {reserved_num_page_chunks, nlayer, n->page_chunk_size, 2, nhead, page_size, nfeat},
+        init->dtype, device);
+
+    n->page_table_indptr_device = NDArray::Empty({reserved_nseq + 1}, n->dtype_aux, device);
+    n->page_table_values_device =
+        NDArray::Empty({reserved_num_page_chunks * n->page_chunk_size}, n->dtype_aux, device);
+    n->last_page_offset_device = NDArray::Empty({reserved_nseq}, n->dtype_aux, device);
+    n->cur_append_length_indptr_device = NDArray::Empty({reserved_nseq + 1}, n->dtype_aux, device);
+    n->cur_pos2seqidx_device = NDArray::Empty(
+        {reserved_num_page_chunks * n->page_chunk_size * page_size}, n->dtype_aux, device);
+
+    return PagedAttentionKVCache(n);
+  }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PagedAttentionKVCache, ObjectRef, PagedAttentionKVCacheObj);
+};
+TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
+
+//-------------------------------------------------
+//  Register runtime functions
+//-------------------------------------------------
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
+    .set_body_typed([](ShapeTuple cache_config, int64_t nlayer, int64_t nhead, int64_t nfeat,
+                       NDArray init) {
+      CHECK_EQ(cache_config.size(), 3);
+      return PagedAttentionKVCache::Create(cache_config[0], cache_config[1], cache_config[2],
+                                           nlayer, nhead, nfeat, init);
+    });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_prepare")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 1 || args.size() == 2);
+      PagedAttentionKVCache cache = args[0];
+      Optional<ShapeTuple> append_lengths;
+      if (args.size() == 2) {
+        append_lengths = ShapeTuple(args[1]);
+      }
+      cache->Prepare(append_lengths);
+      *rv = cache;
+    });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_append")
+    .set_body_typed([](PagedAttentionKVCache cache, PackedFunc f_transpose_append, NDArray k_data,
+                       NDArray v_data, int64_t layer_id) {
+      cache->Append(f_transpose_append, k_data, v_data, layer_id);
+      return cache;
+    });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_remove")
+    .set_body_typed([](PagedAttentionKVCache cache, int64_t seq_id) { cache->Remove(seq_id); });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_view_testing")
+    .set_body_typed([](PagedAttentionKVCache cache, PackedFunc f_view) {
+      return cache->View(f_view);
+    });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_popn")
+    .set_body_typed([](PagedAttentionKVCache cache, int64_t seq_id, int64_t n) {
+      return cache->PopN(seq_id, n);
+    });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
+    .set_body_typed([](PagedAttentionKVCache cache) { return cache->Clear(); });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 5 || args.size() == 8);
+      bool apply_rotary = false;
+      double rotary_scale = 1.0;
+      double rotary_theta = 1e4;
+      if (args.size() == 8) {
+        apply_rotary = args[4];
+        rotary_scale = args[5];
+        rotary_theta = args[6];
+      }
+      PagedAttentionKVCache cache = args[0];
+      cache->Attention(/*f_attention=*/args[1], /*q_data=*/args[2], /*layer_id=*/args[3],
+                       /*output=*/args[args.size() - 1], apply_rotary, rotary_scale, rotary_theta);
+    });
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
new file mode 100644
index 000000000000..dfbc7aaa736c
--- /dev/null
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py
@@ -0,0 +1,407 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List + +import numpy as np +import tvm +import tvm.testing +from tvm.script import tir as T + + +reserved_nseq = 2 +total_seq_len = 128 +page_size = 8 +nlayer = 4 +nhead = 16 +nfeat = 32 +dtype = "float16" + + +# fmt: off +@T.prim_func +def transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_page_table_indptr: T.handle, + var_page_table_values: T.handle, + var_last_page_offset: T.handle, + var_append_length_indptr: T.handle, + var_pos2seqidx: T.handle, + layer_id: T.int32, +): + nseq = T.int32() + ntoken = T.int32() + nhead = T.int32() + nfeat = T.int32() + nlayer = T.int32() + npage = T.int32() + page_size = T.int32() + num_page_chunks = T.int32() + page_chunk_size = T.int32() + + pages = T.match_buffer(var_pages, (num_page_chunks, nlayer, page_chunk_size, 2, nhead, page_size, nfeat), "float16") + k_data = T.match_buffer(var_k_data, (ntoken, nhead, nfeat), "float16") + v_data = T.match_buffer(var_v_data, (ntoken, nhead, nfeat), "float16") + last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32") + page_table_indptr = T.match_buffer(var_page_table_indptr, (nseq + 1,), "int32") + page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32") + append_length_indptr = T.match_buffer(var_append_length_indptr, (nseq + 1,), "int32") + pos2seqidx = T.match_buffer(var_pos2seqidx, (ntoken,), "int32") + + for global_pos, h, f in T.grid(ntoken, nhead, nfeat): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + T.floordiv(page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], page_chunk_size), + layer_id, + T.floormod(page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], page_chunk_size), + 0, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size), + vf, + ] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + T.floordiv(page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], page_chunk_size), + layer_id, + T.floormod(page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], page_chunk_size), + 1, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size), + vf, + ] = v_data[vgpos, vh, vf] + + +@T.prim_func +def view_cache( + var_pages: T.handle, + var_page_table_indptr: T.handle, + var_page_table_values: T.handle, + var_values: T.handle, + seq_id: T.int32, +): + nhead = T.int32() + nfeat = T.int32() + nlayer = T.int32() + seqlen = T.int32() + npage = T.int32() + page_size = T.int32() + num_page_chunks = T.int32() + page_chunk_size = T.int32() + num_total_seqs_plus_1 = T.int32() + + pages = T.match_buffer(var_pages, (num_page_chunks, nlayer, page_chunk_size, 2, nhead, page_size, nfeat), "float16") + page_table_indptr = 
T.match_buffer(var_page_table_indptr, (num_total_seqs_plus_1,), "int32") + page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32") + values = T.match_buffer(var_values, (nlayer, 2, nhead, seqlen, nfeat), "float16") + + for l, kv_idx, h, pos, f in T.grid(nlayer, 2, nhead, seqlen, nfeat): + with T.block("view"): + vl, vi, vh, vp, vf = T.axis.remap("SSSSS", [l, kv_idx, h, pos, f]) + values[vl, vi, vh, vp, vf] = pages[ + T.floordiv(page_table_values[page_table_indptr[seq_id] + T.floordiv(vp, page_size)], page_chunk_size), + vl, + T.floormod(page_table_values[page_table_indptr[seq_id] + T.floordiv(vp, page_size)], page_chunk_size), + vi, + vh, + T.floormod(vp, page_size), + vf, + ] +# fmt: on + + +def verify_cached_values(cache, expected, f_view_cache): + fview = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_view_testing") + + actual = fview(cache, f_view_cache) + assert len(actual) == len(expected) + for seq_actual, seq_expected in zip(actual, expected): + tvm.testing.assert_allclose(np.transpose(seq_actual.numpy(), [0, 1, 3, 2, 4]), seq_expected) + + +def build_tir_func(tir_funcs: List[tvm.tir.PrimFunc], target="llvm"): + return [tvm.build(tir_func, target=target).entry_func for tir_func in tir_funcs] + + +def test_paged_attention_kv_cache_append_prefill(): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") + fprepare = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_prepare") + fappend = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_append") + f_transpose_append, f_view_cache = build_tir_func([transpose_append, view_cache]) + + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]] + operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]] + operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]] + + current_nseq = 0 + append_lengths_list = [] + + cached_values = [] + for batch in operation_seq: + for seq_id, _ in batch: + if seq_id >= current_nseq: + assert seq_id == current_nseq + current_nseq += 1 + + append_lengths_list = [0] * current_nseq + for seq_id, append_length in batch: + append_lengths_list[seq_id] = append_length + + append_lengths = tvm.runtime.ShapeTuple(append_lengths_list) + fprepare(cache, append_lengths) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for seq_id, new_len in batch: + if seq_id >= len(cached_values): + assert seq_id == len(cached_values) + cached_values.append(np.zeros((nlayer, 2, 0, nhead, nfeat), dtype)) + + new_kv = np.random.rand(nlayer, 2, new_len, nhead, nfeat).astype(dtype) + cached_values[seq_id] = np.concatenate([cached_values[seq_id], new_kv], axis=2) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + # Verify + verify_cached_values(cache, cached_values, f_view_cache) + + +def test_paged_attention_kv_cache_append_decode(): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") + fprepare = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_prepare") + fappend = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_append") + f_transpose_append, 
f_view_cache = build_tir_func([transpose_append, view_cache]) + + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [31, 21, 16, 3, 8, 7, 3] + nseq = len(initial_lengths) + + # Initial prefill + append_lengths = tvm.runtime.ShapeTuple(tuple(length for length in initial_lengths)) + fprepare(cache, append_lengths) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_view_cache) + + # Decode + for _ in range(16): + decode_new_kv = np.random.rand(nlayer, 2, nseq, 1, nhead, nfeat).astype(dtype) + fprepare(cache) + for seq_id in range(nseq): + cached_values[seq_id] = np.concatenate( + [cached_values[seq_id], decode_new_kv[:, :, seq_id, ...]], axis=2 + ) + for layer_id in range(nlayer): + keys = tvm.nd.array(decode_new_kv[layer_id, 0]) + values = tvm.nd.array(decode_new_kv[layer_id, 1]) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_view_cache) + + +def test_paged_attention_kv_cache_remove(): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") + fprepare = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_prepare") + fappend = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_append") + fremove = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_remove") + f_transpose_append, f_view_cache = build_tir_func([transpose_append, view_cache]) + + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [31, 21, 16, 3, 8, 7, 3] + + # Initial prefill + append_lengths = tvm.runtime.ShapeTuple(tuple(length for length in initial_lengths)) + fprepare(cache, append_lengths) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_view_cache) + + # Remove + while len(cached_values) > 2: + seq_id = np.random.randint(0, len(cached_values)) + fremove(cache, seq_id) + cached_values.pop(seq_id) + verify_cached_values(cache, cached_values, f_view_cache) + + # Append after removal + seq_id = 2 + new_len = 29 + fprepare(cache, tvm.runtime.ShapeTuple((0, 0, new_len))) + new_kv = np.random.rand(nlayer, 2, new_len, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(new_kv[layer_id, 1], axis=0)) + fappend(cache, 
f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_view_cache) + + +def test_paged_attention_kv_cache_popn(): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") + fprepare = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_prepare") + fappend = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_append") + fpopn = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_popn") + f_transpose_append, f_view_cache = build_tir_func([transpose_append, view_cache]) + + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [20, 24, 26, 27] + nseq = len(initial_lengths) + + # Initial prefill + append_lengths = tvm.runtime.ShapeTuple(tuple(length for length in initial_lengths)) + fprepare(cache, append_lengths) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_view_cache) + + # Pop n + for pop_length in [3, 13]: + for seq_id in range(nseq): + fpopn(cache, seq_id, pop_length) + cached_values[seq_id] = cached_values[seq_id][:, :, :-pop_length, ...] + verify_cached_values(cache, cached_values, f_view_cache) + + # Decode after pop n + for _ in range(5): + decode_new_kv = np.random.rand(nlayer, 2, nseq, 1, nhead, nfeat).astype(dtype) + fprepare(cache) + for seq_id in range(nseq): + cached_values[seq_id] = np.concatenate( + [cached_values[seq_id], decode_new_kv[:, :, seq_id, ...]], axis=2 + ) + for layer_id in range(nlayer): + keys = tvm.nd.array(decode_new_kv[layer_id, 0]) + values = tvm.nd.array(decode_new_kv[layer_id, 1]) + fappend(cache, f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_view_cache) + + +def test_paged_attention_kv_cache_clear(): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") + fprepare = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_prepare") + fappend = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_append") + fclear = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_clear") + f_transpose_append, f_view_cache = build_tir_func([transpose_append, view_cache]) + + cache = fcreate( + tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]), + nlayer, + nhead, + nfeat, + tvm.nd.empty((), dtype), + ) + + cached_values = [] + initial_lengths = [20, 24, 26, 27] + + # Initial prefill + append_lengths = tvm.runtime.ShapeTuple(tuple(length for length in initial_lengths)) + fprepare(cache, append_lengths) + + global_new_kv = np.zeros((nlayer, 2, 0, nhead, nfeat), dtype) + for length in initial_lengths: + new_kv = np.random.rand(nlayer, 2, length, nhead, nfeat).astype(dtype) + cached_values.append(new_kv) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=2) + for layer_id in range(nlayer): + keys = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 0], axis=0)) + values = tvm.nd.array(np.expand_dims(global_new_kv[layer_id, 1], axis=0)) + fappend(cache, 
f_transpose_append, keys, values, layer_id) + + verify_cached_values(cache, cached_values, f_view_cache) + + # Clear + fclear(cache) + verify_cached_values(cache, [], f_view_cache) + + +if __name__ == "__main__": + test_paged_attention_kv_cache_append_prefill() + test_paged_attention_kv_cache_append_decode() + test_paged_attention_kv_cache_remove() + test_paged_attention_kv_cache_popn() + test_paged_attention_kv_cache_clear() + # tvm.testing.main()
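
Note (not part of the patch): the tests above exercise create/prepare/append/view/remove/popn/clear, but not the
"vm.builtin.paged_attention_kv_cache_attention" entry point. The sketch below illustrates the calling convention
implied by the C++ registration: a 5-argument form (cache, f_attention, q_data, layer_id, output), or an 8-argument
form that additionally passes apply_rotary, rotary_scale and rotary_theta between layer_id and output. It reuses the
module-level constants defined in the test file, and `dummy_attention` is a hypothetical stand-in kernel written as a
Python callable (converted to a PackedFunc by the FFI); a real deployment would pass a compiled attention kernel,
which this patch does not include.

    import numpy as np
    import tvm

    def dummy_attention(q, pages, page_table_indptr, page_table_values, last_page_offset,
                        append_length_indptr, layer_id, tmp_buffer, output,
                        apply_rotary, rotary_scale, rotary_theta):
        # A real kernel would gather K/V from `pages` through the page table and write
        # softmax(Q K^T / sqrt(d)) V into `output`; here we only zero the output.
        output.copyfrom(np.zeros(output.shape, dtype=output.dtype))

    fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
    fprepare = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_prepare")
    fattention = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_attention")

    cache = fcreate(
        tvm.runtime.ShapeTuple([reserved_nseq, total_seq_len, page_size]),
        nlayer, nhead, nfeat, tvm.nd.empty((), dtype),
    )
    # One decode step over two sequences, one new token each.
    fprepare(cache, tvm.runtime.ShapeTuple([1, 1]))
    q = tvm.nd.array(np.random.rand(2, 1, nhead, nfeat).astype(dtype))
    out = tvm.nd.empty((2, 1, nhead, nfeat), dtype)
    # 5-argument form; the 8-argument form would be
    # fattention(cache, dummy_attention, q, layer_id, apply_rotary, rotary_scale, rotary_theta, out).
    fattention(cache, dummy_attention, q, 0, out)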