[Frontend] Refactor prompt processing #4028

Merged Jul 22, 2024 · 139 commits

Commits (changes shown from 131 commits)
ce770f4
Use discriminated union in prompt parsing
DarkLight1337 Apr 12, 2024
6b016bc
Fix some type errors along the way
DarkLight1337 Apr 12, 2024
7620354
Some more fixes
DarkLight1337 Apr 12, 2024
7c3e6d9
Apply formatter
DarkLight1337 Apr 12, 2024
7bdc84e
Refactor prompt parsing so that it can be shared between Chat Complet…
DarkLight1337 Apr 12, 2024
a7d1098
Make code more readable
DarkLight1337 Apr 12, 2024
8b9d636
Move assertion to a more appropriate place
DarkLight1337 Apr 12, 2024
c48c13a
Add code documentation
DarkLight1337 Apr 12, 2024
3530362
Decompose `_validate_prompt_and_tokenize`
DarkLight1337 Apr 12, 2024
b8feec9
Fix missing import due to renaming
DarkLight1337 Apr 12, 2024
89d9086
Merge branch 'upstream' into openai-typing
DarkLight1337 Apr 13, 2024
cc1a5b3
Fix bug when parsing array of tokens
DarkLight1337 Apr 13, 2024
f9c1135
Add token array to batch completions testing
DarkLight1337 Apr 13, 2024
f2e8180
Replace legacy `conint` with `Annotated` field
DarkLight1337 Apr 14, 2024
797326b
Merge branch 'upstream' into openai-typing
DarkLight1337 Apr 19, 2024
1d00087
Merge branch 'upstream' into openai-typing
DarkLight1337 Apr 23, 2024
a1db4e0
Fix `mypy` error
DarkLight1337 Apr 23, 2024
7b9a3ff
Merge branch 'upstream' into openai-typing
DarkLight1337 Apr 24, 2024
5d42800
Combine prompt inputs
DarkLight1337 Apr 24, 2024
5db2c5e
Fix a bunch of tests
DarkLight1337 Apr 25, 2024
74c5905
Fix LLaVA test
DarkLight1337 Apr 25, 2024
cd8917b
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 25, 2024
b49aba7
Fix `benchmark_latency` test
DarkLight1337 Apr 25, 2024
bfd7295
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 25, 2024
928b9b9
Merge branch 'upstream' into openai-typing
DarkLight1337 Apr 27, 2024
45c7f23
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 27, 2024
493e6ed
Merge branch 'upstream' into llm-inputs
DarkLight1337 Apr 28, 2024
aa5f32a
Merge branch 'upstream' into openai-typing
DarkLight1337 May 1, 2024
20aeceb
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 1, 2024
0f46653
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 3, 2024
c4f3540
Clarify tokenizer usage
DarkLight1337 May 3, 2024
ab8182c
Rename `encode_request -> process_model_inputs`
DarkLight1337 May 3, 2024
eac33e1
Support old API in `LLM.generate`
DarkLight1337 May 3, 2024
703d318
Add tests to ensure old API still works
DarkLight1337 May 3, 2024
19d85f9
Let all entrypoints tests be run at the same time
DarkLight1337 May 3, 2024
034ef30
Merge branch 'upstream' into openai-typing
DarkLight1337 May 7, 2024
baebd99
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 7, 2024
dc9816f
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 8, 2024
29144e9
Merge branch 'upstream' into openai-typing
DarkLight1337 May 8, 2024
80729ad
Merge branch 'upstream' into openai-typing
DarkLight1337 May 9, 2024
1c50600
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 14, 2024
5759dfa
Add tests for LLM.encode and fix corresponding bugs
DarkLight1337 May 14, 2024
cc4bfb5
Apply formatter
DarkLight1337 May 14, 2024
6085b08
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 14, 2024
d5c9731
Rename `_add_requests` to `_validate_and_add_requests` to be more sim…
DarkLight1337 May 14, 2024
4f218a5
Separate `entrypoints` tests into two groups
DarkLight1337 May 14, 2024
a9201d0
Fix memory profiling error
DarkLight1337 May 14, 2024
ceebfa6
Fix memory usage for embedding server
DarkLight1337 May 15, 2024
7d991cd
Update embeddings API to use new inputs
DarkLight1337 May 15, 2024
0e79dfb
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 15, 2024
f23583c
Merge branch 'upstream' into openai-typing
DarkLight1337 May 15, 2024
07c0a2e
Remove unnecessary commas
DarkLight1337 May 15, 2024
2c0d58f
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 15, 2024
48e7a4a
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 16, 2024
0cc0b74
Merge branch 'upstream' into openai-typing
DarkLight1337 May 17, 2024
b6c0e29
Merge branch 'upstream' into llm-inputs
DarkLight1337 May 20, 2024
3097582
Merge `llm` groups back into one by enabling gc
DarkLight1337 May 20, 2024
7bbd123
Improve documentation for LLM/engine
DarkLight1337 May 20, 2024
056eb61
Direct readers to the `PromptInputs` class
DarkLight1337 May 22, 2024
b3b990a
Separate `_run_engine` from `_validate_and_add_requests`
DarkLight1337 May 22, 2024
2169def
Add flag for deprecating legacy API
DarkLight1337 May 22, 2024
3dbded1
Add tests for `deprecate_kwargs`
DarkLight1337 May 22, 2024
8e20317
Apply formatter
DarkLight1337 May 22, 2024
fdccaa2
Rename attribute to be less misleading
DarkLight1337 May 22, 2024
77ee1c8
Re-enable using `'fork'` start method and improve speed by using `torch…
DarkLight1337 May 23, 2024
b1bcdd1
Simplify logic of casting request output
DarkLight1337 May 23, 2024
44b4681
Improve code readability
DarkLight1337 May 23, 2024
50343cb
Fix `multi_modal_data` being a required key
DarkLight1337 May 23, 2024
45aa420
Fix index out of range error
DarkLight1337 May 23, 2024
d4e2589
Use a flag to control whether to check output types
DarkLight1337 May 23, 2024
c07b579
Simplify flags
DarkLight1337 May 23, 2024
9d56eb0
Move output validation to a more appropriate location
DarkLight1337 May 23, 2024
bc05031
Add message to deprecation notice
DarkLight1337 May 23, 2024
95d4130
Apply formatter
DarkLight1337 May 23, 2024
cc84f65
Remove unused parameter in `_validate_and_add_requests` and fix test
DarkLight1337 May 24, 2024
19de59f
Merge branch 'upstream' into openai-typing
DarkLight1337 May 24, 2024
6c5d4a6
Simplify code
DarkLight1337 May 25, 2024
fd2da12
Move attribute assignment outside `_init_tokenizer`
DarkLight1337 May 25, 2024
d78de94
Only emit warning once
DarkLight1337 May 25, 2024
8a86829
Simplify assignment expression
DarkLight1337 May 25, 2024
731ac0e
Place special case at the start
DarkLight1337 May 25, 2024
1be75d2
Merge branch 'llm-inputs' into openai-typing
DarkLight1337 May 25, 2024
2d1a0bc
move API reference to under developer doc
ywang96 May 25, 2024
7b8ce2c
Fix links in docs
DarkLight1337 May 26, 2024
fff21a1
Remove unnecessary code to avoid repeated warning
DarkLight1337 May 26, 2024
18d5bcd
Merge branch 'llm-inputs' into openai-typing
DarkLight1337 May 26, 2024
fab7f92
Parse and batch the prompt using #4328
DarkLight1337 May 26, 2024
cb057eb
Move logging from async engine to server
DarkLight1337 May 26, 2024
b682dfa
Fix missing args in test
DarkLight1337 May 26, 2024
c982b88
Merge branch 'upstream' into openai-typing
DarkLight1337 May 26, 2024
d87ae34
Improve control flow
DarkLight1337 May 26, 2024
c73e972
Fix missing input text when processing Completions API output
DarkLight1337 May 26, 2024
7d2c08b
Remove extra attribute
DarkLight1337 May 26, 2024
3eb1eaa
Merge branch 'upstream' into openai-typing
DarkLight1337 May 29, 2024
9862802
Merge branch 'upstream' into openai-typing
DarkLight1337 May 29, 2024
62418d6
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 3, 2024
ce48c94
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 8, 2024
2936ee0
Update docs
DarkLight1337 Jun 8, 2024
c932b89
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 10, 2024
156c017
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 11, 2024
5d57ce2
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 14, 2024
55044ba
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 15, 2024
87dc5d3
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 15, 2024
118f935
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 15, 2024
7f5668e
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 25, 2024
b0fa8ff
Fix type errors
DarkLight1337 Jun 25, 2024
4aed3e4
Merge branch 'upstream' into openai-typing
DarkLight1337 Jun 27, 2024
187fb24
Merge branch 'upstream' into openai-typing
DarkLight1337 Jul 2, 2024
4fc018c
Fix bad merge
DarkLight1337 Jul 2, 2024
e9ea1ac
Merge branch 'upstream' into openai-typing
DarkLight1337 Jul 5, 2024
81e8a55
Merge branch 'upstream' into openai-typing
DarkLight1337 Jul 18, 2024
59870cf
Fix linter errors
DarkLight1337 Jul 18, 2024
e07c2be
Fix docs
DarkLight1337 Jul 18, 2024
1b41a29
Remove logging params from tokenization endpoint
DarkLight1337 Jul 18, 2024
3e88444
Remove duplicated function
DarkLight1337 Jul 18, 2024
0b2eac0
Fix type errors
DarkLight1337 Jul 18, 2024
165c0b1
Handle lora and prompt adapters for embeddings
DarkLight1337 Jul 18, 2024
1dc3d21
Remove redundant code
DarkLight1337 Jul 18, 2024
6264c82
Use type union for tokenization requests
DarkLight1337 Jul 18, 2024
0492d79
Fix `request_id` and add logging
DarkLight1337 Jul 18, 2024
4aa552a
yapf
DarkLight1337 Jul 18, 2024
994f3ee
Update request id
DarkLight1337 Jul 18, 2024
fe0629d
Enable logging
DarkLight1337 Jul 18, 2024
d85ba52
Fix invalid attribute access
DarkLight1337 Jul 18, 2024
d0fd1f4
Add `add_special_tokens` to Completions API to simplify the logic
DarkLight1337 Jul 19, 2024
2c38ccf
Factor out logging args
DarkLight1337 Jul 19, 2024
b0f9595
Make optional args keyword-only; cleanup
DarkLight1337 Jul 19, 2024
4fddfa0
Remove extra line
DarkLight1337 Jul 19, 2024
18a1fac
Move definition back
DarkLight1337 Jul 19, 2024
61cf999
Avoid creating new list
DarkLight1337 Jul 19, 2024
76124a9
Update args in test_serving_chat.py
DarkLight1337 Jul 19, 2024
68d2c96
Silently ignore prompt adapter for tokenization
DarkLight1337 Jul 19, 2024
7a92e6e
Clean
DarkLight1337 Jul 19, 2024
e78dd28
Use HTTP boolean format
DarkLight1337 Jul 19, 2024
d56c9cc
Fix incorrectly allowing some sampling params to be `None`
DarkLight1337 Jul 20, 2024
f62edef
Fix inconsistent arg availability and ordering
DarkLight1337 Jul 20, 2024
032eeec
Fix wrong model class
DarkLight1337 Jul 20, 2024
2adb0ad
Merge branch 'upstream' into openai-typing
DarkLight1337 Jul 21, 2024
0a9b0d8
isort
DarkLight1337 Jul 21, 2024
Files changed
4 changes: 2 additions & 2 deletions benchmarks/benchmark_latency.py
@@ -11,7 +11,7 @@

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs
from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptStrictInputs] = [{
dummy_inputs: List[PromptInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

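For context, a minimal sketch of the renamed `PromptInputs` type in use, in the same style as the benchmark change above; the model name and token IDs are arbitrary examples and not part of this PR:

```python
from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptInputs

# Text prompts and token prompts are both valid PromptInputs after this change.
inputs: List[PromptInputs] = [
    {"prompt": "Hello, my name is"},
    {"prompt_token_ids": [31414, 6, 127, 766, 16]},  # arbitrary token IDs
]

llm = LLM(model="facebook/opt-125m")  # example model choice
outputs = llm.generate(inputs, sampling_params=SamplingParams(max_tokens=16))
for output in outputs:
    print(output.outputs[0].text)
```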
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
2 changes: 1 addition & 1 deletion docs/source/dev/offline_inference/llm_inputs.rst
@@ -1,7 +1,7 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptStrictInputs
.. autodata:: vllm.inputs.PromptInputs

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/source/models/vlm.rst
@@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
internally for each model.


To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
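A minimal usage sketch of the two fields described in the docs hunk above; the checkpoint name, the LLaVA-style prompt format, and the image path are illustrative assumptions rather than anything prescribed by this PR:

```python
from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # example VLM checkpoint
image = Image.open("example.jpg")            # placeholder local image

# `prompt` follows the format documented for the model on HuggingFace;
# `multi_modal_data` follows the vllm.multimodal.MultiModalDataDict schema.
outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
    "multi_modal_data": {"image": image},
})

for output in outputs:
    print(output.outputs[0].text)
```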
5 changes: 4 additions & 1 deletion tests/entrypoints/openai/test_serving_chat.py
@@ -32,7 +32,10 @@ async def _async_serving_chat_init():
model_config,
served_model_names=[MODEL_NAME],
response_role="assistant",
chat_template=CHAT_TEMPLATE)
chat_template=CHAT_TEMPLATE,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_completion


4 changes: 2 additions & 2 deletions vllm/__init__.py
@@ -5,7 +5,7 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@
"__version__",
"LLM",
"ModelRegistry",
"PromptStrictInputs",
"PromptInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
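A small sketch of the renamed top-level exports, assuming `TextPrompt` and `TokensPrompt` remain TypedDicts as in `vllm.inputs` (so these calls simply build the corresponding dicts); the token IDs are arbitrary:

```python
from vllm import PromptInputs, TextPrompt, TokensPrompt

# Either variant satisfies the PromptInputs union that replaced PromptStrictInputs.
text_input: PromptInputs = TextPrompt(prompt="The capital of France is")
token_input: PromptInputs = TokensPrompt(prompt_token_ids=[464, 3139, 286])
```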
7 changes: 0 additions & 7 deletions vllm/engine/arg_utils.py
@@ -801,7 +801,6 @@ class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser,
@@ -815,12 +814,6 @@ def add_cli_args(parser: FlexibleArgumentParser,
parser.add_argument('--disable-log-requests',
action='store_true',
help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser


63 changes: 22 additions & 41 deletions vllm/engine/async_llm_engine.py
@@ -1,8 +1,8 @@
import asyncio
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)

from transformers import PreTrainedTokenizer

@@ -147,7 +147,10 @@ def process_exception(self,
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)

def add_request(self, request_id: str,
def add_request(self,
request_id: str,
*,
verbose: bool = False,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
@@ -162,6 +165,9 @@ def add_request(self, request_id: str,

self.new_requests_event.set()

if verbose:
logger.info("Added request %s.", request_id)

return stream

def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -295,14 +301,14 @@ async def process_model_inputs_async(
return self.input_processor(llm_inputs)

async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -349,8 +355,6 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt ID numbers
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for :class:`LLMEngine`.
@@ -364,13 +368,11 @@ def __init__(self,
engine_use_ray: bool,
*args,
log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)

self.background_loop: Optional[asyncio.Future] = None
@@ -450,7 +452,6 @@ def from_engine_args(
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
@@ -649,30 +650,9 @@ async def add_request(
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")

max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:max_log_len]

logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)

if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
@@ -688,6 +668,7 @@

stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
params=params,
arrival_time=arrival_time,
@@ -703,7 +684,7 @@ async def generate(
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@@ -786,7 +767,7 @@ async def encode(
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.

@@ -864,7 +845,7 @@ async def _process_request(
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
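Side note on the recurring `Dict[str, str]` → `Mapping[str, str]` change for `trace_headers` in this file: the wider read-only bound lets callers pass any mapping-like object without copying it into a dict first. A minimal sketch with a hypothetical helper, not part of vLLM:

```python
from types import MappingProxyType
from typing import Mapping


def forward_trace_headers(trace_headers: Mapping[str, str]) -> None:
    """Only reads the headers, so any Mapping implementation is acceptable."""
    for name, value in trace_headers.items():
        print(f"{name}: {value}")


# A plain dict and an immutable proxy both satisfy Mapping[str, str].
forward_trace_headers({"traceparent": "00-abc-def-01"})
forward_trace_headers(MappingProxyType({"traceparent": "00-abc-def-01"}))
```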
9 changes: 5 additions & 4 deletions vllm/engine/llm_engine.py
@@ -1,6 +1,7 @@
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union

@@ -508,7 +509,7 @@ def _add_processed_request(
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
@@ -589,7 +590,7 @@ def add_request(
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Add a request to the engine's request pool.
@@ -663,7 +664,7 @@ def _create_sequence_group_with_sampling(
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
37 changes: 13 additions & 24 deletions vllm/entrypoints/llm.py
@@ -6,8 +6,7 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -238,7 +237,7 @@ def generate(
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
@@ -255,7 +254,7 @@ def generate(
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@@ -302,9 +301,7 @@ def generate(
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

if sampling_params is None:
# Use default sampling params.
@@ -383,7 +380,7 @@ def encode(
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
@@ -400,7 +397,7 @@ def encode(
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@@ -417,7 +414,7 @@

Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
@@ -446,9 +443,7 @@ def encode(
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

if pooling_params is None:
# Use default pooling params.
@@ -496,25 +491,19 @@ def _convert_v1_inputs(
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
raise AssertionError

inputs.append(item)

return inputs

def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
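To summarize the `_convert_v1_inputs` change above: each legacy request now becomes either a `TextPrompt` or a `TokensPrompt`, and a text prompt takes precedence over token IDs instead of being combined into a `TextTokensPrompt`. A simplified standalone sketch that omits the real method's `parse_and_batch_prompt` normalization and length checks:

```python
from typing import List, Optional, Sequence

from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt


def convert_v1_inputs(
    prompts: Optional[Sequence[str]],
    prompt_token_ids: Optional[Sequence[List[int]]],
) -> List[PromptInputs]:
    """Simplified sketch of the legacy-argument conversion."""
    if prompts is None and prompt_token_ids is None:
        raise ValueError("Either prompts or prompt_token_ids must be provided.")

    num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)

    inputs: List[PromptInputs] = []
    for i in range(num_requests):
        if prompts is not None:
            # Text prompts win; token IDs for the same index are ignored.
            inputs.append(TextPrompt(prompt=prompts[i]))
        else:
            inputs.append(TokensPrompt(prompt_token_ids=prompt_token_ids[i]))

    return inputs
```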