From de3f53e3b34060eff0855a2619dedcd404e14467 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 9 Jan 2024 23:43:12 +0000 Subject: [PATCH] fix typo bug and add test for vllm reconstruct_from_cache kernel --- src/runtime/contrib/vllm/cache_kernels.cu | 2 +- tests/python/relax/test_contrib_vllm.py | 34 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index 537ff31fd0a6..b53cd094c1aa 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -94,7 +94,7 @@ __global__ void reconstruct_from_cache_kernel( block_offset; key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[src_value_idx] = __ldg(&value_cache[tgt_value_idx]); + value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); } } diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py index b674e1c9fb2f..dd2149e572cf 100644 --- a/tests/python/relax/test_contrib_vllm.py +++ b/tests/python/relax/test_contrib_vllm.py @@ -742,5 +742,39 @@ def main( assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0 +def test_reconstruct_from_cache(): + num_heads = 1 + head_dim = 8 + vec_size = 8 + block_size = 16 + num_tokens = 8 + num_blocks = 1 + + dev = tvm.device("cuda", 0) + + key = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev) + value = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev) + slot_mapping = tvm.nd.array(np.arange(num_tokens).astype("int32"), dev) + + k_cache = tvm.nd.array( + np.random.randn(num_blocks, num_heads, head_dim // vec_size, block_size, vec_size).astype( + "float16" + ), + dev, + ) + v_cache = tvm.nd.array( + np.random.randn(num_blocks, num_heads, head_dim, block_size).astype("float16"), dev + ) + + reshape_and_cache_func = tvm.get_global_func("tvm.contrib.vllm.reshape_and_cache") + reconstruct_from_cache_func = tvm.get_global_func("tvm.contrib.vllm.reconstruct_from_cache") + + reshape_and_cache_func(key, value, k_cache, v_cache, slot_mapping) + out = reconstruct_from_cache_func(k_cache, v_cache, slot_mapping) + + np.testing.assert_equal(key.numpy(), out[0].numpy()) + np.testing.assert_equal(value.numpy(), out[1].numpy()) + + if __name__ == "__main__": tvm.testing.main()