From 5277dc35f4782537eea8ebddb3c133de4916120e Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Wed, 7 Aug 2024 11:31:05 +0200 Subject: [PATCH] Revert "Support for guided decoding for offline LLM (#6878)" This reverts commit 654bc5ca49bde0969bc95e4b1dbe7fabbb8f631c. --- docs/source/conf.py | 1 - tests/entrypoints/llm/test_guided_generate.py | 142 ------------------ tests/entrypoints/{ => openai}/conftest.py | 22 +-- vllm/entrypoints/llm.py | 44 +----- vllm/entrypoints/openai/protocol.py | 26 +--- .../guided_decoding/__init__.py | 26 +--- .../guided_decoding/guided_fields.py | 38 ----- .../lm_format_enforcer_decoding.py | 39 ----- .../guided_decoding/outlines_decoding.py | 26 +--- 9 files changed, 12 insertions(+), 352 deletions(-) delete mode 100644 tests/entrypoints/llm/test_guided_generate.py rename tests/entrypoints/{ => openai}/conftest.py (83%) delete mode 100644 vllm/model_executor/guided_decoding/guided_fields.py diff --git a/docs/source/conf.py b/docs/source/conf.py index f1eb8524d4e9c..1093b30bca11d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -111,7 +111,6 @@ def setup(app): "tqdm", "tensorizer", "pynvml", - "outlines", ] for mock_target in autodoc_mock_imports: diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py deleted file mode 100644 index 873e115421257..0000000000000 --- a/tests/entrypoints/llm/test_guided_generate.py +++ /dev/null @@ -1,142 +0,0 @@ -import json -import re -import weakref - -import jsonschema -import pytest - -from vllm.entrypoints.llm import LLM -from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams - -from ...conftest import cleanup - -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" - - -@pytest.fixture(scope="module") -def llm(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, max_model_len=1024) - - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) - del llm - cleanup() - - -@pytest.mark.skip_global_cleanup -def test_guided_regex(sample_regex, llm): - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - ) - outputs = llm.generate( - prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_regex=sample_regex)) - - assert outputs is not None - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) - prompt = output.prompt - generated_text = output.outputs[0].text - print(generated_text) - assert generated_text is not None - assert re.fullmatch(sample_regex, generated_text) is not None - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -@pytest.mark.skip_global_cleanup -def test_guided_json_completion(sample_json_schema, llm): - sampling_params = SamplingParams( - temperature=1.0, - max_tokens=1000, - ) - outputs = llm.generate( - prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_json=sample_json_schema)) - - assert outputs is not None - - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) - prompt = output.prompt - - generated_text = output.outputs[0].text - assert generated_text is not None - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - output_json = 
json.loads(generated_text) - jsonschema.validate(instance=output_json, schema=sample_json_schema) - - -@pytest.mark.skip_global_cleanup -def test_guided_choice_completion(sample_guided_choice, llm): - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - ) - outputs = llm.generate( - prompts="The best language for type-safe systems programming is ", - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_choice=sample_guided_choice)) - - assert outputs is not None - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) - prompt = output.prompt - generated_text = output.outputs[0].text - print(generated_text) - assert generated_text is not None - assert generated_text in sample_guided_choice - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -@pytest.mark.skip_global_cleanup -def test_guided_grammar(sample_sql_statements, llm): - - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - max_tokens=1000, - ) - outputs = llm.generate( - prompts=("Generate a sql state that select col_1 from " - "table_1 where it is equals to 1"), - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_grammar=sample_sql_statements)) - - assert outputs is not None - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) - prompt = output.prompt - - generated_text = output.outputs[0].text - assert generated_text is not None - # use Lark to parse the output, and make sure it's a valid parse tree - from lark import Lark - parser = Lark(sample_sql_statements) - parser.parse(generated_text) - - # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( - " ", "") - - assert generated_text.strip() == ground_truth - - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/openai/conftest.py similarity index 83% rename from tests/entrypoints/conftest.py rename to tests/entrypoints/openai/conftest.py index e7ef5637c8ccb..0837644f26bde 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/openai/conftest.py @@ -1,26 +1,6 @@ import pytest -@pytest.fixture -def sample_prompts(): - return [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - -@pytest.fixture -def sample_token_ids(): - return [ - [0], - [0, 1], - [0, 2, 1], - [0, 3, 1, 2], - ] - - @pytest.fixture def sample_regex(): return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" @@ -86,4 +66,4 @@ def sample_sql_statements(): table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") +""") \ No newline at end of file diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 262cba79e5712..62309ed345b1d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,9 +10,6 @@ parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import ( - GuidedDecodingRequest, get_local_guided_decoding_logits_processor) -from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -265,8 +262,6 @@ def generate( use_tqdm: bool = True, 
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - guided_options_request: Optional[Union[LLMGuidedOptions, - GuidedDecodingRequest]] = None ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -308,14 +303,6 @@ def generate( else: inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) - if isinstance(guided_options_request, dict): - if len(guided_options_request) > 1: - raise ValueError( - "You can only use one guided decoding but multiple is " - f"specified: {guided_options_request}") - guided_options_request = GuidedDecodingRequest( - **guided_options_request) - if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() @@ -324,8 +311,7 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - guided_options=guided_options_request) + prompt_adapter_request=prompt_adapter_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -522,7 +508,6 @@ def _validate_and_add_requests( Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], - guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -538,15 +523,6 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") - if isinstance(params, list): - params = [ - self._add_guided_processor(param, guided_options) - if isinstance(param, SamplingParams) else param - for param in params - ] - elif isinstance(params, SamplingParams): - params = self._add_guided_processor(params, guided_options) - # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): self._add_request( @@ -572,24 +548,6 @@ def _add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - def _add_guided_processor( - self, - params: SamplingParams, - guided_options: Optional[GuidedDecodingRequest] = None): - if guided_options: - if guided_options.guided_decoding_backend is None: - decoding_config = self.llm_engine.get_decoding_config() - guided_options.guided_decoding_backend = ( - decoding_config.guided_decoding_backend) - guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa - guided_options.guided_decoding_backend, guided_options, - self.get_tokenizer()) - if guided_logits_processor: - if params.logits_processors is None: - params.logits_processors = [] - params.logits_processors.append(guided_logits_processor) - return params - def _run_engine( self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 76318a1271229..3b35ae1ebd705 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,7 +1,6 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from argparse import Namespace from typing import Any, Dict, List, Literal, Optional, Union import torch @@ -15,23 +14,6 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid -# torch is mocked during docs generation, -# so we have to provide the values as literals -_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) - -try: - from sphinx.ext.autodoc.mock import _MockModule - - if isinstance(torch, _MockModule): - _LONG_INFO = _MOCK_LONG_INFO - else: - _LONG_INFO = torch.iinfo(torch.long) -except ModuleNotFoundError: - _LONG_INFO = torch.iinfo(torch.long) - -assert _LONG_INFO.min == _MOCK_LONG_INFO.min -assert _LONG_INFO.max == _MOCK_LONG_INFO.max - class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields @@ -126,7 +108,9 @@ class ChatCompletionRequest(OpenAIBaseModel): n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None @@ -343,7 +327,9 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 4a2476dd6314d..50aa3ec379f4a 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -3,10 +3,9 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, CompletionRequest) -from 
vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) +from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( + get_lm_format_enforcer_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_decoding import ( - get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -21,8 +20,6 @@ async def get_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': - from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa - get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( request, tokenizer) @@ -31,25 +28,6 @@ async def get_guided_decoding_logits_processor( "Must be one of 'outlines, 'lm-format-enforcer'") -def get_local_guided_decoding_logits_processor( - guided_decoding_backend: str, guided_options: GuidedDecodingRequest, - tokenizer) -> Optional[LogitsProcessor]: - # request = _adapt_request_for_tool_use(request) - - if guided_decoding_backend == 'outlines': - return get_local_outlines_guided_decoding_logits_processor( - guided_options, tokenizer) - if guided_decoding_backend == 'lm-format-enforcer': - from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa - get_local_lm_format_enforcer_guided_decoding_logits_processor) - return get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_options, tokenizer) - - raise ValueError( - f"Unknown guided decoding backend '{guided_decoding_backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer'") - - def _adapt_request_for_tool_use(request: Union[CompletionRequest, ChatCompletionRequest]): # the legacy completion API does not support tool use diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py deleted file mode 100644 index 3082ac1510ccc..0000000000000 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ /dev/null @@ -1,38 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional, TypedDict, Union - -from pydantic import BaseModel - - -class LLMGuidedOptions(TypedDict, total=False): - guided_json: Union[Dict, BaseModel, str] - guided_regex: str - guided_choice: List[str] - guided_grammar: str - guided_decoding_backend: str - guided_whitespace_pattern: str - guided_json_object: bool - - -@dataclass -class GuidedDecodingRequest: - """One of the fields will be used to retrieve the logit processor.""" - guided_json: Optional[Union[Dict, BaseModel, str]] = None - guided_regex: Optional[str] = None - guided_choice: Optional[List[str]] = None - guided_grammar: Optional[str] = None - guided_decoding_backend: Optional[str] = None - guided_whitespace_pattern: Optional[str] = None - guided_json_object: Optional[bool] = None - - def __post_init__(self): - """Validate that some fields are mutually exclusive.""" - guide_count = sum([ - self.guided_json is not None, self.guided_regex is not None, - self.guided_choice is not None, self.guided_grammar is not None, - self.guided_json_object is not None - ]) - if guide_count > 1: - raise ValueError( - "You can only use one kind of guided decoding but multiple are " - f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py 
b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index b2188c9cbc2bb..d0a5ca5592f9d 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -12,10 +12,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( - get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -57,42 +54,6 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( return logits_processor -def get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_options: GuidedDecodingRequest, - tokenizer) -> Optional[LogitsProcessor]: - """ - Given an OpenAI-compatible request, check for guided decoding parameters - and get the necessary logits processor for the given guide. - We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. - """ - - tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( - tokenizer) - character_level_parser: CharacterLevelParser - if guided_options.guided_json: - schema = _normalize_json_schema_object(guided_options.guided_json) - character_level_parser = JsonSchemaParser(schema) - elif guided_options.guided_choice: - character_level_parser = UnionParser( - [StringParser(choice) for choice in guided_options.guided_choice]) - elif guided_options.guided_regex: - character_level_parser = RegexParser(guided_options.guided_regex) - elif guided_options.guided_grammar: - # CFG grammar not supported by LMFE, revert to outlines - return get_local_outlines_guided_decoding_logits_processor( - guided_options, tokenizer) - elif guided_options.guided_json_object: - # None means any json object - character_level_parser = JsonSchemaParser(None) - else: - return None - - logits_processor = build_vllm_logits_processor(tokenizer_data, - character_level_parser) - return logits_processor - - def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: if isinstance(schema, str): return json_loads(schema) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 2ffa1d38a7a8c..f7216e52e925b 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -10,8 +10,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) @@ -79,27 +77,8 @@ async def get_outlines_guided_decoding_logits_processor( mode, request.guided_whitespace_pattern) -def get_local_outlines_guided_decoding_logits_processor( - guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, - None]: - """ - Given an OpenAI-compatible request, check for guided decoding parameters - and get the necessary logits processor for the given guide. 
- We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. - """ - guide, mode = _get_guide_and_mode(guided_options) - if not guide or not mode: - return None - - return _get_logits_processor(guide, tokenizer, mode, - guided_options.guided_whitespace_pattern) - - def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest, - GuidedDecodingRequest] + request: Union[CompletionRequest, ChatCompletionRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: @@ -123,8 +102,7 @@ def _get_guide_and_mode( return choices_regex, GuidedDecodingMode.CHOICE elif request.guided_grammar: return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (not isinstance(request, GuidedDecodingRequest) - and request.response_format is not None + elif (request.response_format is not None and request.response_format.type == "json_object"): return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR else:
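Note: the sketch below is not part of the diff. It reconstructs, from the deleted tests/entrypoints/llm/test_guided_generate.py and the sample_regex fixture in the relocated conftest.py, how the offline guided-decoding API being reverted here was invoked. It only runs on a vLLM build that still contains the reverted commit 654bc5ca (#6878); after this revert, LLM.generate() no longer accepts the guided_options_request keyword, and guided decoding remains available only through the OpenAI-compatible server path.

# Minimal sketch of the offline guided-decoding call removed by this revert,
# adapted from the deleted test_guided_generate.py. The model name and regex
# are the ones used in that test and its conftest fixture.
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

sample_regex = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
                r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)
outputs = llm.generate(
    prompts=[f"Give an example IPv4 address with this regex: {sample_regex}"],
    sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
    use_tqdm=True,
    # The keyword below is exactly what this revert removes from LLM.generate().
    guided_options_request=dict(guided_regex=sample_regex),
)
print(outputs[0].outputs[0].text)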