Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][lm_support] window kvcache sink #16240

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions src/runtime/relax_vm/lm_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ class AttentionKVCacheObj : public Object {
/*!
* \brief Append value to the cache, overrides if full.
* \param value The value to override previous elements.
* \param max_cache_size max size of the cache.
* \param num_attention_sinks number of sinks to store (https://arxiv.org/abs/2309.17453).
*/
void WindowOverride(NDArray value, int64_t max_cache_size) {
void WindowOverride(NDArray value, int64_t max_cache_size, int64_t num_attention_sinks = 0) {
CHECK(data.DataType() == value.DataType()) << "dtype mismatch";
CHECK_LE(value->shape[0], max_cache_size) << "dim 0 of value too large";
CHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) << "dim 0 of value too large";
// reallocate cache
if (fill_count + value->shape[0] <= max_cache_size) {
int64_t reserved_slots = data->shape[0];
Expand Down Expand Up @@ -148,20 +150,22 @@ class AttentionKVCacheObj : public Object {
shape.push_back(data->shape[i]);
}
int64_t num_filled_elements = window_attention_current_pos * num_elements_p_entry;

DLTensor copy_dst = *(data.operator->());
copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
copy_dst.shape = &shape[0];

DLTensor copy_src = *(value.operator->());
copy_src.byte_offset = 0;
copy_src.shape = &shape[0];

NDArray::CopyFromTo(&copy_src, &copy_dst);
this->fill_count = std::min(this->fill_count + value->shape[0], max_cache_size);
this->window_attention_current_pos =
std::min(this->window_attention_current_pos + value->shape[0], max_cache_size);

if (num_elements_to_copy > 0) {
DLTensor copy_dst = *(data.operator->());
copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
copy_dst.shape = &shape[0];

DLTensor copy_src = *(value.operator->());
copy_src.byte_offset = 0;
copy_src.shape = &shape[0];

NDArray::CopyFromTo(&copy_src, &copy_dst);
}

// copy the remainder to the beginning of the cache
if (num_elements_to_copy < value->shape[0]) {
ICHECK_EQ(this->fill_count, max_cache_size);
Expand All @@ -171,7 +175,8 @@ class AttentionKVCacheObj : public Object {
num_filled_elements = num_elements_to_copy * num_elements_p_entry;

DLTensor copy_dst = *(data.operator->());
copy_dst.byte_offset = 0;
copy_dst.byte_offset = (num_attention_sinks * num_elements_p_entry) *
((data->dtype.bits * data->dtype.lanes + 7) / 8);
copy_dst.shape = &shape[0];

DLTensor copy_src = *(value.operator->());
Expand All @@ -180,7 +185,8 @@ class AttentionKVCacheObj : public Object {
copy_src.shape = &shape[0];

NDArray::CopyFromTo(&copy_src, &copy_dst);
this->window_attention_current_pos = value->shape[0] - num_elements_to_copy;
this->window_attention_current_pos =
value->shape[0] - num_elements_to_copy + num_attention_sinks;
}
}

Expand Down Expand Up @@ -277,6 +283,16 @@ AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
.set_body_typed(AttentionKVCacheWindowOverride);

AttentionKVCache AttentionKVCacheWindowOverrideWithSinks(AttentionKVCache cache, NDArray value,
int64_t max_cache_size,
int64_t num_attention_sinks) {
cache->WindowOverride(value, max_cache_size, num_attention_sinks);
return cache;
}

TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks")
.set_body_typed(AttentionKVCacheWindowOverrideWithSinks);

NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
return cache->View(shape);
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/relax/test_runtime_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,5 +217,46 @@ def test_attention_kv_cache_window_override():
).all()


def test_attention_kv_cache_window_override_with_sinks():
fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")

num_attention_sinks = 2
has_sink = False
current_pos = 0

cache = fcreate(
tvm.nd.array(np.full((16, 2), -1).astype("int32")),
tvm.runtime.ShapeTuple([16, 2]),
current_pos,
)
np_all_arrays = np.zeros((0, 2)).astype("int32")

num_steps = 40
for i in range(num_steps):
np_array = i * np.ones((1, 2)).astype("int32")
np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
cache = foverride(cache, tvm.nd.array(np_array), 16, num_attention_sinks)

if has_sink:
current_pos = max((current_pos + 1) % 16, num_attention_sinks)
else:
current_pos += 1
has_sink = current_pos >= num_attention_sinks

res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()

# unrotate cache and assert cache matches last 16 elements
assert (
np.concatenate(
(np_all_arrays[:num_attention_sinks, :], np_all_arrays[-16 + num_attention_sinks :, :])
)
== np.concatenate(
(res[:num_attention_sinks], res[current_pos:], res[num_attention_sinks:current_pos])
)
).all()


if __name__ == "__main__":
tvm.testing.main()
Loading