From a1dfbbd0d67f38982356b0f3f2046cceb11ef475 Mon Sep 17 00:00:00 2001
From: Zach Blank
Date: Thu, 20 Jul 2023 16:58:04 +0000
Subject: [PATCH 1/3] add logits processors to enable logit_bias in OpenAI server

---
 vllm/entrypoints/openai/api_server.py | 26 ++++++++++--------
 vllm/logits_processors.py             | 38 +++++++++++++++++++++++++++
 vllm/model_executor/layers/sampler.py | 16 +++++++++++
 vllm/sampling_params.py               |  5 ++++
 4 files changed, 74 insertions(+), 11 deletions(-)
 create mode 100644 vllm/logits_processors.py

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 373c4812264a0..db04b85b1eab2 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -32,6 +32,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid
+from vllm.logits_processors import BiasLogitsProcessor

 TIMEOUT_KEEP_ALIVE = 5  # seconds

@@ -170,7 +171,6 @@ async def create_chat_completion(raw_request: Request):

     NOTE: Currently we do not support the following features:
         - function_call (Users should implement this by themselves)
-        - logit_bias (to be supported by vLLM engine)
     """
     request = ChatCompletionRequest(**await raw_request.json())
     logger.info(f"Received chat completion request: {request}")
@@ -179,16 +179,17 @@ async def create_chat_completion(raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret

-    if request.logit_bias is not None:
-        # TODO: support logit_bias in vLLM engine.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "logit_bias is not currently supported")
-
     prompt = await get_gen_prompt(request)
     error_check_ret = await check_length(request, prompt, engine_model_config)
     if error_check_ret is not None:
         return error_check_ret

+    if not request.logit_bias:
+        logit_processors = []
+    else:
+        biases = dict(map(lambda bias: (int(bias[0]), bias[1]), request.logit_bias.items()))
+        logit_processors = [BiasLogitsProcessor(biases)]
+
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.time())
@@ -205,6 +206,7 @@ async def create_chat_completion(raw_request: Request):
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            logit_processors=logit_processors
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -342,7 +344,6 @@ async def create_completion(raw_request: Request):
           getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
-        - logit_bias (to be supported by vLLM engine)
     """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
@@ -362,10 +363,11 @@ async def create_completion(raw_request: Request):
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "suffix is not currently supported")

-    if request.logit_bias is not None:
-        # TODO: support logit_bias in vLLM engine.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "logit_bias is not currently supported")
+    if not request.logit_bias:
+        logit_processors = []
+    else:
+        logit_bias = dict(map(lambda logit: (int(logit[0]), logit[1]), request.logit_bias.items()))
+        logit_processors = [BiasLogitsProcessor(logit_bias)]

     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
@@ -381,6 +383,7 @@ async def create_completion(raw_request: Request):
     else:
         prompt = request.prompt
     created_time = int(time.time())
+
     try:
         sampling_params = SamplingParams(
             n=request.n,
@@ -395,6 +398,7 @@ async def create_completion(raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            logits_processors=logit_processors
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
diff --git a/vllm/logits_processors.py b/vllm/logits_processors.py
new file mode 100644
index 0000000000000..262d4c01814c2
--- /dev/null
+++ b/vllm/logits_processors.py
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+import torch
+from typing import Dict
+
+class LogitsProcessor(ABC):
+  @abstractmethod
+  def __call__(self, logits: torch.tensor) -> torch.tensor:
+    pass
+
+class BiasLogitsProcessor(LogitsProcessor):
+  """This is to enable logit_bias in the OpenAI server.
+
+  biases is a dict where each value is -100 to 100 according to the OpenAI API docs.
+
+  Args:
+    biases: Dict of values from -100 to 100 to scale the probability of a token being generated.
+      Each key of the dict corresponds to the token id.
+  """
+  def __init__(self, biases: Dict[int, float]):
+    self.biases = biases
+
+    if not biases:
+      return
+
+    self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
+    self.values = torch.tensor(list(self.biases.values()), dtype=torch.long)
+
+  def __call__(self, logits):
+    if not self.biases:
+      return logits
+
+    values = self.values.to(logits.device)
+    keys = self.keys.to(logits.device)
+
+    update_factors = torch.where(values >= 0, 1 + (values / 100), 1 / (1 - (values / 100)))
+    logits[0, keys] *= update_factors
+
+    return logits
\ No newline at end of file
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index b586c98bd13a9..b0009e48e17a1 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -45,6 +45,10 @@ def forward(

         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())
+
+        # Apply any user defined logits processors.
+        logits = _apply_logits_processors(input_metadata, logits)
+
         if embedding_bias is not None:
             logits += embedding_bias
         logits = gather_from_tensor_model_parallel_region(logits)
@@ -141,6 +145,18 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
         output_tokens.append(seq_data.output_token_ids)
     return output_tokens

+def _apply_logits_processors(
+    input_metadata: InputMetadata,
+    logits: torch.Tensor,
+) -> torch.Tensor:
+    for _, seq_group in enumerate(input_metadata.seq_groups):
+        _, sampling_params = seq_group
+        logits_processors = sampling_params.logits_processors
+
+        for logits_processor in logits_processors:
+            logits = logits_processor(logits)
+
+    return logits

 def _apply_penalties(
     logits: torch.Tensor,
     output_tokens: List[List[int]],
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 6e51ad19e4e0e..cdf12c6066a41 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -1,5 +1,6 @@
 """Sampling parameters for text generation."""
 from typing import List, Optional, Union
+from vllm.logits_processors import LogitsProcessor

 _SAMPLING_EPS = 1e-5

@@ -40,6 +41,8 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        logits_processors: List of LogitsProcessors to change the probability
+            of token prediction at runtime.
     """

     def __init__(
@@ -56,6 +59,7 @@ def __init__(
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        logits_processors: List[LogitsProcessor] = []
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -74,6 +78,7 @@ def __init__(
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.logits_processors = logits_processors

         self._verify_args()
         if self.use_beam_search:

From 9234de5dfa45f04e28bee11da1def80e1a432acb Mon Sep 17 00:00:00 2001
From: Zach Blank
Date: Fri, 21 Jul 2023 15:03:47 +0000
Subject: [PATCH 2/3] forgot to run format.sh

---
 vllm/entrypoints/openai/api_server.py | 14 ++++---
 vllm/logits_processors.py             | 54 +++++++++++++++------------
 vllm/model_executor/layers/sampler.py |  9 +++--
 vllm/sampling_params.py               | 32 ++++++++--------
 4 files changed, 60 insertions(+), 49 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index db04b85b1eab2..aa11571c59833 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -187,7 +187,9 @@ async def create_chat_completion(raw_request: Request):
     if not request.logit_bias:
         logit_processors = []
     else:
-        biases = dict(map(lambda bias: (int(bias[0]), bias[1]), request.logit_bias.items()))
+        biases = dict(
+            map(lambda bias: (int(bias[0]), bias[1]),
+                request.logit_bias.items()))
         logit_processors = [BiasLogitsProcessor(biases)]

     model_name = request.model
@@ -206,8 +208,7 @@ async def create_chat_completion(raw_request: Request):
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
-            logit_processors=logit_processors
-        )
+            logits_processors=logit_processors)
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -366,7 +367,9 @@ async def create_completion(raw_request: Request):
     if not request.logit_bias:
         logit_processors = []
     else:
-        logit_bias = dict(map(lambda logit: (int(logit[0]), logit[1]), request.logit_bias.items()))
+        logit_bias = dict(
+            map(lambda logit: (int(logit[0]), logit[1]),
+                request.logit_bias.items()))
         logit_processors = [BiasLogitsProcessor(logit_bias)]

     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
@@ -398,8 +401,7 @@ async def create_completion(raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
-            logits_processors=logit_processors
-        )
+            logits_processors=logit_processors)
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
diff --git a/vllm/logits_processors.py b/vllm/logits_processors.py
index 262d4c01814c2..cfae2ad6ccb62 100644
--- a/vllm/logits_processors.py
+++ b/vllm/logits_processors.py
@@ -2,37 +2,45 @@
 import torch
 from typing import Dict

+
 class LogitsProcessor(ABC):
-  @abstractmethod
-  def __call__(self, logits: torch.tensor) -> torch.tensor:
-    pass
+
+    @abstractmethod
+    def __call__(self, logits: torch.tensor) -> torch.tensor:
+        pass
+

 class BiasLogitsProcessor(LogitsProcessor):
-  """This is to enable logit_bias in the OpenAI server.
+    """This is to enable logit_bias in the OpenAI server.
+
+    biases is a dict where each value is -100 to 100
+    according to the OpenAI API docs.

-  biases is a dict where each value is -100 to 100 according to the OpenAI API docs.
+    Args:
+      biases: Dict of values from -100 to 100 to scale the
+        probability of a token being generated.
+        Each key of the dict corresponds to the token id.
+    """

-  Args:
-    biases: Dict of values from -100 to 100 to scale the probability of a token being generated.
-      Each key of the dict corresponds to the token id.
-  """
-  def __init__(self, biases: Dict[int, float]):
-    self.biases = biases
+    def __init__(self, biases: Dict[int, float]):
+        self.biases = biases

-    if not biases:
-      return
+        if not biases:
+            return

-    self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
-    self.values = torch.tensor(list(self.biases.values()), dtype=torch.long)
+        self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
+        self.values = torch.tensor(list(self.biases.values()),
+                                   dtype=torch.long)

-  def __call__(self, logits):
-    if not self.biases:
-      return logits
+    def __call__(self, logits):
+        if not self.biases:
+            return logits

-    values = self.values.to(logits.device)
-    keys = self.keys.to(logits.device)
+        values = self.values.to(logits.device)
+        keys = self.keys.to(logits.device)

-    update_factors = torch.where(values >= 0, 1 + (values / 100), 1 / (1 - (values / 100)))
-    logits[0, keys] *= update_factors
+        update_factors = torch.where(values >= 0, 1 + (values / 100),
+                                     1 / (1 - (values / 100)))
+        logits[0, keys] *= update_factors

-    return logits
\ No newline at end of file
+        return logits
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index b0009e48e17a1..12f90022b00f0 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -145,6 +145,7 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
         output_tokens.append(seq_data.output_token_ids)
     return output_tokens

+
 def _apply_logits_processors(
     input_metadata: InputMetadata,
     logits: torch.Tensor,
@@ -153,11 +154,13 @@ def _apply_logits_processors(
         _, sampling_params = seq_group
         logits_processors = sampling_params.logits_processors

-        for logits_processor in logits_processors:
-            logits = logits_processor(logits)
-
+        if logits_processors is not None:
+            for logits_processor in logits_processors:
+                logits = logits_processor(logits)
+
     return logits

+
 def _apply_penalties(
     logits: torch.Tensor,
     output_tokens: List[List[int]],
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index cdf12c6066a41..b6516907d31b6 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -41,26 +41,24 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
-        logits_processors: List of LogitsProcessors to change the probability
+        logits_processors: List of LogitsProcessors to change the probability
             of token prediction at runtime.
     """

-    def __init__(
-        self,
-        n: int = 1,
-        best_of: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        top_k: int = -1,
-        use_beam_search: bool = False,
-        stop: Union[None, str, List[str]] = None,
-        ignore_eos: bool = False,
-        max_tokens: int = 16,
-        logprobs: Optional[int] = None,
-        logits_processors: List[LogitsProcessor] = []
-    ) -> None:
+    def __init__(self,
+                 n: int = 1,
+                 best_of: Optional[int] = None,
+                 presence_penalty: float = 0.0,
+                 frequency_penalty: float = 0.0,
+                 temperature: float = 1.0,
+                 top_p: float = 1.0,
+                 top_k: int = -1,
+                 use_beam_search: bool = False,
+                 stop: Union[None, str, List[str]] = None,
+                 ignore_eos: bool = False,
+                 max_tokens: int = 16,
+                 logprobs: Optional[int] = None,
+                 logits_processors: List[LogitsProcessor] = None) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty

From 10a18cd433858a5bb91ecd79d66ba819e4993dbb Mon Sep 17 00:00:00 2001
From: Zach Blank
Date: Tue, 25 Jul 2023 16:02:58 +0000
Subject: [PATCH 3/3] add input ids to the logit processors

---
 vllm/model_executor/layers/sampler.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 12f90022b00f0..3cea6f02c0daa 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -46,9 +46,6 @@ def forward(
         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())

-        # Apply any user defined logits processors.
-        logits = _apply_logits_processors(input_metadata, logits)
-
         if embedding_bias is not None:
             logits += embedding_bias
         logits = gather_from_tensor_model_parallel_region(logits)
@@ -65,6 +62,9 @@ def forward(
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                   frequency_penalties, self.vocab_size)

+        # Apply any user defined logits processors.
+        logits = _apply_logits_processors(input_metadata, logits, output_tokens)
+
         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]
@@ -149,6 +149,7 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
 def _apply_logits_processors(
     input_metadata: InputMetadata,
     logits: torch.Tensor,
+    output_tokens: List[List[int]]
 ) -> torch.Tensor:
     for _, seq_group in enumerate(input_metadata.seq_groups):
         _, sampling_params = seq_group
         logits_processors = sampling_params.logits_processors

         if logits_processors is not None:
             for logits_processor in logits_processors:
-                logits = logits_processor(logits)
+                logits = logits_processor(logits, output_tokens)

     return logits
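Note (not part of the patch series above): the short standalone Python sketch below reproduces the multiplicative update that BiasLogitsProcessor applies to the logits, so the effect of a logit_bias entry can be checked in isolation. The token ids (5 and 9), the bias values, and the 12-token toy vocabulary are made-up illustration values, not anything defined by the patches.

import torch

# Hypothetical token_id -> bias mapping; values follow the OpenAI convention of -100..100.
biases = {5: 80, 9: -60}
keys = torch.tensor(list(biases.keys()), dtype=torch.long)
values = torch.tensor(list(biases.values()), dtype=torch.long)

# Toy logits for a single sequence over a 12-token vocabulary.
logits = torch.ones(1, 12)

# Same rule as BiasLogitsProcessor.__call__: a positive bias scales the logit
# by 1 + value/100, a negative bias scales it by 1 / (1 - value/100).
update_factors = torch.where(values >= 0, 1 + (values / 100),
                             1 / (1 - (values / 100)))
logits[0, keys] *= update_factors

print(logits[0, 5].item())  # ~1.8, token 5 is boosted
print(logits[0, 9].item())  # 0.625, token 9 is suppressed

Because the factors multiply the raw logits in place, the size (and for negative logits even the direction) of the resulting probability shift depends on the sign and magnitude of the original logit, which is worth keeping in mind when choosing bias values.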