diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 828bb89d70baa..cd72c71199090 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -118,14 +118,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         xm.wait_device_ops()

         m = xm.get_memory_info(self.device)
-        program_size = 1024 * 1024 * 1024  # 1GB
-        free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
-        kv_cache_bytes = int(free_bytes *
-                             self.cache_config.gpu_memory_utilization)
-        kv_cache_dtype_btyes = get_dtype_size(self.cache_dtype)
+        total_memory_size = m["bytes_limit"]
+        usable_memory_size = int(total_memory_size *
+                                 self.cache_config.gpu_memory_utilization)
+        profiled = m["bytes_used"]  # Weights + intermediate activations.
+        kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_bytes = get_dtype_size(self.cache_dtype)
         block_size = self.cache_config.block_size
         num_tpu_blocks = (kv_cache_bytes //
-                          (kv_cache_dtype_btyes * block_size * num_layers * 2 *
+                          (dtype_bytes * block_size * num_layers * 2 *
                            head_size * num_kv_heads))
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
         return num_tpu_blocks, 0
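
Below is a minimal, standalone sketch of the new sizing arithmetic. The function name and all input values are illustrative assumptions, not part of the patch, which reads them from xm.get_memory_info(), self.cache_config, and the model config.

    # Hypothetical re-derivation of the new KV cache sizing logic.
    def num_kv_cache_blocks(
            bytes_limit: int,  # Total device memory ("bytes_limit").
            bytes_used: int,  # Profiled use: weights + activations ("bytes_used").
            gpu_memory_utilization: float,
            dtype_bytes: int,  # e.g. 2 for bfloat16.
            block_size: int,  # Tokens per KV cache block.
            num_layers: int,
            head_size: int,
            num_kv_heads: int) -> int:
        usable_memory_size = int(bytes_limit * gpu_memory_utilization)
        kv_cache_bytes = max(usable_memory_size - bytes_used, 0)
        # One block stores K and V (the factor of 2) for block_size tokens,
        # across every layer and KV head.
        bytes_per_block = (dtype_bytes * block_size * num_layers * 2 *
                           head_size * num_kv_heads)
        num_blocks = kv_cache_bytes // bytes_per_block
        return (num_blocks // 8) * 8  # Round down to a multiple of 8.

    # Made-up example: 32 GiB HBM, 18 GiB profiled, 90% utilization, bf16
    # cache, 16-token blocks, 32 layers, head size 128, 8 KV heads -> 5528.
    print(num_kv_cache_blocks(32 << 30, 18 << 30, 0.90, 2, 16, 32, 128, 8))

Note the change in what gpu_memory_utilization bounds: the old code subtracted profiled usage and a fixed 1GB program_size guess first and applied the fraction to the leftover, while the new code applies the fraction to total device memory and subtracts profiled usage from that, so the cap now bounds the whole footprint (weights + activations + KV cache) and the hardcoded program_size estimate is no longer needed.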