[Hardware][Neuron] Refactor neuron support #3471

Merged 23 commits on Mar 22, 2024
5 changes: 3 additions & 2 deletions examples/offline_inference_neuron.py
@@ -12,7 +12,7 @@

# Create an LLM.
llm = LLM(
model="openlm-research/open_llama_3b",
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
max_num_seqs=8,
# The max_model_len and block_size arguments are required to be same as
# max sequence length when targeting neuron device.
@@ -24,7 +24,8 @@
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron")
device="neuron",
tensor_parallel_size=2)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
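For orientation, a condensed, hedged sketch of how the updated example runs end to end. The prompt list and SamplingParams values are assumptions standing in for the unchanged parts of examples/offline_inference_neuron.py, and the max_model_len/block_size values are illustrative.

from vllm import LLM, SamplingParams

# Assumed prompts and sampling settings (the real ones live in the unchanged
# portion of the example file).
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Mirrors the diff: TinyLlama on a Neuron device, with tensor_parallel_size=2
# forwarded to transformers-neuronx. On Neuron, max_model_len and block_size
# must both equal the maximum sequence length.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_seqs=8,
    max_model_len=128,
    block_size=128,
    device="neuron",
    tensor_parallel_size=2,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)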
2 changes: 1 addition & 1 deletion tests/lora/test_worker.py
@@ -33,7 +33,7 @@ def test_worker_apply_lora(sql_lora_files):
max_loras=32),
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_model()
worker.init_device()
worker.load_model()

worker.model_runner.set_active_loras([], LoRAMapping([], []))
18 changes: 9 additions & 9 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -71,7 +71,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):

worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()

vocab_size = 32_000

@@ -151,7 +151,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):

worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()

proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -230,7 +230,7 @@ def test_correctly_formats_output(k: int, batch_size: int):

worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()

proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -342,7 +342,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):

worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
worker.init_model()
worker.init_device()

proposal_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -486,8 +486,8 @@ def test_empty_input_batch(k: int, batch_size: int):


@torch.inference_mode()
def test_init_model():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
@@ -499,11 +499,11 @@ def test_init_model():
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)

worker.init_model()
worker.init_device()

draft_worker.init_model.assert_called_once()
draft_worker.init_device.assert_called_once()

target_worker.init_model.assert_called_once()
target_worker.init_device.assert_called_once()

metrics_collector.init_gpu_tensors.assert_called_once()
rejection_sampler.init_gpu_tensors.assert_called_once()
2 changes: 1 addition & 1 deletion tests/spec_decode/utils.py
@@ -123,7 +123,7 @@ def create_worker(cls: type,
is_driver_worker=is_driver_worker,
)

worker.init_model()
worker.init_device()
worker.load_model()

cache_config.num_gpu_blocks = num_gpu_blocks
2 changes: 1 addition & 1 deletion tests/worker/test_swap.py
@@ -30,7 +30,7 @@ def test_swap() -> None:
)

# Initialize the worker.
worker.init_model()
worker.init_device()
worker.load_model()
worker.init_cache_engine(cache_config)
worker.warm_up_model()
17 changes: 2 additions & 15 deletions vllm/config.py
@@ -474,15 +474,7 @@ def __init__(
placement_group: Optional["PlacementGroup"] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
if is_neuron():
# For Neuron device support, here we assign TP=1 to avoid sharding
# within vLLM directly. Transformer-neuronx would take
# neuron_tp_degree attribute, and distribute the workload
# to multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size
else:
self.tensor_parallel_size = tensor_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
@@ -491,8 +483,7 @@ def __init__(
self.placement_group = placement_group

self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray worker is not supported for Neuron backend.
if self.world_size > 1 and not is_neuron():
if self.world_size > 1:
self.worker_use_ray = True
self._verify_args()

@@ -591,10 +582,6 @@ def __init__(self, device: str = "auto") -> None:
# Set device with device type
self.device = torch.device(self.device_type)

@property
def is_neuron(self):
return self.device_type == "neuron"


@dataclass
class LoRAConfig:
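The net effect on ParallelConfig: the Neuron special case is gone, the requested TP degree passes through unchanged, and Ray workers are enabled purely by world size. A minimal sketch, assuming the remaining constructor arguments keep their defaults.

from vllm.config import ParallelConfig

# With this change there is no neuron_tp_degree rewrite: TP is whatever the
# user asked for, and Ray is switched on whenever world_size = PP * TP > 1.
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=2,
    worker_use_ray=False,
)
assert parallel_config.tensor_parallel_size == 2
assert parallel_config.world_size == 2
assert parallel_config.worker_use_ray  # auto-enabled since world_size > 1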
7 changes: 6 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -325,7 +325,12 @@ def from_engine_args(cls,
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
device_config = engine_configs[4]

if device_config.device_type == "neuron":
Contributor: Instead of string comparison, set a property in device_config?
FYI, we might consider some device-specific config with device="neuron config=(a=1, b=2, c=3)".

Member (Author): device_type is different from device. I think we can safely assume device_type can only be cuda and neuron for now? We can put the config string in other fields of device_config.

Contributor: Sure, I think we are fine to move on with this.
raise NotImplementedError("Neuron is not supported for "
"async engine yet.")
elif parallel_config.worker_use_ray or engine_args.engine_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
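For reference, the property-based check the reviewer floated would look roughly like the sketch below. It is hypothetical and was not adopted; this PR instead removes DeviceConfig.is_neuron and compares device_type directly.

class DeviceConfig:  # simplified stand-in for vllm.config.DeviceConfig
    def __init__(self, device_type: str = "cuda") -> None:
        self.device_type = device_type

    @property
    def is_neuron(self) -> bool:
        return self.device_type == "neuron"

# Hypothetical call site:
# if device_config.is_neuron:
#     raise NotImplementedError("Neuron is not supported for async engine yet.")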
6 changes: 5 additions & 1 deletion vllm/engine/llm_engine.py
@@ -125,9 +125,13 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
device_config = engine_configs[4]

# Initialize the cluster and specify the executor class.
if parallel_config.worker_use_ray:
if device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif parallel_config.worker_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
18 changes: 2 additions & 16 deletions vllm/executor/gpu_executor.py
@@ -1,4 +1,3 @@
import importlib
from typing import Dict, List, Optional

from vllm.lora.request import LoRARequest
@@ -13,12 +12,6 @@

logger = init_logger(__name__)

# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}


class GPUExecutor(ExecutorBase):

@@ -44,17 +37,10 @@ def __init__(
# Profile the memory usage and initialize the cache.
self._init_cache()

def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker

def _init_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()
from vllm.worker.worker import Worker

assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
@@ -73,7 +59,7 @@ def _init_worker(self):
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_model()
self.driver_worker.init_device()
self.driver_worker.load_model()

def _init_cache(self) -> None:
80 changes: 80 additions & 0 deletions vllm/executor/neuron_executor.py
@@ -0,0 +1,80 @@
from typing import Dict, List, Optional

from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata

logger = init_logger(__name__)


class NeuronExecutor(ExecutorBase):

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
assert lora_config is None, "LoRA is not supported for Neuron backend."
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config

# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
self.cache_config.num_cpu_blocks = 0

# Instantiate the worker and load the model to the device.
self._init_worker()

def _init_worker(self):
from vllm.worker.neuron_worker import NeuronWorker

self.driver_worker = NeuronWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
)
self.driver_worker.init_device()
self.driver_worker.load_model()

def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.")

output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")

def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")

def list_loras(self) -> List[int]:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")

def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
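Roughly, the engine drives this executor like the GPU path except that paged-cache bookkeeping must stay empty. A hedged usage sketch; the unpacking order of create_engine_configs() is inferred from the indices used in the engine diffs above and is otherwise an assumption, and an actual Neuron environment is required for the model to load.

from vllm.engine.arg_utils import EngineArgs
from vllm.executor.neuron_executor import NeuronExecutor

engine_args = EngineArgs(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                         device="neuron")
# Assumed packing: (model, cache, parallel, scheduler, device, lora).
(model_config, cache_config, parallel_config, scheduler_config,
 device_config, lora_config) = engine_args.create_engine_configs()

executor = NeuronExecutor(
    model_config=model_config,
    cache_config=cache_config,
    parallel_config=parallel_config,
    scheduler_config=scheduler_config,
    device_config=device_config,
    lora_config=lora_config,  # must be None; the Neuron backend rejects LoRA
)

# One step: the three block maps must be empty, or the executor asserts.
seq_group_metadata_list = []  # normally produced by the scheduler; empty here for illustration
output = executor.execute_model(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
)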
18 changes: 2 additions & 16 deletions vllm/executor/ray_gpu_executor.py
@@ -3,7 +3,6 @@
from collections import defaultdict
import os
import pickle
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -25,12 +24,6 @@

logger = init_logger(__name__)

# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -73,13 +66,6 @@ def __init__(
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()

def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker

def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
@@ -155,7 +141,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()
from vllm.worker.worker import Worker

model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
@@ -201,7 +187,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

# FIXME(woosuk): We are not properly initializing cupy NCCL when
# we have multiple nodes.
self._run_workers("init_model",
self._run_workers("init_device",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers(
4 changes: 2 additions & 2 deletions vllm/lora/layers.py
@@ -799,8 +799,8 @@ def __init__(
self.device = device

@property
def logits_as_hidden_states(self):
return self.base_layer.logits_as_hidden_states
def logits_as_input(self):
return self.base_layer.logits_as_input

@property
def vocab_size(self):
4 changes: 2 additions & 2 deletions vllm/lora/lora.py
@@ -1,7 +1,7 @@
from typing import List, Optional

import torch
from vllm.utils import in_wsl
from vllm.utils import is_pin_memory_available


class LoRALayerWeights:
@@ -64,7 +64,7 @@ def create_dummy_lora_weights(
dtype: torch.dtype,
device: torch.device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and not in_wsl()
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
dtype=dtype,
device=device,
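The WSL-only check is replaced by a broader capability helper. A minimal sketch of what is_pin_memory_available could look like under that assumption; the actual vllm.utils implementation may differ, for instance in how it detects WSL or Neuron.

import platform

def is_pin_memory_available() -> bool:
    # Sketch only. Pinned host memory is not usable under WSL,
    # and a Neuron host has no CUDA allocator to pin against.
    if "microsoft" in platform.release().lower():  # heuristic WSL detection
        return False
    try:
        import transformers_neuronx  # noqa: F401  (presence implies a Neuron setup)
        return False
    except ImportError:
        return True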