From 4c67afb1d971c075325d82bc5b1cc0627beeae92 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 4 Jan 2024 22:19:01 +0400 Subject: [PATCH 1/8] updates after review --- serve/mlc_serve/api/handler.py | 20 +++++++++++--------- serve/mlc_serve/engine/__init__.py | 2 +- serve/mlc_serve/engine/sampling_params.py | 6 +++--- serve/mlc_serve/model/paged_cache_model.py | 10 +++++----- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 0a373ec272..0ffd0483f3 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -241,19 +241,21 @@ async def collect_result_stream( finish_reasons[seq.index] = seq.finish_reason.value # type: ignore choices = [] - for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)): - content = [] - if logprob_infos[index] != []: - for logprob_info in logprob_infos[index]: + for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): + logprobs_content = [] + if logprob_info_seq != []: + for logprob_info in logprob_info_seq: + cur_token_logprob_info = logprob_info[0] + top_logprobs_info = logprob_info[1] top_logprobs = [TopLogprobs( token=str(token), logprob=float(logprob), # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object bytes=None, - ) for token, logprob in logprob_info[1]] - content.append(LogprobsContent( - token=str(logprob_info[0][0]), - logprob=float(logprob_info[0][1]), + ) for token, logprob in top_logprobs_info] + logprobs_content.append(LogprobsContent( + token=str(cur_token_logprob_info[0]), + logprob=float(cur_token_logprob_info[1]), # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object bytes=None, top_logprobs=top_logprobs, @@ -262,7 +264,7 @@ async def collect_result_stream( index=index, message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, - logprobs=Logprobs(content=content), + logprobs=Logprobs(content=logprobs_content), ) choices.append(choice) diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index d17ed20d05..61700bbb00 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -18,4 +18,4 @@ get_prompt_sequence_id, LOGPROBS_TYPE, ) -from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER +from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index 9fcbd64fab..d5ba5d109c 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -10,7 +10,7 @@ from typing import Dict, Optional _SAMPLING_EPS = 1e-5 -TOP_LOGPROBS_NUMBER = 5 +LOGPROB_TOP_K_MAX = 5 class SamplingType(IntEnum): @@ -105,9 +105,9 @@ def _verify_args(self) -> None: f"logit bias must be in [-100, 100], got {bias} for token {token}." ) if self.logprobs is not None and self.logprobs: - if (self.top_logprobs < 1 or self.top_logprobs > TOP_LOGPROBS_NUMBER): + if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX): raise ValueError( - f"top_logprobs must be between 1 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}." + f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." 
) def _verify_greedy_sampling(self) -> None: diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 673d0ea766..8feaab0c98 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -19,7 +19,7 @@ SamplingType, MLCServeEngineConfig, SamplingParams, - TOP_LOGPROBS_NUMBER, + LOGPROB_TOP_K_MAX, LOGPROBS_TYPE, SequenceId, PROMPT_SEQEUNCE_INDEX, @@ -87,7 +87,7 @@ def _is_safe_to_sample(prob_like): res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1) top_greedy_logprob, top_greedy = torch.topk( - logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True + logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True ) # Convert to numpy res_greedy_logprob = res_greedy_logprob.cpu().numpy() @@ -145,7 +145,7 @@ def _is_safe_to_sample(prob_like): probs = torch.softmax(logits_random, dim=-1) logprobs = torch.log_softmax(logits_greedy, dim=-1) top_random_logprob, top_random = torch.topk( - logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True + logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True ) top_random_logprob = top_random_logprob.cpu().numpy() top_random = top_random.cpu().numpy() @@ -161,8 +161,8 @@ def _is_safe_to_sample(prob_like): res = np.empty((num_seq,), dtype=np.int32) res_logprobs = np.empty((num_seq,), dtype=np.float32) - top = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.int32) - top_logprobs = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.float32) + top = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.int32) + top_logprobs = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.float32) res[mask_random] = res_random res_logprobs[mask_random] = res_random_logprobs From 294815ba1431dde2cf02531b7e599cc007095dfd Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sat, 6 Jan 2024 14:03:13 +0400 Subject: [PATCH 2/8] hide logprobs calculation by condition, add dataclass for raw logprobs info, fix logprobs_random --- serve/mlc_serve/api/protocol.py | 8 +- serve/mlc_serve/engine/__init__.py | 1 + serve/mlc_serve/engine/base.py | 9 ++ serve/mlc_serve/engine/model_module.py | 4 +- serve/mlc_serve/model/paged_cache_model.py | 141 ++++++++++++--------- 5 files changed, 95 insertions(+), 68 deletions(-) diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index b179e3a164..94e917e611 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -78,7 +78,8 @@ class ChatCompletionRequest(BaseModel): class TopLogprobs(BaseModel): """An OpenAI API compatible schema for logprobs output.""" - token: str + # token is string in OpenAI, but for unification int type is added + token: Union[str, int] logprob: float bytes: Optional[List] = None @@ -86,10 +87,11 @@ class TopLogprobs(BaseModel): class LogprobsContent(BaseModel): """An OpenAI API compatible schema for logprobs output.""" - token: str + # token is string in OpenAI, but for unification int type is added + token: Union[str, int] logprob: float bytes: Optional[List] = None - top_logprobs: List[TopLogprobs] + top_logprobs: List[TopLogprobs] # It can be empty class Logprobs(BaseModel): diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 61700bbb00..83e58ddccc 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -17,5 +17,6 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, LOGPROBS_TYPE, + RawLogprobsInfo, ) from .sampling_params import SamplingParams, SamplingType, 
LOGPROB_TOP_K_MAX diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 3a324ec1d0..ad572ba368 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -6,6 +6,7 @@ from typing import List, Callable, Any, Optional, Dict, Tuple import inspect +import numpy as np from .sampling_params import SamplingParams, SamplingType @@ -14,6 +15,14 @@ # ((token, logprob), [(top1_token, top1_logprob), ...]) +@dataclass +class RawLogprobsInfo: + current_token: np.array + current_logprob: np.array + top_tokens: Optional[np.array] + top_logprobs: Optional[np.array] + + # TODO(@sunggg): consider transition to something like Pydantic @dataclass class MLCServeEngineConfig: diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 439f2b20ee..d23086c592 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -6,8 +6,8 @@ from .base import ( ChatMessage, - LOGPROBS_TYPE, MLCServeEngineConfig, + RawLogprobsInfo, RequestId, RequestState, SequenceId, @@ -51,7 +51,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] - logprob_info: Optional[LOGPROBS_TYPE] = None + logprob_info: Optional[RawLogprobsInfo] = None class KVCache(Protocol): diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 8feaab0c98..414b4c254e 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -3,7 +3,6 @@ from pathlib import Path from collections import defaultdict from typing import List, Union, Optional, Tuple -from dataclasses import dataclass import structlog import numpy as np @@ -20,9 +19,9 @@ MLCServeEngineConfig, SamplingParams, LOGPROB_TOP_K_MAX, - LOGPROBS_TYPE, SequenceId, PROMPT_SEQEUNCE_INDEX, + RawLogprobsInfo, get_prompt_sequence_id, ) from ..engine.model_module import ( @@ -35,6 +34,49 @@ LOG = structlog.stdlib.get_logger(__name__) +def fetch_raw_logprob_infos( + logits, + res_tokens, + sampling_params, +) -> List[Optional[RawLogprobsInfo]]: + logprob_infos: List[Optional[RawLogprobsInfo]] = [] + num_seq = logits.shape[0] + for index in range(num_seq): + if ( + sampling_params[index].logprobs is None or + not sampling_params[index].logprobs + ): + # Logprob sampling + logprobs = torch.log_softmax(logits[index], dim=-1) + res_logprob = logprobs[res_tokens[index]].cpu().numpy() + + top_logprobs_num = sampling_params[index].top_logprobs + if ( + top_logprobs_num is None or + top_logprobs_num == 0 + ): + top_logprobs = None + top_tokens = None + else: + assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs" + top_logprobs, top_tokens = torch.topk( + logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + ) + top_tokens=top_tokens.cpu().numpy(), + top_logprobs=top_logprobs.cpu().numpy() + + # Set to raw logprob info + logprob_infos.append(RawLogprobsInfo( + current_token=res_tokens[index].cpu().numpy(), + current_logprob=res_logprob, + top_tokens=top_tokens, + top_logprobs=top_logprobs, + )) + else: + logprob_infos.append(None) + return logprob_infos + + def _apply_top_p_top_k(logits, top_ps, top_ks): p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) @@ -63,7 +105,7 @@ def sample( sampling_params: List[SamplingParams], vocab_size: int, check_safety=False, -) -> 
Optional[Tuple[np.ndarray, Optional[Tuple[Tuple, Tuple]]]]: +) -> Optional[Tuple[np.ndarray, List[Optional[RawLogprobsInfo]]]]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -82,21 +124,18 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy] if logits_greedy.shape[0] > 0: - # Greedy sampling - logprobs = torch.log_softmax(logits_greedy, dim=-1) - res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1) - - top_greedy_logprob, top_greedy = torch.topk( - logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True + res_greedy = torch.argmax(logits_greedy, -1) + + logprob_infos_greedy = fetch_raw_logprob_infos( + logits_greedy, + res_greedy, + sampling_params[mask_greedy], ) - # Convert to numpy - res_greedy_logprob = res_greedy_logprob.cpu().numpy() + res_greedy = res_greedy.cpu().numpy() - top_greedy_logprob = top_greedy_logprob.cpu().numpy() - top_greedy = top_greedy.cpu().numpy() # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: - return res_greedy, ((res_greedy, res_greedy_logprob), (top_greedy, top_greedy_logprob)) + return res_greedy, logprob_infos_greedy temperatures = [] top_ps = [] @@ -131,7 +170,6 @@ def _is_safe_to_sample(prob_like): if param.logit_bias: logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device) - logits_random = logits[mask_random] @@ -140,43 +178,40 @@ def _is_safe_to_sample(prob_like): logits_random.div_(t.unsqueeze(dim=1)) if do_top_p or do_top_k: + # TODO(vvchernov): looks like there is misprinting. Should logits_random be returned? + # If no, where are logits used below? logits = _apply_top_p_top_k(logits_random, top_ps, top_ks) probs = torch.softmax(logits_random, dim=-1) - logprobs = torch.log_softmax(logits_greedy, dim=-1) - top_random_logprob, top_random = torch.topk( - logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True - ) - top_random_logprob = top_random_logprob.cpu().numpy() - top_random = top_random.cpu().numpy() if check_safety and not _is_safe_to_sample(probs): return None - res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0] - res_random_logprobs = torch.gather(logprobs, dim=-1, index=torch.tensor(res_random, dtype=torch.int64, device=logits.device)).cpu().numpy() + res_random = torch.multinomial(probs, 1, True)[:, 0] + logprob_infos_random = fetch_raw_logprob_infos( + logits_random, + res_random, + sampling_params[mask_random], + ) + + res_random = res_random.cpu().numpy() + # Case when there's only random sampling if logits_random.shape[0] == num_seq: - return res_random, ((res_random, res_random_logprobs), (top_random, top_random_logprob)) + return res_random, logprob_infos_random res = np.empty((num_seq,), dtype=np.int32) - res_logprobs = np.empty((num_seq,), dtype=np.float32) - top = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.int32) - top_logprobs = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.float32) - res[mask_random] = res_random - res_logprobs[mask_random] = res_random_logprobs - top[mask_random] = top_random - top_logprobs[mask_random] = top_random_logprob + logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq + logprob_infos[mask_random] = logprob_infos_random if logits_greedy.shape[0] > 0: res[mask_greedy] = res_greedy - res_logprobs[mask_greedy] = res_greedy_logprob - top[mask_greedy] = top_greedy - top_logprobs[mask_greedy] = top_greedy_logprob - return res, ((res, res_logprobs), (top, 
top_logprobs)) + logprob_infos[mask_greedy] = logprob_infos_greedy + + return res, logprob_infos def load_disco_module(artifact_path, lib_path, num_shards): @@ -228,26 +263,6 @@ def get_tvm_model(config, dev): return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) -def fetch_logprobs( - logprob_info: Optional[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]], - index: int, - sampling_param: SamplingParams, - ) -> Optional[LOGPROBS_TYPE]: # np.ndarray inside - """Fetch the logprob information with index""" - if ( - sampling_param.logprobs is None or - not sampling_param.logprobs or - logprob_info is None - ): - return None - (res, res_logprobs), (top, top_logprobs) = logprob_info - return (res[index],res_logprobs[index]), \ - list(zip( - top[index][:sampling_param.top_logprobs], - top_logprobs[index][:sampling_param.top_logprobs] - )) - - def _prepare_inputs( sequence_ids, all_token_ids, @@ -528,7 +543,7 @@ def generate( cache.pending_copy_from_to = [] try: - next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size) + next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( @@ -544,7 +559,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=fetch_logprobs(logprob_info, seq_id, sampling_params[seq_id]), + logprob_info=logprob_infos[i], ) ) else: @@ -553,7 +568,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, - logprob_info=fetch_logprobs(logprob_info, i, sampling_params[i]), + logprob_info=logprob_infos[i], ) ) @@ -569,7 +584,7 @@ def generate( for i, (sequence_id, logits_per_token, sampling_param) in enumerate( zip(sequence_ids, torch.from_dlpack(logits), sampling_params) ): - maybe_new_token, logprob_info = sample( + maybe_new_token, logprob_infos = sample( torch.unsqueeze(logits_per_token, 0), [sampling_param], self.vocab_size, @@ -590,7 +605,7 @@ def generate( ), generated_tokens=[new_token], # type: ignore error=None, - logprob_info=fetch_logprobs(logprob_info, i, sampling_param) + logprob_info=logprob_infos[i] ) ) else: @@ -599,7 +614,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], # type: ignore error=None, - logprob_info=fetch_logprobs(logprob_info, i, sampling_param) + logprob_info=logprob_infos[i] ) ) else: @@ -612,7 +627,7 @@ def generate( ), generated_tokens=[], error=err_msg, - logprob_info=fetch_logprobs(logprob_info, i, sampling_param) + logprob_info=logprob_infos[i] ) ) else: @@ -621,7 +636,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, - logprob_info=fetch_logprobs(logprob_info, i, sampling_param) + logprob_info=logprob_infos[i] ) ) From 8922a427d4d08877bf0c36e090d997204e8d58eb Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 8 Jan 2024 13:03:50 +0400 Subject: [PATCH 3/8] convert RawLogprobsInfo to LogprobsContent --- serve/mlc_serve/api/protocol.py | 6 +-- serve/mlc_serve/engine/base.py | 4 +- serve/mlc_serve/engine/engine_common.py | 67 ++++++++++++++++--------- 3 files changed, 48 insertions(+), 29 deletions(-) diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index 94e917e611..341eb0c939 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -78,8 +78,7 @@ class ChatCompletionRequest(BaseModel): class TopLogprobs(BaseModel): """An OpenAI API compatible schema 
for logprobs output.""" - # token is string in OpenAI, but for unification int type is added - token: Union[str, int] + token: str logprob: float bytes: Optional[List] = None @@ -87,8 +86,7 @@ class TopLogprobs(BaseModel): class LogprobsContent(BaseModel): """An OpenAI API compatible schema for logprobs output.""" - # token is string in OpenAI, but for unification int type is added - token: Union[str, int] + token: str logprob: float bytes: Optional[List] = None top_logprobs: List[TopLogprobs] # It can be empty diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index ad572ba368..147f3fd633 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -17,8 +17,8 @@ @dataclass class RawLogprobsInfo: - current_token: np.array - current_logprob: np.array + current_token: int + current_logprob: float top_tokens: Optional[np.array] top_logprobs: Optional[np.array] diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index bf6a242b0c..ccb88a0799 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -10,13 +10,13 @@ import structlog from .base import ( + GenerationSequence, + RawLogprobsInfo, Request, RequestId, RequestState, - GenerationSequence, SequenceId, StoppingCriteria, - LOGPROBS_TYPE, ) from .model_module import ( DecodeRequest, @@ -28,6 +28,7 @@ Tokenizer as TokenizerP, ) from ..model.base import ModelArtifactConfig +from ..api.protocol import LogprobsContent, TopLogprobs LOG = structlog.stdlib.get_logger(__name__) @@ -133,29 +134,49 @@ def detokenize_incrementally( return delta -def logprob_detokenize(tokenizer: TokenizerP, logprob_info: Optional[LOGPROBS_TYPE]) -> Optional[LOGPROBS_TYPE]: - """Detokenize top tokens in logprob information""" +def logprob_detokenize( + tokenizer: TokenizerP, + logprob_info: Optional[RawLogprobsInfo], +) -> Optional[LogprobsContent]: + """Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent""" if logprob_info is None: return None - (res, res_logprob), top_tokens = logprob_info - top_tokens = list(top_tokens) - count: Dict[str, int] = {} - top_logprobs: List[Tuple] = [] - # dedup duplicates - # Todo: Make sure decode can generate different tokens - for top_token, _ in top_tokens: - detokenized = tokenizer.decode(top_token) - if detokenized in count: - count[detokenized] += 1 - else: - count[detokenized] = 1 - for top_token, top_logprob in top_tokens: - detokenized = tokenizer.decode(top_token) - if count[detokenized] == 1: - top_logprobs.append((detokenized, float(top_logprob))) - else: - top_logprobs.append((f"{detokenized}_{top_token}", float(top_logprob))) - return (str(tokenizer.decode(res)), res_logprob), top_logprobs + + top_logprobs: List[TopLogprobs] = [] + if ( + logprob_info.top_tokens is not None and + logprob_info.top_logprobs is not None + ): + top_tokens = zip(logprob_info.top_tokens[:], logprob_info.top_logprobs[:]) + count: Dict[str, int] = {} + # dedup duplicates + # Todo: Make sure decode can generate different tokens + for top_token, _ in top_tokens: + detokenized = tokenizer.decode(top_token) + if detokenized in count: + count[detokenized] += 1 + else: + count[detokenized] = 1 + for top_token, top_logprob in top_tokens: + detokenized = tokenizer.decode(top_token) + if count[detokenized] != 1: + detokenized = f"{detokenized}_{top_token}" + top_logprobs.append(TopLogprobs( + token=detokenized, + logprob=float(top_logprob), + # TODO(vvchernov): implement bytes based on 
https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + )) + + logprobs_content = LogprobsContent( + token=tokenizer.decode(logprob_info.current_token), + logprob=logprob_info, + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + top_logprobs=top_logprobs, + ) + + return logprobs_content def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended): From 7965267a2ab16f8240e6924f06458760c4b8384f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 8 Jan 2024 13:12:22 +0400 Subject: [PATCH 4/8] remove LOGPROBS_TYPE --- serve/mlc_serve/engine/__init__.py | 1 - serve/mlc_serve/engine/base.py | 5 ++--- serve/mlc_serve/engine/engine_common.py | 2 +- serve/mlc_serve/engine/staging_engine_worker.py | 7 +++---- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 83e58ddccc..92d101ea95 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -16,7 +16,6 @@ RequestState, PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, - LOGPROBS_TYPE, RawLogprobsInfo, ) from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 147f3fd633..5af0939b64 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -9,10 +9,9 @@ import numpy as np from .sampling_params import SamplingParams, SamplingType +from ..api.protocol import LogprobsContent RequestId = str -LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]] -# ((token, logprob), [(top1_token, top1_logprob), ...]) @dataclass @@ -172,7 +171,7 @@ class SequenceOutput: finish_reason: Optional[FinishReason] = None # Number of generated tokens so far num_generated_tokens: int = 0 - logprob_info: Optional[LOGPROBS_TYPE] = None + logprob_info: Optional[LogprobsContent] = None @property def is_finished(self) -> bool: diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index ccb88a0799..9ab238d3bb 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -169,7 +169,7 @@ def logprob_detokenize( )) logprobs_content = LogprobsContent( - token=tokenizer.decode(logprob_info.current_token), + token=tokenizer.decode([logprob_info.current_token]), logprob=logprob_info, # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object bytes=None, diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index 0a0b3041ea..f7df7d10c9 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -6,14 +6,13 @@ import multiprocessing.synchronize from dataclasses import dataclass from threading import Thread, Lock -from typing import Callable, Optional, Union, Tuple, Any, Dict, List +from typing import Callable, Optional, Union, Any, Dict, List import structlog -import numpy as np from .base import ( FinishReason, - LOGPROBS_TYPE, + RawLogprobsInfo, RequestId, RequestState, ValidationError, @@ -67,7 +66,7 @@ class SequenceGenerationOutput: new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None - logprob_info: Optional[LOGPROBS_TYPE] = None + logprob_info: Optional[RawLogprobsInfo] = None @dataclass From b8748fc4421236443e8bab14c05521f95bc6ab31 Mon Sep 17 
00:00:00 2001 From: Valery Chernov Date: Mon, 8 Jan 2024 14:43:37 +0400 Subject: [PATCH 5/8] fix result collection --- serve/mlc_serve/api/handler.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 0ffd0483f3..d644c27db7 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -242,29 +242,18 @@ async def collect_result_stream( choices = [] for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): - logprobs_content = [] + logprobs = None if logprob_info_seq != []: + logprobs_content = [] for logprob_info in logprob_info_seq: - cur_token_logprob_info = logprob_info[0] - top_logprobs_info = logprob_info[1] - top_logprobs = [TopLogprobs( - token=str(token), - logprob=float(logprob), - # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object - bytes=None, - ) for token, logprob in top_logprobs_info] - logprobs_content.append(LogprobsContent( - token=str(cur_token_logprob_info[0]), - logprob=float(cur_token_logprob_info[1]), - # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object - bytes=None, - top_logprobs=top_logprobs, - )) + logprobs_content.append(logprob_info) + logprobs = Logprobs(content=logprobs_content) + choice = ChatCompletionResponseChoice( index=index, message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, - logprobs=Logprobs(content=logprobs_content), + logprobs=logprobs, ) choices.append(choice) From 72b949acd552b81f043c938881c02fb03b334820 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 8 Jan 2024 15:03:17 +0400 Subject: [PATCH 6/8] clean code --- serve/mlc_serve/api/handler.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index d644c27db7..18f253bb8b 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -196,7 +196,7 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, - logprob_info=seq.logprob_info[0] if seq.logprob_info else None + logprob_info=Logprobs(content=[seq.logprob_info]) if seq.logprob_info else None ) for seq in res.sequences ] @@ -244,10 +244,7 @@ async def collect_result_stream( for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): logprobs = None if logprob_info_seq != []: - logprobs_content = [] - for logprob_info in logprob_info_seq: - logprobs_content.append(logprob_info) - logprobs = Logprobs(content=logprobs_content) + logprobs = Logprobs(content=logprob_info_seq) choice = ChatCompletionResponseChoice( index=index, From ae81ccb380cda71d4787b6d6509e62ea8f42e3be Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 8 Jan 2024 15:06:16 +0400 Subject: [PATCH 7/8] more clean --- serve/mlc_serve/api/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index 341eb0c939..605ca98849 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -98,7 +98,7 @@ class Logprobs(BaseModel): See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs """ - content: Optional[List[LogprobsContent]] + content: List[LogprobsContent] class 
ChatCompletionResponseChoice(BaseModel): From b2850baa220044c72cc2e3f7231774bc2446ce6d Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 9 Jan 2024 13:09:17 +0400 Subject: [PATCH 8/8] fix condition --- serve/mlc_serve/model/paged_cache_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 414b4c254e..55e073387b 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -43,8 +43,8 @@ def fetch_raw_logprob_infos( num_seq = logits.shape[0] for index in range(num_seq): if ( - sampling_params[index].logprobs is None or - not sampling_params[index].logprobs + sampling_params[index].logprobs is not None and + sampling_params[index].logprobs ): # Logprob sampling logprobs = torch.log_softmax(logits[index], dim=-1)
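
---

For reference, the end-to-end shape of the logprobs path introduced by this series is: the sampler attaches one `RawLogprobsInfo` per sequence, `logprob_detokenize` in `engine_common.py` converts it into an OpenAI-compatible `LogprobsContent`, and the API handler wraps the per-sequence list in `Logprobs`. The sketch below illustrates that conversion in isolation. It is a minimal, self-contained illustration only: the `RawLogprobsInfo` fields mirror the dataclass after patch 3, while `FakeTokenizer`, `to_logprobs_content`, and all sample values are made-up stand-ins, not code from this series.

```python
# Illustrative sketch only: RawLogprobsInfo -> OpenAI-style logprobs entry, mirroring
# what logprob_detokenize does in engine_common.py. FakeTokenizer and the numbers are
# invented for the example; the real arrays come from torch.topk in the sampler.
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class RawLogprobsInfo:  # mirrors serve/mlc_serve/engine/base.py after patch 3
    current_token: int
    current_logprob: float
    top_tokens: Optional[np.ndarray]
    top_logprobs: Optional[np.ndarray]


class FakeTokenizer:  # hypothetical stand-in for the model module's tokenizer
    vocab = {101: "Hello", 102: " world", 103: "!"}

    def decode(self, token_id: int) -> str:
        return self.vocab.get(int(token_id), f"<unk:{int(token_id)}>")


def to_logprobs_content(tokenizer, info: RawLogprobsInfo) -> dict:
    """Build an OpenAI-style logprobs entry (as a plain dict) from raw sampler output."""
    top = []
    if info.top_tokens is not None and info.top_logprobs is not None:
        for tok, lp in zip(info.top_tokens, info.top_logprobs):
            top.append({"token": tokenizer.decode(tok), "logprob": float(lp), "bytes": None})
    return {
        "token": tokenizer.decode(info.current_token),
        "logprob": float(info.current_logprob),
        "bytes": None,
        "top_logprobs": top,  # empty list when top_logprobs was not requested
    }


if __name__ == "__main__":
    info = RawLogprobsInfo(
        current_token=101,
        current_logprob=-0.12,
        top_tokens=np.array([101, 102, 103]),
        top_logprobs=np.array([-0.12, -2.3, -4.1]),
    )
    print(to_logprobs_content(FakeTokenizer(), info))
```

The dict fields correspond to the `LogprobsContent` and `TopLogprobs` Pydantic models in `api/protocol.py`. The real `logprob_detokenize` additionally de-duplicates top tokens whose detokenized strings collide by appending the token id to the string; that step is omitted here for brevity.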