[ Misc ] fp8-marlin channelwise via compressed-tensors #6524

Merged · 12 commits · Jul 25, 2024
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.578
- name: "exact_match,flexible-extract"
value: 0.585
limit: 1000
num_fewshot: 5
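For context, this new config pins the gsm8k scores that the run-lm-eval-gsm-vllm-baseline.sh comparison expects from nm-testing/Qwen2-1.5B-Instruct-FP8W8 (1000 samples, 5-shot). Below is a minimal sketch of how such a baseline file could be checked against measured scores, assuming a PyYAML-style loader and a relative tolerance; the actual harness scripts are not shown in this diff and may differ.

```python
import yaml

# Baseline parsed from the same structure as the new config file.
BASELINE = yaml.safe_load("""
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.578
  - name: "exact_match,flexible-extract"
    value: 0.585
limit: 1000
num_fewshot: 5
""")

def within_baseline(measured: dict, baseline: dict, rtol: float = 0.05) -> bool:
    # Hypothetical check: every pinned metric must be within rtol of baseline.
    return all(
        abs(measured[task["name"]][m["name"]] - m["value"]) <= rtol * m["value"]
        for task in baseline["tasks"] for m in task["metrics"])

# Example measured scores from a vLLM run of the FP8 checkpoint.
assert within_baseline(
    {"gsm8k": {"exact_match,strict-match": 0.58,
               "exact_match,flexible-extract": 0.59}}, BASELINE)
```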
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
@@ -5,3 +5,4 @@ Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
@@ -10,7 +10,8 @@
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_matched_target, is_activation_quantization_format,
@@ -100,14 +101,18 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
def get_config_filenames(cls) -> List[str]:
return []

def _check_scheme_supported(self, min_capability: int):
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < min_capability:
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
return supported

def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
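For reference, current_platform.get_device_capability() returns a (major, minor) tuple, which the check packs into a two-digit integer: Ampere (8, 0) becomes 80, Ada Lovelace (8, 9) becomes 89, Hopper (9, 0) becomes 90. A standalone sketch of the same comparison, with the device tuple passed in explicitly rather than queried from the platform:

```python
def check_scheme_supported(device_capability: tuple,
                           min_capability: int,
                           error: bool = True) -> bool:
    # Pack (major, minor) into a two-digit integer, e.g. (8, 9) -> 89.
    major, minor = device_capability
    capability = major * 10 + minor
    supported = capability >= min_capability
    if error and not supported:
        raise RuntimeError(
            "Quantization scheme is not supported for the current GPU. "
            f"Min capability: {min_capability}. "
            f"Current capability: {capability}.")
    return supported

# FP8 W8A8 via cutlass needs capability 89 (Ada Lovelace) or newer;
# with error=False an Ampere card reports unsupported instead of raising.
assert check_scheme_supported((8, 9), 89, error=False)
assert not check_scheme_supported((8, 0), 89, error=False)
```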
@@ -170,6 +175,29 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
# All conditions satisfied.
return True

def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights quantized.
if weight_quant is None:
return False

# Confirm we have floating points.
if weight_quant.type != QuantizationType.FLOAT:
return False

# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
return False

# All conditions satisfied.
return True

def _is_wNa16_group_channel(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
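The new _is_fp8_w8a16 predicate accepts checkpoints whose weights are static, symmetric fp8 with per-tensor or per-channel scales, regardless of whether activations are quantized. A hedged illustration using a stand-in dataclass instead of the real pydantic QuantizationArgs model; the lowercase enum strings below are an assumption about how QuantizationType and QuantizationStrategy serialize:

```python
from dataclasses import dataclass

@dataclass
class QuantArgsStub:
    # Stand-in for QuantizationArgs; only the fields the predicate reads.
    type: str = "float"        # QuantizationType.FLOAT
    symmetric: bool = True
    dynamic: bool = False
    strategy: str = "channel"  # QuantizationStrategy.CHANNEL

def is_fp8_w8a16(weight_quant, input_quant) -> bool:
    # Mirrors the new predicate: float weights, symmetric, static scales,
    # per-tensor or per-channel strategy; input_quant is not inspected.
    if weight_quant is None:
        return False
    if weight_quant.type != "float":
        return False
    return (weight_quant.symmetric
            and not weight_quant.dynamic
            and weight_quant.strategy in ("tensor", "channel"))

assert is_fp8_w8a16(QuantArgsStub(), None)                      # fp8 weight-only
assert not is_fp8_w8a16(QuantArgsStub(strategy="group"), None)  # group scales
```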
@@ -204,9 +232,23 @@ def _get_scheme_from_parts(
# Detect If Activation Quantization.
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8(
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))

if self._is_fp8_w8a16(weight_quant, input_quant):
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
is_static_input_scheme=(input_quant
and not input_quant.dynamic))

if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
@@ -257,11 +299,10 @@ def get_scheme(
targets=self.target_scheme_map.keys())

# Find the quant_scheme
scheme = self.target_scheme_map[matched_target]

return self._get_scheme_from_parts(
weight_quant=scheme["weights"],
input_quant=scheme["input_activations"])
scheme_dict = self.target_scheme_map[matched_target]
scheme = self._get_scheme_from_parts(
weight_quant=scheme_dict["weights"],
input_quant=scheme_dict["input_activations"])

# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
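Taken together, the dispatch above means an fp8 W8A8 checkpoint no longer hard-fails on pre-Ada GPUs: if the cutlass path's minimum capability (89) is not met, the config falls back to the weight-only CompressedTensorsW8A16Fp8 (fp8-marlin) scheme, and get_scheme still raises afterwards if even the selected scheme is unsupported. A rough sketch of the decision, with scheme classes replaced by their names:

```python
def pick_fp8_scheme(is_fp8_w8a8: bool, is_fp8_w8a16: bool,
                    capability: int) -> str:
    # Illustrative only: returns a class name instead of a scheme instance.
    if is_fp8_w8a8:
        # CompressedTensorsW8A8Fp8.get_min_capability() == 89 (lovelace and up)
        if capability >= 89:
            return "CompressedTensorsW8A8Fp8"   # cutlass fp8 W8A8 path
        return "CompressedTensorsW8A16Fp8"      # fp8-marlin weight-only fallback
    if is_fp8_w8a16:
        return "CompressedTensorsW8A16Fp8"      # weight-only fp8 checkpoint
    raise ValueError("no fp8 scheme matched")

assert pick_fp8_scheme(True, False, 90) == "CompressedTensorsW8A8Fp8"    # H100
assert pick_fp8_scheme(True, False, 80) == "CompressedTensorsW8A16Fp8"   # A100
assert pick_fp8_scheme(False, True, 80) == "CompressedTensorsW8A16Fp8"
```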
@@ -4,13 +4,15 @@
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)

__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsUnquantized",
"CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
@@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors.
"""

@classmethod
@abstractmethod
def get_min_capability(self) -> int:
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
@@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
"""

def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# volta and up
return 70

@@ -29,7 +29,8 @@ def __init__(self,
raise ValueError(
"group_size must be given when using strategy group")

def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# ampere + up
return 80

@@ -0,0 +1,105 @@
from typing import Callable, List, Optional

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, create_per_channel_scale_param,
create_per_tensor_scale_param)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A16Fp8"]

SUPPORTED_STRATEGIES = [
QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
]


class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):

def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme

@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80

# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
requires_grad=False)

# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(),
requires_grad=False)

prepare_fp8_layer_for_marlin(layer, strategy="channel")

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype

# WEIGHT
weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})

# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, **layer_kwargs)
elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}")
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE (to deal with converted checkpoints)
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", input_scale)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return apply_fp8_marlin_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
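Regarding the comment in process_weights_after_loading: when a fused module (QKV or gate/up MLP) was quantized with one scale per logical shard, those per-tensor scales are expanded so every output channel of each shard carries its shard's scale before the layer is repacked for marlin. The snippet below is an illustrative reimplementation of what that expansion amounts to in plain torch, not the actual convert_to_channelwise helper, whose internals may differ:

```python
import torch

def expand_to_channelwise(shard_scales: torch.Tensor,
                          logical_widths: list) -> torch.Tensor:
    # One scale per fused shard -> one scale per output channel.
    return torch.cat([
        scale.expand(width)
        for scale, width in zip(shard_scales, logical_widths)
    ])

# Fused QKV projection with output widths 4096 (Q), 1024 (K), 1024 (V),
# each shard loaded with its own per-tensor fp8 weight scale.
shard_scales = torch.tensor([0.020, 0.015, 0.017])
channelwise = expand_to_channelwise(shard_scales, [4096, 1024, 1024])
assert channelwise.shape == (4096 + 1024 + 1024,)
assert torch.all(channelwise[:4096] == shard_scales[0])
```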
@@ -23,7 +23,8 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()

def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89

@@ -77,19 +78,20 @@ def create_weights(self, layer: torch.nn.Module,
})

# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, weight_loader=weight_loader)
output_partition_sizes, **layer_kwargs)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
output_partition_sizes, **layer_kwargs)
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", input_scale)

def apply_weights(self,
@@ -19,7 +19,8 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme

def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75

@@ -68,19 +69,19 @@ def create_weights(self, layer: torch.nn.Module,
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
scale = create_per_channel_scale_param(output_partition_sizes,
**layer_kwargs)
weight_scale = create_per_channel_scale_param(
output_partition_sizes, **layer_kwargs)
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = create_per_tensor_scale_param(output_partition_sizes,
**layer_kwargs)
layer.register_parameter("weight_scale", scale)
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE
if self.is_static_input_scheme:
scale = create_per_tensor_scale_param(output_partition_sizes,
**layer_kwargs)
layer.register_parameter("input_scale", scale)
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", input_scale)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
@@ -42,7 +42,8 @@ def __init__(self,
group_size=self.group_size,
is_sym=True)

def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
