Skip to content

Commit

Permalink
[Model] Add FP8 kv cache for Qwen2 (vllm-project#5656)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored and jimpang committed Jul 24, 2024
1 parent 69ab029 commit 64332bc
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once


class Qwen2MLP(nn.Module):
Expand Down Expand Up @@ -375,6 +376,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down

0 comments on commit 64332bc

Please sign in to comment.