Skip to content

Commit

Permalink
[1/N] pass the complete config from engine to executor (#9933)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 1, 2024
1 parent 598b6d7 commit 18bd758
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 137 deletions.
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def from_engine_args(

# Create the async LLM engine.
engine = cls(
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
Expand Down
50 changes: 21 additions & 29 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@
from typing_extensions import TypeIs, TypeVar

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
Expand Down Expand Up @@ -222,24 +219,30 @@ def validate_outputs(

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
vllm_config: EngineConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:

# TODO: remove the local variables and use self.* throughout the class.
model_config = self.model_config = vllm_config.model_config
cache_config = self.cache_config = vllm_config.cache_config
lora_config = self.lora_config = vllm_config.lora_config
parallel_config = self.parallel_config = vllm_config.parallel_config
scheduler_config = self.scheduler_config = vllm_config.scheduler_config
device_config = self.device_config = vllm_config.device_config
speculative_config = self.speculative_config = vllm_config.speculative_config # noqa
load_config = self.load_config = vllm_config.load_config
decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)

logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
Expand Down Expand Up @@ -340,18 +343,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
self.input_processor = input_registry.create_input_processor(
model_config)

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
self.model_executor = executor_class(vllm_config=vllm_config, )

if self.model_config.task != "embedding":
self._initialize_kv_caches()
Expand Down Expand Up @@ -582,7 +574,7 @@ def from_engine_args(
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
Expand Down
7 changes: 1 addition & 6 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import zmq

from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
Expand All @@ -30,9 +28,6 @@
else:
from vllm.engine.llm_engine import LLMEngine

CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]

logger = init_logger(__name__)

POLLING_TIMEOUT_MS = 10000
Expand Down Expand Up @@ -130,7 +125,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,

return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
Expand Down
37 changes: 13 additions & 24 deletions vllm/executor/executor_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import EngineConfig
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
Expand All @@ -23,27 +20,19 @@ class ExecutorBase(ABC):

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
vllm_config: EngineConfig,
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self._init_executor()

@abstractmethod
Expand Down
44 changes: 8 additions & 36 deletions vllm/executor/xpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import ModelConfig, ParallelConfig
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
Expand All @@ -21,38 +18,13 @@ class XPUExecutor(GPUExecutor):

uses_ray: bool = False

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
assert device_config.device_type == "xpu"
assert (not speculative_config
), "Speculative decoding not yet supported for XPU backend"

model_config = _verify_and_get_model_config(model_config)

self.model_config = model_config
self.cache_config = cache_config
self.load_config = load_config
self.lora_config = lora_config
self.parallel_config = _verify_and_get_parallel_config(parallel_config)
self.scheduler_config = scheduler_config
self.device_config = device_config
self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None
self.observability_config = observability_config

# Instantiate the worker and load the model to GPU.
self._init_executor()
def _init_executor(self) -> None:
assert self.device_config.device_type == "xpu"
assert self.speculative_config is None, (
"Speculative decoding not yet supported for XPU backend")

self.model_config = _verify_and_get_model_config(self.model_config)
GPUExecutor._init_executor(self)

def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
Expand Down
62 changes: 21 additions & 41 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
Union)

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
Expand Down Expand Up @@ -35,24 +32,30 @@ class LLMEngine:

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
vllm_config: EngineConfig,
executor_class: Type[GPUExecutor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:

# TODO: remove the local variables and use self.* throughout the class.
model_config = self.model_config = vllm_config.model_config
cache_config = self.cache_config = vllm_config.cache_config
lora_config = self.lora_config = vllm_config.lora_config
parallel_config = self.parallel_config = vllm_config.parallel_config
scheduler_config = self.scheduler_config = vllm_config.scheduler_config
device_config = self.device_config = vllm_config.device_config
speculative_config = self.speculative_config = vllm_config.speculative_config # noqa
load_config = self.load_config = vllm_config.load_config
decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)

# Override the configs for V1.
# FIXME
if usage_context == UsageContext.LLM_CLASS:
Expand Down Expand Up @@ -112,18 +115,6 @@ def __init__(
model_config.mm_processor_kwargs,
)

self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats

assert not self.model_config.skip_tokenizer_init
Expand Down Expand Up @@ -154,18 +145,7 @@ def __init__(
# Request id -> RequestOutput
self.request_outputs: Dict[str, RequestOutput] = {}

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
self.model_executor = executor_class(vllm_config=vllm_config)
assert self.model_config.task != "embedding"
self._initialize_kv_caches()

Expand Down Expand Up @@ -203,7 +183,7 @@ def from_engine_args(
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
Expand Down

0 comments on commit 18bd758

Please sign in to comment.