From 2efce05dc3c7c1e367617465f8f661a058499e37 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 5 Mar 2024 16:17:20 -0800 Subject: [PATCH] [Fix] Avoid pickling entire LLMEngine for Ray workers (#3207) Co-authored-by: Antoni Baum --- vllm/engine/llm_engine.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 52dc96e2b82e1..8484014c9a13f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -158,6 +158,11 @@ def __init__( if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) @@ -280,6 +285,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) device_config = copy.deepcopy(self.device_config) + lora_config = copy.deepcopy(self.lora_config) + kv_cache_dtype = self.cache_config.cache_dtype for rank, (worker, (node_id, _)) in enumerate(zip(self.workers, @@ -295,22 +302,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank, rank, distributed_init_method, - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, )) driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) self.driver_worker = Worker( - model_config, - parallel_config, - scheduler_config, - device_config, + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, driver_local_rank, driver_rank, distributed_init_method, lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, + kv_cache_dtype=kv_cache_dtype, is_driver_worker=True, )