forked from mesolitica/vllm-whisper
Commit
Support Mistral Model Inference with transformers-neuronx (vllm-proje…
1 parent dde4eb4 · commit 72425a5
Showing 3 changed files with 93 additions and 6 deletions.
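The change adds a Neuron-backed Mistral model to vLLM. As a rough usage sketch (assuming a vLLM build with the Neuron backend wired up and transformers-neuronx installed; the engine arguments below are illustrative and may differ between versions), offline inference would look something like:

from vllm import LLM, SamplingParams

# Hypothetical invocation: the model id, block_size and tensor_parallel_size
# are placeholders, and device="neuron" assumes the Neuron backend is available.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1",
          max_num_seqs=8,
          max_model_len=128,
          block_size=128,
          device="neuron",
          tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)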
@@ -0,0 +1,82 @@
"""Inference-only Mistral model compatible with HuggingFace weights.""" | ||
from typing import List, Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
from transformers import MistralConfig | ||
|
||
from vllm.model_executor.input_metadata import InputMetadata | ||
from vllm.model_executor.layers.sampler import Sampler | ||
from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
from vllm.sequence import SamplerOutput | ||
import os | ||
|
||
KVCache = Tuple[torch.Tensor, torch.Tensor] | ||
|
||
|
||
class MistralForCausalLM(nn.Module):
    """Thin wrapper that delegates execution to a transformers-neuronx
    compiled model instead of running the layers through vLLM itself."""

    def __init__(
        self,
        config: MistralConfig,
        linear_method=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        # The compiled model and its LM head are attached later by
        # load_weights(); only the sampler is built eagerly.
        self.model = None
        self.lm_head = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> SamplerOutput:
        # The Neuron-compiled model manages its KV cache on device, so the
        # kv_caches argument is unused; only the sequence ids that index
        # that cache must be recovered from the scheduler metadata.
        with torch.inference_mode():
            block_size = self.model.context_buckets[-1]
            if input_metadata.is_prompt:
                # Prompt phase: each sequence occupies one contiguous
                # block, so its id is the first slot index divided by the
                # block size (e.g. block_size=128, first slot 256 -> seq 2).
                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
            else:
                # Decode phase: the block table already maps each sequence
                # to its single block id.
                seq_ids = input_metadata.block_tables

            logits = self.model(input_ids,
                                cache_ids=positions,
                                start_ids=seq_ids)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # Sample next tokens with the LM head of the underlying
        # checkpoint model wrapped by the Neuron graph.
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        # cache_dir, load_format and revision are unused but kept for
        # interface compatibility with the other vLLM model classes.
        from transformers_neuronx.mistral.model import MistralForSampling

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            # In a checkpoint produced by save_pretrained_split,
            # pytorch_model.bin is a directory of per-tensor files, so the
            # given path already points at a split checkpoint.
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            # Local import: the transformers class shares this class's name.
            from transformers import MistralForCausalLM
            from transformers_neuronx.module import save_pretrained_split

            hf_model = MistralForCausalLM.from_pretrained(
                model_name_or_path, low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        self.model = MistralForSampling.from_pretrained(
            split_model_dir, **kwargs)
        self.model.to_neuron()
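For context, the split-checkpoint flow in load_weights() can be reproduced standalone. A minimal sketch, assuming transformers-neuronx ships the Mistral support imported above; the model id and the sampling-model keyword arguments (batch_size, tp_degree, amp) are illustrative and version-dependent:

import os

from transformers import MistralForCausalLM
from transformers_neuronx.mistral.model import MistralForSampling
from transformers_neuronx.module import save_pretrained_split

ckpt = "mistralai/Mistral-7B-v0.1"  # placeholder model id
split_dir = f"{ckpt}-split"

if not os.path.exists(split_dir):
    # Re-save the HF checkpoint with one file per state-dict entry; in the
    # result, pytorch_model.bin is a directory rather than a single file.
    hf_model = MistralForCausalLM.from_pretrained(ckpt, low_cpu_mem_usage=True)
    save_pretrained_split(hf_model, split_dir)

neuron_model = MistralForSampling.from_pretrained(split_dir,
                                                  batch_size=1,
                                                  tp_degree=2,
                                                  amp="f32")
neuron_model.to_neuron()  # triggers compilation for the NeuronCores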