Skip to content

Commit

Permalink
[Neuron] Adding support for context-lenght, token-gen buckets for lat…
Browse files Browse the repository at this point in the history
…ency optimization with neuron device.
  • Loading branch information
Harsha Bikki committed Aug 27, 2024
1 parent 9c71c97 commit 3d6e0c7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
11 changes: 9 additions & 2 deletions examples/offline_inference_neuron.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import os

from vllm import LLM, SamplingParams

# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"

# Sample prompts.
prompts = [
"Hello, my name is",
Expand All @@ -19,8 +26,8 @@
# Currently, this is a known limitation in continuous batching support
# in transformers-neuronx.
# TODO(liangfu): Support paged-attention in transformers-neuronx.
max_model_len=128,
block_size=128,
max_model_len=2048,
block_size=2048,
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
Expand Down
33 changes: 24 additions & 9 deletions vllm/model_executor/model_loader/neuron.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utilities for selecting and loading neuron models."""
import importlib
import os
from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")


def _get_buckets(env: str, default_value: List[int]) -> List[int]:
env_value = os.getenv(env)
if env_value is None:
return default_value
buckets_remove_empty = filter(
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
buckets_int = map(int, buckets_remove_empty)
buckets_list = list(buckets_int)
return buckets_list


def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
Expand All @@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig,
neuron_config = NeuronConfig(
continuous_batching=continuous_batching_config)

context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])

# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=[scheduler_config.max_model_len],
n_positions=[scheduler_config.max_model_len],
batch_size=scheduler_config.max_num_seqs)
model.load_weights(model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)

return model.eval()

0 comments on commit 3d6e0c7

Please sign in to comment.