
Commit

LM Format Enforcer Guided Decoding Support (vllm-project#3868)
Co-authored-by: Simon Mo <[email protected]>
2 people authored and robertgshaw2-neuralmagic committed Apr 21, 2024
1 parent ead1e24 commit 723e328
Showing 13 changed files with 304 additions and 87 deletions.
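
The change makes the guided decoding backend selectable per request through the OpenAI-compatible server. As a minimal client-side sketch (not part of the commit; the base URL, API key, model name, and schema below are placeholders mirroring the tests added further down):

```python
import json

from openai import OpenAI

# Placeholder endpoint and model; point these at your running vLLM server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt=f"Give an example JSON for an employee profile "
    f"that fits this schema: {schema}",
    max_tokens=500,
    # New in this commit: choose the guided decoding backend per request.
    extra_body=dict(guided_json=schema,
                    guided_decoding_backend="lm-format-enforcer"),
)
print(json.loads(completion.choices[0].text))
```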
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -11,6 +11,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.3
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
42 changes: 39 additions & 3 deletions tests/entrypoints/test_guided_processors.py
@@ -1,11 +1,14 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.

import pytest
import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
RegexLogitsProcessor)
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)

TEST_SCHEMA = {
"type": "object",
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
json_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=TEST_REGEX)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=TEST_SCHEMA)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
69 changes: 50 additions & 19 deletions tests/entrypoints/test_openai_server.py
@@ -506,15 +506,19 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text


async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile "
f"that fits this schema: {TEST_SCHEMA}",
n=3,
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)


async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
assert json1["age"] != json2["age"]


async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
n=3,
temperature=1.0,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None
@@ -606,29 +623,37 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert ip1 != ip2


async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 2
for i in range(2):
assert completion.choices[i].text in TEST_CHOICE


async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE

@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice1 != choice2


async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(guided_json=42))
extra_body=dict(guided_json=42,
guided_decoding_backend=guided_decoding_backend))

messages = [{
"role": "system",
26 changes: 21 additions & 5 deletions vllm/config.py
@@ -66,8 +66,8 @@ class ModelConfig:
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
@@ -454,7 +454,7 @@ def verify_with_parallel_config(
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
@@ -478,9 +478,9 @@ def create_config(
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
@@ -1119,6 +1119,21 @@ def _get_and_verify_max_len(
return int(max_model_len)


@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""

# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
guided_decoding_backend: str = 'outlines'

def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer']
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
f"must be one of {valid_guided_backends}")


@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
@@ -1133,6 +1148,7 @@ class EngineConfig:
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
tensorizer_config: Optional[TensorizerConfig]

def __post_init__(self):
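
A minimal sketch (not part of the diff above) of how the new DecodingConfig behaves; the invalid backend name is only an illustration:

```python
from vllm.config import DecodingConfig

config = DecodingConfig()               # defaults to the 'outlines' backend
print(config.guided_decoding_backend)

DecodingConfig(guided_decoding_backend="lm-format-enforcer")  # accepted

try:
    DecodingConfig(guided_decoding_backend="not-a-backend")
except ValueError as exc:
    print(exc)                          # rejected in __post_init__
```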
18 changes: 15 additions & 3 deletions vllm/engine/arg_utils.py
@@ -7,9 +7,9 @@
from dataclasses import dataclass
from typing import BinaryIO, Optional, Union

from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TensorizerConfig,
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig, TensorizerConfig,
TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple
@@ -84,6 +84,7 @@ class EngineArgs:
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False

guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
@@ -204,6 +205,13 @@ def add_cli_args(
default=EngineArgs.max_model_len,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc)')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
@@ -539,6 +547,9 @@ def create_engine_config(self, ) -> EngineConfig:
else:
vision_language_config = None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)

return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
@@ -547,6 +558,7 @@ def create_engine_config(self, ) -> EngineConfig:
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
decoding_config=decoding_config,
tensorizer_config=tensorizer_config)


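
The same backend choice can also be made when constructing the engine programmatically. A hedged sketch, assuming the model name is reachable (create_engine_config resolves the model's Hugging Face config, so this needs network access or a local cache):

```python
from vllm.engine.arg_utils import EngineArgs

# CLI equivalent added by this commit:
#   --guided-decoding-backend lm-format-enforcer
args = EngineArgs(model="HuggingFaceH4/zephyr-7b-beta",
                  guided_decoding_backend="lm-format-enforcer")
engine_config = args.create_engine_config()
print(engine_config.decoding_config)
```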
10 changes: 7 additions & 3 deletions vllm/engine/llm_engine.py
@@ -4,9 +4,10 @@
from transformers import PreTrainedTokenizer

import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TensorizerConfig, VisionLanguageConfig)
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TensorizerConfig,
VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +75,7 @@ def __init__(
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
@@ -102,6 +104,7 @@ def __init__(
f"kv_cache_dtype={cache_config.cache_dtype}, "
f"quantization_param_path={model_config.quantization_param_path}, "
f"device_config={device_config.device}, "
f"decoding_config={decoding_config!r}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.

@@ -113,6 +116,7 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.decoding_config = decoding_config or DecodingConfig()
self.tensorizer_config = tensorizer_config
self.log_stats = log_stats

(Diffs for the remaining 7 changed files are not shown.)