From 70a4fe386561f91fd3e7514e723a5557a810953b Mon Sep 17 00:00:00 2001
From: David Pissarra
Date: Thu, 14 Dec 2023 01:15:27 +0000
Subject: [PATCH 1/2] attention sinks with correctness test

---
 src/runtime/relax_vm/lm_support.cc         | 22 +++++++++---
 tests/python/relax/test_runtime_builtin.py | 41 ++++++++++++++++++++++
 2 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 6301245dac43..27d2e3eae2cd 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -116,10 +116,12 @@ class AttentionKVCacheObj : public Object {
   /*!
    * \brief Append value to the cache, overrides if full.
    * \param value The value to override previous elements.
+   * \param max_cache_size max size of the cache.
+   * \param num_attention_sinks number of sinks to store (https://arxiv.org/abs/2309.17453).
    */
-  void WindowOverride(NDArray value, int64_t max_cache_size) {
+  void WindowOverride(NDArray value, int64_t max_cache_size, int64_t num_attention_sinks = 0) {
     CHECK(data.DataType() == value.DataType()) << "dtype mismatch";
-    CHECK_LE(value->shape[0], max_cache_size) << "dim 0 of value too large";
+    CHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) << "dim 0 of value too large";
     // reallocate cache
     if (fill_count + value->shape[0] <= max_cache_size) {
       int64_t reserved_slots = data->shape[0];
@@ -171,7 +173,8 @@
       num_filled_elements = num_elements_to_copy * num_elements_p_entry;
 
       DLTensor copy_dst = *(data.operator->());
-      copy_dst.byte_offset = 0;
+      copy_dst.byte_offset = (num_attention_sinks * num_elements_p_entry) *
+                             ((data->dtype.bits * data->dtype.lanes + 7) / 8);
       copy_dst.shape = &shape[0];
 
       DLTensor copy_src = *(value.operator->());
@@ -180,7 +183,8 @@
       copy_src.shape = &shape[0];
 
       NDArray::CopyFromTo(&copy_src, &copy_dst);
-      this->window_attention_current_pos = value->shape[0] - num_elements_to_copy;
+      this->window_attention_current_pos =
+          value->shape[0] - num_elements_to_copy + num_attention_sinks;
     }
   }
@@ -277,6 +281,16 @@ AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
     .set_body_typed(AttentionKVCacheWindowOverride);
 
+AttentionKVCache AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache cache, NDArray value,
+                                                         int64_t max_cache_size,
+                                                         int64_t num_attention_sinks) {
+  cache->WindowOverride(value, max_cache_size, num_attention_sinks);
+  return cache;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks")
+    .set_body_typed(AttentionKVCacheWindowOverrideWithSinks);
+
 NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
   return cache->View(shape);
 }

diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py
index 0417f992338e..614d32ce0c7d 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -217,5 +217,46 @@ def test_attention_kv_cache_window_override():
     ).all()
 
 
+def test_attention_kv_cache_window_override_with_sinks():
+    fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
+    foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
+    fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
+
+    num_attention_sinks = 2
+    has_sink = False
+    current_pos = 0
+
+    cache = fcreate(
+        tvm.nd.array(np.full((16, 2), -1).astype("int32")),
+        tvm.runtime.ShapeTuple([16, 2]),
+        current_pos,
+    )
+    np_all_arrays = np.zeros((0, 2)).astype("int32")
+
+    num_steps = 40
+    for i in range(num_steps):
+        np_array = i * np.ones((1, 2)).astype("int32")
+        np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
+        cache = foverride(cache, tvm.nd.array(np_array), 16, num_attention_sinks)
+
+        if has_sink:
+            current_pos = max((current_pos + 1) % 16, num_attention_sinks)
+        else:
+            current_pos += 1
+            has_sink = current_pos >= num_attention_sinks
+
+    res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()
+
+    # unrotate cache and assert cache matches last 16 elements
+    assert (
+        np.concatenate(
+            (np_all_arrays[:num_attention_sinks, :], np_all_arrays[-16 + num_attention_sinks :, :])
+        )
+        == np.concatenate(
+            (res[:num_attention_sinks], res[current_pos:], res[num_attention_sinks:current_pos])
+        )
+    ).all()
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From bd021c1c7a82a2ac02191b337a53e7feb4de988f Mon Sep 17 00:00:00 2001
From: David Pissarra
Date: Thu, 14 Dec 2023 13:28:00 +0000
Subject: [PATCH 2/2] fix override sink

---
 src/runtime/relax_vm/lm_support.cc | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index 27d2e3eae2cd..706e2c3d5f48 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -150,20 +150,22 @@ class AttentionKVCacheObj : public Object {
       shape.push_back(data->shape[i]);
     }
     int64_t num_filled_elements = window_attention_current_pos * num_elements_p_entry;
-
-    DLTensor copy_dst = *(data.operator->());
-    copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
-    copy_dst.shape = &shape[0];
-
-    DLTensor copy_src = *(value.operator->());
-    copy_src.byte_offset = 0;
-    copy_src.shape = &shape[0];
-
-    NDArray::CopyFromTo(&copy_src, &copy_dst);
     this->fill_count = std::min(this->fill_count + value->shape[0], max_cache_size);
     this->window_attention_current_pos =
         std::min(this->window_attention_current_pos + value->shape[0], max_cache_size);
+    if (num_elements_to_copy > 0) {
+      DLTensor copy_dst = *(data.operator->());
+      copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
+      copy_dst.shape = &shape[0];
+
+      DLTensor copy_src = *(value.operator->());
+      copy_src.byte_offset = 0;
+      copy_src.shape = &shape[0];
+
+      NDArray::CopyFromTo(&copy_src, &copy_dst);
+    }
+
     // copy the remainder to the beginning of the cache
     if (num_elements_to_copy < value->shape[0]) {
      ICHECK_EQ(this->fill_count, max_cache_size);
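
For reference beyond the patches themselves, below is a minimal sketch of driving the new builtin from Python, mirroring the unit test added in PATCH 1/2. The cache sizes are hypothetical and chosen only for illustration, and the snippet assumes a TVM build that includes both patches. Slots [0, num_sinks) stay pinned to the first entries ever appended (the "attention sinks"), while the remaining slots behave as a rotating window.

import numpy as np
import tvm

# Packed functions registered in lm_support.cc.
fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")

max_cache_size, num_sinks = 8, 2  # illustrative sizes
cache = fcreate(
    tvm.nd.array(np.zeros((max_cache_size, 1), dtype="int32")),  # initial backing buffer
    tvm.runtime.ShapeTuple([max_cache_size, 1]),                 # reserved shape
    0,                                                           # initial fill count
)

# Append one entry per step; once the cache is full, writes wrap around to
# slot num_sinks, so the first num_sinks entries are never evicted.
for step in range(20):
    kv = tvm.nd.array(np.full((1, 1), step, dtype="int32"))
    cache = foverride(cache, kv, max_cache_size, num_sinks)

window = fview(cache, tvm.runtime.ShapeTuple((max_cache_size, 1))).numpy()
assert window[:num_sinks, 0].tolist() == [0, 1]  # sink entries survive eviction

Pinning the sinks at fixed offsets is what lets the wrap-around branch of WindowOverride copy the remainder to byte offset num_attention_sinks * entry_size rather than offset 0.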