Skip to content

Commit

Permalink
[Bugfix] Fix PerTensorScaleParameter weight loading for fused models (
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka authored Aug 9, 2024
1 parent 933790c commit 5c6c54d
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
PackedvLLMParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)
Expand Down Expand Up @@ -573,11 +574,13 @@ def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
if loaded_shard_id is None:
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
Expand Down Expand Up @@ -720,11 +723,13 @@ def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
if loaded_shard_id is None: # special case for certain models
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
Expand Down

0 comments on commit 5c6c54d

Please sign in to comment.