Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for guided decoding for offline LLM #6878

Merged
merged 13 commits into from
Aug 4, 2024
69 changes: 69 additions & 0 deletions tests/entrypoints/llm/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest


@pytest.fixture
def sample_regex():
    """Return a regular expression describing a dotted-quad IPv4 address.

    The pattern is one octet sub-pattern followed by a literal dot, repeated
    three times, and then a final octet.
    """
    # A single octet: either 250-255, or an optional prefix plus one digit.
    octet = r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
    return r"(" + octet + r"\.){3}" + octet


@pytest.fixture
def sample_json_schema():
    """Return a JSON Schema for an employee-profile object.

    The object requires ``name`` (string), ``age`` (integer), ``skills``
    (array of at least three strings, each at most ten characters) and
    ``work_history`` (array of objects that require ``company`` and
    ``position`` and may carry a numeric ``duration``).

    A new dictionary is built on every call. Key order matches the order
    listed above.
    """
    # Schema of a single ``work_history`` entry.
    work_history_entry = {
        "type": "object",
        "properties": {
            "company": {"type": "string"},
            "duration": {"type": "number"},
            "position": {"type": "string"},
        },
        "required": ["company", "position"],
    }

    # Schema of the ``skills`` list.
    skills = {
        "type": "array",
        "items": {"type": "string", "maxLength": 10},
        "minItems": 3,
    }

    properties = {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "skills": skills,
        "work_history": {"type": "array", "items": work_history_entry},
    }

    return {
        "type": "object",
        "properties": properties,
        "required": ["name", "age", "skills", "work_history"],
    }


@pytest.fixture
def sample_guided_choice():
    """Return the list of programming-language names used as choices."""
    languages = [
        "Python",
        "Java",
        "JavaScript",
        "C++",
        "C#",
        "PHP",
        "TypeScript",
        "Ruby",
        "Swift",
        "Kotlin",
    ]
    return languages


@pytest.fixture
def sample_sql_statements():
    """Return a small grammar for a single-column SQL ``SELECT`` statement.

    The grammar text starts and ends with a newline and contains one rule
    per line. It defines no whitespace between terminals.
    """
    rules = (
        'start: select_statement',
        'select_statement: "SELECT" column "from" table "where" condition',
        'column: "col_1" | "col_2"',
        'table: "table_1" | "table_2"',
        'condition: column "=" number',
        'number: "1" | "2"',
    )
    return "\n" + "\n".join(rules) + "\n"
kevinbu233 marked this conversation as resolved.
Show resolved Hide resolved
154 changes: 154 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import json
import re
import weakref
from typing import List

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from ...conftest import cleanup


# Model identifier passed to ``LLM(model=...)`` by the ``llm`` fixture.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

# Plain-text prompts.
# NOTE(review): not referenced by any test visible in this file - confirm
# whether they are still needed.
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Token-id sequences of increasing length (one to four ids).
# NOTE(review): likewise not referenced by any test visible here.
TOKEN_IDS = [[0], [0, 1], [0, 2, 1], [0, 3, 1, 2]]
kevinbu233 marked this conversation as resolved.
Show resolved Hide resolved

@pytest.fixture(scope="module")
def llm():
    """Yield a weak proxy to one ``LLM`` instance shared by the module.

    The engine is built for ``MODEL_NAME`` with ``max_model_len=1024`` and is
    yielded from inside ``deprecate_legacy_api()``. After the tests finish,
    the local strong reference is dropped and ``cleanup()`` is called.
    """
    engine = LLM(model=MODEL_NAME, max_model_len=1024)

    with engine.deprecate_legacy_api():
        # pytest caches the value a fixture yields, so hand out a
        # weakref.proxy to keep that cache from blocking garbage collection.
        yield weakref.proxy(engine)

        del engine

    cleanup()

