fix vllm qwen2 models (#11879)
gc-fu authored Aug 21, 2024
1 parent bd1e490 commit 537c0d2
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -30,7 +30,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
 
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 from ipex_llm.utils.common import invalidInputError
 from vllm.sequence import SamplerOutput
 
@@ -51,8 +51,10 @@ def _Qwen2_sample(
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     if self.config.tie_word_embeddings:
-        lm_head_weight = self.model.embed_tokens
+        # Embedding layer is not optimized to LowBitLinear
+        lm_head_weight = self.model.embed_tokens.weight
     else:
+        # This layer is optimized to LowBitLinear
         lm_head_weight = self.lm_head
     next_tokens = self.sampler(lm_head_weight, hidden_states,
                                sampling_metadata)
@@ -70,9 +72,15 @@ def _Chatglm_sample(
     return next_tokens
 
 
-def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+def _sample_get_logits(self, hidden_states: torch.Tensor,
+                       embedding: Union[torch.nn.Module, torch.Tensor],
                        embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-    logits = embedding(hidden_states)
+    # For tie_word_embedding models, the embedding is not optimized as
+    # the low_bit_linear layer...
+    if isinstance(embedding, torch.Tensor):
+        logits = torch.matmul(hidden_states, embedding.t())
+    else:
+        logits = embedding(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
     logits = tensor_model_parallel_gather(logits)
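Context on the change: when a model sets tie_word_embeddings, the output projection reuses the input embedding's weight matrix. The old code passed the nn.Embedding module itself to the sampler, but an embedding module performs an integer index lookup and cannot be applied to float hidden states; and because the embedding layer is not converted to LowBitLinear, it cannot be called like the optimized lm_head either. The fix passes the raw weight tensor instead and has _sample_get_logits dispatch on type: a plain matmul for a tensor, the normal module call otherwise. Below is a minimal standalone sketch of that dispatch, not part of the commit; the sizes and names (vocab_size, hidden_size, embed_tokens, hidden_states) are illustrative only:

    import torch
    import torch.nn as nn

    vocab_size, hidden_size = 32, 8
    embed_tokens = nn.Embedding(vocab_size, hidden_size)  # tied input/output weights
    hidden_states = torch.randn(4, hidden_size)           # one vector per position

    # Pre-fix behavior: calling the module treats the float hidden states as
    # token ids, so embed_tokens(hidden_states) raises a RuntimeError
    # (embedding lookup expects integer indices).

    # Post-fix behavior: project hidden states onto the vocabulary with the
    # shared weight matrix, [4, 8] @ [8, 32] -> [4, 32] logits.
    logits = torch.matmul(hidden_states, embed_tokens.weight.t())
    assert logits.shape == (4, vocab_size)

The isinstance(embedding, torch.Tensor) check keeps the non-tied path unchanged: models whose lm_head was converted to LowBitLinear still go through the optimized module call.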
