Skip to content

Commit

Permalink
[Bugfix] load fc bias from config for eagle (vllm-project#8790)
Browse files Browse the repository at this point in the history
  • Loading branch information
sohamparikh authored and siddharth9820 committed Sep 30, 2024
1 parent 5ae6cea commit 9db1b59
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions vllm/model_executor/models/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=False)
bias=getattr(self.config, "bias", False))

self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
Expand Down Expand Up @@ -136,10 +136,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name.startswith("fc."):
elif name.startswith("fc.weight"):
weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("fc.bias"):
if self.fc.bias is not None:
weight_loader = getattr(self.fc.bias, "weight_loader",
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
raise ValueError("Found bias in the loaded weights "
"but the model config doesn't have bias")
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
Expand Down

0 comments on commit 9db1b59

Please sign in to comment.