From 257afc37c5b3e4c6d491d105337387989b013aee Mon Sep 17 00:00:00 2001
From: Harsha vardhan manoj Bikki <39381063+hbikki@users.noreply.github.com>
Date: Thu, 29 Aug 2024 13:58:14 -0700
Subject: [PATCH] [Neuron] Adding support for context-length, token-gen
 buckets. (#7885)

Co-authored-by: Harsha Bikki
---
 examples/offline_inference_neuron.py       | 11 ++++++--
 vllm/model_executor/model_loader/neuron.py | 33 ++++++++++++++++------
 2 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/examples/offline_inference_neuron.py b/examples/offline_inference_neuron.py
index 5ecbbf020ab8b..2856be7c864ea 100644
--- a/examples/offline_inference_neuron.py
+++ b/examples/offline_inference_neuron.py
@@ -1,5 +1,12 @@
+import os
+
 from vllm import LLM, SamplingParams
 
+# Creates XLA HLO graphs for all the context-length buckets.
+os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
+# Creates XLA HLO graphs for all the token-gen buckets.
+os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
+
 # Sample prompts.
 prompts = [
     "Hello, my name is",
@@ -19,8 +26,8 @@
     # Currently, this is a known limitation in continuous batching support
     # in transformers-neuronx.
     # TODO(liangfu): Support paged-attention in transformers-neuronx.
-    max_model_len=128,
-    block_size=128,
+    max_model_len=2048,
+    block_size=2048,
     # The device can be automatically detected when AWS Neuron SDK is installed.
     # The device argument can be either unspecified for automated detection,
     # or explicitly assigned.
diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index 07e23aca6cc5f..24fa13d7e5fe5 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -1,7 +1,7 @@
 """Utilities for selecting and loading neuron models."""
 import importlib
 import os
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
                          f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
 
 
+def _get_buckets(env: str, default_value: List[int]) -> List[int]:
+    env_value = os.getenv(env)
+    if env_value is None:
+        return default_value
+    buckets_remove_empty = filter(
+        lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
+    buckets_int = map(int, buckets_remove_empty)
+    buckets_list = list(buckets_int)
+    return buckets_list
+
+
 def get_neuron_model(model_config: ModelConfig,
                      parallel_config: ParallelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
@@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig,
     neuron_config = NeuronConfig(
         continuous_batching=continuous_batching_config)
 
+    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
+                                            [scheduler_config.max_model_len])
+    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
+                               [scheduler_config.max_model_len])
+
     # Load the weights from the cached or downloaded files.
-    model.load_weights(
-        model_config.model,
-        tp_degree=parallel_config.tensor_parallel_size,
-        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
-        neuron_config=neuron_config,
-        context_length_estimate=[scheduler_config.max_model_len],
-        n_positions=[scheduler_config.max_model_len],
-        batch_size=scheduler_config.max_num_seqs)
+    model.load_weights(model_config.model,
+                       tp_degree=parallel_config.tensor_parallel_size,
+                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
+                       neuron_config=neuron_config,
+                       context_length_estimate=context_length_estimates,
+                       n_positions=n_positions,
+                       batch_size=scheduler_config.max_num_seqs)
 
     return model.eval()
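
Reviewer note (not part of the patch): a minimal sketch of how the new
_get_buckets helper resolves the bucket lists. The condensed parser below is
behaviorally equivalent to the filter/map version in the diff; the env var
values and the [2048] default are hypothetical stand-ins for
scheduler_config.max_model_len.

    import os
    from typing import List

    def get_buckets(env: str, default_value: List[int]) -> List[int]:
        # Fall back to the default when the variable is unset.
        env_value = os.getenv(env)
        if env_value is None:
            return default_value
        # Split on commas and drop empty entries (e.g. a trailing comma)
        # before converting each bucket to int.
        return [int(x) for x in env_value.split(",") if x.strip()]

    os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
    os.environ.pop("NEURON_TOKEN_GEN_BUCKETS", None)
    assert get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                       [2048]) == [128, 512, 1024, 2048]
    assert get_buckets("NEURON_TOKEN_GEN_BUCKETS", [2048]) == [2048]

As the updated example's comments note, one XLA HLO graph is compiled per
bucket, so listing several buckets trades longer ahead-of-time compilation
for less padding at runtime; the example sets its largest bucket equal to
max_model_len (2048).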