diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 987c1be3d5ad9..d759ce04d75e7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -317,9 +317,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
+                            choices=[8, 16, 32],
                             help='Token block size for contiguous chunks of '
-                            'tokens.')
+                            'tokens. This is ignored on neuron devices and '
+                            'set to max-model-len.')
 
         parser.add_argument('--enable-prefix-caching',
                             action='store_true',
@@ -793,7 +794,8 @@ def create_engine_config(self) -> EngineConfig:
             limit_mm_per_prompt=self.limit_mm_per_prompt,
         )
         cache_config = CacheConfig(
-            block_size=self.block_size,
+            block_size=self.block_size if self.device != "neuron" else
+            self.max_model_len,  # neuron needs block_size = max_model_len
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index b45d5d86b54fa..02627de3e0be7 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -4,7 +4,8 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.utils import make_async
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        make_async)
 
 logger = init_logger(__name__)
 
@@ -24,14 +25,17 @@ def _init_executor(self) -> None:
 
     def _init_worker(self):
         from vllm.worker.neuron_worker import NeuronWorker
-
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
         self.driver_worker = NeuronWorker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            self.cache_config,
-        )
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method)
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py
index 3b0ded36ca1b6..fff14d6402b44 100644
--- a/vllm/worker/neuron_worker.py
+++ b/vllm/worker/neuron_worker.py
@@ -6,6 +6,8 @@
 
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment)
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.worker.neuron_model_runner import NeuronModelRunner
@@ -24,12 +26,18 @@ def __init__(
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         cache_config: CacheConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
@@ -40,6 +48,8 @@ def __init__(
         self.is_driver_worker = True
 
     def init_device(self) -> None:
+        self.init_distributed_environment()
+
         # Set random seed.
         set_random_seed(self.model_config.seed)
 
@@ -98,3 +108,20 @@ def get_cache_block_size_bytes(self) -> int:
         This is required for speculative decoding; it is not yet implemented.
         """
         raise NotImplementedError
+
+    def init_distributed_environment(self):
+        """Neuron uses transformers-neuronx for tensor parallelism.
+
+        vLLM still needs the environment initialized even when TP/PP > 1.
+        """
+        init_distributed_environment(
+            world_size=1,
+            rank=self.rank,
+            local_rank=self.local_rank,
+            distributed_init_method=self.distributed_init_method,
+            backend="gloo",
+        )
+        ensure_model_parallel_initialized(
+            1,
+            1,
+        )
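
For reviewers, a minimal sketch of the resulting behavior, assuming a Neuron build of vLLM (the model name and length values are illustrative, not part of this patch): with `device="neuron"`, the engine now ignores the user-supplied `--block-size` and sizes each KV-cache block to `max_model_len`, so a single block holds an entire sequence.

```python
# Minimal sketch, not part of this patch; assumes a Neuron build of vLLM
# and illustrative model/length values.
from vllm import LLM

llm = LLM(
    model="openlm-research/open_llama_3b",  # illustrative
    device="neuron",
    max_model_len=2048,
    block_size=8,  # accepted by the CLI choices, but ignored on neuron
)

# With this patch, CacheConfig is built with block_size == max_model_len
# on neuron devices, not with the value passed above.
assert llm.llm_engine.cache_config.block_size == 2048
```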