diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json deleted file mode 100644 index 0f99d2597e5..00000000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6230469, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.046875, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1425781, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.9238281, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.076660156, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10821533, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" -} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json deleted file mode 100644 index 4152b5b308b..00000000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_all_params.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": 0, - "tokens": [ - { - "id": 13, - "logprob": -2.2539062, - "special": false, - "text": "." - }, - { - "id": 578, - "logprob": -0.15563965, - "special": false, - "text": " The" - }, - { - "id": 3622, - "logprob": -0.8203125, - "special": false, - "text": " server" - }, - { - "id": 706, - "logprob": 0.0, - "special": false, - "text": " has" - }, - { - "id": 539, - "logprob": 0.0, - "special": false, - "text": " not" - }, - { - "id": 3686, - "logprob": 0.0, - "special": false, - "text": " yet" - }, - { - "id": 3288, - "logprob": 0.0, - "special": false, - "text": " sent" - }, - { - "id": 904, - "logprob": 0.0, - "special": false, - "text": " any" - }, - { - "id": 828, - "logprob": 0.0, - "special": false, - "text": " data" - }, - { - "id": 382, - "logprob": -1.5517578, - "special": false, - "text": ".\n\n" - } - ], - "top_tokens": null - }, - "generated_text": "Test request. The server has not yet sent any data.\n\n" -} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json deleted file mode 100644 index 75e903033c4..00000000000 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq_marlin/test_flash_llama_gptq_marlin_load.json +++ /dev/null @@ -1,338 +0,0 @@ -[ - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - }, - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 2323, - "logprob": null, - "text": "Test" - }, - { - "id": 1715, - "logprob": -11.34375, - "text": " request" - } - ], - "seed": null, - "tokens": [ - { - "id": 198, - "logprob": -2.5742188, - "special": false, - "text": "\n" - }, - { - "id": 262, - "logprob": -1.6220703, - "special": false, - "text": " " - }, - { - "id": 3270, - "logprob": -2.0410156, - "special": false, - "text": " \"\"\"\n" - }, - { - "id": 262, - "logprob": -0.015281677, - "special": false, - "text": " " - }, - { - "id": 422, - "logprob": -2.1445312, - "special": false, - "text": " if" - }, - { - "id": 1715, - "logprob": -0.92333984, - "special": false, - "text": " request" - }, - { - "id": 13204, - "logprob": -0.07672119, - "special": false, - "text": ".method" - }, - { - "id": 624, - "logprob": -0.021987915, - "special": false, - "text": " ==" - }, - { - "id": 364, - "logprob": -0.39208984, - "special": false, - "text": " '" - }, - { - "id": 3019, - "logprob": -0.10638428, - "special": false, - "text": "POST" - } - ], - "top_tokens": null - }, - "generated_text": "\n \"\"\"\n if request.method == 'POST" - } -] diff --git a/integration-tests/models/test_flash_llama_gptq_marlin.py b/integration-tests/models/test_flash_llama_gptq_marlin.py deleted file mode 100644 index 2274abce9ab..00000000000 --- a/integration-tests/models/test_flash_llama_gptq_marlin.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest - - -@pytest.fixture(scope="module") -def flash_llama_gptq_marlin_handle(launcher): - with launcher( - "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin" - ) as handle: - yield handle - - -@pytest.fixture(scope="module") -async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): - await flash_llama_gptq_marlin_handle.health(300) - return flash_llama_gptq_marlin_handle.client - - -@pytest.mark.release -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): - response = await flash_llama_gptq_marlin.generate( - "Test request", max_new_tokens=10, decoder_input_details=True - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.release -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin_all_params( - flash_llama_gptq_marlin, response_snapshot -): - response = await flash_llama_gptq_marlin.generate( - "Test request", - max_new_tokens=10, - repetition_penalty=1.2, - return_full_text=True, - temperature=0.5, - top_p=0.9, - top_k=10, - truncate=5, - typical_p=0.9, - watermark=True, - decoder_input_details=True, - seed=0, - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.release -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_gptq_marlin_load( - flash_llama_gptq_marlin, generate_load, response_snapshot -): - responses = await generate_load( - flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4 - ) - - assert len(responses) == 4 - assert all([r.generated_text == responses[0].generated_text for r in responses]) - - assert responses == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 816fa5f37f3..d2ca38e5a0e 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -898,13 +898,20 @@ enum LauncherError { WebserverCannotStart, } -fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { +fn download_convert_model( + model_id: &str, + revision: Option<&str>, + trust_remote_code: bool, + huggingface_hub_cache: Option<&str>, + weights_cache_override: Option<&str>, + running: Arc, +) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); let mut download_args = vec![ "download-weights".to_string(), - args.model_id.to_string(), + model_id.to_string(), "--extension".to_string(), ".safetensors".to_string(), "--logger-level".to_string(), @@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L ]; // Model optional revision - if let Some(revision) = &args.revision { + if let Some(revision) = &revision { download_args.push("--revision".to_string()); download_args.push(revision.to_string()) } // Trust remote code for automatic peft fusion - if args.trust_remote_code { + if trust_remote_code { download_args.push("--trust-remote-code".to_string()); } @@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // If huggingface_hub_cache is set, pass it to the download process // Useful when running inside a docker container - if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { + if let Some(ref huggingface_hub_cache) = huggingface_hub_cache { envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; @@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // If args.weights_cache_override is some, pass it to the download process // Useful when running inside a HuggingFace Inference Endpoint - if let Some(weights_cache_override) = &args.weights_cache_override { + if let Some(weights_cache_override) = &weights_cache_override { envs.push(( "WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into(), @@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L }; // Start process - tracing::info!("Starting download process."); + tracing::info!("Starting check and download process for {model_id}"); let mut download_process = match Command::new("text-generation-server") .args(download_args) .env_clear() @@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L loop { if let Some(status) = download_process.try_wait().unwrap() { if status.success() { - tracing::info!("Successfully downloaded weights."); + tracing::info!("Successfully downloaded weights for {model_id}"); break; } @@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> { .expect("Error setting Ctrl-C handler"); // Download and convert model weights - download_convert_model(&args, running.clone())?; + download_convert_model( + &args.model_id, + args.revision.as_deref(), + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + + // Download and convert lora adapters if any + if let Some(lora_adapters) = &args.lora_adapters { + for adapter in lora_adapters.split(',') { + download_convert_model( + adapter, + None, + args.trust_remote_code, + args.huggingface_hub_cache.as_deref(), + args.weights_cache_override.as_deref(), + running.clone(), + )?; + } + } if !running.load(Ordering::SeqCst) { // Launcher was asked to stop diff --git a/router/src/main.rs b/router/src/main.rs index 9a281556ab5..e5c4e607121 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -309,7 +309,7 @@ async fn main() -> Result<(), RouterError> { let mut tokenizer = Tokenizer::from_file(filename).ok(); if let Some(tokenizer) = &mut tokenizer { if let Some(class) = &tokenizer_config.tokenizer_class { - if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() { + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); tokenizer.with_post_processor(post_processor); @@ -577,7 +577,7 @@ pub fn create_post_processor( if add_bos_token { if let Some(bos) = bos_token { - single.push(format!("{}:1", bos.as_str())); + pair.push(format!("{}:1", bos.as_str())); } } diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 1172775f096..56080145028 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -7,6 +7,16 @@ ) +@dataclass +class GPTQParams: + bits: int + checkpoint_format: Optional[str] + groupsize: int + desc_act: bool + quant_method: str + sym: bool + + @dataclass class GPTQWeight: qweight: torch.Tensor diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index dd48465feee..e94e5465cf2 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -166,35 +166,45 @@ def get_linear(weight, bias, quantize): elif quantize == "gptq": from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + GPTQMarlinLinear, + GPTQMarlinWeight, + ) - if not isinstance(weight, GPTQWeight): - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." + if isinstance(weight, GPTQMarlinWeight): + linear = GPTQMarlinLinear( + weight=weight, + bias=bias, ) + elif isinstance(weight, GPTQWeight): + if weight.use_exllama: + try: + from text_generation_server.layers.gptq import ( + ExllamaQuantLinear, + ) + except ImportError: + raise NotImplementedError( + f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) - if weight.use_exllama: - try: - from text_generation_server.layers.gptq import ( - ExllamaQuantLinear, - ) - except ImportError: - raise NotImplementedError( - f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) + linear = ExllamaQuantLinear(weight, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear - linear = ExllamaQuantLinear(weight, bias) + linear = QuantLinear( + weight.qweight, + weight.qzeros, + weight.scales, + weight.g_idx, + bias, + weight.bits, + weight.groupsize, + ) else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - - linear = QuantLinear( - weight.qweight, - weight.qzeros, - weight.scales, - weight.g_idx, - bias, - weight.bits, - weight.groupsize, + raise NotImplementedError( + f"The passed weight is not `gptq` compatible, loader needs to be updated." ) + elif quantize == "awq": from text_generation_server.layers.gptq import GPTQWeight @@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize): from text_generation_server.layers.marlin import ( GPTQMarlin24Linear, GPTQMarlin24Weight, - GPTQMarlinLinear, - GPTQMarlinWeight, MarlinLinear, MarlinWeight, ) - if isinstance(weight, GPTQMarlinWeight): - linear = GPTQMarlinLinear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, GPTQMarlin24Weight): + if isinstance(weight, GPTQMarlin24Weight): linear = GPTQMarlin24Linear( weight=weight, bias=bias, diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 2207b2e41c0..a1af67a3f5f 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -3,6 +3,8 @@ import torch import torch.nn as nn + +from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.import_utils import SYSTEM try: @@ -22,6 +24,19 @@ MARLIN_TILE_SIZE = 16 +def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: + return ( + SYSTEM == "cuda" + and marlin_kernels is not None + and has_sm_8_0 + and quantize == "gptq" + and gptq_params.quant_method == "gptq" + and gptq_params.bits in GPTQ_MARLIN_BITS + and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES + and gptq_params.sym + ) + + def _check_marlin_kernels(): if not (SYSTEM == "cuda" and has_sm_8_0): raise NotImplementedError( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c48ed26883f..6b82aeca798 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -309,7 +309,9 @@ def forward(self, hidden_states, adapter_data): dtype=hidden_states.dtype, device="cuda", ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + _custom_C.LLMM_Silu( + self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 + ) return self.down_proj(out, adapter_data) else: gate_up_states = self.gate_up_proj(hidden_states, adapter_data) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 348d215cdbc..3731fd249f7 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,25 +1,15 @@ import os -from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger from huggingface_hub import hf_hub_download import json +from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.log import log_once -@dataclass -class _GPTQParams: - bits: int - checkpoint_format: Optional[str] - groupsize: int - desc_act: bool - quant_method: str - sym: bool - - class Weights: def __init__( self, @@ -212,6 +202,10 @@ def get_weights_col_packed( """ if quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) try: qweight = self.get_packed_sharded( @@ -221,17 +215,28 @@ def get_weights_col_packed( raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." ) + scales = self.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=self.dtype) gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + g_idx = self.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=False, + ) qzeros = self.get_packed_sharded( f"{prefix}.qzeros", dim=1, block_sizes=block_sizes ) - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=self.dtype) - if quantize == "gptq" and gptq_params.quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") elif quantize == "gptq" and gptq_params.quant_method == "awq": @@ -269,7 +274,6 @@ def get_weights_col_packed( repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if is_marlin_24: B = self.get_packed_sharded( @@ -286,31 +290,6 @@ def get_weights_col_packed( weight = GPTQMarlin24Weight( B=B, B_meta=B_meta, s=s, bits=gptq_params.bits ) - elif quant_method == "gptq": - gptq_params = self._get_gptq_params() - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - g_idx = self.get_tensor(f"{prefix}.g_idx") - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) else: B = self.get_packed_sharded( f"{prefix}.B", dim=1, block_sizes=block_sizes @@ -356,6 +335,10 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): raise ValueError("get_multi_weights_col is not supported for exl2") elif quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) try: qweight = torch.cat( @@ -366,14 +349,31 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): f"Cannot load `{quantize}` weight, make sure the model is already quantized" ) - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) scales = torch.cat( [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 ) gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) from text_generation_server.layers.gptq import HAS_EXLLAMA @@ -425,10 +425,8 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): from text_generation_server.layers.marlin import ( GPTQMarlin24Weight, MarlinWeight, - repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if is_marlin_24: try: @@ -452,36 +450,6 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): weight = GPTQMarlin24Weight( B=B, B_meta=B_meta, s=s, bits=gptq_params.bits ) - elif quant_method == "gptq": - gptq_params = self._get_gptq_params() - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], - dim=1, - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) else: try: B = torch.cat( @@ -544,9 +512,41 @@ def get_multi_weights_row(self, prefix: str, quantize: str): ) elif quantize == "gptq": - use_exllama = True + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + gptq_params = self._get_gptq_params() + if can_use_gptq_marlin(gptq_params, quantize): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + if gptq_params.desc_act or gptq_params.groupsize == -1: + scales = self.get_tensor(f"{prefix}.scales") + else: + scales = self.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = self.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=gptq_params.bits, + desc_act=gptq_params.desc_act, + groupsize=gptq_params.groupsize, + sym=gptq_params.sym, + sharded_infeatures=sharded_in_features, + ) + use_exllama = True if gptq_params.bits != 4: use_exllama = False @@ -672,10 +672,8 @@ def get_multi_weights_row(self, prefix: str, quantize: str): from text_generation_server.layers.marlin import ( GPTQMarlin24Weight, MarlinWeight, - repack_gptq_for_marlin, ) - quant_method = getattr(self, "quant_method", "marlin") is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if is_marlin_24: try: @@ -698,35 +696,6 @@ def get_multi_weights_row(self, prefix: str, quantize: str): weight = GPTQMarlin24Weight( B=B, B_meta=B_meta, s=s, bits=gptq_params.bits ) - elif quant_method == "gptq": - log_once(logger.info, "Converting GPTQ model to Marlin packing format.") - gptq_params = self._get_gptq_params() - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - weight = repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, - ) else: try: B = self.get_sharded(f"{prefix}.B", dim=0) @@ -743,18 +712,17 @@ def get_multi_weights_row(self, prefix: str, quantize: str): else: s = self.get_sharded(f"{prefix}.s", dim=0) weight = MarlinWeight(B=B, s=s) - else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> _GPTQParams: + def _get_gptq_params(self) -> GPTQParams: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() checkpoint_format = getattr(self, "gptq_checkpoint_format", None) desc_act = False - sym = True + sym = False quant_method = "gptq" except (SafetensorError, RuntimeError) as e: try: @@ -767,7 +735,7 @@ def _get_gptq_params(self) -> _GPTQParams: except Exception: raise e - return _GPTQParams( + return GPTQParams( bits=bits, checkpoint_format=checkpoint_format, desc_act=desc_act,