diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index dfa1872107..5ec9625577 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -2,7 +2,7 @@ import uuid import json from http import HTTPStatus -from typing import Annotated, AsyncIterator +from typing import Annotated, AsyncIterator, List from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, StreamingResponse @@ -39,7 +39,8 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse router = APIRouter() - +import logging +logger = logging.getLogger(__name__) def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params = SamplingParams( @@ -58,6 +59,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params.temperature = request.temperature if request.top_p is not None: sampling_params.top_p = request.top_p + if request.logprobs is not None: + sampling_params.logprobs = request.logprobs return sampling_params @@ -128,7 +131,7 @@ async def generate_completion_stream( created_time = int(time.time()) def create_stream_response( - choices: list[ChatCompletionResponseStreamChoice], + choices: List[ChatCompletionResponseStreamChoice], ) -> ChatCompletionStreamResponse: return ChatCompletionStreamResponse( id=request_id, @@ -148,7 +151,6 @@ def create_stream_response( ], ) yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n" - async for res in result_generator: if res.error: raise RuntimeError(f"Error when generating: {res.error}") @@ -164,6 +166,7 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, + logprob_info=seq.logprob_info[0] if seq.logprob_info else None ) for seq in res.sequences ] @@ -184,6 +187,7 @@ async def collect_result_stream( finish_reasons = [None] * num_sequences num_prompt_tokens = 0 num_generated_tokens = [0 for _ in range(num_sequences)] + logprob_infos = [[] for _ in range(num_sequences)] async for res in result_generator: # TODO: verify that the request cancellation happens after this returns if res.error: @@ -191,6 +195,8 @@ async def collect_result_stream( if res.num_prompt_tokens is not None: num_prompt_tokens = res.num_prompt_tokens for seq in res.sequences: + if seq.logprob_info: + logprob_infos[seq.index].append(seq.logprob_info) if seq.index >= len(sequences): raise RuntimeError(f"Unexpected sequence index: {seq.index}.") num_generated_tokens[seq.index] = seq.num_generated_tokens @@ -198,12 +204,18 @@ async def collect_result_stream( finish_reasons[seq.index] = seq.finish_reason.value else: sequences[seq.index].append(seq.delta) - + breakpoint() choices = [ ChatCompletionResponseChoice( index=index, message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, + logprobs={ + "token_logprobs": [float(logprob_info[0]) for logprob_info in logprob_infos[index]], + # "tokens": [], + # "offset": [], + "top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]] + }, ) for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)) ] diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index abc977a59b..622eb6a093 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -2,7 +2,7 @@ # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py import time -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union, Any from pydantic import BaseModel, Field @@ -70,11 +70,13 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False + logprobs: Optional[int] = None class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage + logprobs: Optional[Dict[str, Any]] finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None @@ -95,6 +97,7 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage + logprob_info: Optional[Any] finish_reason: Optional[Literal["stop", "length"]] = None diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py index afc8068b37..1ec40b60cc 100644 --- a/serve/mlc_serve/engine/async_connector.py +++ b/serve/mlc_serve/engine/async_connector.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import AsyncIterator, Any +from typing import AsyncIterator, Any, Dict from .base import ( InferenceEngine, @@ -27,7 +27,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1): self.engine_loop_task = None self.engine_loop_exception = None self.shutdown_event = asyncio.Event() - self.result_queues = dict[RequestId, ResultQueue]() + self.result_queues: Dict[RequestId, ResultQueue] = {} async def start(self): """ diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index b6e5533886..761821eea9 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -29,6 +29,15 @@ logger = logging.getLogger(__name__) +def logprob_detok(tokenizer, logprob_info): + if logprob_info is None: + return None + return ( + logprob_info[0], { + tokenizer.decode(top_token): float(logprob) for top_token, logprob in logprob_info[1] + } + ) + class StagingInferenceEngine(ScopedInferenceEngine): """ An implementation of InferenceEngine that offloads the text generation loop to another worker process, @@ -200,7 +209,7 @@ def step(self) -> InferenceStepResult: len(state.token_ids) - state.prompt_len ), finish_reason=seq_output.finish_reason, - logprob_info=seq_output.logprob_info, + logprob_info=logprob_detok(self.tokenizer, seq_output.logprob_info), ), ], num_prompt_tokens=state.prompt_len, diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index 4c318bf3bc..828e3c74c3 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -7,7 +7,7 @@ from collections import deque from dataclasses import dataclass from threading import Condition, Lock, Thread -from typing import Callable, Optional, Union, Any, Tuple, List +from typing import Callable, Optional, Union, Any, Tuple, List, Deque, Dict import numpy as np from .base import FinishReason, RequestId, RequestState @@ -79,15 +79,15 @@ def __init__( assert self.prompt_allocate_ratio >= 1.0 self.queue_lock = Lock() - self.queue = deque[RequestState]() + self.queue: Deque[RequestState] = deque() self.has_new_requests = Condition(lock=self.queue_lock) - self.cancelled_requests = list[RequestState]() - self.stopped_requests = list[RequestState]() + self.cancelled_requests: List[RequestState] = [] + self.stopped_requests: List[RequestState] = [] - self.current_batch = dict[RequestId, RequestState]() + self.current_batch: Dict[RequestId, RequestState] = {} - def add(self, request_states: list[RequestState]): + def add(self, request_states: List[RequestState]): with self.queue_lock: # States which have been invalidated should never be added, directly # cancel them instead. @@ -140,7 +140,7 @@ def has_pending_requests(self) -> bool: def step(self) -> GenerationLoopWorkerOutput: logger.debug("Starting new inference step.") - outputs = list[SequenceGenerationOutput]() + outputs: List[SequenceGenerationOutput] = [] result = GenerationLoopWorkerOutput(sequences=outputs) # TODO: consolidate into a single function @@ -215,7 +215,6 @@ def step(self) -> GenerationLoopWorkerOutput: id=res.sequence_id, new_tokens=[], error=res.error, - logprob_info=res.logprob_info, ) ) continue diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index b652278123..d52a1e8fd8 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -650,12 +650,12 @@ def generate( next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size) return [ TextGenerationResult( - sequence_id=zipped[0], - generated_tokens=[zipped[1]], + sequence_id=sequence_id, + generated_tokens=[next_token], error=None, logprob_info=fetch_logprobs(logprob_info, idx, sampling_params[idx]), ) - for idx, zipped in enumerate(zip(sequence_ids, next_tokens)) + for idx, (sequence_id, next_token) in enumerate(zip(sequence_ids, next_tokens)) ] except RuntimeError: # Fallback to per-token sampling in case some logits values are corrupted. @@ -690,7 +690,6 @@ def generate( logprob_info=fetch_logprobs(logprob_info, idx, sampling_param) ) ) - return outputs @@ -721,8 +720,8 @@ def __init__(self, model: Model): self.model = model def generate( - self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache - ) -> list[TextGenerationResult]: + self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache + ) -> List[TextGenerationResult]: prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] decode_requests = [r for r in requests if isinstance(r, DecodeRequest)]