From ec7b61d308620fa2767795d7cbc2fedc015e4270 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 2 Feb 2024 12:28:32 +0400 Subject: [PATCH 1/3] refactor sampling from logits, other small fixes --- serve/mlc_serve/engine/model_module.py | 4 + serve/mlc_serve/model/dummy_model.py | 3 +- serve/mlc_serve/model/model_common.py | 140 ++++++++++++------------- serve/mlc_serve/model/tvm_model.py | 9 +- 4 files changed, 79 insertions(+), 77 deletions(-) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 00893efa44..a5b86d69b9 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -66,6 +66,10 @@ class EvalMultiQueryRequest: sampling_params: SamplingParams +RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest] +RequestsType = Sequence[RequestType] + + @dataclass class TextGenerationResult: """ diff --git a/serve/mlc_serve/model/dummy_model.py b/serve/mlc_serve/model/dummy_model.py index 4e41508c1e..630ed7c267 100644 --- a/serve/mlc_serve/model/dummy_model.py +++ b/serve/mlc_serve/model/dummy_model.py @@ -102,12 +102,13 @@ def generate( ) -> list[TextGenerationResult]: result = [] for req in requests: + # TODO(vvchernov): support other types of Request if isinstance(req, DecodeRequest): seq_id = req.sequence_id request_id = req.sequence_id.request_id if req.sequence_id.sequence_index > 0: raise RuntimeError("Multiple generated sequences not supported") - else: + else: # PrefillRequest seq_id = SequenceId(req.request_id, 0) request_id = req.request_id diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index fee952ad4d..9f712433d5 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union, Sequence +from typing import List, Optional, Tuple, Union import structlog import numpy as np @@ -11,16 +11,16 @@ SamplingParams, get_prompt_sequence_id, LOGPROB_TOP_K_MAX, - RawLogprobsInfo, - RawLogprobsInfos, PROMPT_SEQEUNCE_INDEX, + RawLogprobsInfo, RawLogprobsInfos, SequenceId, ) from ..engine.model_module import ( - DecodeRequest, PrefillRequest, EvalMultiQueryRequest, + RequestType, + RequestsType, TextGenerationResult, ) @@ -302,10 +302,50 @@ def _is_safe_to_sample(prob_like): return res, check_logprob_infos(logprob_infos) +def update_tokens_frequency( + request: RequestType, + new_token: int +): + if not new_token in request.sampling_params.appeared_tokens_freq: + request.sampling_params.appeared_tokens_freq[new_token] = 0 + request.sampling_params.appeared_tokens_freq[new_token] += 1 + + +def append_text_gen_res( + outputs: List[TextGenerationResult], + request: RequestType, + new_token: List[int], + sequence_id: SequenceId, + logprob_info: Optional[RawLogprobsInfos], + err_msg: Optional[str]=None, +) -> List[TextGenerationResult]: + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(request, PrefillRequest) + for seq_id in range(request.num_sequence): # type: ignore + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=new_token, + error=err_msg, + logprob_info=logprob_info, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=new_token, + error=err_msg, + logprob_info=logprob_info, + ) + ) + return outputs + + def sample_from_logits( logits: Union[tvm.nd.NDArray, torch.Tensor], sequence_ids: List[SequenceId], - 
requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], + requests: RequestsType, vocab_size, ) -> List[TextGenerationResult]: assert logits.shape[0] == len(requests) @@ -317,29 +357,14 @@ def sample_from_logits( assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): - if not new_token in sampling_params[i].appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): # type: ignore - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - logprob_info=get_logprob_infos(i, logprob_infos), - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - logprob_info=get_logprob_infos(i, logprob_infos), - ) - ) + update_tokens_frequency(requests[i], new_token) + outputs = append_text_gen_res( + outputs, + requests[i], + [new_token], + sequence_id, + get_logprob_infos(i, logprob_infos), + ) return outputs except RuntimeError: @@ -362,50 +387,23 @@ def sample_from_logits( if maybe_new_token is not None: new_token = maybe_new_token[0] - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): # type: ignore - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], # type: ignore - error=None, - logprob_info=get_logprob_infos(0, logprob_infos), - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], # type: ignore - error=None, - logprob_info=get_logprob_infos(0, logprob_infos), - ) - ) + update_tokens_frequency(requests[i], new_token) + outputs = append_text_gen_res( + outputs, + requests[i], + [new_token], + sequence_id, + get_logprob_infos(0, logprob_infos), + ) else: - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): # type: ignore - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[], - error=err_msg, - logprob_info=get_logprob_infos(0, logprob_infos), - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[], - error=err_msg, - logprob_info=get_logprob_infos(0, logprob_infos), - ) - ) + outputs = append_text_gen_res( + outputs, + requests[i], + [], # new_token + sequence_id, + get_logprob_infos(0, logprob_infos), + err_msg, + ) return outputs diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 202a04e30d..0c28ff7003 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -1,6 +1,6 @@ import math import os -from typing import List, Union, Tuple, Sequence +from typing import List, Tuple import structlog import numpy as np @@ -24,9 +24,10 @@ ) from ..engine.model_module import ( 
DecodeRequest, - PrefillRequest, DraftTokens, EvalMultiQueryRequest, + PrefillRequest, + RequestsType, TextGenerationResult, TextGenerator, ) @@ -276,9 +277,7 @@ def generate_multi_query( def generate( self, - requests: Sequence[ - Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest] - ], + requests: RequestsType, cache: KVCacheInfo, ) -> List[TextGenerationResult]: if len(requests) == 0: From e58b7d3421608f6a19d7ad29d785fea1e1022586 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 2 Feb 2024 12:39:06 +0400 Subject: [PATCH 2/3] use RequestType --- serve/mlc_serve/engine/engine_common.py | 7 +++---- serve/mlc_serve/model/dummy_model.py | 3 ++- serve/mlc_serve/model/paged_cache_model.py | 7 ++++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 8cb503de06..4a02ce2f60 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -27,6 +27,7 @@ ConversationTemplate, KVCacheManager, ModelModule, + RequestType, TextGenerator, Tokenizer as TokenizerP, ) @@ -228,10 +229,8 @@ def update_sequence( def get_requests_to_process( current_states: list[RequestState], cache_manager: KVCacheManager -) -> Tuple[ - list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int -]: - requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = [] +) -> Tuple[list[RequestType], bool, int]: + requests: list[RequestType] = [] # TODO: consider having hybrid batch if the underlying attention kernel supports # mixing prefill and decode. is_prompt_batch = any(not state.is_prefilled for state in current_states) diff --git a/serve/mlc_serve/model/dummy_model.py b/serve/mlc_serve/model/dummy_model.py index 630ed7c267..b8900273a6 100644 --- a/serve/mlc_serve/model/dummy_model.py +++ b/serve/mlc_serve/model/dummy_model.py @@ -11,6 +11,7 @@ DecodeRequest, KVCache, PrefillRequest, + RequestType, SequenceId, TextGenerationResult, ) @@ -97,7 +98,7 @@ def get_max_new_tokens(self) -> int: class DummyTextGenerator: def generate( self, - requests: list[Union[PrefillRequest, DecodeRequest]], + requests: list[RequestType], kv_cache: DummyCache, ) -> list[TextGenerationResult]: result = [] diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 0b16ab0b3c..433ca2baa3 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -1,6 +1,6 @@ from pathlib import Path import structlog -from typing import List, Union +from typing import List from .base import get_model_artifact_config from .paged_cache_manager import CacheManager @@ -13,6 +13,7 @@ ModelModule, PrefillRequest, EvalMultiQueryRequest, + RequestType, TextGenerationResult, TextGenerator, ) @@ -26,9 +27,9 @@ def __init__(self, model: TextGenerator): def generate( self, - requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], + requests: List[RequestType], kv_cache, - ) -> list[TextGenerationResult]: + ) -> List[TextGenerationResult]: prefill_requests = [] decode_requests = [] multi_query_decode_requests = [] From fc1356866b83b99da8e19dca2621d5bdf842d626 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 2 Feb 2024 13:43:38 +0400 Subject: [PATCH 3/3] fix lint --- serve/mlc_serve/model/dummy_model.py | 6 ++---- serve/mlc_serve/model/model_common.py | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/model/dummy_model.py 
b/serve/mlc_serve/model/dummy_model.py index b8900273a6..4e41508c1e 100644 --- a/serve/mlc_serve/model/dummy_model.py +++ b/serve/mlc_serve/model/dummy_model.py @@ -11,7 +11,6 @@ DecodeRequest, KVCache, PrefillRequest, - RequestType, SequenceId, TextGenerationResult, ) @@ -98,18 +97,17 @@ def get_max_new_tokens(self) -> int: class DummyTextGenerator: def generate( self, - requests: list[RequestType], + requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache: DummyCache, ) -> list[TextGenerationResult]: result = [] for req in requests: - # TODO(vvchernov): support other types of Request if isinstance(req, DecodeRequest): seq_id = req.sequence_id request_id = req.sequence_id.request_id if req.sequence_id.sequence_index > 0: raise RuntimeError("Multiple generated sequences not supported") - else: # PrefillRequest + else: seq_id = SequenceId(req.request_id, 0) request_id = req.request_id diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 9f712433d5..22ebec7ebb 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -351,11 +351,11 @@ def sample_from_logits( assert logits.shape[0] == len(requests) sampling_params = [req.sampling_params for req in requests] + outputs: List[TextGenerationResult] = [] try: next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size) assert next_tokens is not None - outputs = [] for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): update_tokens_frequency(requests[i], new_token) outputs = append_text_gen_res( @@ -369,7 +369,6 @@ def sample_from_logits( return outputs except RuntimeError: # Fallback to per-token sampling in case some logits values are corrupted. - outputs = [] err_msg = ( "Error from sampling: probability tensor contains either `inf`, `nan`" " or element < 0"
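
Note (not part of the patches above): the core of patch 1 is pulling the duplicated per-token bookkeeping and result construction in sample_from_logits out into update_tokens_frequency and append_text_gen_res. The following is a minimal, self-contained sketch of that flow for illustration only; SamplingState, Result, append_results, and PROMPT_SEQUENCE_INDEX here are simplified hypothetical stand-ins for the real SamplingParams, TextGenerationResult, append_text_gen_res, and PROMPT_SEQEUNCE_INDEX names in mlc_serve, not the actual engine types.

# Illustrative sketch only; simplified stand-ins for mlc_serve types (assumption).
from dataclasses import dataclass, field
from typing import Dict, List, Optional

PROMPT_SEQUENCE_INDEX = -1  # placeholder marker for "this result came from a prefill"


@dataclass
class SamplingState:
    # Mirrors SamplingParams.appeared_tokens_freq, used for repetition penalties.
    appeared_tokens_freq: Dict[int, int] = field(default_factory=dict)


@dataclass
class Result:
    request_id: int
    sequence_index: int
    generated_tokens: List[int]
    error: Optional[str] = None


def update_tokens_frequency(state: SamplingState, new_token: int) -> None:
    # Same bookkeeping the patch factors out of sample_from_logits:
    # count how often each sampled token has appeared for this request.
    state.appeared_tokens_freq[new_token] = (
        state.appeared_tokens_freq.get(new_token, 0) + 1
    )


def append_results(
    outputs: List[Result],
    request_id: int,
    sequence_index: int,
    num_sequence: int,
    new_token: List[int],
    err_msg: Optional[str] = None,
) -> List[Result]:
    if sequence_index == PROMPT_SEQUENCE_INDEX:
        # A prefill result is fanned out to every parallel sequence of the request,
        # as append_text_gen_res does for PrefillRequest.num_sequence.
        for seq_id in range(num_sequence):
            outputs.append(Result(request_id, seq_id, new_token, err_msg))
    else:
        outputs.append(Result(request_id, sequence_index, new_token, err_msg))
    return outputs


if __name__ == "__main__":
    state = SamplingState()
    update_tokens_frequency(state, 42)
    update_tokens_frequency(state, 42)
    outputs = append_results(
        [], request_id=0, sequence_index=PROMPT_SEQUENCE_INDEX,
        num_sequence=2, new_token=[42],
    )
    print(state.appeared_tokens_freq, outputs)  # {42: 2} and one Result per sequence

Both the sampling path and the per-token fallback path in sample_from_logits reduce to these two calls after the refactor, which is why the patch can delete the two near-identical if/else blocks.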