From 22de45235c6dd14e901e089971635ec655d5fbe0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 11:54:06 -0800 Subject: [PATCH] Push logprob generation to LLMEngine (#3065) Co-authored-by: Avnish Narayan --- tests/entrypoints/test_openai_server.py | 61 ++- tests/samplers/test_logprobs.py | 42 +- tests/worker/spec_decode/utils.py | 12 +- vllm/config.py | 2 + vllm/engine/arg_utils.py | 10 +- vllm/engine/async_llm_engine.py | 29 +- vllm/engine/llm_engine.py | 42 +- vllm/entrypoints/openai/serving_chat.py | 236 ++++++----- vllm/entrypoints/openai/serving_completion.py | 391 +++++++++--------- vllm/entrypoints/openai/serving_engine.py | 23 +- vllm/model_executor/layers/sampler.py | 15 +- vllm/sequence.py | 25 +- vllm/worker/spec_decode/multi_step_worker.py | 2 +- 13 files changed, 555 insertions(+), 335 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index e426cf7eed72b..f4a6e44d88a87 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -213,14 +213,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10) + top_logprobs=5) assert chat_completion.id is not None assert chat_completion.choices is not None and len( chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10 + assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" @@ -229,7 +229,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, # test multi-turn dialogue messages.append({"role": "user", "content": "express your result in json"}) chat_completion = await client.chat.completions.create( - model=MODEL_NAME, + model=model_name, messages=messages, max_tokens=10, ) @@ -237,6 +237,61 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + # Default max_logprobs is 5, so this should raise an error + with pytest.raises((openai.BadRequestError, openai.APIError)): + stream = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=10, + stream=True) + async for chunk in stream: + ... + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=10, + stream=False) + + with pytest.raises((openai.BadRequestError, openai.APIError)): + stream = await client.completions.create(model=model_name, + prompt="Test", + max_tokens=10, + logprobs=10, + stream=True) + async for chunk in stream: + ... + + with pytest.raises(openai.BadRequestError): + await client.completions.create(model=model_name, + prompt="Test", + max_tokens=10, + logprobs=10, + stream=False) + + # the server should still work afterwards + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + stream=False) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 0ea3704462fcb..1abb55f021214 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -1,5 +1,6 @@ import pytest import torch +from tests.conftest import VllmRunner from vllm import SamplingParams @@ -16,6 +17,7 @@ def test_get_prompt_logprobs( example_prompts, ): max_tokens = 5 + num_top_logprobs = 6 hf_model = hf_runner(model, dtype=dtype) hf_logprobs = hf_model.generate_greedy_logprobs( example_prompts, @@ -23,19 +25,32 @@ def test_get_prompt_logprobs( ) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) vllm_sampling_params = SamplingParams(max_tokens=max_tokens, - logprobs=5, + logprobs=num_top_logprobs, prompt_logprobs=5, temperature=0.0) vllm_results = vllm_model.model.generate( example_prompts, sampling_params=vllm_sampling_params) - del vllm_model # Test whether logprobs are included in the results. for result in vllm_results: assert result.prompt_logprobs is not None assert result.outputs[0].logprobs is not None + assert len(result.outputs[0].logprobs) == max_tokens + for logprobs in result.outputs[0].logprobs: + assert len(logprobs) == num_top_logprobs + output_text = result.outputs[0].text + output_string_from_most_likely_tokens = [] + for top_logprobs in result.outputs[0].logprobs: + top_logprob = next(iter(top_logprobs.values())) + output_string_from_most_likely_tokens.append( + top_logprob.decoded_token) + output_string_from_most_likely_tokens = "".join( + output_string_from_most_likely_tokens) + assert output_text == output_string_from_most_likely_tokens, ( + "The output text from the top logprob for each token position " + "should be the same as the output text in the result.") # Test whether prompt logprobs are consistent with HF for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): @@ -43,14 +58,29 @@ def test_get_prompt_logprobs( vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): for token_id, logprob in vllm_prompt_logprob_dict.items(): - torch.testing.assert_close(logprob, + torch.testing.assert_close(logprob.logprob, hf_logprob[0][i][token_id].item(), atol=1e-2, rtol=1e-2) vllm_sample_logprobs = vllm_result.outputs[0].logprobs - for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs): - for token_id, logprob in vllm_sample_logprob_dict.items(): + for i, top_logprobs in enumerate(vllm_sample_logprobs): + for token_id, sample_logprob in top_logprobs.items(): + logprob = sample_logprob.logprob torch.testing.assert_close(logprob, hf_logprob[i][-1][token_id].item(), atol=1e-2, rtol=1e-2) + assert isinstance(sample_logprob.decoded_token, str), \ + ("The token should be decoded by the time it is returned " + " to the user.") + + +def test_max_logprobs(): + runner = VllmRunner("facebook/opt-125m", max_logprobs=1) + vllm_sampling_params = SamplingParams(logprobs=1) + # should pass + runner.generate(["Hello world"], sampling_params=vllm_sampling_params) + + bad_sampling_params = SamplingParams(logprobs=2) + with pytest.raises(ValueError): + runner.generate(["Hello world"], sampling_params=bad_sampling_params) diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py index 8d74509fea488..fa8767cf898aa 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/worker/spec_decode/utils.py @@ -4,7 +4,7 @@ from vllm.worker.worker import Worker from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import SequenceGroupMetadata, SequenceData +from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData from vllm.sampling_params import SamplingParams from vllm.worker.cache_engine import CacheEngine from vllm.model_executor.utils import set_random_seed @@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts( def assert_logprobs_dict_allclose( - actual_logprobs: List[Dict[int, float]], - expected_logprobs: List[Dict[int, float]]) -> None: + actual_logprobs: List[Dict[int, Logprob]], + expected_logprobs: List[Dict[int, Logprob]]) -> None: for single_step_actual_logprobs, single_step_expected_logprobs in zip( actual_logprobs, expected_logprobs): assert set(single_step_actual_logprobs.keys()) == set( single_step_expected_logprobs.keys()) for token_id in single_step_actual_logprobs: - actual = torch.tensor(single_step_actual_logprobs[token_id]) - expected = torch.tensor(single_step_expected_logprobs[token_id]) + actual = torch.tensor( + single_step_actual_logprobs[token_id].logprob) + expected = torch.tensor( + single_step_expected_logprobs[token_id].logprob) assert torch.allclose(actual, expected) diff --git a/vllm/config.py b/vllm/config.py index e39fd7265689f..ef9a920f29c2a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -79,6 +79,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_logprobs: int = 5, ) -> None: self.model = model self.tokenizer = tokenizer @@ -93,6 +94,7 @@ def __init__( self.quantization = quantization self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + self.max_logprobs = max_logprobs if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6882e8be34d11..c3dccdd5bb50b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,6 +31,7 @@ class EngineArgs: max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 + max_logprobs: int = 5 # OpenAI default value disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None @@ -212,6 +213,12 @@ def add_cli_args( type=int, default=EngineArgs.max_paddings, help='maximum number of paddings in a batch') + parser.add_argument( + '--max-logprobs', + type=int, + default=EngineArgs.max_logprobs, + help=('max number of log probs to return logprobs is specified in' + ' SamplingParams')) parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') @@ -300,7 +307,8 @@ def create_engine_configs( self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture) + self.enforce_eager, self.max_context_len_to_capture, + self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9e52d20ca4980..df66139fddcd1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -47,7 +47,7 @@ def __init__(self, request_id: str) -> None: self._queue = asyncio.Queue() self._finished = False - def put(self, item: RequestOutput) -> None: + def put(self, item: Union[RequestOutput, Exception]) -> None: if self._finished: return self._queue.put_nowait(item) @@ -110,6 +110,17 @@ def process_request_output(self, logger.info(f"Finished request {request_id}.") self.abort_request(request_id) + def process_exception(self, + request_id: str, + exception: Exception, + *, + verbose: bool = False) -> None: + """Propagate an exception from the engine.""" + self._request_streams[request_id].put(exception) + if verbose: + logger.info(f"Finished request {request_id}.") + self.abort_request(request_id) + def add_request(self, request_id: str, **engine_add_request_kwargs) -> AsyncStream: """Add a request to be sent to the engine on the next background @@ -377,10 +388,18 @@ async def engine_step(self) -> bool: for new_request in new_requests: # Add the request into the vLLM engine's waiting queue. # TODO: Maybe add add_request_batch to reduce Ray overhead - if self.engine_use_ray: - await self.engine.add_request.remote(**new_request) - else: - await self.engine.add_request_async(**new_request) + try: + if self.engine_use_ray: + await self.engine.add_request.remote(**new_request) + else: + await self.engine.add_request_async(**new_request) + except ValueError as e: + # TODO: use a vLLM specific error for failed validation + self._request_tracker.process_exception( + new_request["request_id"], + e, + verbose=self.log_requests, + ) if finished_requests: await self._engine_abort(finished_requests) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8a2573034c940..703756996b7f7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -18,7 +18,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, +from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, TokenizerGroup) @@ -473,6 +473,13 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") if arrival_time is None: arrival_time = time.monotonic() prompt_token_ids = self.encode_request( @@ -583,6 +590,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs if prompt_logprobs is not None: + # We can pick any sequence for the prompt. + seq = next(iter(seq_group.seqs_dict.values())) + all_token_ids = seq.get_token_ids() + for i, prompt_logprobs_for_token in enumerate(prompt_logprobs): + self._decode_logprobs(seq, seq_group.sampling_params, + prompt_logprobs_for_token, + all_token_ids[:i]) seq_group.prompt_logprobs = prompt_logprobs # Process samples @@ -930,12 +944,36 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) + def _decode_logprobs(self, seq: Sequence, prms: SamplingParams, + logprobs: Dict[int, Logprob], + all_input_ids: List[int]) -> None: + if not logprobs: + return + for token_id, sample_logprob in logprobs.items(): + if (sample_logprob.decoded_token is None and token_id != -1): + all_input_ids_with_logprob = all_input_ids[:-1] + [token_id] + _, new_text, prefix_offset, read_offset = detokenize_incrementally( + self.get_tokenizer_for_seq(seq), + all_input_ids=all_input_ids_with_logprob, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. + spaces_between_special_tokens, + ) + sample_logprob.decoded_token = new_text + def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" + all_input_ids = seq.get_token_ids() + self._decode_logprobs(seq, prms, seq.output_logprobs[-1], + all_input_ids) + (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( self.get_tokenizer_for_seq(seq), - all_input_ids=seq.get_token_ids(), + all_input_ids=all_input_ids, prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, read_offset=seq.read_offset, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f4ad0aa5a0184..ba352f18f6454 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -82,8 +82,12 @@ async def create_chat_completion( return self.chat_completion_stream_generator( request, result_generator, request_id) else: - return await self.chat_completion_full_generator( - request, raw_request, result_generator, request_id) + try: + return await self.chat_completion_full_generator( + request, raw_request, result_generator, request_id) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: @@ -99,117 +103,133 @@ async def chat_completion_stream_generator( model_name = request.model created_time = int(time.monotonic()) chunk_object_type = "chat.completion.chunk" - - # Send first response for each request.n (index) with the role - role = self.get_chat_request_role(request) - for i in range(request.n): - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(role=role), - logprobs=None, - finish_reason=None) - chunk = ChatCompletionStreamResponse(id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - - # Send response to echo the input portion of the last message - if request.echo: - last_msg_content = "" - if request.messages and isinstance( - request.messages, list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] - - if last_msg_content: - for i in range(request.n): - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=last_msg_content), - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - logprobs=None, - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" + first_iteration = True # Send response for each token for each request.n (index) previous_texts = [""] * request.n previous_num_tokens = [0] * request.n finish_reason_sent = [False] * request.n - async for res in result_generator: - res: RequestOutput - for output in res.outputs: - i = output.index - - if finish_reason_sent[i]: - continue - - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - top_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - - if request.logprobs: - logprobs = self._create_logprobs( - token_ids=delta_token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=len(previous_texts[i]), - ) - else: - logprobs = None - - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - if output.finish_reason is None: - # Send token-by-token response for each request.n - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=delta_text), - logprobs=logprobs, - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - else: - # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], - ) - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=delta_text), - logprobs=logprobs, - finish_reason=output.finish_reason) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - if final_usage is not None: - chunk.usage = final_usage - data = chunk.model_dump_json(exclude_unset=True, - exclude_none=True) - yield f"data: {data}\n\n" - finish_reason_sent[i] = True + try: + async for res in result_generator: + res: RequestOutput + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + if first_iteration: + # Send first response for each request.n (index) with the role + role = self.get_chat_request_role(request) + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role=role), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, + list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + logprobs=None, + model=model_name) + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + + if finish_reason_sent[i]: + continue + + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + + if request.logprobs: + logprobs = self._create_logprobs( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + logprobs=logprobs, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once + prompt_tokens = len(res.prompt_token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + + previous_num_tokens[i], + ) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + logprobs=logprobs, + finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if final_usage is not None: + chunk.usage = final_usage + data = chunk.model_dump_json(exclude_unset=True, + exclude_none=True) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 99a10196b5f73..a8244fd150753 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -26,107 +26,6 @@ [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] -async def completion_stream_generator( - request: CompletionRequest, - raw_request: Request, - on_abort, - result_generator: AsyncIterator[Tuple[int, RequestOutput]], - create_logprobs_fn: TypeCreateLogProbsFn, - request_id: str, - created_time: int, - model_name: str, - num_prompts: int, -) -> AsyncGenerator[str, None]: - previous_texts = [""] * request.n * num_prompts - previous_num_tokens = [0] * request.n * num_prompts - has_echoed = [False] * request.n * num_prompts - - async for prompt_idx, res in result_generator: - - # Abort the request if the client disconnects. - if await raw_request.is_disconnected(): - await on_abort(f"{request_id}-{prompt_idx}") - raise StopAsyncIteration() - - for output in res.outputs: - i = output.index + prompt_idx * request.n - # TODO(simon): optimize the performance by avoiding full text O(n^2) sending. - - if request.echo and request.max_tokens == 0: - # only return the prompt - delta_text = res.prompt - delta_token_ids = res.prompt_token_ids - top_logprobs = res.prompt_logprobs - has_echoed[i] = True - elif request.echo and request.max_tokens > 0 and not has_echoed[i]: - # echo the prompt and first token - delta_text = res.prompt + output.text - delta_token_ids = res.prompt_token_ids + output.token_ids - top_logprobs = res.prompt_logprobs + (output.logprobs or []) - has_echoed[i] = True - else: - # return just the delta - delta_text = output.text[len(previous_texts[i]):] - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - top_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - - if request.logprobs is not None: - assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" - logprobs = create_logprobs_fn( - token_ids=delta_token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=len(previous_texts[i]), - ) - else: - logprobs = None - - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - finish_reason = output.finish_reason - response_json = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[ - CompletionResponseStreamChoice( - index=i, - text=delta_text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - ]).model_dump_json() - yield f"data: {response_json}\n\n" - - if output.finish_reason is not None: # return final usage - logprobs = LogProbs() if request.logprobs is not None else None - prompt_tokens = len(res.prompt_token_ids) - completion_tokens = len(output.token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - response_json = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[ - CompletionResponseStreamChoice( - index=i, - text="", - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - ], - usage=final_usage, - ).model_dump_json() - yield f"data: {response_json}\n\n" - - yield "data: [DONE]\n\n" - - def parse_prompt_format(prompt) -> Tuple[bool, list]: # get the prompt, openai supports the following # "a string, array of strings, array of tokens, or array of token arrays." @@ -151,73 +50,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: return prompt_is_tokens, prompts -def request_output_to_completion_response( - final_res_batch: List[RequestOutput], - request: CompletionRequest, - create_logprobs_fn: TypeCreateLogProbsFn, - request_id: str, - created_time: int, - model_name: str, -) -> CompletionResponse: - choices = [] - num_prompt_tokens = 0 - num_generated_tokens = 0 - for final_res in final_res_batch: - assert final_res is not None - prompt_token_ids = final_res.prompt_token_ids - prompt_logprobs = final_res.prompt_logprobs - prompt_text = final_res.prompt - - for output in final_res.outputs: - if request.echo and request.max_tokens == 0: - token_ids = prompt_token_ids - top_logprobs = prompt_logprobs - output_text = prompt_text - elif request.echo and request.max_tokens > 0: - token_ids = prompt_token_ids + output.token_ids - top_logprobs = prompt_logprobs + output.logprobs - output_text = prompt_text + output.text - else: - token_ids = output.token_ids - top_logprobs = output.logprobs - output_text = output.text - - if request.logprobs is not None: - logprobs = create_logprobs_fn( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - ) - else: - logprobs = None - - choice_data = CompletionResponseChoice( - index=len(choices), - text=output_text, - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - choices.append(choice_data) - - num_prompt_tokens += len(prompt_token_ids) - num_generated_tokens += sum( - len(output.token_ids) for output in final_res.outputs) - - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - - return CompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) - - def merge_async_iterators(*iterators): """Merge multiple asynchronous iterators into a single iterator. @@ -230,8 +62,11 @@ def merge_async_iterators(*iterators): finished = [False] * len(iterators) async def producer(i, iterator): - async for item in iterator: - await queue.put((i, item)) + try: + async for item in iterator: + await queue.put((i, item)) + except Exception as e: + await queue.put(e) finished[i] = True _tasks = [ @@ -242,6 +77,8 @@ async def producer(i, iterator): async def consumer(): while not all(finished) or not queue.empty(): item = await queue.get() + if isinstance(item, Exception): + raise item yield item await asyncio.gather(*_tasks) @@ -312,6 +149,7 @@ async def create_completion(self, request: CompletionRequest, prompt_token_ids=input_ids, lora_request=lora_request)) except ValueError as e: + # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) result_generator: AsyncIterator[Tuple[ @@ -325,27 +163,28 @@ async def create_completion(self, request: CompletionRequest, # Streaming response if stream: - return completion_stream_generator(request, - raw_request, - self.engine.abort, - result_generator, - self._create_logprobs, - request_id, - created_time, - model_name, - num_prompts=len(prompts)) + return self.completion_stream_generator(request, + raw_request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(prompts)) # Non-streaming response final_res_batch: RequestOutput = [None] * len(prompts) - async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") - return self.create_error_response("Client disconnected") - final_res_batch[i] = res - response = request_output_to_completion_response( - final_res_batch, request, self._create_logprobs, request_id, - created_time, model_name) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = self.request_output_to_completion_response( + final_res_batch, request, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) # When user requests streaming but we don't stream, we still need to # return a streaming response with a single event. @@ -359,3 +198,179 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: return fake_stream_generator() return response + + async def completion_stream_generator( + self, + request: CompletionRequest, + raw_request: Request, + result_generator: AsyncIterator[Tuple[int, RequestOutput]], + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, + ) -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n * num_prompts + previous_num_tokens = [0] * request.n * num_prompts + has_echoed = [False] * request.n * num_prompts + + try: + async for prompt_idx, res in result_generator: + + # Abort the request if the client disconnects. + if await raw_request.is_disconnected(): + await self.engine.abort(f"{request_id}-{prompt_idx}") + raise StopAsyncIteration() + + for output in res.outputs: + i = output.index + prompt_idx * request.n + # TODO(simon): optimize the performance by avoiding full text O(n^2) sending. + + if request.echo and request.max_tokens == 0: + # only return the prompt + delta_text = res.prompt + delta_token_ids = res.prompt_token_ids + top_logprobs = res.prompt_logprobs + has_echoed[i] = True + elif request.echo and request.max_tokens > 0 and not has_echoed[ + i]: + # echo the prompt and first token + delta_text = res.prompt + output.text + delta_token_ids = res.prompt_token_ids + output.token_ids + top_logprobs = res.prompt_logprobs + (output.logprobs + or []) + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text[len(previous_texts[i]):] + delta_token_ids = output.token_ids[ + previous_num_tokens[i]:] + top_logprobs = output.logprobs[previous_num_tokens[ + i]:] if output.logprobs else None + + if request.logprobs is not None: + assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" + logprobs = self._create_logprobs( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + ]).model_dump_json() + yield f"data: {response_json}\n\n" + + if output.finish_reason is not None: # return final usage + logprobs = LogProbs( + ) if request.logprobs is not None else None + prompt_tokens = len(res.prompt_token_ids) + completion_tokens = len(output.token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + ], + usage=final_usage, + ).model_dump_json() + yield f"data: {response_json}\n\n" + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + print("yield", f"data: {data}\n\n") + yield f"data: {data}\n\n" + + print("yield", "data: [DONE]\n\n") + yield "data: [DONE]\n\n" + + def request_output_to_completion_response( + self, + final_res_batch: List[RequestOutput], + request: CompletionRequest, + request_id: str, + created_time: int, + model_name: str, + ) -> CompletionResponse: + choices = [] + num_prompt_tokens = 0 + num_generated_tokens = 0 + for final_res in final_res_batch: + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + + for output in final_res.outputs: + if request.echo and request.max_tokens == 0: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + output_text = prompt_text + elif request.echo and request.max_tokens > 0: + token_ids = prompt_token_ids + output.token_ids + top_logprobs = prompt_logprobs + output.logprobs + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + top_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + logprobs = self._create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens += len(prompt_token_ids) + num_generated_tokens += sum( + len(output.token_ids) for output in final_res.outputs) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 09945471e9af0..230d13d97dbba 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,4 +1,5 @@ import asyncio +import json from dataclasses import dataclass from http import HTTPStatus from typing import Dict, List, Optional, Union @@ -11,6 +12,7 @@ ModelCard, ModelList, ModelPermission) from vllm.lora.request import LoRARequest +from vllm.sequence import Logprob logger = init_logger(__name__) @@ -83,7 +85,7 @@ async def show_available_models(self) -> ModelList: def _create_logprobs( self, token_ids: List[int], - top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, + top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None, num_output_top_logprobs: Optional[int] = None, initial_text_offset: int = 0, ) -> LogProbs: @@ -95,10 +97,10 @@ def _create_logprobs( for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is not None: - token_logprob = step_top_logprobs[token_id] + token_logprob = step_top_logprobs[token_id].logprob else: token_logprob = None - token = self.tokenizer.convert_ids_to_tokens(token_id) + token = step_top_logprobs[token_id].decoded_token logprobs.tokens.append(token) logprobs.token_logprobs.append(token_logprob) if len(logprobs.text_offset) == 0: @@ -110,7 +112,7 @@ def _create_logprobs( if num_output_top_logprobs: logprobs.top_logprobs.append({ - self.tokenizer.convert_ids_to_tokens(i): p + p.decoded_token: p.logprob for i, p in step_top_logprobs.items() } if step_top_logprobs else None) return logprobs @@ -124,6 +126,19 @@ def create_error_response( type=err_type, code=status_code.value) + def create_streaming_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + json_str = json.dumps({ + "error": + self.create_error_response(message=message, + err_type=err_type, + status_code=status_code).model_dump() + }) + return json_str + async def _check_model(self, request) -> Optional[ErrorResponse]: if request.model == self.served_model: return diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 71655b216fb3d..b48dde0318d09 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -8,8 +8,9 @@ tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, + SamplerOutput, SequenceData, SequenceGroupOutput, + SequenceOutput) from vllm.utils import is_neuron @@ -528,7 +529,10 @@ def _get_logprobs( prompt_logprobs_dict.update( zip(top_token_ids[sample_idx, :num_logprobs].tolist(), top_logprobs[sample_idx, :num_logprobs].tolist())) - group_prompt_logprobs.append(prompt_logprobs_dict) + group_prompt_logprobs.append({ + token_id: Logprob(logprob) + for token_id, logprob in prompt_logprobs_dict.items() + }) sample_idx += 1 query_result_idx += 1 result_prompt_logprobs.append(group_prompt_logprobs) @@ -553,7 +557,10 @@ def _get_logprobs( parent_id, :num_logprobs].tolist(), top_logprobs[sample_idx + parent_id, :num_logprobs].tolist())) - group_sample_logprobs.append(sample_logprobs_dict) + group_sample_logprobs.append({ + token_id: Logprob(logprob) + for token_id, logprob in sample_logprobs_dict.items() + }) result_sample_logprobs.append(group_sample_logprobs) sample_idx += len(seq_ids) diff --git a/vllm/sequence.py b/vllm/sequence.py index 04a9a90a68bcc..a110ab6b748f8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -8,8 +8,16 @@ from vllm.sampling_params import SamplingParams from vllm.lora.request import LoRARequest -PromptLogprobs = List[Optional[Dict[int, float]]] -SampleLogprobs = List[Dict[int, float]] + +@dataclass +class Logprob: + """Infos for supporting OpenAI compatible logprobs.""" + logprob: float + decoded_token: Optional[str] = None + + +PromptLogprobs = List[Optional[Dict[int, Logprob]]] +SampleLogprobs = List[Dict[int, Logprob]] class SequenceStatus(enum.Enum): @@ -196,12 +204,12 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: def append_token_id( self, token_id: int, - logprobs: Dict[int, float], + logprobs: Dict[int, Logprob], ) -> None: assert token_id in logprobs self._append_tokens_to_blocks([token_id]) self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id]) + self.data.append_token_id(token_id, logprobs[token_id].logprob) def get_len(self) -> int: return self.data.get_len() @@ -456,7 +464,7 @@ def __init__( self, parent_seq_id: int, output_token: int, - logprobs: Dict[int, float], + logprobs: Dict[int, Logprob], ) -> None: self.parent_seq_id = parent_seq_id self.output_token = output_token @@ -470,9 +478,10 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: if not isinstance(other, SequenceOutput): raise NotImplementedError() - return (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token - and self.logprobs == other.logprobs) + equal = (self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token) + log_probs_equal = other.logprobs == self.logprobs + return equal and log_probs_equal class SequenceGroupOutput: diff --git a/vllm/worker/spec_decode/multi_step_worker.py b/vllm/worker/spec_decode/multi_step_worker.py index 591d1b1300c88..ab3e28389a04c 100644 --- a/vllm/worker/spec_decode/multi_step_worker.py +++ b/vllm/worker/spec_decode/multi_step_worker.py @@ -77,7 +77,7 @@ def _append_new_tokens( token_id = seq_output.output_token token_logprob = seq_output.logprobs[token_id] - seq.append_token_id(token_id, token_logprob) + seq.append_token_id(token_id, token_logprob.logprob) def _shallow_copy_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata]