From 7e248067aed034536b72070649a150f70873fda1 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 7 Nov 2024 17:08:24 -0800
Subject: [PATCH] [V1] Add all_token_ids attribute to Request (#10135)

Signed-off-by: Woosuk Kwon
Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/v1/core/scheduler.py    |  2 +-
 vllm/v1/engine/llm_engine.py |  2 +-
 vllm/v1/request.py           | 29 ++++++++++++++--
 vllm/v1/utils.py             | 64 ++++++++++++++++++++++++++++++++++++
 4 files changed, 92 insertions(+), 5 deletions(-)
 create mode 100644 vllm/v1/utils.py

diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 41659ff62747d..6017905642172 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -246,7 +246,7 @@ def update_from_output(
                 # NOTE(woosuk): Currently, we assume that each request
                 # generates at most one token at each step.
                 token_id = sampled_token_ids[req_index]
-                request.output_token_ids.append(token_id)
+                request.append_output_token_ids(token_id)
                 sampled.append((request, 1))
                 # TODO: Update the KV cache manager for prefix caching.
 
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 5f5720480abdc..b538c2c7d63bc 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -324,7 +324,7 @@ def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
         )
         for req, num_tokens in sampled:
             inputs.req_ids.append(req.request_id)
-            if len(req.output_token_ids) == num_tokens:
+            if req.num_output_tokens == num_tokens:
                 # The request is first detokenized.
                 inputs.prompt_token_ids.append(req.prompt_token_ids)
             else:
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index be7d4d165d280..087067cdac56f 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -4,6 +4,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import RequestMetrics
+from vllm.v1.utils import ConstantList
 
 if TYPE_CHECKING:
     from vllm.inputs import DecoderOnlyInputs
@@ -40,17 +41,39 @@ def __init__(
         self.prompt = inputs.get("prompt")
         self.prompt_token_ids = inputs["prompt_token_ids"]
         self.num_prompt_tokens = len(self.prompt_token_ids)
-        self.output_token_ids: List[int] = []
+        self._output_token_ids: List[int] = []
+        self._all_token_ids: List[int] = self.prompt_token_ids.copy()
         self.output_text = ""
         self.num_computed_tokens = 0
 
+    @property
+    def output_token_ids(self) -> ConstantList[int]:
+        # Prevent directly appending to the output_token_ids since
+        # all_token_ids should also be updated simultaneously.
+        return ConstantList(self._output_token_ids)
+
+    @property
+    def all_token_ids(self) -> ConstantList[int]:
+        # Prevent directly appending to the all_token_ids since
+        # output_token_ids should also be updated simultaneously.
+        return ConstantList(self._all_token_ids)
+
+    def append_output_token_ids(
+        self,
+        token_ids: Union[int, List[int]],
+    ) -> None:
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+        self._output_token_ids.extend(token_ids)
+        self._all_token_ids.extend(token_ids)
+
     @property
     def num_tokens(self) -> int:
-        return self.num_prompt_tokens + len(self.output_token_ids)
+        return len(self._all_token_ids)
 
     @property
     def num_output_tokens(self) -> int:
-        return len(self.output_token_ids)
+        return len(self._output_token_ids)
 
     def is_finished(self) -> bool:
         return RequestStatus.is_finished(self.status)
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
new file mode 100644
index 0000000000000..4b26749712e32
--- /dev/null
+++ b/vllm/v1/utils.py
@@ -0,0 +1,64 @@
+from typing import Generic, List, TypeVar, overload
+
+T = TypeVar("T")
+
+
+class ConstantList(Generic[T]):
+
+    def __init__(self, x: List[T]) -> None:
+        self._x = x
+
+    def append(self, item):
+        raise Exception("Cannot append to a constant list")
+
+    def extend(self, item):
+        raise Exception("Cannot extend a constant list")
+
+    def insert(self, item):
+        raise Exception("Cannot insert into a constant list")
+
+    def pop(self, item):
+        raise Exception("Cannot pop from a constant list")
+
+    def remove(self, item):
+        raise Exception("Cannot remove from a constant list")
+
+    def clear(self):
+        raise Exception("Cannot clear a constant list")
+
+    def index(self, item):
+        return self._x.index(item)
+
+    @overload
+    def __getitem__(self, item) -> T:
+        ...
+
+    @overload
+    def __getitem__(self, s: slice, /) -> List[T]:
+        ...
+
+    def __getitem__(self, item):
+        return self._x[item]
+
+    @overload
+    def __setitem__(self, item, value):
+        ...
+
+    @overload
+    def __setitem__(self, s: slice, value, /):
+        ...
+
+    def __setitem__(self, item, value):
+        raise Exception("Cannot set item in a constant list")
+
+    def __delitem__(self, item):
+        raise Exception("Cannot delete item from a constant list")
+
+    def __iter__(self):
+        return iter(self._x)
+
+    def __contains__(self, item):
+        return item in self._x
+
+    def __len__(self):
+        return len(self._x)
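Usage sketch (not part of the patch, illustrative only): after this change,
Request.output_token_ids and Request.all_token_ids are read-only ConstantList
views over the internal lists, and all mutation goes through
append_output_token_ids so the two lists stay in sync. In the snippet below,
`req` is assumed to be an already-constructed vllm.v1.request.Request instance;
it is not defined anywhere in this diff.

    # Hypothetical example of the Request API introduced by this patch.
    # `req` is assumed to be an existing vllm.v1.request.Request object.
    req.append_output_token_ids(42)       # append a single sampled token id
    req.append_output_token_ids([7, 8])   # a list of token ids also works
    assert req.num_tokens == req.num_prompt_tokens + req.num_output_tokens
    last = req.all_token_ids[-1]          # reads behave like a normal list
    req.output_token_ids.append(9)        # raises Exception: the view is read-only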