@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    """Check that regex-guided generation only emits text matching the regex.

    The same prompt is submitted twice with ``guided_regex`` set, and every
    completion must fully match ``sample_regex``.
    """
    request_text = (
        f"Give an example IPv4 address with this regex: {sample_regex}")
    results = llm.generate(
        prompts=[request_text, request_text],
        sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
        use_tqdm=True,
        guided_options={"guided_regex": sample_regex},
    )

    assert results is not None
    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        completion = result.outputs[0].text
        print(completion)
        assert completion is not None
        # The whole completion, not just a prefix, has to satisfy the regex.
        assert re.fullmatch(sample_regex, completion) is not None
        print(f"Prompt: {prompt!r}, Generated text: {completion!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    """Check that JSON-guided generation yields schema-valid documents.

    The same prompt is submitted twice with ``guided_json`` set. Each
    completion must parse as JSON and validate against
    ``sample_json_schema``.
    """
    request_text = ("Give an example JSON for an employee profile "
                    f"that fits this schema: {sample_json_schema}")
    results = llm.generate(
        prompts=[request_text, request_text],
        sampling_params=SamplingParams(temperature=1.0, max_tokens=1000),
        use_tqdm=True,
        guided_options={"guided_json": sample_json_schema},
    )

    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt

        completion = result.outputs[0].text
        assert completion is not None
        print(f"Prompt: {prompt!r}, Generated text: {completion!r}")
        # json.loads raises on malformed JSON and jsonschema.validate raises
        # on a schema violation; either one fails the test.
        parsed = json.loads(completion)
        jsonschema.validate(instance=parsed, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    """Check that choice-guided generation emits exactly one listed choice.

    A single prompt string is submitted with ``guided_choice`` set, and every
    completion must be an element of ``sample_guided_choice``.
    """
    results = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
        use_tqdm=True,
        guided_options={"guided_choice": sample_guided_choice},
    )

    assert results is not None
    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        completion = result.outputs[0].text
        print(completion)
        assert completion is not None
        assert completion in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {completion!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    """Check that grammar-guided generation follows the SQL grammar.

    A single prompt is submitted with ``guided_grammar`` set. Every
    completion must parse under ``sample_sql_statements`` with Lark and,
    after stripping, equal the expected statement with all spaces removed.
    """
    # Kept as a local import, as in the original test. The parser and the
    # expected text do not depend on the individual output, so they are
    # built once here rather than once per loop iteration.
    from lark import Lark
    parser = Lark(sample_sql_statements)

    # The grammar defines no whitespace between its terminals, so a
    # conforming completion contains no spaces; drop them from the
    # reference statement before comparing.
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
        " ", "")

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options=dict(guided_grammar=sample_sql_statements))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # parse() raises if the text is not derivable from the grammar,
        # which fails the test.
        parser.parse(generated_text)

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
47 changes: 45 additions & 2 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast,
overload)

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
Expand All @@ -10,6 +11,8 @@
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingFields, get_local_guided_decoding_logits_processor)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
Expand Down Expand Up @@ -262,6 +265,7 @@ def generate(
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options: Optional[Union[Dict, GuidedDecodingFields]] = None
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.

Expand Down Expand Up @@ -303,6 +307,13 @@ def generate(
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

if isinstance(guided_options, Dict):
kevinbu233 marked this conversation as resolved.
Show resolved Hide resolved
if len(guided_options) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {self.__dict__}")
kevinbu233 marked this conversation as resolved.
Show resolved Hide resolved
guided_options = GuidedDecodingFields(**guided_options)

if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
Expand All @@ -311,7 +322,8 @@ def generate(
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options)

outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
Expand Down Expand Up @@ -508,6 +520,7 @@ def _validate_and_add_requests(
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingFields] = None,
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
Expand All @@ -523,6 +536,18 @@ def _validate_and_add_requests(
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

if isinstance(params, list):
if len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
kevinbu233 marked this conversation as resolved.
Show resolved Hide resolved

params = [
self._add_guided_processor(param, guided_options)
for param in params if isinstance(param, SamplingParams)
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
self._add_request(
Expand All @@ -548,6 +573,24 @@ def _add_request(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

def _add_guided_processor(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingFields] = None):
if guided_options:
if guided_options.guided_decoding_backend is None:
decoding_config = self.llm_engine.get_decoding_config()
guided_options.guided_decoding_backend = (
decoding_config.guided_decoding_backend)
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
guided_options.guided_decoding_backend, guided_options,
self.get_tokenizer())
if guided_logits_processor:
if params.logits_processors is None:
params.logits_processors = []
params.logits_processors.append(guided_logits_processor)
return params
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved

def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
Expand Down
26 changes: 24 additions & 2 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingFields)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor

Expand All @@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)

Expand All @@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
"Must be one of 'outlines, 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
        guided_decoding_backend: str, guided_options: GuidedDecodingFields,
        tokenizer) -> Optional[LogitsProcessor]:
    """Build a guided-decoding logits processor synchronously.

    Dispatches on ``guided_decoding_backend`` to the matching backend-specific
    factory and returns whatever that factory returns.

    Args:
        guided_decoding_backend: Backend name, either ``'outlines'`` or
            ``'lm-format-enforcer'``.
        guided_options: Guided-decoding settings forwarded to the backend.
        tokenizer: Tokenizer forwarded to the backend.

    Returns:
        The logits processor produced by the selected backend (possibly
        ``None``, per the declared return type).

    Raises:
        ValueError: If ``guided_decoding_backend`` is not a known backend.
    """
    if guided_decoding_backend == 'outlines':
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        # Function-scope import: the lm-format-enforcer module is only loaded
        # when this backend is actually selected.
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_options, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
Expand Down
Loading
Loading