From 72425a5b48a8147f72010ffe06edfc23c808df7b Mon Sep 17 00:00:00 2001 From: DAIZHENWEI <32122197+DAIZHENWEI@users.noreply.github.com> Date: Mon, 11 Mar 2024 13:19:51 -0700 Subject: [PATCH] Support Mistral Model Inference with transformers-neuronx (#3153) --- examples/offline_inference_neuron.py | 10 ++- vllm/model_executor/models/__init__.py | 7 +- vllm/model_executor/models/neuron/mistral.py | 82 ++++++++++++++++++++ 3 files changed, 93 insertions(+), 6 deletions(-) mode change 100644 => 100755 examples/offline_inference_neuron.py mode change 100644 => 100755 vllm/model_executor/models/__init__.py create mode 100755 vllm/model_executor/models/neuron/mistral.py diff --git a/examples/offline_inference_neuron.py b/examples/offline_inference_neuron.py old mode 100644 new mode 100755 index 9b9dc4d94892f..da8874abd92a2 --- a/examples/offline_inference_neuron.py +++ b/examples/offline_inference_neuron.py @@ -14,14 +14,16 @@ llm = LLM( model="openlm-research/open_llama_3b", max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as max sequence length, - # when targeting neuron device. Currently, this is a known limitation in continuous batching - # support in transformers-neuronx. + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in transformers-neuronx. # TODO(liangfu): Support paged-attention in transformers-neuronx. max_model_len=128, block_size=128, # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, or explicitly assigned. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. device="neuron") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py old mode 100644 new mode 100755 index 75c2ae1e9f48e..bc3b6a582d53d --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -62,8 +62,11 @@ "Sliding window attention is not yet supported in ROCm's flash attention", } -# Models not supported by Neuron. -_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"} +# Models supported by Neuron. +_NEURON_SUPPORTED_MODELS = { + "LlamaForCausalLM": "neuron.llama", + "MistralForCausalLM": "neuron.mistral" +} class ModelRegistry: diff --git a/vllm/model_executor/models/neuron/mistral.py b/vllm/model_executor/models/neuron/mistral.py new file mode 100755 index 0000000000000..a302cce30abab --- /dev/null +++ b/vllm/model_executor/models/neuron/mistral.py @@ -0,0 +1,82 @@ +"""Inference-only Mistral model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import MistralConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput +import os + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class MistralForCausalLM(nn.Module): + + def __init__( + self, + config: MistralConfig, + linear_method=None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = None + self.lm_head = None + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> SamplerOutput: + with torch.inference_mode(): + seq_ids = [] + block_size = self.model.context_buckets[-1] + if input_metadata.is_prompt: + seq_ids = input_metadata.slot_mapping[:, 0] // block_size + else: + seq_ids = input_metadata.block_tables + + logits = self.model(input_ids, + cache_ids=positions, + start_ids=seq_ids) + return logits + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + **kwargs): + from transformers_neuronx.mistral.model import MistralForSampling + + split_model_dir = f"{model_name_or_path}-split" + if os.path.isdir(os.path.join(model_name_or_path, + "pytorch_model.bin")): + split_model_dir = model_name_or_path + elif not os.path.exists(f"{model_name_or_path}-split"): + from transformers import MistralForCausalLM + from transformers_neuronx.module import save_pretrained_split + + hf_model = MistralForCausalLM.from_pretrained( + model_name_or_path, low_cpu_mem_usage=True) + save_pretrained_split(hf_model, f"{model_name_or_path}-split") + + self.model = MistralForSampling.from_pretrained( + split_model_dir, **kwargs) + self.model.to_neuron()