From c3845d82dc3d1831714898114f87d9c103e2dd41 Mon Sep 17 00:00:00 2001 From: Robert Caulk Date: Wed, 1 May 2024 05:48:39 +0200 Subject: [PATCH] Allow user to define whitespace pattern for outlines (#4305) --- tests/entrypoints/test_guided_processors.py | 4 +++- vllm/entrypoints/openai/protocol.py | 10 ++++++++++ .../guided_decoding/outlines_decoding.py | 8 +++++--- .../guided_decoding/outlines_logits_processors.py | 7 +++---- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 30f0ad5d8272f..41c871ca40bc8 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -57,7 +57,9 @@ def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer) - json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) + json_LP = JSONLogitsProcessor(TEST_SCHEMA, + tokenizer, + whitespace_pattern=None) regex_LP.init_state() token_ids = tokenizer.encode( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0a949f9867754..731596e80bd71 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel): "If specified, will override the default guided decoding backend " "of the server for this specific request. If set, must be either " "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-chat-completion-extra-params @@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel): "If specified, will override the default guided decoding backend " "of the server for this specific request. If set, must be one of " "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-completion-extra-params diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 53efebb604048..8403604286903 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor( result = await loop.run_in_executor(global_thread_pool, _get_cached_logits_processor, guide, - tokenizer, mode) + tokenizer, mode, + request.guided_whitespace_pattern) logits_processor = copy(result) # reset logits processor's internal state @@ -117,9 +118,10 @@ def _get_guide_and_mode( @lru_cache(maxsize=32) def _get_cached_logits_processor(guide: str, tokenizer: PreTrainedTokenizerBase, - mode: GuidedDecodingMode): + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None]): if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, tokenizer) + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, tokenizer) elif mode == GuidedDecodingMode.GRAMMAR: diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 25ab5bf8b6a9c..a131c6a1b92b4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -18,7 +18,7 @@ import math from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Optional, Union +from typing import Callable, DefaultDict, Dict, List, Union import torch from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM @@ -80,10 +80,9 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, - schema: Union[str, Dict, BaseModel], + def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Optional[str] = None): + whitespace_pattern: Union[str, None]): """Compile the FSM that drives the JSON-guided generation. Parameters