From 4cf50cede888efe97e88edb04ae56357312ff2e9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 5 Dec 2023 14:40:50 +0400 Subject: [PATCH 1/4] fix log_softmax --- serve/mlc_serve/model/paged_cache_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 192c75bd62..5d06993ddc 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -269,7 +269,7 @@ def _is_safe_to_sample(prob_like): if logits_greedy.shape[0] > 0: # Greedy sampling - logprobs = torch.log(torch.softmax(logits_greedy, dim=-1)) + logprobs = torch.log_softmax(logits_greedy, dim=-1) res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1) top_greedy_logprob, top_greedy = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True) @@ -311,7 +311,7 @@ def _is_safe_to_sample(prob_like): logits = _apply_top_p_top_k(logits_random, top_ps, top_ks) probs = torch.softmax(logits_random, dim=-1) - logprobs = torch.log(torch.softmax(logits_greedy, dim=-1)) + logprobs = torch.log_softmax(logits_greedy, dim=-1) top_random_logprob, top_random = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True) top_random_logprob = top_random_logprob.cpu().numpy() top_random = top_random.cpu().numpy() From 7680c05eb4ee729f33469164a8fbb468cce83766 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 5 Dec 2023 14:51:51 +0400 Subject: [PATCH 2/4] use constant for number of top logprobs --- serve/mlc_serve/engine/__init__.py | 2 +- serve/mlc_serve/engine/sampling_params.py | 5 +++-- serve/mlc_serve/model/paged_cache_model.py | 20 +++++++++++++++----- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 1b873a4692..08666ae4c8 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -13,4 +13,4 @@ MLCServeEngineConfig, get_engine_config ) -from .sampling_params import SamplingParams, SamplingType +from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index d1f85b4d08..9639158165 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -10,6 +10,7 @@ from typing import Optional _SAMPLING_EPS = 1e-5 +TOP_LOGPROBS_NUMBER = 5 class SamplingType(IntEnum): @@ -76,9 +77,9 @@ def _verify_args(self) -> None: raise ValueError( f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." ) - if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 5): + if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > TOP_LOGPROBS_NUMBER): raise ValueError( - f"logprobs must be between 0 and 5, got {self.logprobs}." + f"logprobs must be between 0 and {TOP_LOGPROBS_NUMBER}, got {self.logprobs}." 
) def _verify_greedy_sampling(self) -> None: diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 5d06993ddc..dc99478008 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -16,7 +16,13 @@ from .base import get_model_artifact_config from .tokenizer import HfTokenizerModule, ConversationTemplate -from ..engine import RequestId, SamplingType, MLCServeEngineConfig, SamplingParams +from ..engine import ( + RequestId, + SamplingType, + MLCServeEngineConfig, + SamplingParams, + TOP_LOGPROBS_NUMBER +) from ..engine.model_module import ( DecodeRequest, PrefillRequest, @@ -272,7 +278,9 @@ def _is_safe_to_sample(prob_like): logprobs = torch.log_softmax(logits_greedy, dim=-1) res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1) - top_greedy_logprob, top_greedy = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True) + top_greedy_logprob, top_greedy = torch.topk( + logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True + ) # Convert to numpy res_greedy_logprob = res_greedy_logprob.cpu().numpy() res_greedy = res_greedy.cpu().numpy() @@ -312,7 +320,9 @@ def _is_safe_to_sample(prob_like): probs = torch.softmax(logits_random, dim=-1) logprobs = torch.log_softmax(logits_greedy, dim=-1) - top_random_logprob, top_random = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True) + top_random_logprob, top_random = torch.topk( + logprobs, k=TOP_LOGPROBS_NUMBER, dim=-1, largest=True, sorted=True + ) top_random_logprob = top_random_logprob.cpu().numpy() top_random = top_random.cpu().numpy() @@ -327,8 +337,8 @@ def _is_safe_to_sample(prob_like): res = np.empty((num_seq,), dtype=np.int32) res_logprobs = np.empty((num_seq,), dtype=np.float32) - top = np.empty((num_seq, 5), dtype=np.int32) - top_logprobs = np.empty((num_seq, 5), dtype=np.float32) + top = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.int32) + top_logprobs = np.empty((num_seq, TOP_LOGPROBS_NUMBER), dtype=np.float32) res[mask_random] = res_random res_logprobs[mask_random] = res_random_logprobs From 72f30449a9c7c3014384e5070f256b3904ec9797 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 7 Dec 2023 11:49:35 +0400 Subject: [PATCH 3/4] small clean --- serve/mlc_serve/model/paged_cache_model.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index dc99478008..8547561ead 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -632,13 +632,6 @@ def generate( out = self.mod["prefill"]( input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[ - 0 - ] # Ignore returned KV cache since it is updated in-place anyway. else: torch.cuda.nvtx.range_push(f"forward decode {input_shape}") @@ -655,10 +648,12 @@ def generate( self.params, ) - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[0] + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[ + 0 + ] # Ignore returned KV cache since it is updated in-place anyway. 
torch.cuda.synchronize() torch.cuda.nvtx.range_pop() From ca91b7abc33b529b4e796d0a8bb05d6f92018ab9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 18 Dec 2023 22:41:58 +0400 Subject: [PATCH 4/4] upstream to new OpenAI API --- serve/mlc_serve/api/handler.py | 21 +++++++++++++----- serve/mlc_serve/api/protocol.py | 9 ++++++-- serve/mlc_serve/engine/sampling_params.py | 22 ++++++++++++------- serve/mlc_serve/model/paged_cache_model.py | 10 ++++++--- .../unittest/test_engine_with_samplers.py | 11 +++++----- 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 63deae4a94..05ba738e61 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -59,6 +59,7 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: if request.top_p is not None: sampling_params.top_p = request.top_p if request.logprobs is not None: + sampling_params.top_logprobs = request.top_logprobs sampling_params.logprobs = request.logprobs return sampling_params @@ -211,13 +212,21 @@ async def collect_result_stream( message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, ) + content = [] if logprob_infos[index] != []: - choice.logprobs={ - "token_logprobs": [float(logprob_info[0][1]) for logprob_info in logprob_infos[index]], - "tokens": [str(logprob_info[0][0]) for logprob_info in logprob_infos[index]], - "offset": list(accumulate([len(str(logprob_info[0][0])) for logprob_info in logprob_infos[index]])), - "top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]] - } + for logprob_info in logprob_infos[index]: + content.append({ + "token": str(logprob_info[0][0]), + "logprob": float(logprob_info[0][1]), + # TODO(vvchernov): implement bytes bases on https://platform.openai.com/docs/api-reference/chat/object + "bytes": None, + "top_logprobs": [{ + "token": top_logprob[0], + "logprob": top_logprob[1], + "bytes": None, + } for top_logprob in logprob_info[1]], + }) + choice.logprobs.content = content choices.append(choice) usage = UsageInfo( diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index a56ff4dc16..518e3291b6 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -70,13 +70,18 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False - logprobs: Optional[int] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None + + +class Logprobs(BaseModel): + content: Optional[List[Dict]] class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage - logprobs: Optional[Dict[str, Union[List, Dict]]] + logprobs: Optional[Logprobs] finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index 9639158165..b721881197 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -39,9 +39,13 @@ class SamplingParams: to consider. Must be in (0, 1]. Set to 1 to consider all tokens. top_k: Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens. - logprobs: Optional[Integer] that determines number of log probabilities - to return per sampled tokens, default to None meaning disabled, - otherwise minimum 0, maximum 5. 
+ logprobs: Optional[bool] Whether to return log probabilities of the output + tokens or not. If true, returns the log probabilities of each output + token returned in the content of message. + top_logprobs: Optional[Integer] An integer between 0 and 5 specifying + the number of most likely tokens to return at each token position, + each with an associated log probability. logprobs must be set to + true if this parameter is used. """ presence_penalty: float = 0.0 @@ -49,7 +53,8 @@ class SamplingParams: temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 - logprobs: Optional[int] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None def __post_init__(self): self._verify_args() @@ -77,10 +82,11 @@ def _verify_args(self) -> None: raise ValueError( f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." ) - if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > TOP_LOGPROBS_NUMBER): - raise ValueError( - f"logprobs must be between 0 and {TOP_LOGPROBS_NUMBER}, got {self.logprobs}." - ) + if self.logprobs is not None and self.logprobs: + if (self.top_logprobs < 0 or self.top_logprobs > TOP_LOGPROBS_NUMBER): + raise ValueError( + f"top_logprobs must be between 0 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}." + ) def _verify_greedy_sampling(self) -> None: if self.top_p < 1.0 - _SAMPLING_EPS: diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 8547561ead..659a4ffcfc 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -397,13 +397,17 @@ def fetch_logprobs( sampling_param: SamplingParams, ) -> Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]]: """Fetch the logprob information with index""" - if sampling_param.logprobs is None or logprob_info is None: + if ( + sampling_param.logprobs is None or + not sampling_param.logprobs or + logprob_info is None + ): return None (res, res_logprobs), (top, top_logprobs) = logprob_info return (res[index],res_logprobs[index]), \ zip( - top[index][:sampling_param.logprobs], - top_logprobs[index][:sampling_param.logprobs] + top[index][:sampling_param.top_logprobs], + top_logprobs[index][:sampling_param.top_logprobs] ) diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index f230301bc0..53edebcb32 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -52,13 +52,14 @@ def create_engine( )) return engine -def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos, logprobs): +def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos, top_logprobs): return Request( request_id = str(idx), messages = [ChatMessage(role="user", content=prompt)], sampling_params = SamplingParams( temperature=0.0, - logprobs=logprobs + logprobs=True, + top_logprobs=top_logprobs ), stopping_criteria = StoppingCriteria( max_tokens=max_tokens, @@ -226,7 +227,7 @@ def test_logprobs( max_num_sequences=4, max_input_len=512, num_requests=5, - logprobs=3, + top_logprobs=3, ): prompt = "hi" engine = create_engine( @@ -236,7 +237,7 @@ def test_logprobs( max_input_len, ) s = 113 - requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, logprobs=logprobs) for n in range(s, s+num_requests)] + requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs) 
for n in range(s, s+num_requests)] engine.add(requests) generated = ["" for _ in range(num_requests)] @@ -247,7 +248,7 @@ def test_logprobs( assert len(res.sequences) == 1 seq = res.sequences[0] - assert seq.finish_reason is not None or len(list(seq.logprob_info[1])) == logprobs + assert seq.finish_reason is not None or len(list(seq.logprobs.content[0]["top_logprobs"])) == top_logprobs if seq.is_finished: assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
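
For context on the API shape this series ends up with, here is a minimal client-side sketch of the new logprobs flow. It is not part of the patches: it assumes a locally running mlc-serve instance exposing an OpenAI-style /v1/chat/completions route, and the host, port, model id, and max_tokens field are placeholders. Only the logprobs and top_logprobs request fields and the choices[0].logprobs.content response layout come from the diffs above.

import requests  # any HTTP client works; requests keeps the sketch short

# Placeholder endpoint and model id; adjust to the actual deployment.
url = "http://127.0.0.1:8000/v1/chat/completions"

payload = {
    "model": "local-model",
    "messages": [{"role": "user", "content": "hi"}],
    "max_tokens": 16,
    # PATCH 4/4 turns `logprobs` into a boolean switch and adds `top_logprobs`,
    # an int between 0 and TOP_LOGPROBS_NUMBER (5), validated in SamplingParams.
    "logprobs": True,
    "top_logprobs": 3,
}

resp = requests.post(url, json=payload, timeout=60)
resp.raise_for_status()
choice = resp.json()["choices"][0]

# `choice["logprobs"]["content"]` holds one entry per generated token:
# {"token": ..., "logprob": ..., "bytes": None, "top_logprobs": [...]}.
for entry in choice["logprobs"]["content"]:
    alternatives = [(alt["token"], alt["logprob"]) for alt in entry["top_logprobs"]]
    print(entry["token"], entry["logprob"], alternatives)

The per-token entries mirror what collect_result_stream assembles in handler.py, with "bytes" left as None until the TODO in that patch is addressed.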