From bcb371ef4ae2cf9ac6eff2751d85f6511462ec24 Mon Sep 17 00:00:00 2001
From: mgerstgrasser
Date: Sat, 30 Mar 2024 12:32:37 -0700
Subject: [PATCH 1/6] make detokenization optional

---
 vllm/engine/llm_engine.py | 20 +++++++++++---------
 vllm/sampling_params.py   |  7 +++++++
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index dec42c633b10b..7d28723c8a80f 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -415,7 +415,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
 
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None:
+        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
@@ -461,8 +461,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                 child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self.detokenizer.decode_sequence_inplace(seq,
-                                                     seq_group.sampling_params)
+            if seq_group.sampling_params.detokenize:
+                self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -774,12 +775,13 @@ def _check_stop(self, seq: Sequence,
         if seq.get_output_len() < sampling_params.min_tokens:
             return
 
-        for stop_str in sampling_params.stop:
-            if seq.output_text.endswith(stop_str):
-                self._finalize_sequence(seq, sampling_params, stop_str)
-                seq.status = SequenceStatus.FINISHED_STOPPED
-                seq.stop_reason = stop_str
-                return
+        if sampling_params.detokenize:
+            for stop_str in sampling_params.stop:
+                if seq.output_text.endswith(stop_str):
+                    self._finalize_sequence(seq, sampling_params, stop_str)
+                    seq.status = SequenceStatus.FINISHED_STOPPED
+                    seq.stop_reason = stop_str
+                    return
         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 6f81ee31f84dd..ffedaf596a5c7 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -88,6 +88,7 @@ class SamplingParams:
             log probability of the sampled token, so there may be up to
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
+        detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
         spaces_between_special_tokens: Whether to add spaces between special
             tokens in the output. Defaults to True.
@@ -118,6 +119,7 @@ def __init__(
         min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
+        detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
@@ -150,6 +152,7 @@ def __init__(
         self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        self.detokenize = detokenize
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
@@ -210,6 +213,10 @@ def _verify_args(self) -> None:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if len(self.stop) > 0 and self.detokenize is False:
+            raise ValueError(
+                "stop strings are only supported when detokenize is True. "
+                "Set detokenize=True to use stop.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:

From 3dd6255612c5251ae275a0d972110557562d33c9 Mon Sep 17 00:00:00 2001
From: mgerstgrasser
Date: Sat, 30 Mar 2024 21:58:28 -0700
Subject: [PATCH 2/6] Allow skipping detokenization in API server

---
 vllm/entrypoints/api_server.py | 50 ++++++++++++++++++++++++++--------
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 2a47eae112c12..02cdab6cbf004 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -42,21 +42,33 @@ async def generate(request: Request) -> Response:
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
-    prompt = request_dict.pop("prompt")
+    prompt = request_dict.pop("prompt", None)
+    prompt_token_ids = request_dict.pop("prompt_token_ids", None)
     stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    results_generator = engine.generate(prompt, sampling_params, request_id)
+    results_generator = engine.generate(
+        prompt=prompt,
+        prompt_token_ids=prompt_token_ids,
+        sampling_params=sampling_params,
+        request_id=request_id,
+    )
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
-            prompt = request_output.prompt
-            text_outputs = [
-                prompt + output.text for output in request_output.outputs
-            ]
-            ret = {"text": text_outputs}
+            if sampling_params.detokenize:
+                prompt = request_output.prompt if final_output.prompt else ""
+                text_outputs = [
+                    prompt + output.text for output in request_output.outputs
+                ]
+                ret = {"text": text_outputs}
+            else:
+                ret = {
+                    "token_ids":
+                    [output.token_ids for output in request_output.outputs]
+                }
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
     if stream:
@@ -72,9 +84,18 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         final_output = request_output
 
     assert final_output is not None
-    prompt = final_output.prompt
-    text_outputs = [prompt + output.text for output in final_output.outputs]
-    ret = {"text": text_outputs}
+    if sampling_params.detokenize:
+        prompt = final_output.prompt if final_output.prompt else ""
+        text_outputs = [
+            prompt + output.text for output in final_output.outputs
+        ]
+        ret = {
+            "text": text_outputs,
+        }
+    else:
+        ret = {
+            "token_ids": [output.token_ids for output in final_output.outputs],
+        }
     return JSONResponse(ret)
 
 
@@ -99,6 +120,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         type=str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument(
+        "--uvicorn-log-level",
+        type=str,
+        default="info",
+        choices=["debug", "info", "warning", "error", "critical", "trace"],
+        help="log level for uvicorn",
+    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     engine_args = AsyncEngineArgs.from_cli_args(args)
@@ -109,7 +137,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
-                log_level="debug",
+                log_level=args.uvicorn_log_level,
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                 ssl_keyfile=args.ssl_keyfile,
                 ssl_certfile=args.ssl_certfile,

From 36fb3be29532a0547de8de9640c3e1aefa21241e Mon Sep 17 00:00:00 2001
From: mgerstgrasser
Date: Sun, 31 Mar 2024 17:14:17 -0700
Subject: [PATCH 3/6] Revert "Allow skipping detokenization in API server"

This reverts commit 3dd6255612c5251ae275a0d972110557562d33c9.
---
 vllm/entrypoints/api_server.py | 50 ++++++++--------------------------
 1 file changed, 11 insertions(+), 39 deletions(-)

diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 02cdab6cbf004..2a47eae112c12 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -42,33 +42,21 @@ async def generate(request: Request) -> Response:
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
-    prompt = request_dict.pop("prompt", None)
-    prompt_token_ids = request_dict.pop("prompt_token_ids", None)
+    prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    results_generator = engine.generate(
-        prompt=prompt,
-        prompt_token_ids=prompt_token_ids,
-        sampling_params=sampling_params,
-        request_id=request_id,
-    )
+    results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
-            if sampling_params.detokenize:
-                prompt = request_output.prompt if final_output.prompt else ""
-                text_outputs = [
-                    prompt + output.text for output in request_output.outputs
-                ]
-                ret = {"text": text_outputs}
-            else:
-                ret = {
-                    "token_ids":
-                    [output.token_ids for output in request_output.outputs]
-                }
+            prompt = request_output.prompt
+            text_outputs = [
+                prompt + output.text for output in request_output.outputs
+            ]
+            ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
     if stream:
@@ -84,18 +72,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         final_output = request_output
 
     assert final_output is not None
-    if sampling_params.detokenize:
-        prompt = final_output.prompt if final_output.prompt else ""
-        text_outputs = [
-            prompt + output.text for output in final_output.outputs
-        ]
-        ret = {
-            "text": text_outputs,
-        }
-    else:
-        ret = {
-            "token_ids": [output.token_ids for output in final_output.outputs],
-        }
+    prompt = final_output.prompt
+    text_outputs = [prompt + output.text for output in final_output.outputs]
+    ret = {"text": text_outputs}
     return JSONResponse(ret)
 
 
@@ -120,13 +99,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         type=str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy")
-    parser.add_argument(
-        "--uvicorn-log-level",
-        type=str,
-        default="info",
-        choices=["debug", "info", "warning", "error", "critical", "trace"],
-        help="log level for uvicorn",
-    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     engine_args = AsyncEngineArgs.from_cli_args(args)
@@ -137,7 +109,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
-                log_level=args.uvicorn_log_level,
+                log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                 ssl_keyfile=args.ssl_keyfile,
                 ssl_certfile=args.ssl_certfile,

From 9627425aedc31aaa7be6e9f58765e86493c9b49a Mon Sep 17 00:00:00 2001
From: mgerstgrasser
Date: Wed, 3 Apr 2024 17:03:00 -0700
Subject: [PATCH 4/6] Add note.

---
 vllm/sampling_params.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index ffedaf596a5c7..f4dce4eecf41e 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -152,6 +152,9 @@ def __init__(
         self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        # NOTE: This parameter is only exposed at the engine level for now.
+        # It is not exposed in the OpenAI API server, as the OpenAI API does
+        # not support returning only a list of token IDs.
         self.detokenize = detokenize
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens

From fedea30123d6e8b43ed07404c4b147ab92473458 Mon Sep 17 00:00:00 2001
From: mgerstgrasser
Date: Wed, 3 Apr 2024 17:03:10 -0700
Subject: [PATCH 5/6] Add detokenization test

---
 tests/engine/test_detokenization.py | 32 +++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 tests/engine/test_detokenization.py

diff --git a/tests/engine/test_detokenization.py b/tests/engine/test_detokenization.py
new file mode 100644
index 0000000000000..f77f6d0725b6b
--- /dev/null
+++ b/tests/engine/test_detokenization.py
@@ -0,0 +1,32 @@
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_computed_prefix_blocks(model: str):
+    # This test checks if the engine generates completions both with and
+    # without optional detokenization, that detokenization includes text
+    # and no-detokenization doesn't, and that both completions have the same
+    # token_ids.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+
+    llm = LLM(model=model)
+    sampling_params = SamplingParams(max_tokens=10,
+                                     temperature=0.0,
+                                     detokenize=False)
+
+    outputs_no_detokenization = llm.generate(prompt,
+                                             sampling_params)[0].outputs[0]
+    sampling_params.detokenize = True
+    outputs_with_detokenization = llm.generate(prompt,
+                                               sampling_params)[0].outputs[0]
+
+    assert outputs_no_detokenization.text == ''
+    assert outputs_with_detokenization.text != ''
+    assert outputs_no_detokenization.token_ids == \
+        outputs_with_detokenization.token_ids

From 61e33a2d8bf056b5082a068e0b14d19478db8702 Mon Sep 17 00:00:00 2001
From: Matthias Gerstgrasser
Date: Wed, 3 Apr 2024 20:03:43 -0700
Subject: [PATCH 6/6] Simplification

Co-authored-by: Nick Hill
---
 vllm/sampling_params.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index f4dce4eecf41e..bbba02a833fc6 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -216,7 +216,7 @@ def _verify_args(self) -> None:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
-        if len(self.stop) > 0 and self.detokenize is False:
+        if self.stop and not self.detokenize:
            raise ValueError(
                 "stop strings are only supported when detokenize is True. "
                 "Set detokenize=True to use stop.")
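A minimal usage sketch, separate from the patches above, of how the new detokenize flag can be exercised through the offline LLM entrypoint. It mirrors the test added in PATCH 5/6; the prompt string below is an illustrative assumption, while the model name and max_tokens value are taken from that test.

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

# Sketch only: assumes a vLLM build with this patch series applied.
llm = LLM(model="facebook/opt-125m")

# With detokenize=False the outputs still carry token_ids, but the text field
# stays empty, so no detokenization work is done per step.
params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
out = llm.generate("Hello, my name is", params)[0].outputs[0]
print(out.token_ids)   # token IDs are always populated
print(repr(out.text))  # '' when detokenization is skipped

# Stop strings need detokenized text, so this combination raises ValueError
# (see _verify_args in PATCH 1/6, simplified in PATCH 6/6):
# SamplingParams(stop=["\n"], detokenize=False)

As noted in PATCH 4/6, the flag is only exposed at the engine level; the OpenAI-compatible server does not accept it.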