diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py
index c38bd8955f457..edce6d19b6c49 100644
--- a/vllm/model_executor/layers/quantization/utils/layer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py
@@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str,
                       new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
 
     old = getattr(mod, name)
-    if old.dtype == new.dtype and \
+    if type(old) is type(new) and old.dtype == new.dtype and \
             old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
         # If we can just update in-place to avoid re-registering
         #   can be faster if the underlying storage is the same
         update_tensor_inplace(old, new)
     else:
-        # Fallback re-register parameter
+        # Fallback re-register parameter, convert to Parameter if necessary
+        # this not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility
         if not isinstance(new, torch.nn.Parameter):
-            new = torch.nn.Parameter(new)
-        mod.register_parameter(name, torch.nn.Parameter(new))
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name,
+                               torch.nn.Parameter(new, requires_grad=False))
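
For reference, a minimal sketch (not part of the diff) of how the two code paths behave after this change; the toy `layer` module and tensors are hypothetical, and the import path simply mirrors the file touched here:

    import torch

    from vllm.model_executor.layers.quantization.utils.layer_utils import (
        replace_parameter)

    layer = torch.nn.Linear(16, 16, bias=False)

    # Same type, dtype, and storage size as the existing weight: takes the
    # in-place path (update_tensor_inplace), so existing references to
    # `layer.weight` stay valid.
    replace_parameter(
        layer, "weight",
        torch.nn.Parameter(torch.randn(16, 16), requires_grad=False))

    # A bare tensor (or a Parameter subclass) now fails the type(old) is
    # type(new) check and falls through to the re-register path, which wraps
    # it in a plain torch.nn.Parameter(requires_grad=False) before calling
    # register_parameter -- per the new comment, this keeps parameter
    # subclasses `torch.compile`-compatible.
    replace_parameter(layer, "weight", torch.randn(16, 16))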