diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index dfa1872107..63deae4a94 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -2,7 +2,8 @@
 import uuid
 import json
 from http import HTTPStatus
-from typing import Annotated, AsyncIterator
+from typing import Annotated, AsyncIterator, List
+from itertools import accumulate
 
 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -40,7 +41,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
 
 router = APIRouter()
 
-
 def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
     sampling_params = SamplingParams(
         # These params came from vllm
@@ -58,6 +58,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
         sampling_params.temperature = request.temperature
     if request.top_p is not None:
         sampling_params.top_p = request.top_p
+    if request.logprobs is not None:
+        sampling_params.logprobs = request.logprobs
     return sampling_params
 
 
@@ -128,7 +130,7 @@ async def generate_completion_stream(
     created_time = int(time.time())
 
     def create_stream_response(
-        choices: list[ChatCompletionResponseStreamChoice],
+        choices: List[ChatCompletionResponseStreamChoice],
     ) -> ChatCompletionStreamResponse:
         return ChatCompletionStreamResponse(
             id=request_id,
@@ -148,7 +150,6 @@ def create_stream_response(
         ],
     )
     yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n"
-
     async for res in result_generator:
         if res.error:
             raise RuntimeError(f"Error when generating: {res.error}")
@@ -164,6 +165,7 @@
                     finish_reason=seq.finish_reason.value
                     if seq.finish_reason is not None
                     else None,
+                    logprob_info=seq.logprob_info[0] if seq.logprob_info else None
                 )
                 for seq in res.sequences
             ]
@@ -184,6 +186,7 @@ async def collect_result_stream(
     finish_reasons = [None] * num_sequences
     num_prompt_tokens = 0
     num_generated_tokens = [0 for _ in range(num_sequences)]
+    logprob_infos = [[] for _ in range(num_sequences)]
    async for res in result_generator:
        # TODO: verify that the request cancellation happens after this returns
        if res.error:
@@ -191,6 +194,8 @@
         if res.num_prompt_tokens is not None:
             num_prompt_tokens = res.num_prompt_tokens
         for seq in res.sequences:
+            if seq.logprob_info:
+                logprob_infos[seq.index].append(seq.logprob_info)
             if seq.index >= len(sequences):
                 raise RuntimeError(f"Unexpected sequence index: {seq.index}.")
             num_generated_tokens[seq.index] = seq.num_generated_tokens
@@ -198,15 +203,22 @@
                 finish_reasons[seq.index] = seq.finish_reason.value
             else:
                 sequences[seq.index].append(seq.delta)
-
-    choices = [
-        ChatCompletionResponseChoice(
+
+    choices = []
+    for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)):
+        choice = ChatCompletionResponseChoice(
             index=index,
             message=ChatMessage(role="assistant", content="".join(chunks)),
             finish_reason=finish_reason,
         )
-        for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
-    ]
+        if logprob_infos[index] != []:
+            choice.logprobs = {
+                "token_logprobs": [float(logprob_info[0][1]) for logprob_info in logprob_infos[index]],
+                "tokens": [str(logprob_info[0][0]) for logprob_info in logprob_infos[index]],
+                "offset": list(accumulate([len(str(logprob_info[0][0])) for logprob_info in logprob_infos[index]])),
+                "top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]],
+            }
+        choices.append(choice)
 
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
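Reviewer note: for reference, the non-streaming choices[i].logprobs dict built above ends up shaped roughly like the sketch below. The values are invented and depend on the model and tokenizer; the per-step dicts in "top_logprobs" are the ones produced by logprob_detokenize in staging_engine.py.

# Illustrative only -- not actual model output.
example_choice_logprobs = {
    "tokens": ["Hello", ",", " world"],
    "token_logprobs": [-0.11, -0.56, -0.03],
    # running sum of len(token), as produced by itertools.accumulate above
    "offset": [5, 6, 12],
    "top_logprobs": [
        {"Hello": -0.11, "Hi": -2.31, "Hey": -3.05},
        {",": -0.56, "!": -1.94, ".": -2.20},
        {" world": -0.03, " there": -4.01, " folks": -5.17},
    ],
}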
diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py
index abc977a59b..a56ff4dc16 100644
--- a/serve/mlc_serve/api/protocol.py
+++ b/serve/mlc_serve/api/protocol.py
@@ -2,7 +2,7 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 # https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py
 import time
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union, Tuple
 
 from pydantic import BaseModel, Field
 
@@ -70,11 +70,13 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
     ignore_eos: Optional[bool] = False
+    logprobs: Optional[int] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[Dict[str, Union[List, Dict]]] = None
     finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None
 
 
@@ -95,6 +97,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
 
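As a usage illustration of the new request field, a client can set logprobs like any other sampling option. This is a hypothetical client-side snippet; the endpoint path, port, and model id are assumptions and not part of this diff.

import requests

payload = {
    "model": "llama-2-7b-chat-hf-q4f16_1",  # placeholder model id
    "messages": [{"role": "user", "content": "hi"}],
    "max_tokens": 16,
    "logprobs": 3,  # maps onto ChatCompletionRequest.logprobs
}
resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0].get("logprobs"))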
""" from dataclasses import dataclass -from typing import Optional, Protocol, Union +from typing import Optional, Protocol, Union, Tuple, List + +import numpy as np from .base import ChatMessage, RequestId, MLCServeEngineConfig from ..model.base import ModelArtifactConfig from .sampling_params import SamplingParams +LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]] +# ((token, logprob), [(top1_token, top1_logprob), ...]) + @dataclass class SequenceId: """ @@ -56,6 +61,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: list[int] error: Optional[str] + logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None class KVCache(Protocol): diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index fbe153283d..d1f85b4d08 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -7,6 +7,7 @@ from enum import IntEnum from functools import cached_property +from typing import Optional _SAMPLING_EPS = 1e-5 @@ -37,6 +38,9 @@ class SamplingParams: to consider. Must be in (0, 1]. Set to 1 to consider all tokens. top_k: Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens. + logprobs: Optional[Integer] that determines number of log probabilities + to return per sampled tokens, default to None meaning disabled, + otherwise minimum 0, maximum 5. """ presence_penalty: float = 0.0 @@ -44,6 +48,7 @@ class SamplingParams: temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 + logprobs: Optional[int] = None def __post_init__(self): self._verify_args() @@ -71,6 +76,10 @@ def _verify_args(self) -> None: raise ValueError( f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." ) + if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 5): + raise ValueError( + f"logprobs must be between 0 and 5, got {self.logprobs}." 
diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index 789be80ad8..7ad8c12112 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -5,7 +5,7 @@
 import multiprocessing
 import queue
 from threading import Lock
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple, List, Dict
 
 import os
 
@@ -21,7 +21,7 @@
     SequenceOutput,
     check_stopping_sequences,
 )
-from .model_module import ModelModule, TokenizerModule
+from .model_module import ModelModule, TokenizerModule, Tokenizer
 from .staging_engine_worker import (
     AddRequestsCommand,
     CancelRequestCommand,
@@ -35,6 +35,30 @@
 LOG = structlog.stdlib.get_logger(__name__)
 
 
+def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Optional[Tuple[Tuple, List[Tuple]]]) -> Optional[Tuple[Tuple, Dict[str, float]]]:
+    """Detokenize the sampled token and its top-logprob alternatives."""
+    if logprob_info is None:
+        return None
+    (res, res_logprob), top_tokens = logprob_info
+    top_tokens = list(top_tokens)
+    count = {}
+    logprob_dict = {}
+    # Count duplicate detokenized strings so they can be disambiguated below.
+    # TODO: make sure decode can produce distinct strings for different token ids.
+    for top_token, _ in top_tokens:
+        detokenized = tokenizer.decode(top_token)
+        if detokenized in count:
+            count[detokenized] += 1
+        else:
+            count[detokenized] = 1
+    for top_token, top_logprob in top_tokens:
+        detokenized = tokenizer.decode(top_token)
+        if count[detokenized] == 1:
+            logprob_dict[detokenized] = float(top_logprob)
+        else:
+            logprob_dict[f"{detokenized}_{top_token}"] = float(top_logprob)
+    return (str(tokenizer.decode(res)), res_logprob), logprob_dict
+
 class StagingInferenceEngine(ScopedInferenceEngine):
     """
     An implementation of InferenceEngine that offloads the text generation loop to another worker process,
@@ -223,6 +247,7 @@ def step(self) -> InferenceStepResult:
                                     len(state.token_ids) - state.prompt_len
                                 ),
                                 finish_reason=seq_output.finish_reason,
+                                logprob_info=logprob_detokenize(self.tokenizer, seq_output.logprob_info),
                             ),
                         ],
                         num_prompt_tokens=state.prompt_len,
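To make the disambiguation logic in logprob_detokenize concrete, here is a toy run with a stand-in tokenizer. The DummyTokenizer below is purely illustrative; the real tokenizer comes from the model module.

from mlc_serve.engine.staging_engine import logprob_detokenize

class DummyTokenizer:
    # token ids 1 and 2 both decode to "the"; 0 and 3 are distinct
    _vocab = {0: "The", 1: "the", 2: "the", 3: "a"}

    def decode(self, token_id):
        return self._vocab[int(token_id)]

logprob_info = ((0, -0.2), [(1, -1.0), (2, -1.5), (3, -2.0)])
print(logprob_detokenize(DummyTokenizer(), logprob_info))
# (('The', -0.2), {'the_1': -1.0, 'the_2': -1.5, 'a': -2.0})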
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index e9d8c36966..c8276e7d63 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -6,9 +6,9 @@
 from collections import deque
 from dataclasses import dataclass
 from threading import Condition, Lock, Thread
-from typing import Callable, Optional, Union, Any, Dict, Deque, List
-
+from typing import Callable, Optional, Union, Tuple, Any, Dict, Deque, List
 import structlog
+import numpy as np
 
 from .base import FinishReason, RequestId, RequestState
 from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
@@ -24,7 +24,7 @@ class ShutdownCommand:
 
 @dataclass
 class AddRequestsCommand:
-    request_states: list[RequestState]
+    request_states: List[RequestState]
 
 
 @dataclass
@@ -45,14 +45,15 @@ class StopRequestCommand:
 
 @dataclass
 class SequenceGenerationOutput:
     id: SequenceId
-    new_tokens: list[int]
+    new_tokens: List[int]
     finish_reason: Optional[FinishReason] = None
     error: Optional[str] = None
+    logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None
 
 
 @dataclass
 class GenerationLoopWorkerOutput:
-    sequences: list[SequenceGenerationOutput]
+    sequences: List[SequenceGenerationOutput]
     error: Optional[str] = None
 
 
@@ -96,13 +97,13 @@ def __init__(
         assert self.prompt_allocate_ratio >= 1.0
 
         self.queue_lock = Lock()
-        self.queue = deque[RequestState]()
+        self.queue: Deque[RequestState] = deque()
         self.has_new_requests = Condition(lock=self.queue_lock)
 
-        self.cancelled_requests = list[RequestState]()
-        self.stopped_requests = list[RequestState]()
+        self.cancelled_requests: List[RequestState] = []
+        self.stopped_requests: List[RequestState] = []
 
-        self.current_batch = dict[RequestId, RequestState]()
+        self.current_batch: Dict[RequestId, RequestState] = {}
 
     def add(self, request_states: list[RequestState]):
         LOG.debug("GenerationLoopWorker", requests_states=request_states)
@@ -158,7 +159,7 @@ def has_pending_requests(self) -> bool:
     def step(self) -> GenerationLoopWorkerOutput:
         LOG.debug("Starting new inference step.")
 
-        outputs = list[SequenceGenerationOutput]()
+        outputs: List[SequenceGenerationOutput] = []
         result = GenerationLoopWorkerOutput(sequences=outputs)
 
         # TODO: consolidate into a single function
@@ -253,7 +254,7 @@ def step(self) -> GenerationLoopWorkerOutput:
                 state.token_ids.extend(new_tokens)
 
             outputs.append(
-                SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens)
+                SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens, logprob_info=res.logprob_info)
             )
 
         LOG.debug("Finished state update and stopping criteria check.")
diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py
index 15a184f113..c472df1797 100644
--- a/serve/mlc_serve/engine/sync_engine.py
+++ b/serve/mlc_serve/engine/sync_engine.py
@@ -225,6 +225,7 @@ def step(self) -> InferenceStepResult:
                             num_generated_tokens=(
                                 len(state.token_ids) - state.prompt_len
                             ),
+                            logprob_info=res.logprob_info
                         ),
                     ],
                     num_prompt_tokens=state.prompt_len,
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index a3a2a6ac7f..192c75bd62 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -2,9 +2,8 @@
 import math
 import os
 from collections import defaultdict
-from typing import List, Union, Optional
+from typing import List, Union, Optional, Tuple
 from dataclasses import dataclass
-import inspect
 
 import structlog
 import numpy as np
@@ -23,6 +22,7 @@
     PrefillRequest,
     SequenceId,
     TextGenerationResult,
+    LOGPROBS_TYPE,
 )
 from ..engine.model_module import ModelModule
 
@@ -249,7 +249,7 @@ def sample(
     sampling_params: List[SamplingParams],
     vocab_size: int,
     check_safety=False,
-) -> Optional[np.ndarray]:
+) -> Optional[Tuple[np.ndarray, Optional[LOGPROBS_TYPE]]]:
     def _is_safe_to_sample(prob_like):
         return (
             torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
@@ -268,10 +268,19 @@ def _is_safe_to_sample(prob_like):
     logits_greedy = logits[mask_greedy]
 
     if logits_greedy.shape[0] > 0:
-        res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy()
-
+        # Greedy sampling
+        logprobs = torch.log(torch.softmax(logits_greedy, dim=-1))
+        res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1)
+
+        top_greedy_logprob, top_greedy = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True)
+        # Convert to numpy
+        res_greedy_logprob = res_greedy_logprob.cpu().numpy()
+        res_greedy = res_greedy.cpu().numpy()
+        top_greedy_logprob = top_greedy_logprob.cpu().numpy()
+        top_greedy = top_greedy.cpu().numpy()
+        # Case when there's only greedy sampling
        if logits_greedy.shape[0] == num_seq:
-            return res_greedy
+            return res_greedy, ((res_greedy, res_greedy_logprob), (top_greedy, top_greedy_logprob))
 
     temperatures = []
     top_ps = []
@@ -302,22 +311,38 @@ def _is_safe_to_sample(prob_like):
     logits = _apply_top_p_top_k(logits_random, top_ps, top_ks)
 
     probs = torch.softmax(logits_random, dim=-1)
+    logprobs = torch.log(torch.softmax(logits_random, dim=-1))
+    top_random_logprob, top_random = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True)
+    top_random_logprob = top_random_logprob.cpu().numpy()
+    top_random = top_random.cpu().numpy()
     if check_safety and not _is_safe_to_sample(probs):
         return None
 
     res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]
+    res_random_logprobs = torch.gather(logprobs, dim=-1, index=torch.tensor(res_random, dtype=torch.int64, device=logits.device).unsqueeze(-1)).squeeze(-1).cpu().numpy()
 
     if logits_random.shape[0] == num_seq:
-        return res_random
+        return res_random, ((res_random, res_random_logprobs), (top_random, top_random_logprob))
 
     res = np.empty((num_seq,), dtype=np.int32)
+    res_logprobs = np.empty((num_seq,), dtype=np.float32)
+    top = np.empty((num_seq, 5), dtype=np.int32)
+    top_logprobs = np.empty((num_seq, 5), dtype=np.float32)
+
     res[mask_random] = res_random
+    res_logprobs[mask_random] = res_random_logprobs
+    top[mask_random] = top_random
+    top_logprobs[mask_random] = top_random_logprob
+
     if logits_greedy.shape[0] > 0:
         res[mask_greedy] = res_greedy
+        res_logprobs[mask_greedy] = res_greedy_logprob
+        top[mask_greedy] = top_greedy
+        top_logprobs[mask_greedy] = top_greedy_logprob
 
-    return res
+    return res, ((res, res_logprobs), (top, top_logprobs))
 
 
 def load_disco_module(artifact_path, lib_path, num_shards):
@@ -356,6 +381,22 @@ def get_tvm_model(config, dev):
     return load_disco_module(config.model_artifact_path, lib_path, config.num_shards)
 
 
+def fetch_logprobs(
+    logprob_info: Optional[LOGPROBS_TYPE],
+    index: int,
+    sampling_param: SamplingParams,
+) -> Optional[Tuple[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]]:
+    """Fetch the logprob information for the sequence at `index`."""
+    if sampling_param.logprobs is None or logprob_info is None:
+        return None
+    (res, res_logprobs), (top, top_logprobs) = logprob_info
+    top_pairs = zip(
+        top[index][:sampling_param.logprobs],
+        top_logprobs[index][:sampling_param.logprobs],
+    )
+    return (res[index], res_logprobs[index]), list(top_pairs)
+
+
 def _prepare_inputs(
     sequence_ids,
     all_token_ids,
@@ -613,15 +654,15 @@ def generate(
         torch.cuda.nvtx.range_pop()
 
         try:
-            next_tokens = sample(logits, sampling_params, self.vocab_size)
-
+            next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size)
             return [
                 TextGenerationResult(
                     sequence_id=sequence_id,
                     generated_tokens=[new_token],
                     error=None,
+                    logprob_info=fetch_logprobs(logprob_info, index, sampling_params[index]),
                 )
-                for sequence_id, new_token in zip(sequence_ids, next_tokens)
+                for index, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens))
             ]
         except RuntimeError:
            # Fallback to per-token sampling in case some logits values are corrupted.
@@ -631,10 +672,10 @@ def generate(
                 " or element < 0"
             )
 
-            for sequence_id, logits_per_token, sampling_param in zip(
-                sequence_ids, torch.from_dlpack(logits), sampling_params
+            for index, (sequence_id, logits_per_token, sampling_param) in enumerate(
+                zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
             ):
-                maybe_new_token = sample(
+                maybe_new_token, logprob_info = sample(
                     torch.unsqueeze(logits_per_token, 0),
                     [sampling_param],
                     self.vocab_size,
@@ -647,6 +688,7 @@
                             sequence_id=sequence_id,
                             generated_tokens=[maybe_new_token[0]],
                             error=None,
+                            logprob_info=fetch_logprobs(logprob_info, index, sampling_param)
                         )
                     )
                 else:
@@ -655,9 +697,9 @@
                             sequence_id=sequence_id,
                             generated_tokens=[],
                             error=err_msg,
+                            logprob_info=fetch_logprobs(logprob_info, index, sampling_param)
                         )
                     )
-
         return outputs
 
 
@@ -688,8 +730,8 @@ def __init__(self, model: Model):
         self.model = model
 
     def generate(
-        self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache
-    ) -> list[TextGenerationResult]:
+        self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache
+    ) -> List[TextGenerationResult]:
         prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)]
         decode_requests = [r for r in requests if isinstance(r, DecodeRequest)]
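The greedy branch of sample reduces to a log-softmax followed by max/topk; a self-contained sketch of just that step is below. The shapes and k are illustrative, and torch.log_softmax is used here as the numerically safer equivalent of log(softmax(...)).

import torch

logits = torch.randn(2, 32000)                # (num_seq, vocab_size)
logprobs = torch.log_softmax(logits, dim=-1)

res_logprob, res = torch.max(logprobs, dim=-1)        # greedy token id and its logprob per sequence
top_logprob, top = torch.topk(logprobs, k=5, dim=-1)  # top-5 alternatives per sequence

print(res.shape, res_logprob.shape)  # torch.Size([2]) torch.Size([2])
print(top.shape, top_logprob.shape)  # torch.Size([2, 5]) torch.Size([2, 5])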
diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py
index af10398fc2..f230301bc0 100644
--- a/serve/tests/unittest/test_engine_with_samplers.py
+++ b/serve/tests/unittest/test_engine_with_samplers.py
@@ -52,13 +52,14 @@ def create_engine(
     ))
     return engine
 
-def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos):
+def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos, logprobs=None):
     return Request(
         request_id = str(idx),
         messages = [ChatMessage(role="user", content=prompt)],
         sampling_params = SamplingParams(
             temperature=0.0,
-        ),
+            logprobs=logprobs
+        ),
         stopping_criteria = StoppingCriteria(
             max_tokens=max_tokens,
             stop_sequences=stop
@@ -219,6 +220,43 @@ def test_stop(
     if use_staging_engine:
         engine.stop()
 
+def test_logprobs(
+    model_artifact_path,
+    use_staging_engine,
+    max_num_sequences=4,
+    max_input_len=512,
+    num_requests=5,
+    logprobs=3,
+):
+    prompt = "hi"
+    engine = create_engine(
+        model_artifact_path,
+        use_staging_engine,
+        max_num_sequences,
+        max_input_len,
+    )
+    s = 113
+    requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, logprobs=logprobs) for n in range(s, s+num_requests)]
+    engine.add(requests)
+
+    generated = ["" for _ in range(num_requests)]
+
+    while engine.has_pending_requests():
+        results = engine.step()
+        for res in results.outputs:
+            assert len(res.sequences) == 1
+            seq = res.sequences[0]
+
+            assert seq.finish_reason is not None or len(list(seq.logprob_info[1])) == logprobs
+
+            if seq.is_finished:
+                assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
+                assert seq.finish_reason == FinishReason.Length
+            else:
+                generated[int(res.request_id)] += seq.delta
+
+    if use_staging_engine:
+        engine.stop()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -235,3 +273,5 @@ def test_stop(
     test_stop(model_artifact_path, use_staging_engine=True)
     test_max_context_length(model_artifact_path, use_staging_engine=True)
     test_max_context_length(model_artifact_path, use_staging_engine=False)
+    test_logprobs(model_artifact_path, use_staging_engine=True)
+    test_logprobs(model_artifact_path, use_staging_engine=False)
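If a stricter assertion than the length check in test_logprobs is wanted, a helper along these lines could be added later. This is a sketch only; it assumes the detokenized ((token, logprob), {token: logprob, ...}) shape that the staging engine produces.

def check_logprob_entry(logprob_info, expected_top):
    (token, logprob), top = logprob_info
    assert isinstance(token, str)
    assert logprob <= 0.0                    # log-probabilities are never positive
    assert len(top) == expected_top
    assert all(value <= 0.0 for value in top.values())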