From dae571222838fd4a61ea5ce1c45647e6442464b6 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 24 Sep 2024 14:33:21 -0400
Subject: [PATCH] [Bugfix] Fix torch dynamo fixes caused by `replace_parameters` (#8748)

---
 .../layers/quantization/utils/layer_utils.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py
index c38bd8955f457..edce6d19b6c49 100644
--- a/vllm/model_executor/layers/quantization/utils/layer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py
@@ -21,13 +21,17 @@ def replace_parameter(mod: torch.nn.Module, name: str,
                        new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
 
     old = getattr(mod, name)
-    if old.dtype == new.dtype and \
+    if type(old) is type(new) and old.dtype == new.dtype and \
         old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
         # If we can just update in-place to avoid re-registering
         # can be faster if the underlying storage is the same
         update_tensor_inplace(old, new)
     else:
-        # Fallback re-register parameter
+        # Fallback re-register parameter, convert to Parameter if necessary
+        # this not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility
         if not isinstance(new, torch.nn.Parameter):
-            new = torch.nn.Parameter(new)
-        mod.register_parameter(name, torch.nn.Parameter(new))
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name,
+                               torch.nn.Parameter(new, requires_grad=False))
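
For readers who want to see the fix in isolation, below is a minimal, self-contained sketch of how the patched replace_parameter behaves. The update_tensor_inplace stub here is a simplified stand-in for the helper defined earlier in layer_utils.py (not the exact vLLM implementation), and the torch.nn.Linear usage at the bottom is purely hypothetical, chosen only to exercise both the in-place and re-register paths.

    from typing import Union

    import torch


    def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor) -> None:
        # Simplified stand-in for the helper defined earlier in layer_utils.py:
        # keep dst's existing storage, adopt src's shape/stride, copy the values.
        assert dst.dtype == src.dtype
        dst.as_strided_(src.shape, src.stride())
        with torch.no_grad():
            dst.copy_(src)


    def replace_parameter(mod: torch.nn.Module, name: str,
                          new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
        old = getattr(mod, name)
        if type(old) is type(new) and old.dtype == new.dtype and \
                old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
            # Same class, dtype and storage size: update in place rather than
            # re-registering, which is cheaper when the storage can be reused.
            update_tensor_inplace(old, new)
        else:
            # Re-register. Wrapping in a plain torch.nn.Parameter (with
            # requires_grad=False, since these are inference weights) avoids
            # registering a bare Tensor or a Parameter subclass, which is what
            # tripped up torch.compile / dynamo.
            if not isinstance(new, torch.nn.Parameter):
                new = torch.nn.Parameter(new, requires_grad=False)
            mod.register_parameter(name,
                                   torch.nn.Parameter(new, requires_grad=False))


    # Hypothetical usage: replace a layer's weight with a processed copy.
    layer = torch.nn.Linear(4, 4, bias=False)
    repacked = layer.weight.detach().clone() * 2.0  # plain Tensor, not Parameter

    # Types differ (Tensor vs. Parameter), so the fallback path re-registers the
    # weight as a plain, requires_grad=False Parameter.
    replace_parameter(layer, "weight", repacked)
    assert type(layer.weight) is torch.nn.Parameter
    assert not layer.weight.requires_grad

    # Same class, dtype and storage size now, so this update happens in place.
    replace_parameter(layer, "weight",
                      torch.nn.Parameter(repacked * 3.0, requires_grad=False))

The essential pieces of the patch are the type(old) is type(new) guard and the unconditional wrap into a plain torch.nn.Parameter: the in-place path is only taken when the classes actually match, and anything that falls through is re-registered as a vanilla Parameter, so that (per the patch's own comment) bare tensors and parameter subclasses never end up registered on the module where they would break torch.compile.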