From 8eed4adedaa53fe92519678ef923856438fc7f37 Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Wed, 25 Sep 2024 02:16:30 -0400 Subject: [PATCH] [Bugfix] load fc bias from config for eagle (#8790) --- vllm/model_executor/models/eagle.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index ad1ab0231d861..13811d33768a6 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: self.model = model_cls(self.config.model, *args, **kwargs) self.fc = nn.Linear(config.model.hidden_size * 2, config.model.hidden_size, - bias=False) + bias=getattr(self.config, "bias", False)) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size @@ -136,10 +136,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.truncated_vocab_size < self.config.vocab_size: self.token_map = nn.Parameter(loaded_weight, requires_grad=False) - elif name.startswith("fc."): + elif name.startswith("fc.weight"): weight_loader = getattr(self.fc.weight, "weight_loader", default_weight_loader) weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("fc.bias"): + if self.fc.bias is not None: + weight_loader = getattr(self.fc.bias, "weight_loader", + default_weight_loader) + weight_loader(self.fc.bias, loaded_weight) + else: + raise ValueError("Found bias in the loaded weights " + "but the model config doesn't have bias") elif name.startswith("model.lm_head.") or name.startswith( "model.model."): model_weights[name.split("model.", 1)[-1]] = loaded_weight