From 16a49f5b11807d7f1d1467809fdd7680b51da7e7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 27 Feb 2024 13:43:24 -0800 Subject: [PATCH 1/7] Push logprob generation to LLMEngine Co-authored-by: Avnish Narayan --- tests/samplers/test_logprobs.py | 44 ++++++++++++++++--- tests/worker/spec_decode/utils.py | 12 ++--- vllm/config.py | 2 + vllm/engine/arg_utils.py | 10 ++++- vllm/engine/llm_engine.py | 44 ++++++++++++++++++- vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 9 ++-- vllm/model_executor/layers/sampler.py | 15 +++++-- vllm/sequence.py | 26 ++++++++--- vllm/worker/spec_decode/multi_step_worker.py | 2 +- 10 files changed, 135 insertions(+), 31 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 0ea3704462fcb..911676b05cce3 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -1,5 +1,6 @@ import pytest import torch +from tests.conftest import VllmRunner from vllm import SamplingParams @@ -16,6 +17,7 @@ def test_get_prompt_logprobs( example_prompts, ): max_tokens = 5 + num_top_logprobs = 6 hf_model = hf_runner(model, dtype=dtype) hf_logprobs = hf_model.generate_greedy_logprobs( example_prompts, @@ -23,19 +25,34 @@ def test_get_prompt_logprobs( ) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, + dtype=dtype, + max_log_probs=num_top_logprobs) vllm_sampling_params = SamplingParams(max_tokens=max_tokens, - logprobs=5, + logprobs=num_top_logprobs, prompt_logprobs=5, temperature=0.0) vllm_results = vllm_model.model.generate( example_prompts, sampling_params=vllm_sampling_params) - del vllm_model # Test whether logprobs are included in the results. 
for result in vllm_results: assert result.prompt_logprobs is not None assert result.outputs[0].logprobs is not None + assert len(result.outputs[0].logprobs) == max_tokens + for logprobs in result.outputs[0].logprobs: + assert len(logprobs) == num_top_logprobs + output_text = result.outputs[0].text + output_string_from_most_likely_tokens = [] + for top_logprobs in result.outputs[0].logprobs: + top_logprob = next(iter(top_logprobs.values())) + output_string_from_most_likely_tokens.append( + top_logprob.decoded_token) + output_string_from_most_likely_tokens = "".join( + output_string_from_most_likely_tokens) + assert output_text == output_string_from_most_likely_tokens, ( + "The output text from the top logprob for each token position " + "should be the same as the output text in the result.") # Test whether prompt logprobs are consistent with HF for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): @@ -43,14 +60,29 @@ def test_get_prompt_logprobs( vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): for token_id, logprob in vllm_prompt_logprob_dict.items(): - torch.testing.assert_close(logprob, + torch.testing.assert_close(logprob.logprob, hf_logprob[0][i][token_id].item(), atol=1e-2, rtol=1e-2) vllm_sample_logprobs = vllm_result.outputs[0].logprobs - for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs): - for token_id, logprob in vllm_sample_logprob_dict.items(): + for i, top_logprobs in enumerate(vllm_sample_logprobs): + for token_id, sample_logprob in top_logprobs.items(): + logprob = sample_logprob.logprob torch.testing.assert_close(logprob, hf_logprob[i][-1][token_id].item(), atol=1e-2, rtol=1e-2) + assert isinstance(sample_logprob.decoded_token, str), \ + ("The token should be decoded by the time it is returned " + " to the user.") + + +def test_max_log_probs(): + runner = VllmRunner("facebook/opt-125m", max_log_probs=1) + vllm_sampling_params = 
SamplingParams(logprobs=1) + # should pass + runner.generate(["Hello world"], sampling_params=vllm_sampling_params) + + bad_sampling_params = SamplingParams(logprobs=2) + with pytest.raises(ValueError): + runner.generate(["Hello world"], sampling_params=bad_sampling_params) diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py index 8d74509fea488..fa8767cf898aa 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/worker/spec_decode/utils.py @@ -4,7 +4,7 @@ from vllm.worker.worker import Worker from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import SequenceGroupMetadata, SequenceData +from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData from vllm.sampling_params import SamplingParams from vllm.worker.cache_engine import CacheEngine from vllm.model_executor.utils import set_random_seed @@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts( def assert_logprobs_dict_allclose( - actual_logprobs: List[Dict[int, float]], - expected_logprobs: List[Dict[int, float]]) -> None: + actual_logprobs: List[Dict[int, Logprob]], + expected_logprobs: List[Dict[int, Logprob]]) -> None: for single_step_actual_logprobs, single_step_expected_logprobs in zip( actual_logprobs, expected_logprobs): assert set(single_step_actual_logprobs.keys()) == set( single_step_expected_logprobs.keys()) for token_id in single_step_actual_logprobs: - actual = torch.tensor(single_step_actual_logprobs[token_id]) - expected = torch.tensor(single_step_expected_logprobs[token_id]) + actual = torch.tensor( + single_step_actual_logprobs[token_id].logprob) + expected = torch.tensor( + single_step_expected_logprobs[token_id].logprob) assert torch.allclose(actual, expected) diff --git a/vllm/config.py b/vllm/config.py index bd0dc89b585f7..82dd5cefe93b8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -79,6 +79,7 @@ def __init__( quantization: Optional[str] = 
None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_log_probs: int = 5, ) -> None: self.model = model self.tokenizer = tokenizer @@ -93,6 +94,7 @@ def __init__( self.quantization = quantization self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + self.max_log_probs = max_log_probs if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4efd171b871d..cf1db55b6e16f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -30,6 +30,7 @@ class EngineArgs: max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 + max_log_probs: int = 5 disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None @@ -201,6 +202,12 @@ def add_cli_args( type=int, default=EngineArgs.max_paddings, help='maximum number of paddings in a batch') + parser.add_argument( + '--max-log-probs', + type=int, + default=EngineArgs.max_log_probs, + help=('max number of log probs to return when logprobs is specified' + ' in SamplingParams')) parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') @@ -291,7 +298,8 @@ def create_engine_configs( self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture) + self.enforce_eager, self.max_context_len_to_capture, + self.max_log_probs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f5b2145c22d6f..f690935461ea1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -16,7 +16,7 @@ from vllm.logger import 
init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, +from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, TokenizerGroup) @@ -449,6 +449,13 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + max_log_probs = self.get_model_config().max_log_probs + if (sampling_params.logprobs + and sampling_params.logprobs > max_log_probs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_log_probs): + raise ValueError(f"Cannot request more than " + f"{max_log_probs} logprobs.") if arrival_time is None: arrival_time = time.monotonic() prompt_token_ids = self.encode_request( @@ -460,6 +467,8 @@ def add_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) + assert prompt + assert prompt_token_ids seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, lora_request) @@ -563,6 +572,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs if prompt_logprobs is not None: + # We can pick any sequence for the prompt. 
+ seq = next(iter(seq_group.seqs_dict.values())) + all_token_ids = seq.get_token_ids() + for i, prompt_logprobs_for_token in enumerate(prompt_logprobs): + self._decode_logprobs(seq, seq_group.sampling_params, + prompt_logprobs_for_token, + all_token_ids[:i]) seq_group.prompt_logprobs = prompt_logprobs # Process samples @@ -909,12 +925,36 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) + def _decode_logprobs(self, seq: Sequence, prms: SamplingParams, + logprobs: Dict[int, Logprob], + all_input_ids: List[int]) -> None: + if not logprobs: + return + for token_id, sample_logprob in logprobs.items(): + if (sample_logprob.decoded_token is None and token_id != -1): + all_input_ids_with_logprob = all_input_ids[:-1] + [token_id] + _, new_text, prefix_offset, read_offset = detokenize_incrementally( + self.get_tokenizer_for_seq(seq), + all_input_ids=all_input_ids_with_logprob, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. 
+ spaces_between_special_tokens, + ) + sample_logprob.decoded_token = new_text + def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" + all_input_ids = seq.get_token_ids() + self._decode_logprobs(seq, prms, seq.output_logprobs[-1], + all_input_ids) + (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( self.get_tokenizer_for_seq(seq), - all_input_ids=seq.get_token_ids(), + all_input_ids=all_input_ids, prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, read_offset=seq.read_offset, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 610f53549da48..164b8cce0cd70 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -297,7 +297,7 @@ async def create_completion(self, request: CompletionRequest, request, prompt=prompt) generators.append( - self.engine.generate(None, + self.engine.generate(prompt, sampling_params, f"{request_id}-{i}", prompt_token_ids=input_ids, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 09945471e9af0..3af4838195f08 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -11,6 +11,7 @@ ModelCard, ModelList, ModelPermission) from vllm.lora.request import LoRARequest +from vllm.sequence import Logprob logger = init_logger(__name__) @@ -83,7 +84,7 @@ async def show_available_models(self) -> ModelList: def _create_logprobs( self, token_ids: List[int], - top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, + top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None, num_output_top_logprobs: Optional[int] = None, initial_text_offset: int = 0, ) -> LogProbs: @@ -95,10 +96,10 @@ def _create_logprobs( for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is not None: - 
token_logprob = step_top_logprobs[token_id] + token_logprob = step_top_logprobs[token_id].logprob else: token_logprob = None - token = self.tokenizer.convert_ids_to_tokens(token_id) + token = step_top_logprobs[token_id].decoded_token logprobs.tokens.append(token) logprobs.token_logprobs.append(token_logprob) if len(logprobs.text_offset) == 0: @@ -110,7 +111,7 @@ def _create_logprobs( if num_output_top_logprobs: logprobs.top_logprobs.append({ - self.tokenizer.convert_ids_to_tokens(i): p + p.decoded_token: p.logprob for i, p in step_top_logprobs.items() } if step_top_logprobs else None) return logprobs diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 884d84387e505..65975bf2f63ff 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -8,8 +8,9 @@ tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, + SamplerOutput, SequenceData, SequenceGroupOutput, + SequenceOutput) class Sampler(nn.Module): @@ -520,7 +521,10 @@ def _get_logprobs( prompt_logprobs_dict.update( zip(top_token_ids[sample_idx, :num_logprobs].tolist(), top_logprobs[sample_idx, :num_logprobs].tolist())) - group_prompt_logprobs.append(prompt_logprobs_dict) + group_prompt_logprobs.append({ + token_id: Logprob(logprob) + for token_id, logprob in prompt_logprobs_dict.items() + }) sample_idx += 1 query_result_idx += 1 result_prompt_logprobs.append(group_prompt_logprobs) @@ -545,7 +549,10 @@ def _get_logprobs( parent_id, :num_logprobs].tolist(), top_logprobs[sample_idx + parent_id, :num_logprobs].tolist())) - group_sample_logprobs.append(sample_logprobs_dict) + group_sample_logprobs.append({ + token_id: 
Logprob(logprob) + for token_id, logprob in sample_logprobs_dict.items() + }) result_sample_logprobs.append(group_sample_logprobs) sample_idx += len(seq_ids) diff --git a/vllm/sequence.py b/vllm/sequence.py index 040e9756e15c6..ee0490b11dfa4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,8 +9,16 @@ from vllm.sampling_params import SamplingParams from vllm.lora.request import LoRARequest -PromptLogprobs = List[Optional[Dict[int, float]]] -SampleLogprobs = List[Dict[int, float]] + +@dataclass +class Logprob: + """Infos for supporting OpenAI compatible logprobs.""" + logprob: float + decoded_token: Optional[str] = None + + +PromptLogprobs = List[Optional[Dict[int, Logprob]]] +SampleLogprobs = List[Dict[int, Logprob]] class SequenceStatus(enum.Enum): @@ -187,12 +195,12 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: def append_token_id( self, token_id: int, - logprobs: Dict[int, float], + logprobs: Dict[int, Logprob], ) -> None: assert token_id in logprobs self._append_tokens_to_blocks([token_id]) self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id]) + self.data.append_token_id(token_id, logprobs[token_id].logprob) def get_len(self) -> int: return self.data.get_len() @@ -465,9 +473,13 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: if not isinstance(other, SequenceOutput): raise NotImplementedError() - return (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token - and self.logprobs == other.logprobs) + equal = (self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token) + log_probs_equal = ((len(other.logprobs) == len(self.logprobs)) + and all(other_logprob == self_logprob + for other_logprob, self_logprob in zip( + other.logprobs, self.logprobs))) + return equal and log_probs_equal class SequenceGroupOutput: diff --git a/vllm/worker/spec_decode/multi_step_worker.py 
b/vllm/worker/spec_decode/multi_step_worker.py index 591d1b1300c88..ab3e28389a04c 100644 --- a/vllm/worker/spec_decode/multi_step_worker.py +++ b/vllm/worker/spec_decode/multi_step_worker.py @@ -77,7 +77,7 @@ def _append_new_tokens( token_id = seq_output.output_token token_logprob = seq_output.logprobs[token_id] - seq.append_token_id(token_id, token_logprob) + seq.append_token_id(token_id, token_logprob.logprob) def _shallow_copy_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata] From 3b59c014a1594b8801d9371208fe443a0db54baa Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 28 Feb 2024 17:51:57 -0800 Subject: [PATCH 2/7] Fix error propagation --- tests/entrypoints/test_openai_server.py | 61 ++- vllm/engine/async_llm_engine.py | 29 +- vllm/entrypoints/openai/serving_chat.py | 236 ++++++----- vllm/entrypoints/openai/serving_completion.py | 391 +++++++++--------- vllm/entrypoints/openai/serving_engine.py | 14 + vllm/sequence.py | 7 +- 6 files changed, 429 insertions(+), 309 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 72e2374899793..652abcefdcea8 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -161,14 +161,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10) + top_logprobs=5) assert chat_completion.id is not None assert chat_completion.choices is not None and len( chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10 + assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role 
== "assistant" @@ -177,7 +177,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, # test multi-turn dialogue messages.append({"role": "user", "content": "express your result in json"}) chat_completion = await client.chat.completions.create( - model=MODEL_NAME, + model=model_name, messages=messages, max_tokens=10, ) @@ -185,6 +185,61 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + # Default max_logprobs is 5, so this should raise an error + with pytest.raises((openai.BadRequestError, openai.APIError)): + stream = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=10, + stream=True) + async for chunk in stream: + ... + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=10, + stream=False) + + with pytest.raises((openai.BadRequestError, openai.APIError)): + stream = await client.completions.create(model=model_name, + prompt="Test", + max_tokens=10, + logprobs=10, + stream=True) + async for chunk in stream: + ... 
+ + with pytest.raises(openai.BadRequestError): + await client.completions.create(model=model_name, + prompt="Test", + max_tokens=10, + logprobs=10, + stream=False) + + # the server should still work afterwards + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + stream=False) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7cba654602779..5a960ef60d704 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -47,7 +47,7 @@ def __init__(self, request_id: str) -> None: self._queue = asyncio.Queue() self._finished = False - def put(self, item: RequestOutput) -> None: + def put(self, item: Union[RequestOutput, Exception]) -> None: if self._finished: return self._queue.put_nowait(item) @@ -110,6 +110,17 @@ def process_request_output(self, logger.info(f"Finished request {request_id}.") self.abort_request(request_id) + def process_exception(self, + request_id: str, + exception: Exception, + *, + verbose: bool = False) -> None: + """Propagate an exception from the engine.""" + self._request_streams[request_id].put(exception) + if verbose: + logger.info(f"Finished request {request_id}.") + self.abort_request(request_id) + def add_request(self, request_id: str, **engine_add_request_kwargs) -> AsyncStream: """Add a request to be sent to the engine on the next background @@ -376,10 +387,18 @@ async def engine_step(self) -> bool: for new_request in new_requests: # Add the request into the vLLM engine's waiting queue. 
# TODO: Maybe add add_request_batch to reduce Ray overhead - if self.engine_use_ray: - await self.engine.add_request.remote(**new_request) - else: - await self.engine.add_request_async(**new_request) + try: + if self.engine_use_ray: + await self.engine.add_request.remote(**new_request) + else: + await self.engine.add_request_async(**new_request) + except ValueError as e: + # TODO: use a vLLM specific error for failed validation + self._request_tracker.process_exception( + new_request["request_id"], + e, + verbose=self.log_requests, + ) if finished_requests: await self._engine_abort(finished_requests) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5635ac6c9e106..7b64f333d6fb8 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -73,8 +73,12 @@ async def create_chat_completion( return self.chat_completion_stream_generator( request, result_generator, request_id) else: - return await self.chat_completion_full_generator( - request, raw_request, result_generator, request_id) + try: + return await self.chat_completion_full_generator( + request, raw_request, result_generator, request_id) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: @@ -90,117 +94,133 @@ async def chat_completion_stream_generator( model_name = request.model created_time = int(time.monotonic()) chunk_object_type = "chat.completion.chunk" - - # Send first response for each request.n (index) with the role - role = self.get_chat_request_role(request) - for i in range(request.n): - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(role=role), - logprobs=None, - finish_reason=None) - chunk = ChatCompletionStreamResponse(id=request_id, - object=chunk_object_type, - created=created_time, - 
choices=[choice_data], - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - - # Send response to echo the input portion of the last message - if request.echo: - last_msg_content = "" - if request.messages and isinstance( - request.messages, list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] - - if last_msg_content: - for i in range(request.n): - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=last_msg_content), - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - logprobs=None, - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" + first_iteration = True # Send response for each token for each request.n (index) previous_texts = [""] * request.n previous_num_tokens = [0] * request.n finish_reason_sent = [False] * request.n - async for res in result_generator: - res: RequestOutput - for output in res.outputs: - i = output.index - - if finish_reason_sent[i]: - continue - - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - top_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - - if request.logprobs: - logprobs = self._create_logprobs( - token_ids=delta_token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=len(previous_texts[i]), - ) - else: - logprobs = None - - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - if output.finish_reason is None: - # Send token-by-token response for each request.n - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=delta_text), - logprobs=logprobs, - finish_reason=None) - chunk = 
ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - else: - # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], - ) - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=delta_text), - logprobs=logprobs, - finish_reason=output.finish_reason) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - if final_usage is not None: - chunk.usage = final_usage - data = chunk.model_dump_json(exclude_unset=True, - exclude_none=True) - yield f"data: {data}\n\n" - finish_reason_sent[i] = True + try: + async for res in result_generator: + res: RequestOutput + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). 
+ if first_iteration: + # Send first response for each request.n (index) with the role + role = self.get_chat_request_role(request) + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role=role), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, + list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + logprobs=None, + model=model_name) + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + + if finish_reason_sent[i]: + continue + + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + + if request.logprobs: + logprobs = self._create_logprobs( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + if output.finish_reason is None: + # Send token-by-token 
response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + logprobs=logprobs, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once + prompt_tokens = len(res.prompt_token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + + previous_num_tokens[i], + ) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + logprobs=logprobs, + finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if final_usage is not None: + chunk.usage = final_usage + data = chunk.model_dump_json(exclude_unset=True, + exclude_none=True) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 164b8cce0cd70..597de327eef06 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -25,107 +25,6 @@ [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] -async def completion_stream_generator( - request: CompletionRequest, - raw_request: Request, - on_abort, - result_generator: AsyncIterator[Tuple[int, RequestOutput]], - create_logprobs_fn: TypeCreateLogProbsFn, - request_id: str, 
- created_time: int, - model_name: str, - num_prompts: int, -) -> AsyncGenerator[str, None]: - previous_texts = [""] * request.n * num_prompts - previous_num_tokens = [0] * request.n * num_prompts - has_echoed = [False] * request.n * num_prompts - - async for prompt_idx, res in result_generator: - - # Abort the request if the client disconnects. - if await raw_request.is_disconnected(): - await on_abort(f"{request_id}-{prompt_idx}") - raise StopAsyncIteration() - - for output in res.outputs: - i = output.index + prompt_idx * request.n - # TODO(simon): optimize the performance by avoiding full text O(n^2) sending. - - if request.echo and request.max_tokens == 0: - # only return the prompt - delta_text = res.prompt - delta_token_ids = res.prompt_token_ids - top_logprobs = res.prompt_logprobs - has_echoed[i] = True - elif request.echo and request.max_tokens > 0 and not has_echoed[i]: - # echo the prompt and first token - delta_text = res.prompt + output.text - delta_token_ids = res.prompt_token_ids + output.token_ids - top_logprobs = res.prompt_logprobs + (output.logprobs or []) - has_echoed[i] = True - else: - # return just the delta - delta_text = output.text[len(previous_texts[i]):] - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - top_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - - if request.logprobs is not None: - assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" - logprobs = create_logprobs_fn( - token_ids=delta_token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=len(previous_texts[i]), - ) - else: - logprobs = None - - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - finish_reason = output.finish_reason - response_json = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[ - CompletionResponseStreamChoice( - index=i, - 
text=delta_text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - ]).model_dump_json(exclude_unset=True) - yield f"data: {response_json}\n\n" - - if output.finish_reason is not None: # return final usage - logprobs = LogProbs() if request.logprobs is not None else None - prompt_tokens = len(res.prompt_token_ids) - completion_tokens = len(output.token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - response_json = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[ - CompletionResponseStreamChoice( - index=i, - text="", - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - ], - usage=final_usage, - ).model_dump_json(exclude_unset=True) - yield f"data: {response_json}\n\n" - - yield "data: [DONE]\n\n" - - def parse_prompt_format(prompt) -> Tuple[bool, list]: # get the prompt, openai supports the following # "a string, array of strings, array of tokens, or array of token arrays." 
@@ -150,73 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: return prompt_is_tokens, prompts -def request_output_to_completion_response( - final_res_batch: List[RequestOutput], - request: CompletionRequest, - create_logprobs_fn: TypeCreateLogProbsFn, - request_id: str, - created_time: int, - model_name: str, -) -> CompletionResponse: - choices = [] - num_prompt_tokens = 0 - num_generated_tokens = 0 - for final_res in final_res_batch: - assert final_res is not None - prompt_token_ids = final_res.prompt_token_ids - prompt_logprobs = final_res.prompt_logprobs - prompt_text = final_res.prompt - - for output in final_res.outputs: - if request.echo and request.max_tokens == 0: - token_ids = prompt_token_ids - top_logprobs = prompt_logprobs - output_text = prompt_text - elif request.echo and request.max_tokens > 0: - token_ids = prompt_token_ids + output.token_ids - top_logprobs = prompt_logprobs + output.logprobs - output_text = prompt_text + output.text - else: - token_ids = output.token_ids - top_logprobs = output.logprobs - output_text = output.text - - if request.logprobs is not None: - logprobs = create_logprobs_fn( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - ) - else: - logprobs = None - - choice_data = CompletionResponseChoice( - index=len(choices), - text=output_text, - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - choices.append(choice_data) - - num_prompt_tokens += len(prompt_token_ids) - num_generated_tokens += sum( - len(output.token_ids) for output in final_res.outputs) - - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - - return CompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) - - def merge_async_iterators(*iterators): """Merge multiple asynchronous iterators into a single iterator. 
@@ -229,8 +61,11 @@ def merge_async_iterators(*iterators): finished = [False] * len(iterators) async def producer(i, iterator): - async for item in iterator: - await queue.put((i, item)) + try: + async for item in iterator: + await queue.put((i, item)) + except Exception as e: + await queue.put(e) finished[i] = True _tasks = [ @@ -241,6 +76,8 @@ async def producer(i, iterator): async def consumer(): while not all(finished) or not queue.empty(): item = await queue.get() + if isinstance(item, Exception): + raise item yield item await asyncio.gather(*_tasks) @@ -303,6 +140,7 @@ async def create_completion(self, request: CompletionRequest, prompt_token_ids=input_ids, lora_request=lora_request)) except ValueError as e: + # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) result_generator: AsyncIterator[Tuple[ @@ -316,27 +154,28 @@ async def create_completion(self, request: CompletionRequest, # Streaming response if stream: - return completion_stream_generator(request, - raw_request, - self.engine.abort, - result_generator, - self._create_logprobs, - request_id, - created_time, - model_name, - num_prompts=len(prompts)) + return self.completion_stream_generator(request, + raw_request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(prompts)) # Non-streaming response final_res_batch: RequestOutput = [None] * len(prompts) - async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") - return self.create_error_response("Client disconnected") - final_res_batch[i] = res - response = request_output_to_completion_response( - final_res_batch, request, self._create_logprobs, request_id, - created_time, model_name) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(f"{request_id}-{i}") + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = self.request_output_to_completion_response( + final_res_batch, request, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) # When user requests streaming but we don't stream, we still need to # return a streaming response with a single event. @@ -350,3 +189,179 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: return fake_stream_generator() return response + + async def completion_stream_generator( + self, + request: CompletionRequest, + raw_request: Request, + result_generator: AsyncIterator[Tuple[int, RequestOutput]], + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, + ) -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n * num_prompts + previous_num_tokens = [0] * request.n * num_prompts + has_echoed = [False] * request.n * num_prompts + + try: + async for prompt_idx, res in result_generator: + + # Abort the request if the client disconnects. + if await raw_request.is_disconnected(): + await self.engine.abort(f"{request_id}-{prompt_idx}") + raise StopAsyncIteration() + + for output in res.outputs: + i = output.index + prompt_idx * request.n + # TODO(simon): optimize the performance by avoiding full text O(n^2) sending. 
+ + if request.echo and request.max_tokens == 0: + # only return the prompt + delta_text = res.prompt + delta_token_ids = res.prompt_token_ids + top_logprobs = res.prompt_logprobs + has_echoed[i] = True + elif request.echo and request.max_tokens > 0 and not has_echoed[ + i]: + # echo the prompt and first token + delta_text = res.prompt + output.text + delta_token_ids = res.prompt_token_ids + output.token_ids + top_logprobs = res.prompt_logprobs + (output.logprobs + or []) + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text[len(previous_texts[i]):] + delta_token_ids = output.token_ids[ + previous_num_tokens[i]:] + top_logprobs = output.logprobs[previous_num_tokens[ + i]:] if output.logprobs else None + + if request.logprobs is not None: + assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" + logprobs = self._create_logprobs( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + ]).model_dump_json(exclude_unset=True) + yield f"data: {response_json}\n\n" + + if output.finish_reason is not None: # return final usage + logprobs = LogProbs( + ) if request.logprobs is not None else None + prompt_tokens = len(res.prompt_token_ids) + completion_tokens = len(output.token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, 
+ choices=[ + CompletionResponseStreamChoice( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + ], + usage=final_usage, + ).model_dump_json(exclude_unset=True) + yield f"data: {response_json}\n\n" + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + + yield "data: [DONE]\n\n" + + def request_output_to_completion_response( + self, + final_res_batch: List[RequestOutput], + request: CompletionRequest, + request_id: str, + created_time: int, + model_name: str, + ) -> CompletionResponse: + choices = [] + num_prompt_tokens = 0 + num_generated_tokens = 0 + for final_res in final_res_batch: + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + + for output in final_res.outputs: + if request.echo and request.max_tokens == 0: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + output_text = prompt_text + elif request.echo and request.max_tokens > 0: + token_ids = prompt_token_ids + output.token_ids + top_logprobs = prompt_logprobs + output.logprobs + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + top_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + logprobs = self._create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens += len(prompt_token_ids) + num_generated_tokens += sum( + len(output.token_ids) for output in final_res.outputs) + + usage = UsageInfo( + 
prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 3af4838195f08..230d13d97dbba 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,4 +1,5 @@ import asyncio +import json from dataclasses import dataclass from http import HTTPStatus from typing import Dict, List, Optional, Union @@ -125,6 +126,19 @@ def create_error_response( type=err_type, code=status_code.value) + def create_streaming_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + json_str = json.dumps({ + "error": + self.create_error_response(message=message, + err_type=err_type, + status_code=status_code).model_dump() + }) + return json_str + async def _check_model(self, request) -> Optional[ErrorResponse]: if request.model == self.served_model: return diff --git a/vllm/sequence.py b/vllm/sequence.py index ee0490b11dfa4..b5a24876d905a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -459,7 +459,7 @@ def __init__( self, parent_seq_id: int, output_token: int, - logprobs: Dict[int, float], + logprobs: Dict[int, Logprob], ) -> None: self.parent_seq_id = parent_seq_id self.output_token = output_token @@ -475,10 +475,7 @@ def __eq__(self, other: object) -> bool: raise NotImplementedError() equal = (self.parent_seq_id == other.parent_seq_id and self.output_token == other.output_token) - log_probs_equal = ((len(other.logprobs) == len(self.logprobs)) - and all(other_logprob == self_logprob - for other_logprob, self_logprob in zip( - other.logprobs, self.logprobs))) + log_probs_equal = other.logprobs == self.logprobs return equal and log_probs_equal From 
306d3dd4fef5680596b856e231074baa0ba190ba Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 28 Feb 2024 18:19:05 -0800 Subject: [PATCH 3/7] Trigger CI From cafccae1f79b82d9e96a4640ff752c69c9cb386f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 28 Feb 2024 18:37:53 -0800 Subject: [PATCH 4/7] Revert --- vllm/engine/llm_engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a0ad7e6c9b46d..737434aa58d50 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -482,8 +482,6 @@ def add_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - assert prompt - assert prompt_token_ids seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, lora_request) From 05fcdccfc903a3500e5435a94d37b89b6729ea66 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 10:38:25 -0800 Subject: [PATCH 5/7] max_log_probs -> max_logprobs --- tests/samplers/test_logprobs.py | 6 +++--- vllm/config.py | 4 ++-- vllm/engine/arg_utils.py | 8 ++++---- vllm/engine/llm_engine.py | 8 ++++---- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 911676b05cce3..59c16112a1fed 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -27,7 +27,7 @@ def test_get_prompt_logprobs( vllm_model = vllm_runner(model, dtype=dtype, - max_log_probs=num_top_logprobs) + max_logprobs=num_top_logprobs) vllm_sampling_params = SamplingParams(max_tokens=max_tokens, logprobs=num_top_logprobs, prompt_logprobs=5, @@ -77,8 +77,8 @@ def test_get_prompt_logprobs( " to the user.") -def test_max_log_probs(): - runner = VllmRunner("facebook/opt-125m", max_log_probs=1) +def test_max_logprobs(): + runner = VllmRunner("facebook/opt-125m", max_logprobs=1) vllm_sampling_params = SamplingParams(logprobs=1) # should pass runner.generate(["Hello world"], 
sampling_params=vllm_sampling_params) diff --git a/vllm/config.py b/vllm/config.py index 723b12b802c2e..ef9a920f29c2a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -79,7 +79,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, - max_log_probs: int = 5, + max_logprobs: int = 5, ) -> None: self.model = model self.tokenizer = tokenizer @@ -94,7 +94,7 @@ def __init__( self.quantization = quantization self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture - self.max_log_probs = max_log_probs + self.max_logprobs = max_logprobs if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7effe0eb50c4f..d7e62cc08aa1f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,7 +31,7 @@ class EngineArgs: max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 - max_log_probs: int = 5 + max_logprobs: int = 5 disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None @@ -214,9 +214,9 @@ def add_cli_args( default=EngineArgs.max_paddings, help='maximum number of paddings in a batch') parser.add_argument( - '--max-log-probs', + '--max-logprobs', type=int, - default=EngineArgs.max_log_probs, + default=EngineArgs.max_logprobs, help=('max number of log probs to return logprobs is specified in' ' SamplingParams')) parser.add_argument('--disable-log-stats', @@ -308,7 +308,7 @@ def create_engine_configs( self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager, self.max_context_len_to_capture, - self.max_log_probs) + self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git 
a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 887ee5ae2ef86..703756996b7f7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -473,13 +473,13 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") - max_log_probs = self.get_model_config().max_log_probs + max_logprobs = self.get_model_config().max_logprobs if (sampling_params.logprobs - and sampling_params.logprobs > max_log_probs) or ( + and sampling_params.logprobs > max_logprobs) or ( sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_log_probs): + and sampling_params.prompt_logprobs > max_logprobs): raise ValueError(f"Cannot request more than " - f"{max_log_probs} logprobs.") + f"{max_logprobs} logprobs.") if arrival_time is None: arrival_time = time.monotonic() prompt_token_ids = self.encode_request( From 9da7def1abde05bb1a7524526f996089c26748c0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 10:48:29 -0800 Subject: [PATCH 6/7] Add comment --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d7e62cc08aa1f..c3dccdd5bb50b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,7 +31,7 @@ class EngineArgs: max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 - max_logprobs: int = 5 + max_logprobs: int = 5 # OpenAI default value disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None From 2c3e8dafde8bc610206d8e9674ed32814172b76f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 10:49:02 -0800 Subject: [PATCH 7/7] Lint --- tests/samplers/test_logprobs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 59c16112a1fed..1abb55f021214 
100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -25,9 +25,7 @@ def test_get_prompt_logprobs( ) del hf_model - vllm_model = vllm_runner(model, - dtype=dtype, - max_logprobs=num_top_logprobs) + vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) vllm_sampling_params = SamplingParams(max_tokens=max_tokens, logprobs=num_top_logprobs, prompt_logprobs=5,