diff --git a/serve/benchmarks/benchmark_latency.py b/serve/benchmarks/benchmark_latency.py index c64f968640..84c377b710 100644 --- a/serve/benchmarks/benchmark_latency.py +++ b/serve/benchmarks/benchmark_latency.py @@ -34,6 +34,8 @@ def create_request(request_id): frequency_penalty=args.sampling_setting["frequency_penalty"], presence_penalty=args.sampling_setting["presence_penalty"], logit_bias=args.sampling_setting["logit_bias"], + logprobs=args.sampling_setting["logprobs"], + top_logprobs=args.sampling_setting["top_logprobs"], ), stopping_criteria=StoppingCriteria( max_tokens=args.num_output_tokens, stop_sequences=None diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index 1a7fb46451..4f82d80dd7 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -139,6 +139,8 @@ def run_mlc(engine, requests, args) -> float: frequency_penalty=args.sampling_setting["frequency_penalty"], presence_penalty=args.sampling_setting["presence_penalty"], logit_bias=args.sampling_setting["logit_bias"], + logprobs=args.sampling_setting["logprobs"], + top_logprobs=args.sampling_setting["top_logprobs"], ), stopping_criteria=StoppingCriteria( max_tokens=args.num_output_tokens, stop_sequences=None diff --git a/serve/benchmarks/utils.py b/serve/benchmarks/utils.py index defea7b099..4507dc0dd2 100644 --- a/serve/benchmarks/utils.py +++ b/serve/benchmarks/utils.py @@ -22,6 +22,18 @@ def add_sampling_flags(parser): action="store_true", help="Apply all penalties, logit bias, top-p and top-k.", ) + parser.add_argument( + "--logprobs", + action="store_true", + default=False, + help="Enable logprobs output.", + ) + parser.add_argument( + "--top-logprobs", + type=int, + default=5, + help="Number of top logprobs to return per token, limited to 5. Only used when --logprobs is set."
+ ) def postproc_sampling_args(args): @@ -33,6 +45,8 @@ def postproc_sampling_args(args): "repetition_penalty": 1.0, "top_p": 1.0, "top_k": -1, + "logprobs": False, + "top_logprobs": 5, } if args.apply_all_sampling_params: @@ -51,3 +65,7 @@ def postproc_sampling_args(args): if args.apply_top_p_top_k: args.sampling_setting["top_k"] = 2 args.sampling_setting["top_p"] = 0.7 + + if args.logprobs: + args.sampling_setting["logprobs"] = True + args.sampling_setting["top_logprobs"] = args.top_logprobs diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index df7a74e738..1c558609c8 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, StreamingResponse -# TODO(amalyshe): hadnle random_seed +# TODO(amalyshe): handle random_seed # from .base import set_global_random_seed from ..api.protocol import ( ChatCompletionRequest, @@ -20,6 +20,7 @@ ChatMessage, DeltaMessage, ErrorResponse, + Logprobs, UsageInfo, ) from ..engine import ( @@ -64,6 +65,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params.top_p = request.top_p if request.logit_bias is not None: sampling_params.logit_bias = request.logit_bias + if request.logprobs: + sampling_params.top_logprobs = request.top_logprobs + sampling_params.logprobs = request.logprobs return sampling_params @@ -156,7 +160,7 @@ async def generate_completion_stream( created_time = int(time.time()) def create_stream_response( - choices: list[ChatCompletionResponseStreamChoice], + choices: List[ChatCompletionResponseStreamChoice], ) -> ChatCompletionStreamResponse: return ChatCompletionStreamResponse( id=request_id, @@ -192,6 +196,7 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, + logprobs=Logprobs(content=seq.logprob_info) if seq.logprob_info else None, ) for seq in res.sequences ] @@ -212,6 +217,7 @@ async def collect_result_stream( finish_reasons = [None] * num_sequences num_prompt_tokens = 0 num_generated_tokens = [0 for _ in range(num_sequences)] + logprob_infos = [[] for _ in range(num_sequences)] # type: ignore async for res in result_generator: # TODO: verify that the request cancellation happens after this returns if res.error: @@ -226,18 +232,27 @@ async def collect_result_stream( if seq.delta: sequences[seq.index].append(seq.delta) + if seq.logprob_info: + assert seq.delta + logprob_infos[seq.index].extend(seq.logprob_info) + if seq.is_finished: assert seq.finish_reason is not None finish_reasons[seq.index] = seq.finish_reason.value # type: ignore - choices = [ - ChatCompletionResponseChoice( + choices = [] + for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): + logprobs = None + if logprob_info_seq: + logprobs = Logprobs(content=logprob_info_seq) + + choice = ChatCompletionResponseChoice( index=index, message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, + logprobs=logprobs, ) - for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)) - ] + choices.append(choice) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index 286c13edd2..4f42f7233e 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -6,6 +6,8 @@ from pydantic import BaseModel, Field +from 
..openai_logprob_protocol import Logprobs + class ErrorResponse(BaseModel): object: str = "error" @@ -71,11 +73,14 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[int, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False + logprobs: bool = False + top_logprobs: int = 0 class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage + logprobs: Optional[Logprobs] = None finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None @@ -96,6 +101,7 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage + logprobs: Optional[Logprobs] = None finish_reason: Optional[Literal["stop", "length"]] = None diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 129b7c05ed..b2fb08a079 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -16,5 +16,7 @@ RequestState, PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, + RawLogprobsInfo, + RawLogprobsInfos, ) -from .sampling_params import SamplingParams, SamplingType +from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py index 04a9233296..23d1afc426 100644 --- a/serve/mlc_serve/engine/async_connector.py +++ b/serve/mlc_serve/engine/async_connector.py @@ -1,7 +1,6 @@ import asyncio import structlog -from typing import AsyncIterator, Any -from concurrent.futures import ThreadPoolExecutor +from typing import AsyncIterator, Dict from collections import deque from .base import ( @@ -26,7 +25,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1): self.engine_loop_task = None self.engine_loop_exception = None self.shutdown_event = asyncio.Event() - self.result_queues = dict[RequestId, ResultQueue]() + self.result_queues: Dict[RequestId, ResultQueue] = {} self.recent_cancelled_requests = deque[RequestId](maxlen=64) async def start(self): diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 098209da55..b66dea3479 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -6,13 +6,25 @@ from typing import List, Callable, Any, Optional, Dict import inspect +import numpy as np from .sampling_params import SamplingParams, SamplingType +from ..openai_logprob_protocol import LogprobsContent LOG = structlog.stdlib.get_logger(__name__) RequestId = str +@dataclass +class RawLogprobsInfo: + current_token_id: int + current_logprob: float + top_token_ids: Optional[np.array] + top_logprobs: Optional[np.array] + +RawLogprobsInfos = List[Optional[RawLogprobsInfo]] + + # TODO(@sunggg): consider transition to something like Pydantic @dataclass class MLCServeEngineConfig: @@ -155,6 +167,7 @@ class SequenceOutput: finish_reason: Optional[FinishReason] = None # Number of generated tokens so far num_generated_tokens: int = 0 + logprob_info: List[Optional[LogprobsContent]] = field(default_factory=list) @property def is_finished(self) -> bool: @@ -164,7 +177,7 @@ def is_finished(self) -> bool: @dataclass class RequestOutput: request_id: RequestId - sequences: list[SequenceOutput] + sequences: List[SequenceOutput] # TODO: reconsider the place to put this number # Only set for outputs with valid sequence outputs num_prompt_tokens: Optional[int] = None diff --git a/serve/mlc_serve/engine/dummy.py b/serve/mlc_serve/engine/dummy.py index 3d7c10651c..03227dd3ca 100644 --- a/serve/mlc_serve/engine/dummy.py 
+++ b/serve/mlc_serve/engine/dummy.py @@ -12,9 +12,9 @@ class DummyInferenceEngine: - def __init__(self): - self.queue_lock = Lock() - self.has_new_requests = Condition(self.queue_lock) + def __init__(self) -> None: + self.queue_lock: Lock = Lock() + self.has_new_requests: Condition = Condition(self.queue_lock) self.request_queue: Dict[RequestId, int] = {} def add(self, requests: list[Request]): diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1bc252b48c..675e97e173 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -3,17 +3,19 @@ """ import time -from typing import Tuple, Deque, Dict, Optional, Union, Callable +from typing import Tuple, Deque, Dict, Optional, Union, Callable, List from collections import deque from threading import Condition, Lock import structlog from .base import ( + GenerationSequence, + RawLogprobsInfo, + RawLogprobsInfos, Request, RequestId, RequestState, - GenerationSequence, SequenceId, StoppingCriteria, ) @@ -27,6 +29,7 @@ Tokenizer as TokenizerP, ) from ..model.base import ModelArtifactConfig +from ..openai_logprob_protocol import LogprobsContent, TopLogprobs LOG = structlog.stdlib.get_logger(__name__) @@ -135,6 +138,52 @@ def detokenize_incrementally( return delta +def logprob_detokenize( + tokenizer: TokenizerP, + logprob_info: Optional[RawLogprobsInfo], +) -> Optional[LogprobsContent]: + """Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent""" + if logprob_info is None: + return None + + top_logprobs: List[TopLogprobs] = [] + if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None: + top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs)) + for top_token_id, top_logprob in top_tokens: + top_logprobs.append( + TopLogprobs( + token=tokenizer.decode(top_token_id), + logprob=float(top_logprob), + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + ) + ) + + logprobs_content = LogprobsContent( + token=tokenizer.decode([logprob_info.current_token_id]), + logprob=logprob_info.current_logprob, + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + top_logprobs=top_logprobs, + ) + + return logprobs_content + + +def logprobs_detokenize( + tokenizer: TokenizerP, + logprob_info: Optional[RawLogprobsInfos], +) -> List[Optional[LogprobsContent]]: + if logprob_info is None: + return [] + + res: List[Optional[LogprobsContent]] = [] + for info in logprob_info: + res.append(logprob_detokenize(tokenizer, info)) + + return res + + def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended): if stopping_criteria.stop_sequences: for t in stopping_criteria.stop_sequences: diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 79b77e93a3..c5937dd18b 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -4,7 +4,14 @@ from dataclasses import dataclass from typing import Optional, Protocol, Union, List, Sequence -from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId +from .base import ( + ChatMessage, + MLCServeEngineConfig, + RawLogprobsInfos, + RequestId, + RequestState, + SequenceId, +) from ..model.base import ModelArtifactConfig from .sampling_params import SamplingParams @@ -44,6 +51,7 @@ class TextGenerationResult: # making this a 
list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] + logprob_info: Optional[RawLogprobsInfos] class KVCache(Protocol): diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index bbcdc7fd2c..961b2b744a 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -9,8 +9,8 @@ from functools import cached_property from typing import Dict, Optional - _SAMPLING_EPS = 1e-5 +LOGPROB_TOP_K_MAX = 5 class SamplingType(IntEnum): @@ -46,6 +46,13 @@ class SamplingParams: to -1 to consider all tokens. logit_bias: The bias applied on the logit before sampling. Must be in [-100, 100]. + logprobs: Optional[bool] Whether to return log probabilities of the output + tokens or not. If true, returns the log probabilities of each output + token returned in the content of message. + top_logprobs: Optional[Integer] An integer between 0 and 5 specifying + the number of most likely tokens to return at each token position, + each with an associated log probability. logprobs must be set to + true if this parameter is used. """ presence_penalty: float = 0.0 @@ -58,6 +65,8 @@ class SamplingParams: appeared_tokens_freq: Dict[int, int] = None logit_bias_index: list[int] = None logit_bias_value: list[float] = None + logprobs: bool = False + top_logprobs: int = 0 def __post_init__(self): self.appeared_tokens_freq = {} @@ -95,6 +104,11 @@ def _verify_args(self) -> None: raise ValueError( f"logit bias must be in [-100, 100], got {bias} for token {token}." ) + if self.logprobs: + if (self.top_logprobs < 0 or self.top_logprobs > LOGPROB_TOP_K_MAX): + raise ValueError( + f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." 
+ ) def _verify_greedy_sampling(self) -> None: if self.top_p < 1.0 - _SAMPLING_EPS: diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index 3d1fe70b99..c8354e4c5b 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -5,8 +5,8 @@ import multiprocessing import queue from threading import Lock -from typing import Callable from collections import defaultdict +from typing import Callable import structlog @@ -24,6 +24,7 @@ from .engine_common import ( get_new_request_state, update_sequence, + logprobs_detokenize ) from .model_module import ModelModule, TokenizerModule from .staging_engine_worker import ( @@ -251,6 +252,7 @@ def step(self) -> InferenceStepResult: delta, finish_reason, num_generated_tokens=len(gen_seq.generated_token_ids), + logprob_info=logprobs_detokenize(self.tokenizer, seq_output.logprob_info), ) seq_outputs[request_id].append(output) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index e74a6181c8..6c02c0811c 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -4,7 +4,7 @@ import time import multiprocessing import multiprocessing.synchronize -from dataclasses import dataclass +from dataclasses import dataclass, field from threading import Thread, Lock from typing import Callable, Optional, Union, Any, Dict, List @@ -12,12 +12,14 @@ from .base import ( FinishReason, + RawLogprobsInfos, RequestId, RequestState, ValidationError, SequenceId, GenerationSequence, ) + from .metrics import PrometheusMetrics from .metrics_labels import * from .model_module import ( @@ -40,7 +42,7 @@ class ShutdownCommand: @dataclass class AddRequestsCommand: - request_states: list[RequestState] + request_states: List[RequestState] @dataclass @@ -61,14 +63,15 @@ class StopSequenceCommand: @dataclass class SequenceGenerationOutput: id: SequenceId - new_tokens: list[int] + new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None + logprob_info: Optional[RawLogprobsInfos] = None @dataclass class GenerationLoopWorkerOutput: - sequences: list[SequenceGenerationOutput] + sequences: List[SequenceGenerationOutput] error: Optional[BaseException] = None @@ -288,6 +291,7 @@ def step(self) -> GenerationLoopWorkerOutput: id=res.sequence_id, new_tokens=new_tokens, finish_reason=finish_reason, + logprob_info=res.logprob_info, ) ) diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index 5dcf80eda7..c400ec6b4a 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -22,6 +22,7 @@ get_requests_to_process, update_sequence, EngineBase, + logprobs_detokenize ) from .model_module import ( ModelModule, @@ -222,6 +223,7 @@ def step(self) -> InferenceStepResult: delta, num_generated_tokens=len(gen_seq.generated_token_ids), finish_reason=finish_reason, + logprob_info=logprobs_detokenize(self.tokenizer, res.logprob_info), ) ) diff --git a/serve/mlc_serve/model/dummy_model.py b/serve/mlc_serve/model/dummy_model.py index 5d1084951d..4e41508c1e 100644 --- a/serve/mlc_serve/model/dummy_model.py +++ b/serve/mlc_serve/model/dummy_model.py @@ -123,6 +123,7 @@ def generate( generated_tokens=[req.token_ids[-1] + 1], # generated_tokens=[1], error=None, + logprob_info=None, ) ) return result diff --git a/serve/mlc_serve/model/model_common.py 
b/serve/mlc_serve/model/model_common.py index d99dd0bbeb..b9e23ddad0 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional +from typing import List, Optional, Tuple, Union import structlog import numpy as np @@ -9,6 +9,9 @@ from ..engine import ( SamplingType, SamplingParams, + LOGPROB_TOP_K_MAX, + RawLogprobsInfo, + RawLogprobsInfos, ) LOG = structlog.stdlib.get_logger(__name__) @@ -36,6 +39,86 @@ def get_num_cache_blocks( ) +def get_raw_logprob_info( + logits, + token_id, + top_logprobs_num, +) -> RawLogprobsInfo: + logprobs = torch.log_softmax(logits, dim=-1) + res_logprob = logprobs[token_id] + + if top_logprobs_num == 0: + top_logprobs = None + top_tokens = None + else: + assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs" + top_logprobs, top_tokens = torch.topk( + logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + ) + top_tokens = top_tokens.cpu().numpy() + top_logprobs = top_logprobs.cpu().numpy() + + # Pack the results into a raw logprob info record + return RawLogprobsInfo( + current_token_id=token_id, + current_logprob=float(res_logprob), + top_token_ids=top_tokens, + top_logprobs=top_logprobs, + ) + + +def get_logprob_indices( + sampling_params: List[SamplingParams], + num_seq: int, +) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]: + lgp_inds_greedy: List[Tuple[int, int, int]] = [] + lgp_inds_random: List[Tuple[int, int, int]] = [] + + g_ind = 0 + r_ind = 0 + for i in range(num_seq): + sampling_param = sampling_params[i] + if sampling_param.sampling_type == SamplingType.RANDOM: + if sampling_param.logprobs: + lgp_inds_random.append((i, r_ind, sampling_param.top_logprobs)) + r_ind = r_ind + 1 + else: + if sampling_param.logprobs: + lgp_inds_greedy.append((i, g_ind, sampling_param.top_logprobs)) + g_ind = g_ind + 1 + + return lgp_inds_greedy, lgp_inds_random + + +def get_raw_logprob_infos( + logprob_infos: RawLogprobsInfos, + indices: List[Tuple[int, int, int]], + logits: torch.Tensor, + token_ids: torch.Tensor, + ) -> RawLogprobsInfos: + for (i, ind, top_logprobs) in indices: + logprob_infos[i] = get_raw_logprob_info( + logits[ind], + token_ids[ind], + top_logprobs, + ) + + return logprob_infos + + +def check_logprob_infos( + logprob_infos: RawLogprobsInfos, +) -> Optional[RawLogprobsInfos]: + check = False + for info in logprob_infos: + if info is not None: + check = True + break + if check: + return logprob_infos + return None + + def _apply_top_p_top_k(logits, top_ps, top_ks): p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) @@ -64,7 +147,7 @@ def sample( sampling_params: List[SamplingParams], vocab_size: int, check_safety=False, -) -> Optional[np.ndarray]: +) -> Optional[Tuple[np.ndarray, Optional[RawLogprobsInfos]]]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -89,12 +172,26 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy_dvc] + logprob_infos: RawLogprobsInfos = [None] * num_seq + lgp_inds_greedy, lgp_inds_random = get_logprob_indices( + sampling_params, + num_seq, + ) + if logits_greedy.shape[0] > 0: res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() + logprob_infos = get_raw_logprob_infos( + logprob_infos, + lgp_inds_greedy, + logits_greedy, + res_greedy, + ) + + # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: 
torch.cuda.nvtx.range_pop() - return res_greedy + return res_greedy, check_logprob_infos(logprob_infos) temperatures = [] top_ps = [] @@ -163,9 +260,17 @@ def _is_safe_to_sample(prob_like): res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy() + logprob_infos = get_raw_logprob_infos( + logprob_infos, + lgp_inds_random, + logits_random, + res_random, + ) + + # Case when there's only random sampling if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_random + return res_random, check_logprob_infos(logprob_infos) res = np.empty((num_seq,), dtype=np.int32) res[mask_random_cpu] = res_random @@ -174,7 +279,7 @@ def _is_safe_to_sample(prob_like): res[mask_greedy_cpu] = res_greedy torch.cuda.nvtx.range_pop() - return res + return res, check_logprob_infos(logprob_infos) def prepare_inputs( diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 7193c0eadb..1614b442d4 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -1,6 +1,6 @@ import math from collections import defaultdict -from typing import List, Optional +from typing import Any, List, Optional from ..engine import ( RequestId, @@ -104,16 +104,16 @@ def replace_head_prompt_block_with(self, new_block): class KVCacheInfo: def __init__( self, - block_size, + block_size: int ): self.block_size = block_size # SequenceId -> list[int] - self.prompt_block_tables = defaultdict(list) - self.slot_mappings = defaultdict(list) + self.prompt_block_tables = defaultdict(list) # type: ignore + self.slot_mappings = defaultdict(list) # type: ignore # The core data structure - self.decode_block_tables = dict[SequenceId, DecodeBlockTable]() + self.decode_block_tables: dict = dict[SequenceId, DecodeBlockTable]() # Record indices of blocks to copy after prefill in the format [src1, dst1, src2, dst2, ...] 
self.pending_copy_from_to: list[int] = [] diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index a7f11750eb..7daf3336f4 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -1,6 +1,6 @@ -from typing import Union from pathlib import Path import structlog +from typing import List, Union from .base import get_model_artifact_config from .paged_cache_manager import CacheManager @@ -10,11 +10,11 @@ from ..engine import MLCServeEngineConfig from ..engine.model_module import ( DecodeRequest, + ModelModule, PrefillRequest, TextGenerationResult, TextGenerator, ) -from ..engine.model_module import ModelModule LOG = structlog.stdlib.get_logger(__name__) @@ -24,8 +24,8 @@ def __init__(self, model: TextGenerator): self.model = model def generate( - self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache - ) -> list[TextGenerationResult]: + self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache + ) -> List[TextGenerationResult]: prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] decode_requests = [r for r in requests if isinstance(r, DecodeRequest)] diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index cfb528ac43..eb5cfc30d1 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -1,6 +1,6 @@ import math import os -from typing import List, Union, Tuple, Sequence +from typing import List, Optional, Union, Tuple, Sequence import structlog import numpy as np @@ -18,8 +18,9 @@ ) from ..engine import ( - SequenceId, PROMPT_SEQEUNCE_INDEX, + RawLogprobsInfos, + SequenceId, get_prompt_sequence_id, MLCServeEngineConfig, ) @@ -203,6 +204,16 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() + def get_logprob_infos( + self, + i: int, + logprob_infos: Optional[RawLogprobsInfos], + ) -> Optional[RawLogprobsInfos]: + if logprob_infos is None or logprob_infos[i] is None: + return None + return [logprob_infos[i]] + + def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], @@ -282,13 +293,6 @@ def generate( slot_mapping, self.params, ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[ - 0 - ] # Ignore returned KV cache since it is updated in-place anyway. else: torch.cuda.nvtx.range_push(f"forward decode {input_shape}") @@ -305,10 +309,12 @@ def generate( self.params, ) - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[0] + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[ + 0 + ] # Ignore returned KV cache since it is updated in-place anyway. 
torch.cuda.synchronize() torch.cuda.nvtx.range_pop() @@ -330,7 +336,7 @@ def generate( cache.pending_copy_from_to = [] try: - next_tokens = sample(logits, sampling_params, self.vocab_size) + next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( @@ -346,6 +352,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, + logprob_info=self.get_logprob_infos(i, logprob_infos), ) ) else: @@ -354,6 +361,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, + logprob_info=self.get_logprob_infos(i, logprob_infos), ) ) @@ -369,7 +377,7 @@ def generate( for i, (sequence_id, logits_per_token, sampling_param) in enumerate( zip(sequence_ids, torch.from_dlpack(logits), sampling_params) ): - maybe_new_token = sample( + maybe_new_token, logprob_infos = sample( torch.unsqueeze(logits_per_token, 0), [sampling_param], self.vocab_size, @@ -393,6 +401,7 @@ def generate( ), generated_tokens=[new_token], # type: ignore error=None, + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) else: @@ -401,6 +410,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], # type: ignore error=None, + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) else: @@ -413,6 +423,7 @@ def generate( ), generated_tokens=[], error=err_msg, + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) else: @@ -421,6 +432,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) diff --git a/serve/mlc_serve/openai_logprob_protocol.py b/serve/mlc_serve/openai_logprob_protocol.py new file mode 100644 index 0000000000..36f2b693f6 --- /dev/null +++ b/serve/mlc_serve/openai_logprob_protocol.py @@ -0,0 +1,28 @@ +from typing import List, Optional + +from pydantic import BaseModel + +class TopLogprobs(BaseModel): + """An OpenAI API compatible schema for logprobs output.""" + + token: str + logprob: float + bytes: Optional[List] = None + + +class LogprobsContent(BaseModel): + """An OpenAI API compatible schema for logprobs output.""" + + token: str + logprob: float + bytes: Optional[List] = None + top_logprobs: List[TopLogprobs] # It can be empty + + +class Logprobs(BaseModel): + """ + An OpenAI API compatible schema for logprobs output. 
+ See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs + """ + + content: List[LogprobsContent] diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index 1642364232..2a7bf1efd4 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -48,7 +48,7 @@ def create_engine( def create_request( - idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, logit_bias=None + idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, top_logprobs=0, logprobs=False, logit_bias=None ): return Request( request_id=str(idx), @@ -58,6 +58,8 @@ def create_request( frequency_penalty=freq_pen, presence_penalty=pre_pen, logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, ), stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop), debug_options=DebugOptions(ignore_eos=ignore_eos), @@ -337,6 +339,43 @@ def _test_penalty( if use_staging_engine: engine.stop() +def _test_logprobs( + model_artifact_path, + use_staging_engine, + max_num_sequences=4, + max_input_len=512, + num_requests=5, + top_logprobs=3, +): + prompt = "hi" + engine = create_engine( + model_artifact_path, + use_staging_engine, + max_num_sequences, + max_input_len, + ) + s = 113 + requests = [create_request(idx=str(n - s), prompt=prompt, temp=0, freq_pen=0, pre_pen=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs, logprobs=True) for n in range(s, s + num_requests)] + engine.add(requests) + + generated = ["" for _ in range(num_requests)] + + while engine.has_pending_requests(): + results = engine.step() + for res in results.outputs: + assert len(res.sequences) == 1 + seq = res.sequences[0] + + assert seq.finish_reason is not None or len(seq.logprob_info[0].top_logprobs) == top_logprobs + + if seq.is_finished: + assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens + assert seq.finish_reason == FinishReason.Length + else: + generated[int(res.request_id)] += seq.delta + + if use_staging_engine: + engine.stop() if __name__ == "__main__": parser = get_default_mlc_serve_argparser("test engine with samplers") @@ -349,6 +388,8 @@ def _test_penalty( _test_ignore_eos(args.model_artifact_path, use_staging_engine=False) _test_stop(args.model_artifact_path, use_staging_engine=False) _test_stop(args.model_artifact_path, use_staging_engine=True) + _test_logprobs(args.model_artifact_path, use_staging_engine=True) + _test_logprobs(args.model_artifact_path, use_staging_engine=False) # These tests are broken since we are now imposing no length limit # if max_tokens = None. The tests do not finish in a reasonable time. # _test_max_context_length(model_artifact_path, use_staging_engine=True)
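Usage sketch: a minimal client-side example of how the new logprobs fields can be exercised once this patch is applied. It assumes a running mlc-serve instance that exposes the usual OpenAI-compatible /v1/chat/completions route on localhost:8000 and a placeholder model name; host, port, and model are assumptions to adjust for a real deployment.

import requests

# Hypothetical endpoint and model name; adjust to the running deployment.
payload = {
    "model": "llama-2-13b-chat-hf",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 8,
    "logprobs": True,    # enable per-token log probabilities
    "top_logprobs": 3,   # 0..5, validated against LOGPROB_TOP_K_MAX
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
choice = resp.json()["choices"][0]
if choice.get("logprobs"):
    # Each entry follows the LogprobsContent schema: token, logprob, top_logprobs.
    for entry in choice["logprobs"]["content"]:
        alternatives = [(t["token"], t["logprob"]) for t in entry["top_logprobs"]]
        print(entry["token"], entry["logprob"], alternatives)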
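For readers following the sampling path, a self-contained sketch of the log-softmax / top-k computation that get_raw_logprob_info performs, using dummy logits; the toy vocabulary size and greedy token choice are illustrative only.

import torch

# Dummy logits for one sequence position over a toy vocabulary.
vocab_size, top_logprobs_num = 16, 3
logits = torch.randn(vocab_size)
token_id = int(torch.argmax(logits))         # greedy pick, mirroring the greedy path of sample()

logprobs = torch.log_softmax(logits, dim=-1)
current_logprob = float(logprobs[token_id])  # logprob of the chosen token
top_vals, top_ids = torch.topk(logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True)

# These arrays correspond to RawLogprobsInfo.top_logprobs / top_token_ids,
# which logprobs_detokenize() later converts into OpenAI-style LogprobsContent.
print(token_id, current_logprob)
print(top_ids.numpy(), top_vals.numpy())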