Skip to content

Commit

Permalink
fix typo bug and add test for vllm reconstruct_from_cache kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 11, 2024
1 parent b8230f6 commit de3f53e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/runtime/contrib/vllm/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ __global__ void reconstruct_from_cache_kernel(
block_offset;

key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
value[src_value_idx] = __ldg(&value_cache[tgt_value_idx]);
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
}
}

Expand Down
34 changes: 34 additions & 0 deletions tests/python/relax/test_contrib_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,5 +742,39 @@ def main(
assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0


def test_reconstruct_from_cache():
    """Round-trip random key/value tensors through the vLLM cache functions.

    The test calls ``tvm.contrib.vllm.reshape_and_cache`` with freshly
    generated ``key``/``value`` tensors, then calls
    ``tvm.contrib.vllm.reconstruct_from_cache`` on the same caches and slot
    mapping, and asserts that the two returned tensors are exactly equal to
    the original ``key`` and ``value``.

    Requires a CUDA device (device index 0).
    """
    dev = tvm.device("cuda", 0)

    # Problem sizes: a single block that is larger than the number of tokens.
    num_blocks, block_size = 1, 16
    num_heads, head_dim = 1, 8
    vec_size = 8
    num_tokens = 8

    def random_fp16(*shape):
        # Random float16 tensor of the given shape, placed on the test device.
        return tvm.nd.array(np.random.randn(*shape).astype("float16"), dev)

    token_shape = (num_tokens, num_heads, head_dim)
    key = random_fp16(*token_shape)
    value = random_fp16(*token_shape)

    # slot_mapping is simply 0 .. num_tokens - 1.
    slot_mapping = tvm.nd.array(np.arange(num_tokens).astype("int32"), dev)

    # The caches start out with random contents; the key cache carries an
    # extra trailing axis of length vec_size.
    k_cache = random_fp16(num_blocks, num_heads, head_dim // vec_size, block_size, vec_size)
    v_cache = random_fp16(num_blocks, num_heads, head_dim, block_size)

    write_cache = tvm.get_global_func("tvm.contrib.vllm.reshape_and_cache")
    read_cache = tvm.get_global_func("tvm.contrib.vllm.reconstruct_from_cache")

    write_cache(key, value, k_cache, v_cache, slot_mapping)
    reconstructed = read_cache(k_cache, v_cache, slot_mapping)

    # Output 0 must match the key, output 1 must match the value, bit for bit.
    for idx, expected in enumerate((key, value)):
        np.testing.assert_equal(expected.numpy(), reconstructed[idx].numpy())


# When this file is executed directly as a script, hand control to
# tvm.testing.main().
if __name__ == "__main__":
tvm.testing.main()

0 comments on commit de3f53e

Please sign in to comment.