[bugfix] fix broken tests of mlp speculator (vllm-project#10177)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
youkaichao authored and sumitd2 committed Nov 14, 2024
1 parent a9ffaa2 commit 0779cd0
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions vllm/model_executor/models/mlp_speculator.py
@@ -4,13 +4,13 @@
 import torch
 import torch.nn as nn
 
+from vllm.config import VllmConfig
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
 SQRT2 = 2**0.5

@@ -65,8 +65,9 @@ class MLPSpeculator(nn.Module):
     https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
     """
 
-    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
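For context, a short illustrative sketch (not part of the commit) of how the updated constructor is exercised: under vLLM's VllmConfig-based loading convention, the model is handed a single VllmConfig plus an optional prefix, and the MLPSpeculatorConfig is read from vllm_config.model_config.hf_config rather than being passed in directly. The helper name build_speculator below is hypothetical.

from vllm.config import VllmConfig
from vllm.model_executor.models.mlp_speculator import MLPSpeculator


def build_speculator(vllm_config: VllmConfig) -> MLPSpeculator:
    # Hypothetical helper, shown only to illustrate the new call shape.
    # Old signature (removed above): MLPSpeculator(config=hf_config, **kwargs)
    # New signature (added above): the whole VllmConfig is passed in, and the
    # model pulls its MLPSpeculatorConfig from vllm_config.model_config.hf_config.
    return MLPSpeculator(vllm_config=vllm_config, prefix="")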
