Support Batch Completion in Server (vllm-project#2529)
simon-mo authored Jan 25, 2024
1 parent 223c192 commit 3a7dd7e
Showing 2 changed files with 215 additions and 105 deletions.
tests/entrypoints/test_openai_server.py (55 changes: 53 additions & 2 deletions)
@@ -1,5 +1,6 @@
-import time
+import os
 import subprocess
+import time
 
 import sys
 import pytest
@@ -17,8 +18,11 @@
 class ServerRunner:
 
     def __init__(self, args):
+        env = os.environ.copy()
+        env["PYTHONUNBUFFERED"] = "1"
         self.proc = subprocess.Popen(
             ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+            env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
         )
@@ -58,7 +62,8 @@ def server():
         "--dtype",
         "bfloat16",  # use half precision for speed and memory savings in CI environment
         "--max-model-len",
-        "8192"
+        "8192",
+        "--enforce-eager",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output
 
 
+async def test_batch_completions(server, client: openai.AsyncOpenAI):
+    # test simple list
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(batch.choices) == 2
+    assert batch.choices[0].text == batch.choices[1].text
+
+    # test n = 2
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        n=2,
+        max_tokens=5,
+        temperature=0.0,
+        extra_body=dict(
+            # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
+            use_beam_search=True),
+    )
+    assert len(batch.choices) == 4
+    assert batch.choices[0].text != batch.choices[
+        1].text, "beam search should be different"
+    assert batch.choices[0].text == batch.choices[
+        2].text, "two copies of the same prompt should be the same"
+    assert batch.choices[1].text == batch.choices[
+        3].text, "two copies of the same prompt should be the same"
+
+    # test streaming
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+    )
+    texts = [""] * 2
+    async for chunk in batch:
+        assert len(chunk.choices) == 1
+        choice = chunk.choices[0]
+        texts[choice.index] += choice.text
+    assert texts[0] == texts[1]
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
(The diff for the second changed file is not shown here.)
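As a closing illustration, here is a minimal synchronous client sketch of the batched behavior these tests exercise. It assumes the server was started as in the fixture above and is reachable locally with API key checks disabled; the base URL, API key, and model name below are placeholders, not values taken from this commit.

import openai  # official OpenAI Python client (>= 1.0), as used by the tests

# Assumed endpoint and credentials; adjust to match how the server was started.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

batch = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",  # placeholder for the served model name
    prompt=["Hello, my name is", "The capital of France is"],  # a list prompt makes one batched request
    max_tokens=5,
    temperature=0.0,
)

# One choice comes back per prompt; choice.index identifies the originating prompt,
# mirroring how the streaming test above reassembles per-prompt text.
for choice in batch.choices:
    print(choice.index, repr(choice.text))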
