
Commit

Support Mistral Model Inference with transformers-neuronx (vllm-proje…
DAIZHENWEI authored Mar 11, 2024
1 parent dde4eb4 commit 72425a5
Showing 3 changed files with 93 additions and 6 deletions.
10 changes: 6 additions & 4 deletions examples/offline_inference_neuron.py
100644 → 100755
@@ -14,14 +14,16 @@
 llm = LLM(
     model="openlm-research/open_llama_3b",
     max_num_seqs=8,
-    # The max_model_len and block_size arguments are required to be same as max sequence length,
-    # when targeting neuron device. Currently, this is a known limitation in continuous batching
-    # support in transformers-neuronx.
+    # The max_model_len and block_size arguments are required to be same as
+    # max sequence length when targeting neuron device.
+    # Currently, this is a known limitation in continuous batching support
+    # in transformers-neuronx.
     # TODO(liangfu): Support paged-attention in transformers-neuronx.
     max_model_len=128,
     block_size=128,
     # The device can be automatically detected when AWS Neuron SDK is installed.
-    # The device argument can be either unspecified for automated detection, or explicitly assigned.
+    # The device argument can be either unspecified for automated detection,
+    # or explicitly assigned.
     device="neuron")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
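For context, below is a minimal sketch of how this example could be pointed at a Mistral checkpoint once the new backend is in place. The checkpoint name and sampling settings are illustrative assumptions, not part of the commit.

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    # Illustrative checkpoint; any HF-format Mistral model id or local
    # directory should map to the new neuron.mistral backend.
    model="mistralai/Mistral-7B-v0.1",
    max_num_seqs=8,
    # Same Neuron constraint as in the example above: max_model_len and
    # block_size must equal the maximum sequence length.
    max_model_len=128,
    block_size=128,
    device="neuron")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)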
7 changes: 5 additions & 2 deletions vllm/model_executor/models/__init__.py
100644 → 100755
@@ -62,8 +62,11 @@
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
-# Models not supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {
+    "LlamaForCausalLM": "neuron.llama",
+    "MistralForCausalLM": "neuron.mistral"
+}
 
 
 class ModelRegistry:
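For orientation, the values in this mapping look like module paths under vllm.model_executor.models, so resolution on a Neuron target presumably follows the pattern below. This is a rough sketch of that lookup, not the actual ModelRegistry code.

import importlib


def resolve_neuron_model_cls(model_arch: str):
    # e.g. "MistralForCausalLM" -> "neuron.mistral"
    module_name = _NEURON_SUPPORTED_MODELS[model_arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    # The neuron modules expose a class with the same architecture name,
    # e.g. neuron.mistral.MistralForCausalLM.
    return getattr(module, model_arch)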
82 changes: 82 additions & 0 deletions vllm/model_executor/models/neuron/mistral.py
@@ -0,0 +1,82 @@
"""Inference-only Mistral model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import MistralConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
import os

KVCache = Tuple[torch.Tensor, torch.Tensor]


class MistralForCausalLM(nn.Module):

    def __init__(
        self,
        config: MistralConfig,
        linear_method=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        # The transformers-neuronx model is created lazily in load_weights().
        self.model = None
        self.lm_head = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> SamplerOutput:
        with torch.inference_mode():
            block_size = self.model.context_buckets[-1]
            if input_metadata.is_prompt:
                # During prompt processing, recover each sequence id from the
                # first KV-cache slot assigned to that sequence.
                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
            else:
                # During decoding, the block tables already identify the
                # sequences.
                seq_ids = input_metadata.block_tables

            logits = self.model(input_ids,
                                cache_ids=positions,
                                start_ids=seq_ids)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        from transformers_neuronx.mistral.model import MistralForSampling

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            # save_pretrained_split stores "pytorch_model.bin" as a directory
            # of per-parameter files, so this path is already a split
            # checkpoint and can be loaded directly.
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            from transformers import MistralForCausalLM
            from transformers_neuronx.module import save_pretrained_split

            # Load the HuggingFace checkpoint once and save a split copy for
            # transformers-neuronx to load parameter by parameter.
            hf_model = MistralForCausalLM.from_pretrained(
                model_name_or_path, low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        # Load the split weights and compile the model for the NeuronCores.
        self.model = MistralForSampling.from_pretrained(split_model_dir,
                                                        **kwargs)
        self.model.to_neuron()
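
For reference, the split-then-compile path in load_weights can also be run ahead of time as a standalone step. A minimal sketch follows, reusing only the transformers and transformers-neuronx calls that appear above; the checkpoint path is a placeholder.

import os

from transformers import MistralForCausalLM
from transformers_neuronx.mistral.model import MistralForSampling
from transformers_neuronx.module import save_pretrained_split

model_dir = "./mistral-7b"  # placeholder path to an HF-format checkpoint
split_dir = f"{model_dir}-split"

# Split the checkpoint once so later loads avoid materializing the full
# model in host memory.
if not os.path.exists(split_dir):
    hf_model = MistralForCausalLM.from_pretrained(model_dir,
                                                  low_cpu_mem_usage=True)
    save_pretrained_split(hf_model, split_dir)

# Compile the split checkpoint for the NeuronCores (same calls as in
# load_weights above).
neuron_model = MistralForSampling.from_pretrained(split_dir)
neuron_model.to_neuron()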
