diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 84e22f7..702bcf6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -154,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
     weight = weights.get_multi_weights_col(
         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
         quantize=config.quantize,
-        dim=0
+        dim=0,
     )
 
     if config.quantize != "gptq":
@@ -168,7 +168,9 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(
+        get_linear(weight, bias=None, quantize=config.quantize)
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 517fba6..c0592cb 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -39,6 +39,7 @@ def __init__(
         device: torch.device,
     ):
         self.block_size = BLOCK_SIZE
+        self.num_blocks = num_blocks
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
@@ -714,7 +715,6 @@ def warmup(self, batch: FlashCausalLMBatch):
         global CACHE_MANAGER
 
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats(self.device)
         try:
             CACHE_MANAGER = CacheManager(
                 batch.blocks,
@@ -731,23 +731,20 @@ def warmup(self, batch: FlashCausalLMBatch):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
-        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
         torch.cuda.synchronize(self.device)
-        peak_memory = torch.cuda.max_memory_reserved(self.device)
 
+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        free_memory, _ = torch.cuda.mem_get_info(self.device)
 
-        # 0.98 to add some wiggle room
         num_blocks = (
-            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
+            int(free_memory // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + batch.blocks
+            + CACHE_MANAGER.num_blocks
         )
 
         del CACHE_MANAGER
diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index 160c9c9..45b01ae 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -867,8 +867,9 @@ def quantize(
     )
 
     with init_empty_weights():
-        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16,
-                                                 trust_remote_code=trust_remote_code)
+        model = AutoModelForCausalLM.from_config(
+            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+        )
     model = model.eval()
 
     print("LOADED model")
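
Below is a minimal, self-contained sketch of the block-count math that the warmup hunk above switches to: size one KV-cache block from the model geometry, divide the free CUDA memory reported by torch.cuda.mem_get_info by that size, then add back the blocks already held by the warmup cache. The function name estimate_num_blocks and its parameters are illustrative stand-ins, not part of the patched server code.

# Minimal sketch (not part of the patch): estimate how many KV-cache blocks fit
# in the CUDA memory that is still free after a warmup forward pass. All names
# below (estimate_num_blocks and its parameters) are illustrative stand-ins.
import torch


def estimate_num_blocks(
    device: torch.device,
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    block_size: int = 16,
    dtype: torch.dtype = torch.float16,
    warmup_cache_blocks: int = 0,
) -> int:
    # Bytes for one cache block: block_size tokens * kv heads * head size,
    # doubled for K and V, across all layers, at the cache dtype's width.
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    cache_block_size = block_size * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size

    # mem_get_info reports memory that is actually free on the device, so the
    # activation peak from the warmup run is already excluded.
    free_memory, _ = torch.cuda.mem_get_info(device)

    # Blocks held by the warmup cache can be counted too, because that cache
    # is deleted before the final cache is allocated.
    return int(free_memory // total_cache_size) + warmup_cache_blocks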