Skip to content

Commit

Permalink
update copy blocks test
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao committed May 5, 2024
1 parent 6847846 commit ae7348e
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = {}
block_mapping = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping[src] = [dst1, dst2]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))

# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
Expand All @@ -81,15 +82,17 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

# Call the copy blocks kernel.
ops.copy_blocks(key_caches, value_caches, block_mapping)
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

# Run the reference implementation.
for src, dsts in block_mapping.items():
for dst in dsts:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])
for src, dst in block_mapping:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])

# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
Expand Down

0 comments on commit ae7348e

Please sign in to comment.