From 966e31697bdeb47b33b3e26b4aab5999c85f3e90 Mon Sep 17 00:00:00 2001 From: Wallas Henrique Date: Tue, 5 Nov 2024 21:39:26 -0300 Subject: [PATCH] [Bugfix] Fix pickle of input when async output processing is on (#9931) Signed-off-by: Wallas Santos --- .../test_basic_correctness.py | 26 +++++++++++++++++++ vllm/worker/model_runner.py | 12 +++++++++ 2 files changed, 38 insertions(+) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 79647589d5204..7f16baa65a644 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -156,3 +156,29 @@ def test_model_with_failure(vllm_runner) -> None: ModelInputForGPUWithSamplingMetadata) finally: os.remove(filename) + + +def test_failure_with_async_out_proc(vllm_runner) -> None: + + filename = None + try: + with vllm_runner("facebook/opt-125m", + dtype="half", + enforce_eager=False, + gpu_memory_utilization=0.7) as vllm_model,\ + patch("vllm.model_executor.models.opt.OPTForCausalLM.forward", + side_effect=ValueError()): + model_config = vllm_model.model.llm_engine.model_config + assert model_config.use_async_output_proc + with pytest.raises(ValueError) as exc_info: + vllm_model.generate_greedy('how to make pizza?', 250) + matches = re.search(r"input dumped to (.+).pkl", + str(exc_info.value)) + assert matches is not None + + filename = f"{matches.group(1)}.pkl" + finally: + # Clean up + if filename is not None: + os.remove(filename) + pass diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2447eecf7957d..1e8ea4e8e79cf 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -136,6 +136,18 @@ def from_broadcasted_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) + # Exclude `async_callback` to be able to pickle this object + def __getstate__(self): + state = self.__dict__.copy() + del state["async_callback"] + return state + + # TODO: What happens when we depickle this object? + # How can we update this callback to properly pass it to the engine? + def __setstate__(self, state): + self.__dict__.update(state) + self.__dict__.update({'async_callback': None}) + @dataclass(frozen=True) class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):