Skip to content

Commit

Permalink
Add guided decoding for OpenAI API server (vllm-project#2819)
Browse files Browse the repository at this point in the history
Co-authored-by: br3no <[email protected]>
Co-authored-by: simon-mo <[email protected]>
  • Loading branch information
3 people authored Feb 29, 2024
1 parent 29a8d6a commit 703e42e
Show file tree
Hide file tree
Showing 9 changed files with 597 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
outlines >= 0.0.27
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
75 changes: 75 additions & 0 deletions tests/entrypoints/test_guided_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.

from transformers import AutoTokenizer
import torch

from vllm.model_executor.guided_logits_processors import (RegexLogitsProcessor,
JSONLogitsProcessor)

TEST_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"


def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)

regex_LP.init_state()
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

json_LP.init_state()
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
237 changes: 237 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,64 @@
import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests

# imports for guided decoding tests
import json
import jsonschema
import re

from vllm.transformers_utils.tokenizer import get_tokenizer

MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here

TEST_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"

TEST_CHOICE = [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
"Swift", "Kotlin"
]

pytestmark = pytest.mark.asyncio


Expand Down Expand Up @@ -325,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token_id): 100},
seed=42,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
Expand Down Expand Up @@ -358,5 +411,189 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text


async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=
f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
n=3,
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
for i in range(3):
assert completion.choices[i].text is not None
output_json = json.loads(completion.choices[i].text)
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)


async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "Give an example JSON for an employee profile that " + \
f"fits this schema: {TEST_SCHEMA}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)

messages.append({"role": "assistant", "content": message.content})
messages.append({
"role":
"user",
"content":
"Give me another one with a different name and age"
})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
n=3,
temperature=1.0,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
for i in range(3):
assert completion.choices[i].text is not None
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example IP address with this regex: {TEST_REGEX}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None

messages.append({"role": "assistant", "content": ip1})
messages.append({"role": "user", "content": "Give me a different one"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert ip1 != ip2


async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 2
for i in range(2):
assert completion.choices[i].text in TEST_CHOICE


async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
"The best language for type-safe systems programming is "
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE

messages.append({"role": "assistant", "content": choice1})
messages.append({
"role": "user",
"content": "I disagree, pick another one"
})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice1 != choice2


async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(guided_json=42))

messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
"The best language for type-safe systems programming is "
}]
with pytest.raises(openai.BadRequestError):
_ = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
extra_body=dict(guided_regex={
1: "Python",
2: "C++"
}))

with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example string that fits this regex",
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


if __name__ == "__main__":
pytest.main([__file__])
3 changes: 3 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ def is_running(self) -> bool:
return (self.background_loop is not None
and not self.background_loop.done())

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer

def start_background_loop(self) -> None:
"""Start the background loop."""
if self.is_running:
Expand Down
Loading

0 comments on commit 703e42e

Please sign in to comment.