diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index 0a373ec272..18f253bb8b 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -196,7 +196,7 @@ def create_stream_response(
             finish_reason=seq.finish_reason.value
             if seq.finish_reason is not None
             else None,
-            logprob_info=seq.logprob_info[0] if seq.logprob_info else None
+            logprob_info=Logprobs(content=[seq.logprob_info]) if seq.logprob_info else None
         )
         for seq in res.sequences
     ]
@@ -241,28 +241,16 @@ async def collect_result_stream(
                 finish_reasons[seq.index] = seq.finish_reason.value  # type: ignore

     choices = []
-    for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)):
-        content = []
-        if logprob_infos[index] != []:
-            for logprob_info in logprob_infos[index]:
-                top_logprobs = [TopLogprobs(
-                    token=str(token),
-                    logprob=float(logprob),
-                    # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
-                    bytes=None,
-                ) for token, logprob in logprob_info[1]]
-                content.append(LogprobsContent(
-                    token=str(logprob_info[0][0]),
-                    logprob=float(logprob_info[0][1]),
-                    # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
-                    bytes=None,
-                    top_logprobs=top_logprobs,
-                ))
+    for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
+        logprobs = None
+        if logprob_info_seq != []:
+            logprobs = Logprobs(content=logprob_info_seq)
+
         choice = ChatCompletionResponseChoice(
             index=index,
             message=ChatMessage(role="assistant", content="".join(chunks)),
             finish_reason=finish_reason,
-            logprobs=Logprobs(content=content),
+            logprobs=logprobs,
         )
         choices.append(choice)
diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py
index b179e3a164..605ca98849 100644
--- a/serve/mlc_serve/api/protocol.py
+++ b/serve/mlc_serve/api/protocol.py
@@ -89,7 +89,7 @@ class LogprobsContent(BaseModel):
     token: str
     logprob: float
     bytes: Optional[List] = None
-    top_logprobs: List[TopLogprobs]
+    top_logprobs: List[TopLogprobs]  # It can be empty


 class Logprobs(BaseModel):
@@ -98,7 +98,7 @@ class Logprobs(BaseModel):
     See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs
     """

-    content: Optional[List[LogprobsContent]]
+    content: List[LogprobsContent]


 class ChatCompletionResponseChoice(BaseModel):
diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py
index d17ed20d05..92d101ea95 100644
--- a/serve/mlc_serve/engine/__init__.py
+++ b/serve/mlc_serve/engine/__init__.py
@@ -16,6 +16,6 @@
     RequestState,
     PROMPT_SEQEUNCE_INDEX,
     get_prompt_sequence_id,
-    LOGPROBS_TYPE,
+    RawLogprobsInfo,
 )
-from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER
+from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
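For reference, a minimal sketch (not part of the diff) of the response shape the pydantic models in protocol.py above now produce. Only field names defined in the diff are used; the import path assumes the package is installed as `mlc_serve`, and the token strings and logprob values are made up.

```python
from mlc_serve.api.protocol import Logprobs, LogprobsContent, TopLogprobs

# One generated token with its top-2 alternatives, mirroring the OpenAI
# chat.completion "logprobs" object (bytes is left unimplemented, as in the diff).
content_entry = LogprobsContent(
    token="Hello",
    logprob=-0.12,
    bytes=None,
    top_logprobs=[
        TopLogprobs(token="Hello", logprob=-0.12, bytes=None),
        TopLogprobs(token="Hi", logprob=-2.3, bytes=None),
    ],
)

# `content` holds one entry per generated token; after this diff it is required
# on the model (no longer Optional).
logprobs = Logprobs(content=[content_entry])
```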
diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index 3a324ec1d0..5af0939b64 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -6,12 +6,20 @@
 from typing import List, Callable, Any, Optional, Dict, Tuple
 import inspect

+import numpy as np
+
 from .sampling_params import SamplingParams, SamplingType
+from ..api.protocol import LogprobsContent

 RequestId = str
-LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]]
-# ((token, logprob), [(top1_token, top1_logprob), ...])
+
+
+@dataclass
+class RawLogprobsInfo:
+    current_token: int
+    current_logprob: float
+    top_tokens: Optional[np.ndarray]
+    top_logprobs: Optional[np.ndarray]


 # TODO(@sunggg): consider transition to something like Pydantic
@@ -163,7 +171,7 @@ class SequenceOutput:
     finish_reason: Optional[FinishReason] = None
     # Number of generated tokens so far
    num_generated_tokens: int = 0
-    logprob_info: Optional[LOGPROBS_TYPE] = None
+    logprob_info: Optional[LogprobsContent] = None

     @property
     def is_finished(self) -> bool:
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index bf6a242b0c..9ab238d3bb 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -10,13 +10,13 @@
 import structlog

 from .base import (
+    GenerationSequence,
+    RawLogprobsInfo,
     Request,
     RequestId,
     RequestState,
-    GenerationSequence,
     SequenceId,
     StoppingCriteria,
-    LOGPROBS_TYPE,
 )
 from .model_module import (
     DecodeRequest,
@@ -28,6 +28,7 @@
     Tokenizer as TokenizerP,
 )
 from ..model.base import ModelArtifactConfig
+from ..api.protocol import LogprobsContent, TopLogprobs

 LOG = structlog.stdlib.get_logger(__name__)

@@ -133,29 +134,49 @@ def detokenize_incrementally(
     return delta


-def logprob_detokenize(tokenizer: TokenizerP, logprob_info: Optional[LOGPROBS_TYPE]) -> Optional[LOGPROBS_TYPE]:
-    """Detokenize top tokens in logprob information"""
+def logprob_detokenize(
+    tokenizer: TokenizerP,
+    logprob_info: Optional[RawLogprobsInfo],
+) -> Optional[LogprobsContent]:
+    """Detokenize the tokens in a RawLogprobsInfo and convert it to a LogprobsContent"""
     if logprob_info is None:
         return None
-    (res, res_logprob), top_tokens = logprob_info
-    top_tokens = list(top_tokens)
-    count: Dict[str, int] = {}
-    top_logprobs: List[Tuple] = []
-    # dedup duplicates
-    # Todo: Make sure decode can generate different tokens
-    for top_token, _ in top_tokens:
-        detokenized = tokenizer.decode(top_token)
-        if detokenized in count:
-            count[detokenized] += 1
-        else:
-            count[detokenized] = 1
-    for top_token, top_logprob in top_tokens:
-        detokenized = tokenizer.decode(top_token)
-        if count[detokenized] == 1:
-            top_logprobs.append((detokenized, float(top_logprob)))
-        else:
-            top_logprobs.append((f"{detokenized}_{top_token}", float(top_logprob)))
-    return (str(tokenizer.decode(res)), res_logprob), top_logprobs
+
+    top_logprobs: List[TopLogprobs] = []
+    if (
+        logprob_info.top_tokens is not None and
+        logprob_info.top_logprobs is not None
+    ):
+        # Materialize the pairs so they can be iterated twice below
+        top_tokens = list(zip(logprob_info.top_tokens[:], logprob_info.top_logprobs[:]))
+        count: Dict[str, int] = {}
+        # dedup duplicates
+        # Todo: Make sure decode can generate different tokens
+        for top_token, _ in top_tokens:
+            detokenized = tokenizer.decode(top_token)
+            if detokenized in count:
+                count[detokenized] += 1
+            else:
+                count[detokenized] = 1
+        for top_token, top_logprob in top_tokens:
+            detokenized = tokenizer.decode(top_token)
+            if count[detokenized] != 1:
+                detokenized = f"{detokenized}_{top_token}"
+            top_logprobs.append(TopLogprobs(
+                token=detokenized,
+                logprob=float(top_logprob),
+                # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
+                bytes=None,
+            ))
+
+    logprobs_content = LogprobsContent(
+        token=tokenizer.decode([logprob_info.current_token]),
+        logprob=float(logprob_info.current_logprob),
+        # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
+        bytes=None,
+        top_logprobs=top_logprobs,
+    )
+
+    return logprobs_content


 def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
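A rough usage sketch of the new `logprob_detokenize` helper above (not part of the diff). The tokenizer stand-in and its vocabulary are hypothetical, and only a `decode()` method is assumed; the collision between token ids 2 and 3 is contrived to show how duplicate detokenized strings are disambiguated with their token ids.

```python
import numpy as np
from mlc_serve.engine.base import RawLogprobsInfo
from mlc_serve.engine.engine_common import logprob_detokenize


class _ToyTokenizer:
    """Stand-in for TokenizerP: only decode() is needed here."""

    _vocab = {0: "<s>", 1: "Hel", 2: "lo", 3: "lo", 4: "!"}  # ids 2 and 3 collide

    def decode(self, ids):
        ids = ids if isinstance(ids, (list, tuple)) else [ids]
        return "".join(self._vocab[int(i)] for i in ids)


info = RawLogprobsInfo(
    current_token=2,
    current_logprob=-0.1,
    top_tokens=np.array([2, 3, 4]),
    top_logprobs=np.array([-0.1, -1.2, -3.4]),
)

content = logprob_detokenize(_ToyTokenizer(), info)  # -> LogprobsContent
print(content.token)                                 # "lo"
print([t.token for t in content.top_logprobs])       # ["lo_2", "lo_3", "!"]
```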
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 439f2b20ee..d23086c592 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -6,8 +6,8 @@
 from .base import (
     ChatMessage,
-    LOGPROBS_TYPE,
     MLCServeEngineConfig,
+    RawLogprobsInfo,
     RequestId,
     RequestState,
     SequenceId,
@@ -51,7 +51,7 @@ class TextGenerationResult:
     # making this a list of token ids to leave room for speculative decoding
     generated_tokens: List[int]
     error: Optional[str]
-    logprob_info: Optional[LOGPROBS_TYPE] = None
+    logprob_info: Optional[RawLogprobsInfo] = None


 class KVCache(Protocol):
diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py
index 9fcbd64fab..d5ba5d109c 100644
--- a/serve/mlc_serve/engine/sampling_params.py
+++ b/serve/mlc_serve/engine/sampling_params.py
@@ -10,7 +10,7 @@
 from typing import Dict, Optional

 _SAMPLING_EPS = 1e-5
-TOP_LOGPROBS_NUMBER = 5
+LOGPROB_TOP_K_MAX = 5


 class SamplingType(IntEnum):
@@ -105,9 +105,9 @@ def _verify_args(self) -> None:
                     f"logit bias must be in [-100, 100], got {bias} for token {token}."
                 )
         if self.logprobs is not None and self.logprobs:
-            if (self.top_logprobs < 1 or self.top_logprobs > TOP_LOGPROBS_NUMBER):
+            if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX):
                 raise ValueError(
-                    f"top_logprobs must be between 1 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}."
+                    f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}."
                 )

     def _verify_greedy_sampling(self) -> None:
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 0a0b3041ea..f7df7d10c9 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -6,14 +6,13 @@
 import multiprocessing.synchronize
 from dataclasses import dataclass
 from threading import Thread, Lock
-from typing import Callable, Optional, Union, Tuple, Any, Dict, List
+from typing import Callable, Optional, Union, Any, Dict, List

 import structlog
-import numpy as np

 from .base import (
     FinishReason,
-    LOGPROBS_TYPE,
+    RawLogprobsInfo,
     RequestId,
     RequestState,
     ValidationError,
@@ -67,7 +66,7 @@ class SequenceGenerationOutput:
     new_tokens: List[int]
     finish_reason: Optional[FinishReason] = None
     error: Optional[Union[str, ValidationError]] = None
-    logprob_info: Optional[LOGPROBS_TYPE] = None
+    logprob_info: Optional[RawLogprobsInfo] = None


 @dataclass
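As context for the paged_cache_model.py changes that follow, a standalone sketch (not part of the diff) of the per-sequence extraction pattern used there: log-softmax the logits once, read off the chosen token's logprob, and take a bounded top-k, with k capped by the LOGPROB_TOP_K_MAX constant (5) renamed above. The helper name and the toy vocabulary size here are illustrative only.

```python
import torch

LOGPROB_TOP_K_MAX = 5  # mirrors sampling_params.LOGPROB_TOP_K_MAX above


def top_logprobs_for_sequence(logits_row: torch.Tensor, chosen_token: int, k: int):
    """Return (chosen_logprob, top_token_ids, top_logprobs) for one sequence."""
    assert 1 <= k <= LOGPROB_TOP_K_MAX
    logprobs = torch.log_softmax(logits_row, dim=-1)
    chosen_logprob = float(logprobs[chosen_token])
    top_vals, top_ids = torch.topk(logprobs, k=k, dim=-1, largest=True, sorted=True)
    return chosen_logprob, top_ids.cpu().numpy(), top_vals.cpu().numpy()


# Example with a tiny fake vocabulary of 8 tokens:
row = torch.randn(8)
token = int(torch.argmax(row))
print(top_logprobs_for_sequence(row, token, k=3))
```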
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 673d0ea766..55e073387b 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -3,7 +3,6 @@
 from pathlib import Path
 from collections import defaultdict
 from typing import List, Union, Optional, Tuple
-from dataclasses import dataclass

 import structlog
 import numpy as np
@@ -19,10 +18,10 @@
     SamplingType,
     MLCServeEngineConfig,
     SamplingParams,
-    TOP_LOGPROBS_NUMBER,
-    LOGPROBS_TYPE,
+    LOGPROB_TOP_K_MAX,
     SequenceId,
     PROMPT_SEQEUNCE_INDEX,
+    RawLogprobsInfo,
     get_prompt_sequence_id,
 )
 from ..engine.model_module import (
@@ -35,6 +34,49 @@
 LOG = structlog.stdlib.get_logger(__name__)


+def fetch_raw_logprob_infos(
+    logits,
+    res_tokens,
+    sampling_params,
+) -> List[Optional[RawLogprobsInfo]]:
+    logprob_infos: List[Optional[RawLogprobsInfo]] = []
+    num_seq = logits.shape[0]
+    for index in range(num_seq):
+        if (
+            sampling_params[index].logprobs is not None and
+            sampling_params[index].logprobs
+        ):
+            # Logprob sampling
+            logprobs = torch.log_softmax(logits[index], dim=-1)
+            res_logprob = logprobs[res_tokens[index]].cpu().numpy()
+
+            top_logprobs_num = sampling_params[index].top_logprobs
+            if (
+                top_logprobs_num is None or
+                top_logprobs_num == 0
+            ):
+                top_logprobs = None
+                top_tokens = None
+            else:
+                assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
+                top_logprobs, top_tokens = torch.topk(
+                    logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
+                )
+                top_tokens = top_tokens.cpu().numpy()
+                top_logprobs = top_logprobs.cpu().numpy()
+
+            # Set to raw logprob info
+            logprob_infos.append(RawLogprobsInfo(
+                current_token=res_tokens[index].cpu().numpy(),
+                current_logprob=res_logprob,
+                top_tokens=top_tokens,
+                top_logprobs=top_logprobs,
+            ))
+        else:
+            logprob_infos.append(None)
+    return logprob_infos
+
+
 def _apply_top_p_top_k(logits, top_ps, top_ks):
     p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
     k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
@@ -63,7 +105,7 @@ def sample(
     sampling_params: List[SamplingParams],
     vocab_size: int,
     check_safety=False,
-) -> Optional[Tuple[np.ndarray, Optional[Tuple[Tuple, Tuple]]]]:
+) -> Optional[Tuple[np.ndarray, List[Optional[RawLogprobsInfo]]]]:
     def _is_safe_to_sample(prob_like):
         return (
             torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
@@ -82,21 +124,18 @@ def _is_safe_to_sample(prob_like):
     logits_greedy = logits[mask_greedy]

     if logits_greedy.shape[0] > 0:
-        # Greedy sampling
-        logprobs = torch.log_softmax(logits_greedy, dim=-1)
-        res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1)
-
-        top_greedy_logprob, top_greedy = torch.topk(
-            logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True
+        res_greedy = torch.argmax(logits_greedy, -1)
+
+        logprob_infos_greedy = fetch_raw_logprob_infos(
+            logits_greedy,
+            res_greedy,
+            sampling_params[mask_greedy],
         )
-        # Convert to numpy
-        res_greedy_logprob = res_greedy_logprob.cpu().numpy()
+
         res_greedy = res_greedy.cpu().numpy()
-        top_greedy_logprob = top_greedy_logprob.cpu().numpy()
-        top_greedy = top_greedy.cpu().numpy()
+
         # Case when there's only greedy sampling
         if logits_greedy.shape[0] == num_seq:
-            return res_greedy, ((res_greedy, res_greedy_logprob), (top_greedy, top_greedy_logprob))
+            return res_greedy, logprob_infos_greedy

     temperatures = []
     top_ps = []
@@ -131,7 +170,6 @@ def _is_safe_to_sample(prob_like):
         if param.logit_bias:
             logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device)

-
     logits_random = logits[mask_random]
@@ -140,43 +178,40 @@ def _is_safe_to_sample(prob_like):
         logits_random.div_(t.unsqueeze(dim=1))

     if do_top_p or do_top_k:
+        # TODO(vvchernov): this looks like a misprint. Should the result be assigned to logits_random?
+        # If not, where is this logits value used below?
         logits = _apply_top_p_top_k(logits_random, top_ps, top_ks)

     probs = torch.softmax(logits_random, dim=-1)
-    logprobs = torch.log_softmax(logits_greedy, dim=-1)
-    top_random_logprob, top_random = torch.topk(
-        logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True
-    )
-    top_random_logprob = top_random_logprob.cpu().numpy()
-    top_random = top_random.cpu().numpy()

     if check_safety and not _is_safe_to_sample(probs):
         return None

-    res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]
-    res_random_logprobs = torch.gather(logprobs, dim=-1, index=torch.tensor(res_random, dtype=torch.int64, device=logits.device)).cpu().numpy()
+    res_random = torch.multinomial(probs, 1, True)[:, 0]
+    logprob_infos_random = fetch_raw_logprob_infos(
+        logits_random,
+        res_random,
+        sampling_params[mask_random],
+    )
+
+    res_random = res_random.cpu().numpy()
+
     # Case when there's only random sampling
     if logits_random.shape[0] == num_seq:
-        return res_random, ((res_random, res_random_logprobs), (top_random, top_random_logprob))
+        return res_random, logprob_infos_random

     res = np.empty((num_seq,), dtype=np.int32)
-    res_logprobs = np.empty((num_seq,), dtype=np.float32)
-    top = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.int32)
-    top_logprobs = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.float32)
-
     res[mask_random] = res_random
-    res_logprobs[mask_random] = res_random_logprobs
-    top[mask_random] = top_random
-    top_logprobs[mask_random] = top_random_logprob
+    logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq
+    logprob_infos[mask_random] = logprob_infos_random

     if logits_greedy.shape[0] > 0:
         res[mask_greedy] = res_greedy
-        res_logprobs[mask_greedy] = res_greedy_logprob
-        top[mask_greedy] = top_greedy
-        top_logprobs[mask_greedy] = top_greedy_logprob
-    return res, ((res, res_logprobs), (top, top_logprobs))
+        logprob_infos[mask_greedy] = logprob_infos_greedy
+
+    return res, logprob_infos


 def load_disco_module(artifact_path, lib_path, num_shards):
@@ -228,26 +263,6 @@ def get_tvm_model(config, dev):
     return load_disco_module(config.model_artifact_path, lib_path, config.num_shards)


-def fetch_logprobs(
-    logprob_info: Optional[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]],
-    index: int,
-    sampling_param: SamplingParams,
-    ) -> Optional[LOGPROBS_TYPE]:  # np.ndarray inside
-    """Fetch the logprob information with index"""
-    if (
-        sampling_param.logprobs is None or
-        not sampling_param.logprobs or
-        logprob_info is None
-    ):
-        return None
-    (res, res_logprobs), (top, top_logprobs) = logprob_info
-    return (res[index],res_logprobs[index]), \
-        list(zip(
-            top[index][:sampling_param.top_logprobs],
-            top_logprobs[index][:sampling_param.top_logprobs]
-        ))
-
-
 def _prepare_inputs(
     sequence_ids,
     all_token_ids,
@@ -528,7 +543,7 @@ def generate(
                 cache.pending_copy_from_to = []

             try:
-                next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size)
+                next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size)
                 assert next_tokens is not None
                 outputs = []
                 for i, (sequence_id, new_token) in enumerate(
@@ -544,7 +559,7 @@ def generate(
                                 sequence_id=SequenceId(sequence_id.request_id, seq_id),
                                 generated_tokens=[new_token],
                                 error=None,
-                                logprob_info=fetch_logprobs(logprob_info, seq_id, sampling_params[seq_id]),
+                                logprob_info=logprob_infos[i],
                             )
                         )
                 else:
@@ -553,7 +568,7 @@ def generate(
                             sequence_id=sequence_id,
                             generated_tokens=[new_token],
                             error=None,
-                            logprob_info=fetch_logprobs(logprob_info, i, sampling_params[i]),
+                            logprob_info=logprob_infos[i],
                         )
                    )
@@ -569,7 +584,7 @@ def generate(
             for i, (sequence_id, logits_per_token, sampling_param) in enumerate(
                 zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
             ):
-                maybe_new_token, logprob_info = sample(
+                maybe_new_token, logprob_infos = sample(
                     torch.unsqueeze(logits_per_token, 0),
                     [sampling_param],
                     self.vocab_size,
@@ -590,7 +605,7 @@ def generate(
                                 ),
                                 generated_tokens=[new_token],  # type: ignore
                                 error=None,
-                                logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                                logprob_info=logprob_infos[0]
                             )
                         )
                     else:
@@ -599,7 +614,7 @@ def generate(
                                 sequence_id=sequence_id,
                                 generated_tokens=[new_token],  # type: ignore
                                 error=None,
-                                logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                                logprob_info=logprob_infos[0]
                             )
                         )
                 else:
@@ -612,7 +627,7 @@ def generate(
                                 ),
                                 generated_tokens=[],
                                 error=err_msg,
-                                logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                                logprob_info=logprob_infos[0]
                             )
                         )
                 else:
@@ -621,7 +636,7 @@ def generate(
                             sequence_id=sequence_id,
                             generated_tokens=[],
                             error=err_msg,
-                            logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                            logprob_info=logprob_infos[0]
                         )
                     )
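Finally, a hedged sketch (not part of the diff) of one way to merge the per-mask `logprob_infos_greedy` / `logprob_infos_random` lists inside `sample()` back into a single flat list: a plain Python list does not accept boolean-mask assignment the way a numpy array does, so the merge can be written with explicit integer indices. The helper name is hypothetical, and the masks are assumed to be boolean numpy arrays over the batch (adapt with `.cpu().numpy()` if they are torch tensors); plain strings stand in for RawLogprobsInfo entries.

```python
from typing import List, Optional

import numpy as np


def merge_by_masks(
    num_seq: int,
    mask_greedy: np.ndarray,            # bool array of shape (num_seq,)
    infos_greedy: List[Optional[object]],
    mask_random: np.ndarray,
    infos_random: List[Optional[object]],
) -> List[Optional[object]]:
    """Place each per-mask entry back at its original batch position."""
    merged: List[Optional[object]] = [None] * num_seq
    for i, idx in enumerate(np.flatnonzero(mask_greedy)):
        merged[idx] = infos_greedy[i]
    for i, idx in enumerate(np.flatnonzero(mask_random)):
        merged[idx] = infos_random[i]
    return merged


# Example: sequences 0 and 2 were sampled greedily, 1 and 3 randomly.
mask_g = np.array([True, False, True, False])
print(merge_by_masks(4, mask_g, ["g0", "g1"], ~mask_g, ["r0", "r1"]))
# ['g0', 'r0', 'g1', 'r1']
```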