fix(server): use mem_get_info to get kv cache size (#664)
VerdantCap committed Jul 20, 2023
1 parent a056bf0 commit e352bb1
Showing 3 changed files with 13 additions and 13 deletions.
@@ -154,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
     weight = weights.get_multi_weights_col(
         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
         quantize=config.quantize,
-        dim=0
+        dim=0,
     )
 
     if config.quantize != "gptq":
@@ -168,7 +168,9 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(
+        get_linear(weight, bias=None, quantize=config.quantize)
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
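A side note on the shape assert in the hunk above: with grouped-query attention, the fused QKV weight stacks num_heads query heads plus num_key_value_heads key heads and num_key_value_heads value heads, each head_size wide, which is where the expected (num_heads + 2 * config.num_key_value_heads) * head_size row count comes from. A quick check with illustrative dimensions (assumed here, not taken from the commit):

# Illustrative GQA dimensions (assumptions): 64 query heads, 8 KV heads, head size 128.
num_heads, num_key_value_heads, head_size, hidden_size = 64, 8, 128, 8192

fused_rows = (num_heads + 2 * num_key_value_heads) * head_size
print(fused_rows, hidden_size)  # 10240 output rows (8192 Q + 1024 K + 1024 V), 8192 input columns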
server/text_generation_server/models/flash_causal_lm.py (15 changes: 6 additions & 9 deletions)
@@ -39,6 +39,7 @@ def __init__(
         device: torch.device,
     ):
         self.block_size = BLOCK_SIZE
+        self.num_blocks = num_blocks
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
@@ -714,7 +715,6 @@ def warmup(self, batch: FlashCausalLMBatch):
         global CACHE_MANAGER
 
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats(self.device)
         try:
             CACHE_MANAGER = CacheManager(
                 batch.blocks,
@@ -731,23 +731,20 @@ def warmup(self, batch: FlashCausalLMBatch):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
-        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
-        torch.cuda.synchronize(self.device)
-        peak_memory = torch.cuda.max_memory_reserved(self.device)
 
+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        free_memory, _ = torch.cuda.mem_get_info(self.device)
 
-        # 0.98 to add some wiggle room
         num_blocks = (
-            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
+            int(free_memory // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + batch.blocks
+            + CACHE_MANAGER.num_blocks
         )
 
         del CACHE_MANAGER
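For context on the change above: torch.cuda.mem_get_info returns the free and total device memory in bytes, so warmup can size the KV cache from the memory that is actually free instead of estimating it as 98% of total memory minus the profiled peak. The blocks held by the warmup CacheManager are then added back on top, since that manager is deleted right after the measurement. A standalone sketch of the same arithmetic, with illustrative model dimensions that are assumptions rather than values from this commit:

import torch

# Illustrative model dimensions (assumptions, not taken from the commit).
BLOCK_SIZE = 16      # tokens held by one KV-cache block
num_layers = 32
num_kv_heads = 8
head_size = 128
dtype = torch.float16

# Bytes needed by one block: block_size tokens * kv heads * head size,
# stored twice (keys and values) in every layer, at dtype width.
dtype_size = torch.tensor([], dtype=dtype).element_size()
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size

if torch.cuda.is_available():
    # mem_get_info reports (free_bytes, total_bytes) for the device.
    free_memory, total_memory = torch.cuda.mem_get_info(0)
    num_blocks = int(free_memory // total_cache_size)
    print(f"{num_blocks} KV-cache blocks fit in {free_memory / 1e9:.1f} GB of free memory")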
server/text_generation_server/utils/gptq/quantize.py (5 changes: 3 additions & 2 deletions)
@@ -867,8 +867,9 @@ def quantize(
     )
 
     with init_empty_weights():
-        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16,
-                                                 trust_remote_code=trust_remote_code)
+        model = AutoModelForCausalLM.from_config(
+            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+        )
     model = model.eval()
 
     print("LOADED model")
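The quantize.py change above only reformats the from_config call. For readers unfamiliar with init_empty_weights (from Hugging Face accelerate): it builds the model skeleton on the meta device so no weight memory is allocated before quantization. A minimal sketch, using gpt2 purely as a stand-in config:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# "gpt2" is only a stand-in here; quantize() works with the real model's config.
config = AutoConfig.from_pretrained("gpt2")
with init_empty_weights():
    # Parameters live on the "meta" device: shapes exist, but no memory is allocated.
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

assert all(p.device.type == "meta" for p in model.parameters())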
