diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index 71d51332b2..63deae4a94 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -40,8 +40,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
 
 router = APIRouter()
 
-import logging
-logger = logging.getLogger(__name__)
 
 def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
     sampling_params = SamplingParams(
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 5e658cc636..192c75bd62 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -658,11 +658,11 @@ def generate(
         return [
             TextGenerationResult(
                 sequence_id=sequence_id,
-                generated_tokens=[next_token],
+                generated_tokens=[new_token],
                 error=None,
                 logprob_info=fetch_logprobs(logprob_info, index, sampling_params[index]),
             )
-            for index, (sequence_id, next_token) in enumerate(zip(sequence_ids, next_tokens))
+            for index, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens))
         ]
     except RuntimeError:
         # Fallback to per-token sampling in case some logits values are corrupted.
diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py
index 006ebfbbbf..23c52683a8 100644
--- a/serve/tests/unittest/test_engine_with_samplers.py
+++ b/serve/tests/unittest/test_engine_with_samplers.py
@@ -223,16 +223,16 @@ def test_stop(
 def test_logprobs(
     model_artifact_path,
     use_staging_engine,
-    max_num_batched_tokens=2560,
-    max_input_len=2560,
+    max_num_sequences=4,
+    max_input_len=512,
     num_requests=5,
     logprobs=3,
 ):
     prompt = "hi"
     engine = create_engine(
-        model_artifact_path,
-        use_staging_engine,
-        max_num_batched_tokens,
+        model_artifact_path,
+        use_staging_engine,
+        max_num_sequences,
         max_input_len,
     )
     s = 113