From bf19419af9cb3ca7493e92f37c3a6acfc8012845 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Apr 2024 10:50:40 -0700 Subject: [PATCH 01/16] add fp8 --- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/linear.py | 7 ++ .../layers/quantization/__init__.py | 2 + .../model_executor/layers/quantization/fp8.py | 113 ++++++++++++++++++ .../model_loader/weight_utils.py | 9 +- 5 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/fp8.py diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9e08c253dc539..90914cb5b906e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -42,7 +42,7 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq", "gptq" and "squeezellm". If None, we first check + we support "awq", "gptq", "fp8 and "squeezellm". If None, we first check the `quantization_config` attribute in the model config file. If that is None, we assume the model weights are not quantized and use `dtype` to determine the data type of the weights. diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3ca870742efc5..b3e3ac15b0d9d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -38,6 +38,9 @@ def create_weights(self, layer: torch.nn.Module, The weights will be set as attributes of the layer.""" raise NotImplementedError + def postproc_weights(self, layer: torch.nn.Module): + """Postprocess the weights after loading (optional).""" + @abstractmethod def apply_weights(self, layer: torch.nn.Module, @@ -209,6 +212,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape + loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) def forward(self, input_): @@ -328,6 +332,7 @@ def weight_loader(self, "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") assert param_data.shape == loaded_weight.shape + loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) @@ -469,6 +474,7 @@ def weight_loader(self, "QKVParallelLinear, assume the weight is the same " "for all partitions.") assert param_data.shape == loaded_weight.shape + loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) @@ -558,6 +564,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape + loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) def forward(self, input_): diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index a3b89a66469eb..0344d6e4e3e45 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -3,12 +3,14 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import FP8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from 
vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS = { "awq": AWQConfig, + "fp8": FP8Config, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py new file mode 100644 index 0000000000000..29ab2b8cee20e --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -0,0 +1,113 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class FP8Config(QuantizationConfig): + """Config class for FP8.""" + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [ + torch.bfloat16, torch.half, torch.float8_e4m3fn, torch.float8_e5m2 + ] + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "FP8Config": + return cls() + + def get_linear_method(self) -> "Fp8LinearMethod": + return Fp8LinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: FP8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def postproc_weights(self, weight: torch.Tensor): + qweight, scale = per_tensor_quantize(weight) + self.scale = scale + return qweight + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qinput, scale = per_tensor_quantize(x) + output, _ = torch._scaled_mm( + qinput, + layer.weight.t(), + out_dtype=x.dtype, + scale_a=scale, + scale_b=self.scale, + bias=bias, + ) + return output + + +def per_tensor_quantize( + tensor: torch.Tensor, + qdtype=torch.float8_e4m3fn) -> tuple[torch.Tensor, float]: + """Quantize a tensor using per-tensor static scaling factor. + + Args: + tensor: The input tensor. + qdtype: The quantized data type. 
+ """ + finfo = torch.finfo(qdtype) + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / tensor.abs().max().clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(qdtype) + scale = scale.float().reciprocal() + return qweight, scale diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 1798db0136868..9995f2afe3cf7 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -134,11 +134,18 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm) else: hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + config_files = glob.glob(os.path.join(hf_folder, "*.json")) quant_config_files = [ f for f in config_files if any( - f.endswith(x) for x in quant_cls.get_config_filenames()) + f.endswith(x) for x in possible_config_filenames) ] if len(quant_config_files) == 0: raise ValueError( From 65a999973e491c1c56b03c3b930b87f2f6c7daf0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Apr 2024 12:38:58 -0700 Subject: [PATCH 02/16] work --- vllm/model_executor/layers/linear.py | 7 ----- .../model_executor/layers/quantization/fp8.py | 31 +++++++++++++++---- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b3e3ac15b0d9d..3ca870742efc5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -38,9 +38,6 @@ def create_weights(self, layer: torch.nn.Module, The weights will be set as attributes of the layer.""" raise NotImplementedError - def postproc_weights(self, layer: torch.nn.Module): - """Postprocess the weights after loading (optional).""" - @abstractmethod def apply_weights(self, layer: torch.nn.Module, @@ -212,7 +209,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape - loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) def forward(self, input_): @@ -332,7 +328,6 @@ def weight_loader(self, "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") assert param_data.shape == loaded_weight.shape - loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) @@ -474,7 +469,6 @@ def weight_loader(self, "QKVParallelLinear, assume the weight is the same " "for all partitions.") assert param_data.shape == loaded_weight.shape - loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) @@ -564,7 +558,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape - loaded_weight = self.linear_method.postproc_weights(loaded_weight) param_data.copy_(loaded_weight) def forward(self, input_): diff --git 
a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 29ab2b8cee20e..9a14445d6bf08 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,3 +1,5 @@ +import enum +from enum import Enum from typing import Any, Dict, List, Optional import torch @@ -41,6 +43,12 @@ def get_scaled_act_names(self) -> List[str]: return [] +class Fp8LinearState(Enum): + + UNINITIALIZED = enum.auto() + READY = enum.auto() + + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. @@ -69,27 +77,38 @@ def create_weights( layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - def postproc_weights(self, weight: torch.Tensor): - qweight, scale = per_tensor_quantize(weight) - self.scale = scale - return qweight + scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("scale", scale) + set_weight_attrs(scale, extra_weight_attrs) + layer.fp8_linear_state = Fp8LinearState.UNINITIALIZED def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qinput, scale = per_tensor_quantize(x) + + if layer.fp8_linear_state == Fp8LinearState.UNINITIALIZED: + qweight, weight_scale = per_tensor_quantize(layer.weight) + layer.weight.data = qweight.t() + layer.scale.data = weight_scale + layer.fp8_linear_state = Fp8LinearState.READY + output, _ = torch._scaled_mm( qinput, - layer.weight.t(), + layer.weight, out_dtype=x.dtype, scale_a=scale, - scale_b=self.scale, + scale_b=layer.scale, bias=bias, ) return output +@torch.compile def per_tensor_quantize( tensor: torch.Tensor, qdtype=torch.float8_e4m3fn) -> tuple[torch.Tensor, float]: From ad07afe9baacff5d4ff7d1f61b87a713ec303421 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Apr 2024 12:51:35 -0700 Subject: [PATCH 03/16] lint --- vllm/entrypoints/llm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 90914cb5b906e..8f64d7ea1f69b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -42,10 +42,10 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq", "gptq", "fp8 and "squeezellm". If None, we first check - the `quantization_config` attribute in the model config file. If - that is None, we assume the model weights are not quantized and use - `dtype` to determine the data type of the weights. + we support "awq", "gptq", "fp8 and "squeezellm". If None, we first + check the `quantization_config` attribute in the model config file. + If that is None, we assume the model weights are not quantized and + use `dtype` to determine the data type of the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. 
It can be a From f2ff3e54b56c3a2ec5c6d5a15f8d88b5743005c5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Apr 2024 16:48:31 -0700 Subject: [PATCH 04/16] comments --- requirements-cuda.txt | 2 +- .../model_executor/layers/quantization/fp8.py | 28 +++++++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index c6d2cd46aee54..da53d35aa4ae0 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -1,5 +1,5 @@ # Common dependencies --r requirements-common.txt +# -r requirements-common.txt # Dependencies for NVIDIA GPUs ray >= 2.9 diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9a14445d6bf08..1068d931cea9a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -20,13 +20,14 @@ def get_name(cls) -> str: @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [ - torch.bfloat16, torch.half, torch.float8_e4m3fn, torch.float8_e5m2 - ] + return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: - return 90 + # PyTorch 2.3.0+ is required to run FP8 on SM 89 (e.g. Ada) GPUs. + # Specifially, this PR has to be included: + # https://github.com/pytorch/pytorch/pull/118881/files + return 89 @classmethod def get_config_filenames(cls) -> List[str]: @@ -77,6 +78,14 @@ def create_weights( layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) + qweight = Parameter(torch.empty(input_size_per_partition, + output_size_per_partition, + dtype=torch.float8_e4m3fn), + requires_grad=False) + set_weight_attrs(qweight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + scale = Parameter( torch.empty(1, dtype=torch.float32), requires_grad=False, @@ -92,14 +101,15 @@ def apply_weights(self, qinput, scale = per_tensor_quantize(x) if layer.fp8_linear_state == Fp8LinearState.UNINITIALIZED: - qweight, weight_scale = per_tensor_quantize(layer.weight) - layer.weight.data = qweight.t() + qweight, weight_scale = per_tensor_quantize(layer.weight.data) + layer.weight.data = torch.empty_like(layer.weight.data) + layer.qweight.data = qweight.t() layer.scale.data = weight_scale layer.fp8_linear_state = Fp8LinearState.READY output, _ = torch._scaled_mm( qinput, - layer.weight, + layer.qweight, out_dtype=x.dtype, scale_a=scale, scale_b=layer.scale, @@ -120,7 +130,9 @@ def per_tensor_quantize( """ finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax - scale = finfo.max / tensor.abs().max().clamp(min=1e-12) + min_val, max_val = tensor.aminmax() + amax = max(-min_val, max_val) + scale = finfo.max / amax.clamp(min=1e-12) # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) From dbe46a6ca5ee097bc233c80f5312fd3f564afc85 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 17 Apr 2024 16:35:13 -0700 Subject: [PATCH 05/16] done --- .../model_executor/layers/quantization/fp8.py | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1068d931cea9a..0d414752ca21c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -25,7 +25,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: 
@classmethod def get_min_capability(cls) -> int: # PyTorch 2.3.0+ is required to run FP8 on SM 89 (e.g. Ada) GPUs. - # Specifially, this PR has to be included: + # Specifically, this PR has to be included: # https://github.com/pytorch/pytorch/pull/118881/files return 89 @@ -52,7 +52,15 @@ class Fp8LinearState(Enum): class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. - + We now support two types of model checkpoints: + 1. Common FP16/BF16 model checkpoints. In this case, the weight + scaling factor will be initialized during the first forward pass. + 2. FP8 model checkpoints. In this case, the weight scaling factor + will be loaded from the checkpoint with parameter name "w_scale", + and the weight should already in FP8 and has been transposed. + + Note that we currently only support per-tensor quantization. + Args: quant_config: The quantization config. """ @@ -78,41 +86,37 @@ def create_weights( layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - qweight = Parameter(torch.empty(input_size_per_partition, - output_size_per_partition, - dtype=torch.float8_e4m3fn), - requires_grad=False) - set_weight_attrs(qweight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("qweight", qweight) - set_weight_attrs(qweight, extra_weight_attrs) - - scale = Parameter( + # Will be loaded from FP8 checkpoints, or initialized for + # FP16/BF16 checkpoints during the first forward pass. + w_scale = Parameter( torch.empty(1, dtype=torch.float32), requires_grad=False, ) - layer.register_parameter("scale", scale) - set_weight_attrs(scale, extra_weight_attrs) + layer.register_parameter("w_scale", w_scale) + set_weight_attrs(w_scale, extra_weight_attrs) layer.fp8_linear_state = Fp8LinearState.UNINITIALIZED def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qinput, scale = per_tensor_quantize(x) if layer.fp8_linear_state == Fp8LinearState.UNINITIALIZED: - qweight, weight_scale = per_tensor_quantize(layer.weight.data) - layer.weight.data = torch.empty_like(layer.weight.data) - layer.qweight.data = qweight.t() - layer.scale.data = weight_scale + # Per-tensor scaling on the fly for FP16/BF16 weights + # during the first forward pass. 
+ if layer.weight.dtype != torch.float8_e4m3fn: + qweight, weight_scale = per_tensor_quantize(layer.weight.data) + layer.weight.data = qweight.t() + layer.w_scale.data = weight_scale layer.fp8_linear_state = Fp8LinearState.READY + qinput, x_scale = per_tensor_quantize(x) output, _ = torch._scaled_mm( qinput, - layer.qweight, + layer.weight, out_dtype=x.dtype, - scale_a=scale, - scale_b=layer.scale, + scale_a=x_scale, + scale_b=layer.w_scale, bias=bias, ) return output From 4a8d9232e7a8e63d906d7cab8d4d64c94f2f96d5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 17 Apr 2024 16:39:38 -0700 Subject: [PATCH 06/16] revert --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index da53d35aa4ae0..c6d2cd46aee54 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -1,5 +1,5 @@ # Common dependencies -# -r requirements-common.txt +-r requirements-common.txt # Dependencies for NVIDIA GPUs ray >= 2.9 From 4bd55a52c987026e4b50a6be4e20cfcb35b42284 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 18 Apr 2024 09:44:16 -0700 Subject: [PATCH 07/16] comment --- vllm/entrypoints/llm.py | 9 +++++---- vllm/model_executor/layers/quantization/fp8.py | 10 +++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8f64d7ea1f69b..961de5d5063fa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -42,10 +42,11 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq", "gptq", "fp8 and "squeezellm". If None, we first - check the `quantization_config` attribute in the model config file. - If that is None, we assume the model weights are not quantized and - use `dtype` to determine the data type of the weights. + we support "awq", "gptq", "squeezellm", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0d414752ca21c..db084769f486a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -94,6 +94,11 @@ def create_weights( ) layer.register_parameter("w_scale", w_scale) set_weight_attrs(w_scale, extra_weight_attrs) + + # We always initialize the state to UNINITIALIZED because + # we cannot know whether weights going to be loaded are in FP8 + # or not. If they are in FP8, the first forward pass simply + # sets the state to READY without re-quantization. layer.fp8_linear_state = Fp8LinearState.UNINITIALIZED def apply_weights(self, @@ -103,9 +108,12 @@ def apply_weights(self, if layer.fp8_linear_state == Fp8LinearState.UNINITIALIZED: # Per-tensor scaling on the fly for FP16/BF16 weights - # during the first forward pass. + # during the first forward pass if the loaded weights + # are not in FP8. 
if layer.weight.dtype != torch.float8_e4m3fn: qweight, weight_scale = per_tensor_quantize(layer.weight.data) + # Note that torch._scaled_mm requires column-major in + # the second input, so we transpose the quantized weight here. layer.weight.data = qweight.t() layer.w_scale.data = weight_scale layer.fp8_linear_state = Fp8LinearState.READY From 2dee644a91cd2d3cd220031256cce9ed9bd2e60d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 18 Apr 2024 15:03:00 -0700 Subject: [PATCH 08/16] wip --- tests/quantization/test_fp8.py | 28 +++++++ vllm/model_executor/layers/linear.py | 30 +++++++ .../model_executor/layers/quantization/fp8.py | 81 +++++++++++-------- vllm/model_executor/model_loader/loader.py | 3 + 4 files changed, 109 insertions(+), 33 deletions(-) create mode 100644 tests/quantization/test_fp8.py diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py new file mode 100644 index 0000000000000..9074a0fd4f29f --- /dev/null +++ b/tests/quantization/test_fp8.py @@ -0,0 +1,28 @@ +"""Tests whether FP8 computation is enabled correctly. + +Run `pytest tests/quantization/test_fp8.py --forked`. +""" +import pytest + +import torch + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod, Fp8LinearState + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +@pytest.mark.skipif(capability + < QUANTIZATION_METHODS["fp8"].get_min_capability(), + reason="FP8 is not supported on this GPU type.") +def test_load_fp16_model(vllm_runner) -> None: + llm = vllm_runner("facebook/opt-125m", quantization="fp8") + + model = llm.llm_engine.model_executor.driver_worker.model_runner.model + fc1 = model.model.decoder.layers[0].fc1 + assert isinstance(fc1.linaer_method, Fp8LinearMethod) + + # The engine has been warmed up and the state should be ready. + assert fc1.fp8_linear_state == Fp8LinearState.READY + assert fc1.weight.dtype == torch.float8_e4m3fn diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3ca870742efc5..f57d7a84898fc 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from torch import nn from torch.nn.parameter import Parameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -48,6 +49,23 @@ def apply_weights(self, Expects create_weights to have been called before on the layer.""" raise NotImplementedError + def proc_before_loading(self, layer: nn.Module, param: Parameter, + loaded_weight: torch.Tensor) -> torch.Tensor: + """Process the weight before loading. + + This can be used for exmaple, quantizing the weight before + loading it to the model. + """ + return loaded_weight + + def proc_after_loading(self, layer: nn.Module, param: Parameter, + loaded_weight: torch.Tensor) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + pass + class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization. 
@@ -200,6 +218,9 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + loaded_weight = self.linear_method.proc_before_loading( + self, param, loaded_weight) + tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data @@ -266,6 +287,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): + loaded_weight = self.linear_method.proc_before_loading( + self, param, loaded_weight) + param_data = param.data output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: @@ -392,6 +416,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): + loaded_weight = self.linear_method.proc_before_loading( + self, param, loaded_weight) + param_data = param.data output_dim = getattr(param, "output_dim", None) @@ -549,6 +576,9 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + loaded_weight = self.linear_method.proc_before_loading( + self, param, loaded_weight) + tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index db084769f486a..7c8893ec1ec5a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -3,6 +3,8 @@ from typing import Any, Dict, List, Optional import torch +import torch._dynamo +from torch.nn import Module from torch.nn.parameter import Parameter from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -10,6 +12,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +torch._dynamo.config.suppress_errors = True class FP8Config(QuantizationConfig): """Config class for FP8.""" @@ -24,10 +27,10 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - # PyTorch 2.3.0+ is required to run FP8 on SM 89 (e.g. Ada) GPUs. - # Specifically, this PR has to be included: - # https://github.com/pytorch/pytorch/pull/118881/files - return 89 + # TODO: PyTorch 2.3.0+ is required to run FP8 on + # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to + # be included: https://github.com/pytorch/pytorch/pull/118881 + return 90 @classmethod def get_config_filenames(cls) -> List[str]: @@ -52,14 +55,11 @@ class Fp8LinearState(Enum): class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. - We now support two types of model checkpoints: - 1. Common FP16/BF16 model checkpoints. In this case, the weight - scaling factor will be initialized during the first forward pass. - 2. FP8 model checkpoints. In this case, the weight scaling factor - will be loaded from the checkpoint with parameter name "w_scale", - and the weight should already in FP8 and has been transposed. - - Note that we currently only support per-tensor quantization. + We now support common FP16/BF16 model checkpoints ONLY. The weight + scaling factor will be initialized during the first forward pass. + + Note that we currently only support per-tensor quantization due to + torch._scaled_mm support. Args: quant_config: The quantization config. 
@@ -80,20 +80,19 @@ def create_weights( ): weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, - dtype=params_dtype), + dtype=torch.float8_e4m3fn), requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, extra_weight_attrs) - # Will be loaded from FP8 checkpoints, or initialized for - # FP16/BF16 checkpoints during the first forward pass. + # Will be initialized for FP16/BF16 checkpoints + # during the first forward pass. w_scale = Parameter( torch.empty(1, dtype=torch.float32), requires_grad=False, ) - layer.register_parameter("w_scale", w_scale) - set_weight_attrs(w_scale, extra_weight_attrs) + layer.register_parameter("weight_scaling_factor", w_scale) # We always initialize the state to UNINITIALIZED because # we cannot know whether weights going to be loaded are in FP8 @@ -101,22 +100,37 @@ def create_weights( # sets the state to READY without re-quantization. layer.fp8_linear_state = Fp8LinearState.UNINITIALIZED + def proc_before_loading(self, layer: Module, param: Parameter, + loaded_weight: torch.Tensor) -> torch.Tensor: + if loaded_weight.dtype != torch.float8_e4m3fn: + loaded_weight, weight_scale = per_tensor_quantize(loaded_weight) + # If the loaded weight is not in FP8, we override the + # weight and scaling factor with the quantized weight. + layer.weight_scaling_factor.data.copy_(weight_scale) + return loaded_weight + + def proc_after_loading(self, layer: Module) -> None: + # Note that torch._scaled_mm requires column-major in + # the second input (weight), so we transpose the quantized + # weight here. + layer.weight.data = layer.weight.data.t() + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if layer.fp8_linear_state == Fp8LinearState.UNINITIALIZED: - # Per-tensor scaling on the fly for FP16/BF16 weights - # during the first forward pass if the loaded weights - # are not in FP8. - if layer.weight.dtype != torch.float8_e4m3fn: - qweight, weight_scale = per_tensor_quantize(layer.weight.data) - # Note that torch._scaled_mm requires column-major in - # the second input, so we transpose the quantized weight here. - layer.weight.data = qweight.t() - layer.w_scale.data = weight_scale - layer.fp8_linear_state = Fp8LinearState.READY + # if layer.fp8_linear_state == Fp8LinearState.UNINITIALIZED: + # # Per-tensor scaling on the fly for FP16/BF16 weights + # # during the first forward pass if the loaded weights + # # are not in FP8. + # if layer.weight.dtype != torch.float8_e4m3fn: + # qweight, weight_scale = per_tensor_quantize(layer.weight.data) + # # Note that torch._scaled_mm requires column-major in + # # the second input, so we transpose the quantized weight here. 
+ # layer.weight.data = qweight.t() + # layer.weight_scaling_factor.data = weight_scale + # layer.fp8_linear_state = Fp8LinearState.READY qinput, x_scale = per_tensor_quantize(x) output, _ = torch._scaled_mm( @@ -124,13 +138,13 @@ def apply_weights(self, layer.weight, out_dtype=x.dtype, scale_a=x_scale, - scale_b=layer.w_scale, + scale_b=layer.weight_scaling_factor, bias=bias, ) return output -@torch.compile +#@torch.compile def per_tensor_quantize( tensor: torch.Tensor, qdtype=torch.float8_e4m3fn) -> tuple[torch.Tensor, float]: @@ -142,8 +156,9 @@ def per_tensor_quantize( """ finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax - min_val, max_val = tensor.aminmax() - amax = max(-min_val, max_val) + # min_val, max_val = tensor.aminmax() + # amax = max(-min_val, max_val) + amax = tensor.abs().max() scale = finfo.max / amax.clamp(min=1e-12) # scale and clamp the tensor to bring it to # the representative range of float8 data type diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 3b1d125ef8a67..18589af75ce27 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -228,6 +228,9 @@ def load_model(self, *, model_config: ModelConfig, model, "fall_back_to_pt_during_load", True)), ) + for _, module in model.named_modules(): + if hasattr(module, "linear_method") and hasattr(module, "weight"): + module.linear_method.proc_after_loading(module) return model.eval() From 6a0a8bae491bd3a7fcc5bce7137449ec2cec43b5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 18 Apr 2024 16:07:03 -0700 Subject: [PATCH 09/16] work --- vllm/model_executor/layers/linear.py | 24 +-------- .../model_executor/layers/quantization/fp8.py | 53 ++++--------------- 2 files changed, 12 insertions(+), 65 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f57d7a84898fc..73388a213e0b3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -49,17 +49,7 @@ def apply_weights(self, Expects create_weights to have been called before on the layer.""" raise NotImplementedError - def proc_before_loading(self, layer: nn.Module, param: Parameter, - loaded_weight: torch.Tensor) -> torch.Tensor: - """Process the weight before loading. - - This can be used for exmaple, quantizing the weight before - loading it to the model. - """ - return loaded_weight - - def proc_after_loading(self, layer: nn.Module, param: Parameter, - loaded_weight: torch.Tensor) -> None: + def proc_after_loading(self, layer: nn.Module) -> None: """Process the weight after loading. This can be used for example, to transpose weights for computation. 
@@ -218,9 +208,6 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - loaded_weight = self.linear_method.proc_before_loading( - self, param, loaded_weight) - tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data @@ -287,9 +274,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): - loaded_weight = self.linear_method.proc_before_loading( - self, param, loaded_weight) - param_data = param.data output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: @@ -416,9 +400,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): - loaded_weight = self.linear_method.proc_before_loading( - self, param, loaded_weight) - param_data = param.data output_dim = getattr(param, "output_dim", None) @@ -576,9 +557,6 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - loaded_weight = self.linear_method.proc_before_loading( - self, param, loaded_weight) - tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7c8893ec1ec5a..cd210cb4d4bc6 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional import torch -import torch._dynamo from torch.nn import Module from torch.nn.parameter import Parameter @@ -12,7 +11,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -torch._dynamo.config.suppress_errors = True class FP8Config(QuantizationConfig): """Config class for FP8.""" @@ -47,12 +45,6 @@ def get_scaled_act_names(self) -> List[str]: return [] -class Fp8LinearState(Enum): - - UNINITIALIZED = enum.auto() - READY = enum.auto() - - class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. We now support common FP16/BF16 model checkpoints ONLY. The weight @@ -80,58 +72,35 @@ def create_weights( ): weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, - dtype=torch.float8_e4m3fn), + dtype=params_dtype), requires_grad=False) layer.register_parameter("weight", weight) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, extra_weight_attrs) - # Will be initialized for FP16/BF16 checkpoints - # during the first forward pass. w_scale = Parameter( torch.empty(1, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("weight_scaling_factor", w_scale) - # We always initialize the state to UNINITIALIZED because - # we cannot know whether weights going to be loaded are in FP8 - # or not. If they are in FP8, the first forward pass simply - # sets the state to READY without re-quantization. - layer.fp8_linear_state = Fp8LinearState.UNINITIALIZED - - def proc_before_loading(self, layer: Module, param: Parameter, - loaded_weight: torch.Tensor) -> torch.Tensor: - if loaded_weight.dtype != torch.float8_e4m3fn: - loaded_weight, weight_scale = per_tensor_quantize(loaded_weight) - # If the loaded weight is not in FP8, we override the - # weight and scaling factor with the quantized weight. 
- layer.weight_scaling_factor.data.copy_(weight_scale) - return loaded_weight + # def proc_before_loading(self, layer: Module, param: Parameter, + # loaded_weight: torch.Tensor) -> torch.Tensor: + # loaded_weight, weight_scale = per_tensor_quantize(loaded_weight) + # layer.weight_scaling_factor.data.copy_(weight_scale) + # return loaded_weight def proc_after_loading(self, layer: Module) -> None: - # Note that torch._scaled_mm requires column-major in - # the second input (weight), so we transpose the quantized - # weight here. - layer.weight.data = layer.weight.data.t() + # torch._scaled_mm requires column-major in the second + # input (weight), so we transpose the quantized weight here. + qweight, weight_scale = per_tensor_quantize(layer.weight) + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scaling_factor.data.copy_(weight_scale) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - # if layer.fp8_linear_state == Fp8LinearState.UNINITIALIZED: - # # Per-tensor scaling on the fly for FP16/BF16 weights - # # during the first forward pass if the loaded weights - # # are not in FP8. - # if layer.weight.dtype != torch.float8_e4m3fn: - # qweight, weight_scale = per_tensor_quantize(layer.weight.data) - # # Note that torch._scaled_mm requires column-major in - # # the second input, so we transpose the quantized weight here. - # layer.weight.data = qweight.t() - # layer.weight_scaling_factor.data = weight_scale - # layer.fp8_linear_state = Fp8LinearState.READY - qinput, x_scale = per_tensor_quantize(x) output, _ = torch._scaled_mm( qinput, From 7974ccfa64b9f313b59f7ae86a21dec5846a1c5e Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 18 Apr 2024 16:10:34 -0700 Subject: [PATCH 10/16] lint --- tests/quantization/test_fp8.py | 12 ++++-------- vllm/model_executor/layers/linear.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 2 -- vllm/model_executor/model_loader/loader.py | 3 ++- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 9074a0fd4f29f..2d0bac460941d 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -3,26 +3,22 @@ Run `pytest tests/quantization/test_fp8.py --forked`. """ import pytest - import torch from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod, Fp8LinearState +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] -@pytest.mark.skipif(capability - < QUANTIZATION_METHODS["fp8"].get_min_capability(), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + capability < QUANTIZATION_METHODS["fp8"].get_min_capability(), + reason="FP8 is not supported on this GPU type.") def test_load_fp16_model(vllm_runner) -> None: llm = vllm_runner("facebook/opt-125m", quantization="fp8") model = llm.llm_engine.model_executor.driver_worker.model_runner.model fc1 = model.model.decoder.layers[0].fc1 assert isinstance(fc1.linaer_method, Fp8LinearMethod) - - # The engine has been warmed up and the state should be ready. 
- assert fc1.fp8_linear_state == Fp8LinearState.READY assert fc1.weight.dtype == torch.float8_e4m3fn diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 73388a213e0b3..c84002302ebe8 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -54,7 +54,7 @@ def proc_after_loading(self, layer: nn.Module) -> None: This can be used for example, to transpose weights for computation. """ - pass + return class UnquantizedLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cd210cb4d4bc6..da6b3bf426b4d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,5 +1,3 @@ -import enum -from enum import Enum from typing import Any, Dict, List, Optional import torch diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 18589af75ce27..bfd76b3e77e4d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -229,7 +229,8 @@ def load_model(self, *, model_config: ModelConfig, "fall_back_to_pt_during_load", True)), ) for _, module in model.named_modules(): - if hasattr(module, "linear_method") and hasattr(module, "weight"): + if hasattr(module, "linear_method") and hasattr( + module, "weight"): module.linear_method.proc_after_loading(module) return model.eval() From ca416f390b51ee80be762ce10a5e525cb9360add Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 18 Apr 2024 16:23:25 -0700 Subject: [PATCH 11/16] done --- .../model_executor/layers/quantization/fp8.py | 36 +++++++++---------- vllm/model_executor/model_loader/loader.py | 3 +- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index da6b3bf426b4d..a329505eadf49 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -46,10 +46,12 @@ def get_scaled_act_names(self) -> List[str]: class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. We now support common FP16/BF16 model checkpoints ONLY. The weight - scaling factor will be initialized during the first forward pass. + scaling factor will be initialized after the model weights are loaded. - Note that we currently only support per-tensor quantization due to - torch._scaled_mm support. + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) Args: quant_config: The quantization config. @@ -82,16 +84,13 @@ def create_weights( ) layer.register_parameter("weight_scaling_factor", w_scale) - # def proc_before_loading(self, layer: Module, param: Parameter, - # loaded_weight: torch.Tensor) -> torch.Tensor: - # loaded_weight, weight_scale = per_tensor_quantize(loaded_weight) - # layer.weight_scaling_factor.data.copy_(weight_scale) - # return loaded_weight - def proc_after_loading(self, layer: Module) -> None: - # torch._scaled_mm requires column-major in the second - # input (weight), so we transpose the quantized weight here. 
+ if not hasattr(layer, "weight_scaling_factor"): + return + qweight, weight_scale = per_tensor_quantize(layer.weight) + # torch._scaled_mm requires column-major in the second + # input (weight), so we transpose the quantized weight. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scaling_factor.data.copy_(weight_scale) @@ -111,21 +110,18 @@ def apply_weights(self, return output -#@torch.compile -def per_tensor_quantize( - tensor: torch.Tensor, - qdtype=torch.float8_e4m3fn) -> tuple[torch.Tensor, float]: +@torch.compile +def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: """Quantize a tensor using per-tensor static scaling factor. Args: tensor: The input tensor. qdtype: The quantized data type. """ - finfo = torch.finfo(qdtype) + finfo = torch.finfo(torch.float8_e4m3fn) # Calculate the scale as dtype max divided by absmax - # min_val, max_val = tensor.aminmax() - # amax = max(-min_val, max_val) - amax = tensor.abs().max() + min_val, max_val = tensor.aminmax() + amax = max(-min_val, max_val) scale = finfo.max / amax.clamp(min=1e-12) # scale and clamp the tensor to bring it to # the representative range of float8 data type @@ -133,6 +129,6 @@ def per_tensor_quantize( qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) # Return both float8 data and the inverse scale (as float), # as both required as inputs to torch._scaled_mm - qweight = qweight.to(qdtype) + qweight = qweight.to(torch.float8_e4m3fn) scale = scale.float().reciprocal() return qweight, scale diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bfd76b3e77e4d..bf9575c72fc02 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -229,8 +229,7 @@ def load_model(self, *, model_config: ModelConfig, "fall_back_to_pt_during_load", True)), ) for _, module in model.named_modules(): - if hasattr(module, "linear_method") and hasattr( - module, "weight"): + if hasattr(module, "linear_method"): module.linear_method.proc_after_loading(module) return model.eval() From ee69c1b52d3923a603cc76daa77af725abc60ac6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 19 Apr 2024 09:48:51 -0700 Subject: [PATCH 12/16] fix ci --- vllm/model_executor/layers/quantization/fp8.py | 4 ++++ vllm/model_executor/model_loader/loader.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a329505eadf49..13570182e0da9 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -85,6 +85,10 @@ def create_weights( layer.register_parameter("weight_scaling_factor", w_scale) def proc_after_loading(self, layer: Module) -> None: + # Although the linear_method is propagated to all layers, + # only linear layers invoke "create_weights". So we check + # whether "weight_scaling_facor" is registered to determine + # whether the layer is a linear layer that requires quantization. 
if not hasattr(layer, "weight_scaling_factor"): return diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bf9575c72fc02..d299151490a28 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -229,8 +229,9 @@ def load_model(self, *, model_config: ModelConfig, "fall_back_to_pt_during_load", True)), ) for _, module in model.named_modules(): - if hasattr(module, "linear_method"): - module.linear_method.proc_after_loading(module) + linear_method = getattr(module, "linear_method", None) + if linear_method is not None: + linear_method.proc_after_loading(module) return model.eval() From 2088110c0b7e25264b68a07de2f2b774345ab81a Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 19 Apr 2024 10:30:53 -0700 Subject: [PATCH 13/16] fix --- vllm/model_executor/layers/quantization/fp8.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 13570182e0da9..c05a843390279 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -114,7 +114,6 @@ def apply_weights(self, return output -@torch.compile def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: """Quantize a tensor using per-tensor static scaling factor. @@ -123,9 +122,11 @@ def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: qdtype: The quantized data type. """ finfo = torch.finfo(torch.float8_e4m3fn) - # Calculate the scale as dtype max divided by absmax + # Calculate the scale as dtype max divided by absmax. + # Since .abs() creates a new tensor, we use aminmax to get + # the min and max first and then calculate the absmax. min_val, max_val = tensor.aminmax() - amax = max(-min_val, max_val) + amax = min_val.abs().max(max_val.abs()) scale = finfo.max / amax.clamp(min=1e-12) # scale and clamp the tensor to bring it to # the representative range of float8 data type From e7b2fc8d09024c52b2e7923591a36dadc5c4b852 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 19 Apr 2024 13:10:04 -0700 Subject: [PATCH 14/16] rename --- vllm/model_executor/layers/linear.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 2 +- vllm/model_executor/model_loader/loader.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c84002302ebe8..d466d8807fc64 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -49,7 +49,7 @@ def apply_weights(self, Expects create_weights to have been called before on the layer.""" raise NotImplementedError - def proc_after_loading(self, layer: nn.Module) -> None: + def process_weights_after_loading(self, layer: nn.Module) -> None: """Process the weight after loading. This can be used for example, to transpose weights for computation. 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c05a843390279..b044b5c1a2622 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -84,7 +84,7 @@ def create_weights( ) layer.register_parameter("weight_scaling_factor", w_scale) - def proc_after_loading(self, layer: Module) -> None: + def process_weights_after_loading(self, layer: Module) -> None: # Although the linear_method is propagated to all layers, # only linear layers invoke "create_weights". So we check # whether "weight_scaling_facor" is registered to determine diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d299151490a28..6c8cb2935f37e 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -231,7 +231,7 @@ def load_model(self, *, model_config: ModelConfig, for _, module in model.named_modules(): linear_method = getattr(module, "linear_method", None) if linear_method is not None: - linear_method.proc_after_loading(module) + linear_method.process_weights_after_loading(module) return model.eval() From 1f957393a042c3fa0983c1c5d94400c43e961468 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 19 Apr 2024 16:26:55 -0700 Subject: [PATCH 15/16] comment --- vllm/model_executor/layers/quantization/fp8.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b044b5c1a2622..9dc0e86e1243d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -119,7 +119,6 @@ def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: Args: tensor: The input tensor. - qdtype: The quantized data type. """ finfo = torch.finfo(torch.float8_e4m3fn) # Calculate the scale as dtype max divided by absmax. From 5b70cd1c7c9cf7cc2e60fb5454d460b2c3d3de92 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 17:00:29 -0700 Subject: [PATCH 16/16] fix test --- tests/quantization/test_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 2d0bac460941d..fa10e60de10a7 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -18,7 +18,7 @@ def test_load_fp16_model(vllm_runner) -> None: llm = vllm_runner("facebook/opt-125m", quantization="fp8") - model = llm.llm_engine.model_executor.driver_worker.model_runner.model + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model fc1 = model.model.decoder.layers[0].fc1 - assert isinstance(fc1.linaer_method, Fp8LinearMethod) + assert isinstance(fc1.linear_method, Fp8LinearMethod) assert fc1.weight.dtype == torch.float8_e4m3fn