Skip to content

Commit

Permalink
Allow user to define whitespace pattern for outlines (vllm-project#4305)
Browse files Browse the repository at this point in the history
  • Loading branch information
robcaulk authored May 1, 2024
1 parent a822eb3 commit c3845d8
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
4 changes: 3 additions & 1 deletion tests/entrypoints/test_guided_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
- json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
+ json_LP = JSONLogitsProcessor(TEST_SCHEMA,
+ tokenizer,
+ whitespace_pattern=None)

regex_LP.init_state()
token_ids = tokenizer.encode(
Expand Down
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
+ guided_whitespace_pattern: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, will override the default whitespace pattern "
+ "for guided json decoding."))

# doc: end-chat-completion-extra-params

Expand Down Expand Up @@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
+ guided_whitespace_pattern: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, will override the default whitespace pattern "
+ "for guided json decoding."))

# doc: end-completion-extra-params

Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/guided_decoding/outlines_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(

result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide,
- tokenizer, mode)
+ tokenizer, mode,
+ request.guided_whitespace_pattern)

logits_processor = copy(result)
# reset logits processor's internal state
Expand Down Expand Up @@ -117,9 +118,10 @@ def _get_guide_and_mode(
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase,
- mode: GuidedDecodingMode):
+ mode: GuidedDecodingMode,
+ whitespace_pattern: Union[str, None]):
if mode == GuidedDecodingMode.JSON:
- return JSONLogitsProcessor(guide, tokenizer)
+ return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR:
Expand Down
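7 changes: 3 additions & 4 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
Original file line number Diff line number Diff line change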
Expand Up @@ -18,7 +18,7 @@
import math
from collections import defaultdict
from functools import lru_cache
- from typing import Callable, DefaultDict, Dict, List, Optional, Union
+ from typing import Callable, DefaultDict, Dict, List, Union

import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
Expand Down Expand Up @@ -80,10 +80,9 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):

class JSONLogitsProcessor(RegexLogitsProcessor):

- def __init__(self,
- schema: Union[str, Dict, BaseModel],
+ def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
- whitespace_pattern: Optional[str] = None):
+ whitespace_pattern: Union[str, None]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
Expand Down

0 comments on commit c3845d8

Please sign in to comment.