From fc1356866b83b99da8e19dca2621d5bdf842d626 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 2 Feb 2024 13:43:38 +0400 Subject: [PATCH] fix lint --- serve/mlc_serve/model/dummy_model.py | 6 ++---- serve/mlc_serve/model/model_common.py | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/model/dummy_model.py b/serve/mlc_serve/model/dummy_model.py index b8900273a6..4e41508c1e 100644 --- a/serve/mlc_serve/model/dummy_model.py +++ b/serve/mlc_serve/model/dummy_model.py @@ -11,7 +11,6 @@ DecodeRequest, KVCache, PrefillRequest, - RequestType, SequenceId, TextGenerationResult, ) @@ -98,18 +97,17 @@ def get_max_new_tokens(self) -> int: class DummyTextGenerator: def generate( self, - requests: list[RequestType], + requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache: DummyCache, ) -> list[TextGenerationResult]: result = [] for req in requests: - # TODO(vvchernov): support other types of Request if isinstance(req, DecodeRequest): seq_id = req.sequence_id request_id = req.sequence_id.request_id if req.sequence_id.sequence_index > 0: raise RuntimeError("Multiple generated sequences not supported") - else: # PrefillRequest + else: seq_id = SequenceId(req.request_id, 0) request_id = req.request_id diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 9f712433d5..22ebec7ebb 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -351,11 +351,11 @@ def sample_from_logits( assert logits.shape[0] == len(requests) sampling_params = [req.sampling_params for req in requests] + outputs: List[TextGenerationResult] = [] try: next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size) assert next_tokens is not None - outputs = [] for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): update_tokens_frequency(requests[i], new_token) outputs = append_text_gen_res( @@ -369,7 +369,6 @@ def sample_from_logits( return outputs except RuntimeError: # Fallback to per-token sampling in case some logits values are corrupted. - outputs = [] err_msg = ( "Error from sampling: probability tensor contains either `inf`, `nan`" " or element < 0"