[ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (vllm-project#5921)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
2 people authored and prashantgupta24 committed Jul 1, 2024
1 parent 06ad77d commit eddfbb6
Showing 2 changed files with 32 additions and 23 deletions.
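
For orientation, a hedged sketch of the mechanism this commit enables (toy values, not vLLM code): a checkpoint that stores qkv_proj or gate_up_proj as one fused FP8 tensor ships a single weight scale, while vLLM allocates one scale slot per logical shard. The fix routes that single scale into slot 0, initializes the unused slots to an FP8 sentinel, and recovers the layer-wide scale with max() instead of requantizing shard by shard.

```python
import torch

# One scale slot per logical shard (q, k, v), initialized to the FP8 sentinel
# introduced in _create_scale_param so empty slots never win the max() below.
weight_scale = torch.empty(3, dtype=torch.float32)
weight_scale[:] = torch.finfo(torch.float8_e4m3fn).min

# A fused checkpoint carries one scale; with shard_id=None it is routed to
# slot 0 and fused_module_in_checkpoint is flipped to True.
weight_scale[0] = 0.02

# The layer-wide scale is then the max over the slots, which is exactly the
# checkpoint scale, so no per-shard requantization is needed.
max_w_scale = weight_scale.max()
assert torch.isclose(max_w_scale, torch.tensor(0.02))
```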
14 changes: 12 additions & 2 deletions vllm/model_executor/layers/linear.py
@@ -383,8 +383,13 @@ def weight_loader(self,
                                            None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
@@ -567,8 +572,13 @@ def weight_loader(self,
                                            None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
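
To make the linear.py change concrete, here is a self-contained sketch of the added branch (hypothetical names such as load_scale and toy_indexer, not vLLM's API): when a checkpoint tensor arrives already fused, loaded_shard_id is None, and if the parameter carries an fp8_scales_shard_indexer the scale is still routed into the per-shard scale parameter before the copy.

```python
from typing import Callable, Optional, Tuple

import torch


def load_scale(param_data: torch.Tensor,
               loaded_weight: torch.Tensor,
               loaded_shard_id: Optional[int] = None,
               fp8_scales_shard_indexer: Optional[Callable] = None) -> None:
    """Toy stand-in for the loaded_shard_id-is-None branch of weight_loader."""
    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv/mlp); if it is an FP8
        # scale, the indexer still routes it into the per-shard parameter.
        if fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


def toy_indexer(param: torch.Tensor, loaded_weight: torch.Tensor,
                shard_id: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Fused checkpoint: treat the single scale as shard 0.
    return param[0 if shard_id is None else shard_id], loaded_weight


scales = torch.full((3,), torch.finfo(torch.float8_e4m3fn).min)
load_scale(scales, torch.tensor(0.02), None, toy_indexer)
print(scales)  # slot 0 now holds the checkpoint scale, the rest stay sentinel
```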
41 changes: 20 additions & 21 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -111,6 +112,7 @@ def _create_scale_param(
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
@@ -169,11 +171,15 @@ def create_weights(
             **extra_weight_attrs)
 
     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Optional[Union[str,
+                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}
 
-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
@@ -205,15 +211,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # WEIGHT_SCALE / WEIGHT
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+
+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
             # WEIGHT
@@ -227,10 +235,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
@@ -317,11 +321,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.kv_scale
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
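
Finally, a hedged sketch of what the now-guarded loop does (toy helpers standing in for per_tensor_quantize and per_tensor_dequantize, with assumed shapes): for an unfused checkpoint each logical shard was quantized with its own scale, so each shard is dequantized and re-quantized with the shared max scale; a fused checkpoint is already one FP8 tensor quantized with a single scale, and the remaining scale slots only hold the sentinel, so the loop is skipped when fused_module_in_checkpoint is set.

```python
import torch

FP8 = torch.float8_e4m3fn
FINFO = torch.finfo(FP8)


def toy_quantize(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # q = clamp(x / scale) cast to FP8.
    return (t / scale).clamp(FINFO.min, FINFO.max).to(FP8)


def toy_dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x is approximately q * scale.
    return q.to(torch.float32) * scale


# Unfused checkpoint: each logical shard (e.g. q and kv) has its own scale.
shards = [torch.randn(4, 8), torch.randn(4, 8) * 3.0]
scales = [w.abs().max() / FINFO.max for w in shards]
max_scale = torch.stack(scales).max()

# What the checkpoint stores: each shard in FP8 with its own scale.
quantized_shards = [toy_quantize(w, s) for w, s in zip(shards, scales)]

# The loop in process_weights_after_loading: dequantize each shard with its
# own scale, then requantize with the shared max scale so the layer ends up
# with one FP8 weight and one weight_scale.
requantized = [
    toy_quantize(toy_dequantize(q, s), max_scale)
    for q, s in zip(quantized_shards, scales)
]

# Fused checkpoint: the weight is already a single FP8 tensor with a single
# scale, so this requantization loop is skipped entirely.
```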
