Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HotFix] Fix final output truncation with stop string + streaming #8468

Merged
merged 1 commit on Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:


@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):

scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
Expand All @@ -169,6 +170,7 @@ async def run(prompt: str):
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)

output_count = 0
Expand Down Expand Up @@ -203,7 +205,8 @@ async def run(prompt: str):


@pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""

Expand All @@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)

async def run(prompt: str, kind: RequestOutputKind):
Expand All @@ -229,6 +233,8 @@ async def run(prompt: str, kind: RequestOutputKind):
final_output = output

assert final_output is not None
assert final_output.finished

return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)
Expand All @@ -241,16 +247,18 @@ async def run_deltas(prompt: str):
output_tokens: List[int] = []
output_text = ""
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text
final_output = output

# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert text
assert stop or text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
Expand All @@ -260,6 +268,10 @@ async def run_deltas(prompt: str):
output_text += text

output_count += 1

assert final_output is not None
assert final_output.finished

return prompt_tokens, output_tokens, output_text, output_count

results = await asyncio.gather(
Expand Down Expand Up @@ -291,14 +303,16 @@ async def run_deltas(prompt: str):


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

sampling_params = SamplingParams(
temperature=0,
min_tokens=13,
max_tokens=13,
stop=stop,
)

stop_at = 5 if num_scheduler_steps == 1 else 1
Expand All @@ -319,7 +333,8 @@ async def test_cancellation(async_engine):


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config()

if scheduler_config.num_scheduler_steps != 1:
Expand All @@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
temperature=0,
min_tokens=10,
max_tokens=10,
stop=stop,
)

stream = async_engine.generate("test3", sampling_params, request_id=uid())
Expand Down
4 changes: 3 additions & 1 deletion vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ def get_output_text_to_return(self, buffer_length: int,
if not delta:
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
length = len(self.output_text) - buffer_length
length = len(self.output_text)
if truncate:
length -= buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
Expand Down
Loading