Skip to content

Commit

Permalink
[Bugfix][TPU] Fix KV cache size calculation (vllm-project#5860)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored and prashantgupta24 committed Jul 1, 2024
1 parent 6150a0b commit 66b592f
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions vllm/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
xm.wait_device_ops()

m = xm.get_memory_info(self.device)
program_size = 1024 * 1024 * 1024 # 1GB
free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
kv_cache_bytes = int(free_bytes *
self.cache_config.gpu_memory_utilization)
kv_cache_dtype_btyes = get_dtype_size(self.cache_dtype)
total_memory_size = m["bytes_limit"]
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
profiled = m["bytes_used"] # Weights + intermediate activations.
kv_cache_bytes = max(usable_memory_size - profiled, 0)
dtype_btyes = get_dtype_size(self.cache_dtype)
block_size = self.cache_config.block_size
num_tpu_blocks = (kv_cache_bytes //
(kv_cache_dtype_btyes * block_size * num_layers * 2 *
(dtype_btyes * block_size * num_layers * 2 *
head_size * num_kv_heads))
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, 0
Expand Down

0 comments on commit 66b592f

Please sign in to comment.