From 294815ba1431dde2cf02531b7e599cc007095dfd Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Sat, 6 Jan 2024 14:03:13 +0400
Subject: [PATCH] hide logprobs calculation by condition, add dataclass for raw logprobs info, fix logprobs_random

---
 serve/mlc_serve/api/protocol.py            |   8 +-
 serve/mlc_serve/engine/__init__.py         |   1 +
 serve/mlc_serve/engine/base.py             |   9 ++
 serve/mlc_serve/engine/model_module.py     |   4 +-
 serve/mlc_serve/model/paged_cache_model.py | 141 ++++++++++++---
 5 files changed, 95 insertions(+), 68 deletions(-)

diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py
index b179e3a164..94e917e611 100644
--- a/serve/mlc_serve/api/protocol.py
+++ b/serve/mlc_serve/api/protocol.py
@@ -78,7 +78,8 @@ class ChatCompletionRequest(BaseModel):
 class TopLogprobs(BaseModel):
     """An OpenAI API compatible schema for logprobs output."""
 
-    token: str
+    # In the OpenAI API `token` is a string; an integer token id is also allowed here for unification
+    token: Union[str, int]
     logprob: float
     bytes: Optional[List] = None
 
@@ -86,10 +87,11 @@
 class LogprobsContent(BaseModel):
     """An OpenAI API compatible schema for logprobs output."""
 
-    token: str
+    # In the OpenAI API `token` is a string; an integer token id is also allowed here for unification
+    token: Union[str, int]
     logprob: float
     bytes: Optional[List] = None
-    top_logprobs: List[TopLogprobs]
+    top_logprobs: List[TopLogprobs]  # May be empty when top logprobs are not requested
 
 
 class Logprobs(BaseModel):
diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py
index 61700bbb00..83e58ddccc 100644
--- a/serve/mlc_serve/engine/__init__.py
+++ b/serve/mlc_serve/engine/__init__.py
@@ -17,5 +17,6 @@
     PROMPT_SEQEUNCE_INDEX,
     get_prompt_sequence_id,
     LOGPROBS_TYPE,
+    RawLogprobsInfo,
 )
 from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index 3a324ec1d0..ad572ba368 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -6,6 +6,7 @@
 from typing import List, Callable, Any, Optional, Dict, Tuple
 import inspect
 
+import numpy as np
 
 from .sampling_params import SamplingParams, SamplingType
 
@@ -14,6 +15,14 @@
 # ((token, logprob), [(top1_token, top1_logprob), ...])
 
 
+@dataclass
+class RawLogprobsInfo:
+    current_token: np.ndarray
+    current_logprob: np.ndarray
+    top_tokens: Optional[np.ndarray]
+    top_logprobs: Optional[np.ndarray]
+
+
 # TODO(@sunggg): consider transition to something like Pydantic
 @dataclass
 class MLCServeEngineConfig:
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 439f2b20ee..d23086c592 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -6,8 +6,8 @@
 
 from .base import (
     ChatMessage,
-    LOGPROBS_TYPE,
     MLCServeEngineConfig,
+    RawLogprobsInfo,
     RequestId,
     RequestState,
     SequenceId,
@@ -51,7 +51,7 @@ class TextGenerationResult:
     # making this a list of token ids to leave room for speculative decoding
     generated_tokens: List[int]
     error: Optional[str]
-    logprob_info: Optional[LOGPROBS_TYPE] = None
+    logprob_info: Optional[RawLogprobsInfo] = None
 
 
 class KVCache(Protocol):
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 8feaab0c98..414b4c254e 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -3,7 +3,6 @@
 from pathlib import Path
 from collections import defaultdict
 from typing import List, Union, Optional, Tuple
-from dataclasses import dataclass
 
 import structlog
 import numpy as np
@@ -20,9 +19,9 @@
     MLCServeEngineConfig,
     SamplingParams,
     LOGPROB_TOP_K_MAX,
-    LOGPROBS_TYPE,
     SequenceId,
     PROMPT_SEQEUNCE_INDEX,
+    RawLogprobsInfo,
     get_prompt_sequence_id,
 )
 from ..engine.model_module import (
@@ -35,6 +34,49 @@
 LOG = structlog.stdlib.get_logger(__name__)
 
 
+def fetch_raw_logprob_infos(
+    logits,
+    res_tokens,
+    sampling_params,
+) -> List[Optional[RawLogprobsInfo]]:
+    logprob_infos: List[Optional[RawLogprobsInfo]] = []
+    num_seq = logits.shape[0]
+    for index in range(num_seq):
+        if sampling_params[index].logprobs:
+            # Logprobs were requested for this sequence: compute the logprob
+            # of the sampled token and, optionally, the top-k logprobs.
+            logprobs = torch.log_softmax(logits[index], dim=-1)
+            res_logprob = logprobs[res_tokens[index]].cpu().numpy()
+
+            # Optionally gather the top-k tokens and their logprobs
+            top_logprobs_num = sampling_params[index].top_logprobs
+            if (
+                top_logprobs_num is None or
+                top_logprobs_num == 0
+            ):
+                top_logprobs = None
+                top_tokens = None
+            else:
+                assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs"
+                top_logprobs, top_tokens = torch.topk(
+                    logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
+                )
+                top_tokens = top_tokens.cpu().numpy()
+                top_logprobs = top_logprobs.cpu().numpy()
+
+            # Set to raw logprob info
+            logprob_infos.append(RawLogprobsInfo(
+                current_token=res_tokens[index].cpu().numpy(),
+                current_logprob=res_logprob,
+                top_tokens=top_tokens,
+                top_logprobs=top_logprobs,
+            ))
+        else:
+            # Logprobs were not requested for this sequence
+            logprob_infos.append(None)
+    return logprob_infos
+
+
 def _apply_top_p_top_k(logits, top_ps, top_ks):
     p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
     k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
@@ -63,7 +105,7 @@ def sample(
     sampling_params: List[SamplingParams],
     vocab_size: int,
     check_safety=False,
-) -> Optional[Tuple[np.ndarray, Optional[Tuple[Tuple, Tuple]]]]:
+) -> Optional[Tuple[np.ndarray, List[Optional[RawLogprobsInfo]]]]:
     def _is_safe_to_sample(prob_like):
         return (
             torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
@@ -82,21 +124,18 @@ def _is_safe_to_sample(prob_like):
     logits_greedy = logits[mask_greedy]
 
     if logits_greedy.shape[0] > 0:
-        # Greedy sampling
-        logprobs = torch.log_softmax(logits_greedy, dim=-1)
-        res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1)
-
-        top_greedy_logprob, top_greedy = torch.topk(
-            logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
+        res_greedy = torch.argmax(logits_greedy, -1)
+
+        logprob_infos_greedy = fetch_raw_logprob_infos(
+            logits_greedy,
+            res_greedy,
+            [p for p, m in zip(sampling_params, mask_greedy) if m],
         )
-        # Convert to numpy
-        res_greedy_logprob = res_greedy_logprob.cpu().numpy()
+
         res_greedy = res_greedy.cpu().numpy()
-        top_greedy_logprob = top_greedy_logprob.cpu().numpy()
-        top_greedy = top_greedy.cpu().numpy()
         # Case when there's only greedy sampling
         if logits_greedy.shape[0] == num_seq:
-            return res_greedy, ((res_greedy, res_greedy_logprob), (top_greedy, top_greedy_logprob))
+            return res_greedy, logprob_infos_greedy
 
     temperatures = []
     top_ps = []
@@ -131,7 +170,6 @@ def _is_safe_to_sample(prob_like):
 
         if param.logit_bias:
             logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device)
 
-
     logits_random = logits[mask_random]
 
@@ -140,43 +178,40 @@ def _is_safe_to_sample(prob_like):
         logits_random.div_(t.unsqueeze(dim=1))
 
     if do_top_p or do_top_k:
+        # TODO(vvchernov): this looks like a typo: should the filtered result be assigned
+        # back to logits_random? Otherwise the reassigned logits are never used below.
         logits = _apply_top_p_top_k(logits_random, top_ps, top_ks)
 
     probs = torch.softmax(logits_random, dim=-1)
-    logprobs = torch.log_softmax(logits_greedy, dim=-1)
-    top_random_logprob, top_random = torch.topk(
-        logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
-    )
-    top_random_logprob = top_random_logprob.cpu().numpy()
-    top_random = top_random.cpu().numpy()
 
     if check_safety and not _is_safe_to_sample(probs):
         return None
 
-    res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]
-    res_random_logprobs = torch.gather(logprobs, dim=-1, index=torch.tensor(res_random, dtype=torch.int64, device=logits.device)).cpu().numpy()
+    res_random = torch.multinomial(probs, 1, True)[:, 0]
+
+    logprob_infos_random = fetch_raw_logprob_infos(
+        logits_random,
+        res_random,
+        [p for p, m in zip(sampling_params, mask_random) if m],
+    )
+
+    res_random = res_random.cpu().numpy()
+
     # Case when there's only random sampling
     if logits_random.shape[0] == num_seq:
-        return res_random, ((res_random, res_random_logprobs), (top_random, top_random_logprob))
+        return res_random, logprob_infos_random
 
     res = np.empty((num_seq,), dtype=np.int32)
-    res_logprobs = np.empty((num_seq,), dtype=np.float32)
-    top = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.int32)
-    top_logprobs = np.empty((num_seq, LOGPROB_TOP_K_MAX), dtype=np.float32)
-
     res[mask_random] = res_random
-    res_logprobs[mask_random] = res_random_logprobs
-    top[mask_random] = top_random
-    top_logprobs[mask_random] = top_random_logprob
+
+    logprob_infos = np.empty((num_seq,), dtype=object)  # object ndarray: a plain list cannot be assigned through a mask
+    logprob_infos[mask_random] = logprob_infos_random
 
     if logits_greedy.shape[0] > 0:
         res[mask_greedy] = res_greedy
-        res_logprobs[mask_greedy] = res_greedy_logprob
-        top[mask_greedy] = top_greedy
-        top_logprobs[mask_greedy] = top_greedy_logprob
-
-    return res, ((res, res_logprobs), (top, top_logprobs))
+        logprob_infos[mask_greedy] = logprob_infos_greedy
+
+    return res, list(logprob_infos)
 
 
 def load_disco_module(artifact_path, lib_path, num_shards):
@@ -228,26 +263,6 @@ def get_tvm_model(config, dev):
     return load_disco_module(config.model_artifact_path, lib_path, config.num_shards)
 
 
-def fetch_logprobs(
-    logprob_info: Optional[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]],
-    index: int,
-    sampling_param: SamplingParams,
-    ) -> Optional[LOGPROBS_TYPE]:  # np.ndarray inside
-    """Fetch the logprob information with index"""
-    if (
-        sampling_param.logprobs is None or
-        not sampling_param.logprobs or
-        logprob_info is None
-    ):
-        return None
-    (res, res_logprobs), (top, top_logprobs) = logprob_info
-    return (res[index],res_logprobs[index]), \
-        list(zip(
-            top[index][:sampling_param.top_logprobs],
-            top_logprobs[index][:sampling_param.top_logprobs]
-        ))
-
-
 def _prepare_inputs(
     sequence_ids,
     all_token_ids,
@@ -528,7 +543,7 @@ def generate(
                 cache.pending_copy_from_to = []
 
         try:
-            next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size)
+            next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size)
             assert next_tokens is not None
             outputs = []
             for i, (sequence_id, new_token) in enumerate(
@@ -544,7 +559,7 @@ def generate(
                                 sequence_id=SequenceId(sequence_id.request_id, seq_id),
                                 generated_tokens=[new_token],
                                 error=None,
-                                logprob_info=fetch_logprobs(logprob_info, seq_id, sampling_params[seq_id]),
+                                logprob_info=logprob_infos[i],
                             )
                         )
                 else:
@@ -553,7 +568,7 @@ def generate(
                             sequence_id=sequence_id,
                             generated_tokens=[new_token],
                             error=None,
-                            logprob_info=fetch_logprobs(logprob_info, i, sampling_params[i]),
+                            logprob_info=logprob_infos[i],
                         )
                     )
 
@@ -569,7 +584,7 @@ def generate(
             for i, (sequence_id, logits_per_token, sampling_param) in enumerate(
                 zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
             ):
-                maybe_new_token, logprob_info = sample(
+                maybe_new_token, logprob_infos = sample(
                     torch.unsqueeze(logits_per_token, 0),
                     [sampling_param],
                     self.vocab_size,
@@ -590,7 +605,7 @@ def generate(
                                     ),
                                     generated_tokens=[new_token],  # type: ignore
                                     error=None,
-                                    logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                                    logprob_info=logprob_infos[0]  # sample() was called with a single sequence here
                                 )
                             )
                     else:
@@ -599,7 +614,7 @@ def generate(
                                 sequence_id=sequence_id,
                                 generated_tokens=[new_token],  # type: ignore
                                 error=None,
-                                logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                                logprob_info=logprob_infos[0]
                            )
                         )
                 else:
@@ -612,7 +627,7 @@ def generate(
                                     ),
                                     generated_tokens=[],
                                     error=err_msg,
-                                    logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                                    logprob_info=logprob_infos[0]
                                 )
                             )
                     else:
@@ -621,7 +636,7 @@ def generate(
                                 sequence_id=sequence_id,
                                 generated_tokens=[],
                                 error=err_msg,
-                                logprob_info=fetch_logprobs(logprob_info, i, sampling_param)
+                                logprob_info=logprob_infos[0]
                             )
                         )
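
Note (illustration only, not part of the patch): the sketch below shows one way the per-sequence RawLogprobsInfo produced by fetch_raw_logprob_infos() could be mapped onto the OpenAI-compatible LogprobsContent/TopLogprobs schema from serve/mlc_serve/api/protocol.py. The helper name raw_logprob_info_to_content is hypothetical, token ids are passed through without detokenization (which is why `token: Union[str, int]` is allowed above), and the import paths assume the serve/ package layout is on the Python path.

    from typing import Optional

    from mlc_serve.api.protocol import LogprobsContent, TopLogprobs
    from mlc_serve.engine import RawLogprobsInfo


    def raw_logprob_info_to_content(
        info: Optional[RawLogprobsInfo],
    ) -> Optional[LogprobsContent]:
        """Hypothetical helper: convert raw numpy logprob info into the response schema."""
        if info is None:
            # Logprobs were not requested for this sequence.
            return None

        top_logprobs = []
        if info.top_tokens is not None and info.top_logprobs is not None:
            top_logprobs = [
                TopLogprobs(token=int(token), logprob=float(logprob), bytes=None)
                for token, logprob in zip(info.top_tokens, info.top_logprobs)
            ]

        return LogprobsContent(
            token=int(info.current_token),
            logprob=float(info.current_logprob),
            bytes=None,
            top_logprobs=top_logprobs,  # empty when top logprobs were not requested
        )

A full Logprobs response can then be assembled by collecting one LogprobsContent per generated token, with detokenization of the integer ids applied wherever string tokens are needed.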