Add backwards compatibility for #8673
DarkLight1337 committed Sep 24, 2024
1 parent 6481cf3 commit e729f1d
Showing 8 changed files with 256 additions and 88 deletions.
8 changes: 5 additions & 3 deletions tests/async_engine/test_async_llm_engine.py
@@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

@pytest.mark.asyncio
async def test_new_requests_event():
params = SamplingParams()

engine = MockAsyncLLMEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0

await engine.add_request("1", "", None)
await engine.add_request("1", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1

await engine.add_request("2", "", None)
await engine.add_request("2", "", params)
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
@@ -111,7 +113,7 @@ async def test_new_requests_event():
await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls

await engine.add_request("3", "", None)
await engine.add_request("3", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1
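A note on why these test lines changed: the reworked add_request further down in this diff asserts that both the prompt and the params are supplied, so the mock test can no longer pass None for params. A minimal sketch of the new call shape (the request ID and prompt text here are illustrative, not the test's fixtures):

from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def submit_one(engine: AsyncLLMEngine) -> None:
    # `params` must now be a real SamplingParams (or PoolingParams) instance;
    # the updated add_request asserts that it is not None.
    params = SamplingParams()
    await engine.add_request("request-1", "Hello, world", params)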
34 changes: 0 additions & 34 deletions tests/entrypoints/llm/test_encode.py
@@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert [o.outputs for o in o1] == [o.outputs for o in o2]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)

v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)

v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()
37 changes: 0 additions & 37 deletions tests/entrypoints/llm/test_generate.py
@@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)

v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)

v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
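The tests deleted here and in tests/entrypoints/llm/test_encode.py exercised the already-deprecated prompts= keyword of LLM.generate and LLM.encode (they expected a DeprecationWarning matching "'prompts'"); only the string-prompt variants are removed, while the token-based consistency tests remain. For reference, a rough sketch of the equivalent call shapes the deleted tests compared (the model name is an illustrative choice, not taken from this commit):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

# Passing a raw string and passing a {"prompt": ...} dict are equivalent.
out_from_str = llm.generate("Hello, world", sampling_params=sampling_params)
out_from_dict = llm.generate({"prompt": "Hello, world"},
                             sampling_params=sampling_params)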
86 changes: 81 additions & 5 deletions vllm/engine/async_llm_engine.py
@@ -2,8 +2,8 @@
import time
import weakref
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
List, Mapping, Optional, Set, Tuple, Type, Union, overload)
from weakref import ReferenceType

import vllm.envs as envs
@@ -28,7 +28,7 @@
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind
from vllm.utils import deprecate_kwargs, weak_bind

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -402,6 +402,21 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()

@overload # DEPRECATED
async def add_request_async(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@overload
async def add_request_async(
self,
request_id: str,
@@ -411,8 +426,30 @@ async def add_request_async(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request_async(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Async version of :meth:`add_request`."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None

if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
@@ -774,16 +811,55 @@ async def run_engine_loop(engine_ref: ReferenceType):

# This method does not need to be async, but kept that way
# for backwards compatibility.
async def add_request(
@overload # DEPRECATED
def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...

@overload
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...

@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None

if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
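The pattern above, two @overload stubs (a keyword-only inputs variant marked as deprecated, and a positional prompt variant) in front of a single implementation wrapped in deprecate_kwargs, is what provides the backwards compatibility. The deprecate_kwargs helper itself lives in vllm/utils.py and is not part of this diff; the sketch below is only a rough illustration of what such a decorator could look like, not vLLM's actual implementation:

import functools
import warnings
from typing import Any, Callable, Optional


def deprecate_kwargs(*deprecated: str,
                     additional_message: Optional[str] = None) -> Callable:
    """Warn whenever one of the named keyword arguments is passed explicitly."""

    def wrapper(fn: Callable) -> Callable:

        @functools.wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> Any:
            passed = [k for k in deprecated if kwargs.get(k) is not None]
            if passed:
                msg = f"The keyword arguments {passed} are deprecated."
                if additional_message is not None:
                    msg += f" {additional_message}"
                warnings.warn(DeprecationWarning(msg), stacklevel=2)
            # For async callables this returns the coroutine unawaited, so the
            # warning is emitted at call time, before the caller awaits it.
            return fn(*args, **kwargs)

        return inner

    return wrapper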
41 changes: 39 additions & 2 deletions vllm/engine/llm_engine.py
@@ -6,7 +6,7 @@
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union
from typing import Set, Type, Union, overload

import torch
from typing_extensions import TypeVar
@@ -51,7 +51,7 @@
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, weak_bind
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -686,6 +686,21 @@ def _add_processed_request(
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()

@overload # DEPRECATED
def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@overload
def add_request(
self,
request_id: str,
@@ -695,6 +710,24 @@ def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...

@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Add a request to the engine's request pool.
@@ -737,6 +770,10 @@ def add_request(
>>> # continue the request processing
>>> ...
"""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None

if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
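Taken together, the overloads and the decorator keep the old keyword-only inputs= argument working, with a DeprecationWarning pointing callers at prompt, while new code passes the prompt positionally or as prompt=. A rough usage sketch against the synchronous engine (the helper function, request IDs, and prompt text are illustrative):

from vllm import SamplingParams
from vllm.engine.llm_engine import LLMEngine


def add_both_styles(engine: LLMEngine) -> None:
    params = SamplingParams(temperature=0.0)

    # New-style call: the prompt is the second positional argument.
    engine.add_request("req-1", "Hello, world", params)

    # Old-style call: still accepted, but it emits a DeprecationWarning
    # telling the caller to use the 'prompt' parameter instead.
    engine.add_request("req-2", inputs="Hello, world", params=params)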
(The diffs for the remaining changed files in this commit are not shown here.)
