From ab47b414f025388c7dc66494103242c324e20d1d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 22 Jan 2024 07:43:49 +0000 Subject: [PATCH 01/39] Squashed commit for logprobs implementation. Co-authored-by: Valery Chernov Co-authored-by: Ilya Kozulin --- serve/mlc_serve/api/handler.py | 31 ++++-- serve/mlc_serve/api/protocol.py | 6 ++ serve/mlc_serve/engine/__init__.py | 3 +- serve/mlc_serve/engine/async_connector.py | 5 +- serve/mlc_serve/engine/base.py | 11 ++ serve/mlc_serve/engine/dummy.py | 6 +- serve/mlc_serve/engine/engine_common.py | 65 +++++++++++- serve/mlc_serve/engine/model_module.py | 10 +- serve/mlc_serve/engine/sampling_params.py | 16 ++- serve/mlc_serve/engine/staging_engine.py | 4 +- .../mlc_serve/engine/staging_engine_worker.py | 10 +- serve/mlc_serve/engine/sync_engine.py | 2 + serve/mlc_serve/model/model_common.py | 100 ++++++++++++++++-- serve/mlc_serve/model/paged_cache_manager.py | 10 +- serve/mlc_serve/model/paged_cache_model.py | 8 +- serve/mlc_serve/model/tvm_model.py | 27 ++--- serve/mlc_serve/openai_logprob_protocol.py | 28 +++++ .../unittest/test_engine_with_samplers.py | 43 +++++++- 18 files changed, 331 insertions(+), 54 deletions(-) create mode 100644 serve/mlc_serve/openai_logprob_protocol.py diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index df7a74e738..730cba61bb 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, StreamingResponse -# TODO(amalyshe): hadnle random_seed +# TODO(amalyshe): handle random_seed # from .base import set_global_random_seed from ..api.protocol import ( ChatCompletionRequest, @@ -20,6 +20,7 @@ ChatMessage, DeltaMessage, ErrorResponse, + Logprobs, UsageInfo, ) from ..engine import ( @@ -42,7 +43,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse router = APIRouter() - def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params = SamplingParams( # These params came from vllm @@ -64,6 +64,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params.top_p = request.top_p if request.logit_bias is not None: sampling_params.logit_bias = request.logit_bias + if request.logprobs is not None: + sampling_params.top_logprobs = request.top_logprobs + sampling_params.logprobs = request.logprobs return sampling_params @@ -156,7 +159,7 @@ async def generate_completion_stream( created_time = int(time.time()) def create_stream_response( - choices: list[ChatCompletionResponseStreamChoice], + choices: List[ChatCompletionResponseStreamChoice], ) -> ChatCompletionStreamResponse: return ChatCompletionStreamResponse( id=request_id, @@ -176,7 +179,6 @@ def create_stream_response( ], ) yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n" - async for res in result_generator: if res.error: raise RuntimeError(f"Error when generating: {res.error}") @@ -192,6 +194,7 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, + logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info else None ) for seq in res.sequences ] @@ -212,6 +215,7 @@ async def collect_result_stream( finish_reasons = [None] * num_sequences num_prompt_tokens = 0 num_generated_tokens = [0 for _ in range(num_sequences)] + logprob_infos = [[] for _ in range(num_sequences)] # type: ignore async for res in result_generator: # 
TODO: verify that the request cancellation happens after this returns if res.error: @@ -226,18 +230,27 @@ async def collect_result_stream( if seq.delta: sequences[seq.index].append(seq.delta) + if seq.logprob_info: + assert seq.delta + logprob_infos[seq.index].extend(seq.logprob_info) + if seq.is_finished: assert seq.finish_reason is not None finish_reasons[seq.index] = seq.finish_reason.value # type: ignore - - choices = [ - ChatCompletionResponseChoice( + + choices = [] + for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): + logprobs = None + if logprob_info_seq != []: + logprobs = Logprobs(content=logprob_info_seq) + + choice = ChatCompletionResponseChoice( index=index, message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, + logprobs=logprobs, ) - for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)) - ] + choices.append(choice) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index 286c13edd2..b22a0ca54c 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -6,6 +6,8 @@ from pydantic import BaseModel, Field +from ..openai_logprob_protocol import Logprobs + class ErrorResponse(BaseModel): object: str = "error" @@ -71,11 +73,14 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[int, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage + logprobs: Optional[Logprobs] = None finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None @@ -96,6 +101,7 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage + logprobs: Optional[Logprobs] = None finish_reason: Optional[Literal["stop", "length"]] = None diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 129b7c05ed..92d101ea95 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -16,5 +16,6 @@ RequestState, PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, + RawLogprobsInfo, ) -from .sampling_params import SamplingParams, SamplingType +from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py index 04a9233296..23d1afc426 100644 --- a/serve/mlc_serve/engine/async_connector.py +++ b/serve/mlc_serve/engine/async_connector.py @@ -1,7 +1,6 @@ import asyncio import structlog -from typing import AsyncIterator, Any -from concurrent.futures import ThreadPoolExecutor +from typing import AsyncIterator, Dict from collections import deque from .base import ( @@ -26,7 +25,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1): self.engine_loop_task = None self.engine_loop_exception = None self.shutdown_event = asyncio.Event() - self.result_queues = dict[RequestId, ResultQueue]() + self.result_queues: Dict[RequestId, ResultQueue] = {} self.recent_cancelled_requests = deque[RequestId](maxlen=64) async def start(self): diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 098209da55..d256cadbfe 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -6,13 +6,23 @@ from typing import List, Callable, Any, 
Optional, Dict import inspect +import numpy as np from .sampling_params import SamplingParams, SamplingType +from ..openai_logprob_protocol import LogprobsContent LOG = structlog.stdlib.get_logger(__name__) RequestId = str +@dataclass +class RawLogprobsInfo: + current_token: int + current_logprob: float + top_tokens: Optional[np.array] + top_logprobs: Optional[np.array] + + # TODO(@sunggg): consider transition to something like Pydantic @dataclass class MLCServeEngineConfig: @@ -155,6 +165,7 @@ class SequenceOutput: finish_reason: Optional[FinishReason] = None # Number of generated tokens so far num_generated_tokens: int = 0 + logprob_info: Optional[List[Optional[LogprobsContent]]] = None @property def is_finished(self) -> bool: diff --git a/serve/mlc_serve/engine/dummy.py b/serve/mlc_serve/engine/dummy.py index 3d7c10651c..03227dd3ca 100644 --- a/serve/mlc_serve/engine/dummy.py +++ b/serve/mlc_serve/engine/dummy.py @@ -12,9 +12,9 @@ class DummyInferenceEngine: - def __init__(self): - self.queue_lock = Lock() - self.has_new_requests = Condition(self.queue_lock) + def __init__(self) -> None: + self.queue_lock: Lock = Lock() + self.has_new_requests: Condition = Condition(self.queue_lock) self.request_queue: Dict[RequestId, int] = {} def add(self, requests: list[Request]): diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1bc252b48c..7a9c585659 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -3,17 +3,18 @@ """ import time -from typing import Tuple, Deque, Dict, Optional, Union, Callable +from typing import Tuple, Deque, Dict, Optional, Union, Callable, List from collections import deque from threading import Condition, Lock import structlog from .base import ( + GenerationSequence, + RawLogprobsInfo, Request, RequestId, RequestState, - GenerationSequence, SequenceId, StoppingCriteria, ) @@ -27,6 +28,7 @@ Tokenizer as TokenizerP, ) from ..model.base import ModelArtifactConfig +from ..openai_logprob_protocol import LogprobsContent, TopLogprobs LOG = structlog.stdlib.get_logger(__name__) @@ -135,6 +137,65 @@ def detokenize_incrementally( return delta +def logprob_detokenize( + tokenizer: TokenizerP, + logprob_info: Optional[RawLogprobsInfo], +) -> Optional[LogprobsContent]: + """Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent""" + if logprob_info is None: + return None + + top_logprobs: List[TopLogprobs] = [] + if ( + logprob_info.top_tokens is not None and + logprob_info.top_logprobs is not None + ): + top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) + count: Dict[str, int] = {} + # dedup duplicates + # Todo: Make sure decode can generate different tokens + for top_token, _ in top_tokens: + detokenized = tokenizer.decode(top_token) + if detokenized in count: + count[detokenized] += 1 + else: + count[detokenized] = 1 + for top_token, top_logprob in top_tokens: + detokenized = tokenizer.decode(top_token) + if count[detokenized] != 1: + detokenized = f"{detokenized}_{top_token}" + top_logprobs.append(TopLogprobs( + token=detokenized, + logprob=float(top_logprob), + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + )) + + logprobs_content = LogprobsContent( + token=tokenizer.decode([logprob_info.current_token]), + logprob=logprob_info.current_logprob, + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + 
top_logprobs=top_logprobs, + ) + + return logprobs_content + + +def logprobs_detokenize( + tokenizer: TokenizerP, + logprobs_info: List[Optional[RawLogprobsInfo]], +) -> Optional[List[Optional[LogprobsContent]]]: + res: List[Optional[LogprobsContent]] = [] + for logprob_info in logprobs_info: + res.append(logprob_detokenize(tokenizer, logprob_info)) + + check_all = all([x is None for x in res]) + if check_all: + return None + return res + + def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended): if stopping_criteria.stop_sequences: for t in stopping_criteria.stop_sequences: diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 79b77e93a3..d305c514b8 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -4,7 +4,14 @@ from dataclasses import dataclass from typing import Optional, Protocol, Union, List, Sequence -from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId +from .base import ( + ChatMessage, + MLCServeEngineConfig, + RawLogprobsInfo, + RequestId, + RequestState, + SequenceId, +) from ..model.base import ModelArtifactConfig from .sampling_params import SamplingParams @@ -44,6 +51,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] + logprob_info: Optional[List[Optional[RawLogprobsInfo]]] = None class KVCache(Protocol): diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index bbcdc7fd2c..d5ba5d109c 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -9,8 +9,8 @@ from functools import cached_property from typing import Dict, Optional - _SAMPLING_EPS = 1e-5 +LOGPROB_TOP_K_MAX = 5 class SamplingType(IntEnum): @@ -46,6 +46,13 @@ class SamplingParams: to -1 to consider all tokens. logit_bias: The bias applied on the logit before sampling. Must be in [-100, 100]. + logprobs: Optional[bool] Whether to return log probabilities of the output + tokens or not. If true, returns the log probabilities of each output + token returned in the content of message. + top_logprobs: Optional[Integer] An integer between 1 and 5 specifying + the number of most likely tokens to return at each token position, + each with an associated log probability. logprobs must be set to + true if this parameter is used. """ presence_penalty: float = 0.0 @@ -58,6 +65,8 @@ class SamplingParams: appeared_tokens_freq: Dict[int, int] = None logit_bias_index: list[int] = None logit_bias_value: list[float] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None def __post_init__(self): self.appeared_tokens_freq = {} @@ -95,6 +104,11 @@ def _verify_args(self) -> None: raise ValueError( f"logit bias must be in [-100, 100], got {bias} for token {token}." ) + if self.logprobs is not None and self.logprobs: + if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX): + raise ValueError( + f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." 
+ ) def _verify_greedy_sampling(self) -> None: if self.top_p < 1.0 - _SAMPLING_EPS: diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index 3d1fe70b99..c8354e4c5b 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -5,8 +5,8 @@ import multiprocessing import queue from threading import Lock -from typing import Callable from collections import defaultdict +from typing import Callable import structlog @@ -24,6 +24,7 @@ from .engine_common import ( get_new_request_state, update_sequence, + logprobs_detokenize ) from .model_module import ModelModule, TokenizerModule from .staging_engine_worker import ( @@ -251,6 +252,7 @@ def step(self) -> InferenceStepResult: delta, finish_reason, num_generated_tokens=len(gen_seq.generated_token_ids), + logprob_info=logprobs_detokenize(self.tokenizer, seq_output.logprob_info), ) seq_outputs[request_id].append(output) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index e74a6181c8..ceb84857ae 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -12,12 +12,14 @@ from .base import ( FinishReason, + RawLogprobsInfo, RequestId, RequestState, ValidationError, SequenceId, GenerationSequence, ) + from .metrics import PrometheusMetrics from .metrics_labels import * from .model_module import ( @@ -40,7 +42,7 @@ class ShutdownCommand: @dataclass class AddRequestsCommand: - request_states: list[RequestState] + request_states: List[RequestState] @dataclass @@ -61,14 +63,15 @@ class StopSequenceCommand: @dataclass class SequenceGenerationOutput: id: SequenceId - new_tokens: list[int] + new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None + logprob_info: Optional[List[RawLogprobsInfo]] = None @dataclass class GenerationLoopWorkerOutput: - sequences: list[SequenceGenerationOutput] + sequences: List[SequenceGenerationOutput] error: Optional[BaseException] = None @@ -288,6 +291,7 @@ def step(self) -> GenerationLoopWorkerOutput: id=res.sequence_id, new_tokens=new_tokens, finish_reason=finish_reason, + logprob_info=res.logprob_info, ) ) diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index 5dcf80eda7..c400ec6b4a 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -22,6 +22,7 @@ get_requests_to_process, update_sequence, EngineBase, + logprobs_detokenize ) from .model_module import ( ModelModule, @@ -222,6 +223,7 @@ def step(self) -> InferenceStepResult: delta, num_generated_tokens=len(gen_seq.generated_token_ids), finish_reason=finish_reason, + logprob_info=logprobs_detokenize(self.tokenizer, res.logprob_info), ) ) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 2ecbab9202..13e92425e0 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional +from typing import List, Optional, Tuple, Union import structlog import numpy as np @@ -9,6 +9,8 @@ from ..engine import ( SamplingType, SamplingParams, + LOGPROB_TOP_K_MAX, + RawLogprobsInfo, ) LOG = structlog.stdlib.get_logger(__name__) @@ -36,6 +38,50 @@ def get_num_cache_blocks( ) +def fetch_raw_logprob_infos( + logits, + res_tokens, + sampling_params, +) -> List[Optional[RawLogprobsInfo]]: + logprob_infos: 
List[Optional[RawLogprobsInfo]] = [] + num_seq = logits.shape[0] + for index in range(num_seq): + if ( + sampling_params[index].logprobs is not None and + sampling_params[index].logprobs + ): + # Logprob sampling + logprobs = torch.log_softmax(logits[index], dim=-1) + res_logprob = logprobs[res_tokens[index]].cpu().numpy() + + top_logprobs_num = sampling_params[index].top_logprobs + if ( + top_logprobs_num is None or + top_logprobs_num == 0 + ): + top_logprobs = None + top_tokens = None + else: + assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs" + top_logprobs, top_tokens = torch.topk( + logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + ) + top_tokens=top_tokens.cpu().numpy() + top_logprobs=top_logprobs.cpu().numpy() + + # Set to raw logprob info + logprob_infos.append(RawLogprobsInfo( + current_token=res_tokens[index].cpu().numpy(), + current_logprob=res_logprob, + top_tokens=top_tokens, + top_logprobs=top_logprobs, + )) + else: + logprob_infos.append(None) + + return logprob_infos + + def _apply_top_p_top_k(logits, top_ps, top_ks): p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) @@ -59,12 +105,31 @@ def _apply_top_p_top_k(logits, top_ps, top_ks): return logits +def update_masked_list(input_list, mask, update): + j = 0 + for i in range(len(mask)): + if mask[i]: + input_list[i] = update[j] + j = j + 1 + + return input_list + + +def filter_list_by_mask(i_list, mask): + o_list = [] + for i in range(len(mask)): + if mask[i]: + o_list.append(i_list[i]) + + return o_list + + def sample( logits: Union[tvm.nd.NDArray, torch.Tensor], sampling_params: List[SamplingParams], vocab_size: int, check_safety=False, -) -> Optional[np.ndarray]: +) -> Optional[Tuple[np.ndarray, List[Optional[RawLogprobsInfo]]]]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -84,11 +149,19 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy] if logits_greedy.shape[0] > 0: - res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() + res_greedy = torch.argmax(logits_greedy, -1) + + logprob_infos_greedy = fetch_raw_logprob_infos( + logits_greedy, + res_greedy, + filter_list_by_mask(sampling_params, mask_greedy) + ) + res_greedy = res_greedy.cpu().numpy() + # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_greedy + return res_greedy, logprob_infos_greedy temperatures = [] top_ps = [] @@ -155,20 +228,33 @@ def _is_safe_to_sample(prob_like): torch.cuda.nvtx.range_pop() return None - res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0] + res_random = torch.multinomial(probs, 1, True)[:, 0] + logprob_infos_random = fetch_raw_logprob_infos( + logits_random, + res_random, + filter_list_by_mask(sampling_params, mask_random), + ) + + res_random = res_random.cpu().numpy() + # Case when there's only random sampling if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_random + return res_random, logprob_infos_random res = np.empty((num_seq,), dtype=np.int32) res[mask_random] = res_random + logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq + logprob_infos = update_masked_list(logprob_infos, mask_random, logprob_infos_random) + if logits_greedy.shape[0] > 0: res[mask_greedy] = res_greedy + logprob_infos = update_masked_list(logprob_infos, mask_greedy, logprob_infos_greedy) + 
torch.cuda.nvtx.range_pop() - return res + return res, logprob_infos def prepare_inputs( diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 7193c0eadb..1614b442d4 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -1,6 +1,6 @@ import math from collections import defaultdict -from typing import List, Optional +from typing import Any, List, Optional from ..engine import ( RequestId, @@ -104,16 +104,16 @@ def replace_head_prompt_block_with(self, new_block): class KVCacheInfo: def __init__( self, - block_size, + block_size: int ): self.block_size = block_size # SequenceId -> list[int] - self.prompt_block_tables = defaultdict(list) - self.slot_mappings = defaultdict(list) + self.prompt_block_tables = defaultdict(list) # type: ignore + self.slot_mappings = defaultdict(list) # type: ignore # The core data structure - self.decode_block_tables = dict[SequenceId, DecodeBlockTable]() + self.decode_block_tables: dict = dict[SequenceId, DecodeBlockTable]() # Record indices of blocks to copy after prefill in the format [src1, dst1, src2, dst2, ...] self.pending_copy_from_to: list[int] = [] diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index a7f11750eb..7daf3336f4 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -1,6 +1,6 @@ -from typing import Union from pathlib import Path import structlog +from typing import List, Union from .base import get_model_artifact_config from .paged_cache_manager import CacheManager @@ -10,11 +10,11 @@ from ..engine import MLCServeEngineConfig from ..engine.model_module import ( DecodeRequest, + ModelModule, PrefillRequest, TextGenerationResult, TextGenerator, ) -from ..engine.model_module import ModelModule LOG = structlog.stdlib.get_logger(__name__) @@ -24,8 +24,8 @@ def __init__(self, model: TextGenerator): self.model = model def generate( - self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache - ) -> list[TextGenerationResult]: + self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache + ) -> List[TextGenerationResult]: prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] decode_requests = [r for r in requests if isinstance(r, DecodeRequest)] diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index cfb528ac43..4144c862fb 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -282,13 +282,6 @@ def generate( slot_mapping, self.params, ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[ - 0 - ] # Ignore returned KV cache since it is updated in-place anyway. else: torch.cuda.nvtx.range_push(f"forward decode {input_shape}") @@ -305,10 +298,12 @@ def generate( self.params, ) - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[0] + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[ + 0 + ] # Ignore returned KV cache since it is updated in-place anyway. 
torch.cuda.synchronize() torch.cuda.nvtx.range_pop() @@ -330,7 +325,7 @@ def generate( cache.pending_copy_from_to = [] try: - next_tokens = sample(logits, sampling_params, self.vocab_size) + next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( @@ -346,6 +341,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, + logprob_info=[logprob_infos[i]], ) ) else: @@ -354,6 +350,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, + logprob_info=[logprob_infos[i]], ) ) @@ -369,7 +366,7 @@ def generate( for i, (sequence_id, logits_per_token, sampling_param) in enumerate( zip(sequence_ids, torch.from_dlpack(logits), sampling_params) ): - maybe_new_token = sample( + maybe_new_token, logprob_infos = sample( torch.unsqueeze(logits_per_token, 0), [sampling_param], self.vocab_size, @@ -393,6 +390,7 @@ def generate( ), generated_tokens=[new_token], # type: ignore error=None, + logprob_info=[logprob_infos[0]] ) ) else: @@ -401,6 +399,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], # type: ignore error=None, + logprob_info=[logprob_infos[0]] ) ) else: @@ -413,6 +412,7 @@ def generate( ), generated_tokens=[], error=err_msg, + logprob_info=[logprob_infos[0]] ) ) else: @@ -421,6 +421,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, + logprob_info=[logprob_infos[0]] ) ) diff --git a/serve/mlc_serve/openai_logprob_protocol.py b/serve/mlc_serve/openai_logprob_protocol.py new file mode 100644 index 0000000000..9c2a4db502 --- /dev/null +++ b/serve/mlc_serve/openai_logprob_protocol.py @@ -0,0 +1,28 @@ +from typing import List, Optional + +from pydantic import BaseModel + +class TopLogprobs(BaseModel): + """An OpenAI API compatible schema for logprobs output.""" + + token: str + logprob: float + bytes: Optional[List] = None + + +class LogprobsContent(BaseModel): + """An OpenAI API compatible schema for logprobs output.""" + + token: str + logprob: float + bytes: Optional[List] = None + top_logprobs: List[TopLogprobs] # It can be empty + + +class Logprobs(BaseModel): + """ + An OpenAI API compatible schema for logprobs output. 
+ See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs + """ + + content: List[LogprobsContent] \ No newline at end of file diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index 1642364232..2a7bf1efd4 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -48,7 +48,7 @@ def create_engine( def create_request( - idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, logit_bias=None + idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, top_logprobs=0, logprobs=False, logit_bias=None ): return Request( request_id=str(idx), @@ -58,6 +58,8 @@ def create_request( frequency_penalty=freq_pen, presence_penalty=pre_pen, logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, ), stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop), debug_options=DebugOptions(ignore_eos=ignore_eos), @@ -337,6 +339,43 @@ def _test_penalty( if use_staging_engine: engine.stop() +def _test_logprobs( + model_artifact_path, + use_staging_engine, + max_num_sequences=4, + max_input_len=512, + num_requests=5, + top_logprobs=3, +): + prompt = "hi" + engine = create_engine( + model_artifact_path, + use_staging_engine, + max_num_sequences, + max_input_len, + ) + s = 113 + requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs, logprobs=True) for n in range(s, s+num_requests)] + engine.add(requests) + + generated = ["" for _ in range(num_requests)] + + while engine.has_pending_requests(): + results = engine.step() + for res in results.outputs: + assert len(res.sequences) == 1 + seq = res.sequences[0] + + assert seq.finish_reason is not None or len(list(seq.logprobs.content[0]["top_logprobs"])) == top_logprobs + + if seq.is_finished: + assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens + assert seq.finish_reason == FinishReason.Length + else: + generated[int(res.request_id)] += seq.delta + + if use_staging_engine: + engine.stop() if __name__ == "__main__": parser = get_default_mlc_serve_argparser("test engine with samplers") @@ -349,6 +388,8 @@ def _test_penalty( _test_ignore_eos(args.model_artifact_path, use_staging_engine=False) _test_stop(args.model_artifact_path, use_staging_engine=False) _test_stop(args.model_artifact_path, use_staging_engine=True) + _test_logprobs(args.model_artifact_path, use_staging_engine=True) + _test_logprobs(args.model_artifact_path, use_staging_engine=False) # These tests are broken since we are now imposing no length limit # if max_tokens = None. The tests do not finish in a reasonable time. 
# _test_max_context_length(model_artifact_path, use_staging_engine=True) From 86f6fa18382d0b124e43064fb9699405442ba8ef Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 23 Jan 2024 10:24:27 +0400 Subject: [PATCH 02/39] fix None check --- serve/mlc_serve/engine/engine_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 7a9c585659..1f08d892e0 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -186,6 +186,9 @@ def logprobs_detokenize( tokenizer: TokenizerP, logprobs_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: + if logprobs_info is None: + return None + res: List[Optional[LogprobsContent]] = [] for logprob_info in logprobs_info: res.append(logprob_detokenize(tokenizer, logprob_info)) From 9a296500746fcb03f79267dbda773c0c7ae347d1 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 25 Jan 2024 15:03:41 +0000 Subject: [PATCH 03/39] Change detokenization to using token ids. --- serve/mlc_serve/engine/base.py | 1 + serve/mlc_serve/engine/engine_common.py | 22 ++++++++-------------- serve/mlc_serve/model/model_common.py | 1 + serve/mlc_serve/model/tvm_model.py | 11 ++++++++++- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index d256cadbfe..33e76ae743 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -21,6 +21,7 @@ class RawLogprobsInfo: current_logprob: float top_tokens: Optional[np.array] top_logprobs: Optional[np.array] + previous_tokens: Optional[List[int]] # TODO(@sunggg): consider transition to something like Pydantic diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1f08d892e0..4d0554294d 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -151,19 +151,13 @@ def logprob_detokenize( logprob_info.top_logprobs is not None ): top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) - count: Dict[str, int] = {} # dedup duplicates # Todo: Make sure decode can generate different tokens - for top_token, _ in top_tokens: - detokenized = tokenizer.decode(top_token) - if detokenized in count: - count[detokenized] += 1 - else: - count[detokenized] = 1 + if logprob_info.previous_tokens is None: + logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: - detokenized = tokenizer.decode(top_token) - if count[detokenized] != 1: - detokenized = f"{detokenized}_{top_token}" + detokenized = tokenizer.convert_ids_to_tokens(logprob_info.previous_tokens + [top_token])[-1] + LOG.info(f"detokenized: {detokenized}") top_logprobs.append(TopLogprobs( token=detokenized, logprob=float(top_logprob), @@ -184,14 +178,14 @@ def logprob_detokenize( def logprobs_detokenize( tokenizer: TokenizerP, - logprobs_info: List[Optional[RawLogprobsInfo]], + logprob_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: - if logprobs_info is None: + if logprob_info is None: return None res: List[Optional[LogprobsContent]] = [] - for logprob_info in logprobs_info: - res.append(logprob_detokenize(tokenizer, logprob_info)) + for info in logprob_info: + res.append(logprob_detokenize(tokenizer, info)) check_all = all([x is None for x in res]) if check_all: diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 
13e92425e0..127e551692 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -75,6 +75,7 @@ def fetch_raw_logprob_infos( current_logprob=res_logprob, top_tokens=top_tokens, top_logprobs=top_logprobs, + previous_tokens=None )) else: logprob_infos.append(None) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 4144c862fb..e011ddf48b 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -22,6 +22,7 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, MLCServeEngineConfig, + RawLogprobsInfo, ) from ..engine.model_module import ( DecodeRequest, @@ -84,6 +85,11 @@ def get_tvm_model(config, dev): return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) +def attach_detokenization_info(logprob_info:RawLogprobsInfo, token_ids: List[int]): + if logprob_info is None: + return None + logprob_info.previous_tokens = token_ids + return logprob_info def _prepare_inputs( sequence_ids, @@ -326,6 +332,7 @@ def generate( try: next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) + current_ids = list(input_ids.numpy()) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( @@ -341,9 +348,10 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=[logprob_infos[i]], + logprob_info=[attach_detokenization_info(logprob_infos[i], current_ids)], ) ) + current_ids.append(new_token) else: outputs.append( TextGenerationResult( @@ -353,6 +361,7 @@ def generate( logprob_info=[logprob_infos[i]], ) ) + current_ids.append(new_token) return outputs except RuntimeError: From 012388d8975344a569462632e5752c5e1f8d3344 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 29 Jan 2024 08:46:54 +0000 Subject: [PATCH 04/39] Fix wrong usage of token ids. Remove logging. 
--- serve/mlc_serve/engine/engine_common.py | 1 - serve/mlc_serve/model/tvm_model.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 4d0554294d..73be36f9dd 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -157,7 +157,6 @@ def logprob_detokenize( logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: detokenized = tokenizer.convert_ids_to_tokens(logprob_info.previous_tokens + [top_token])[-1] - LOG.info(f"detokenized: {detokenized}") top_logprobs.append(TopLogprobs( token=detokenized, logprob=float(top_logprob), diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index e011ddf48b..70147d721f 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -86,8 +86,6 @@ def get_tvm_model(config, dev): return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) def attach_detokenization_info(logprob_info:RawLogprobsInfo, token_ids: List[int]): - if logprob_info is None: - return None logprob_info.previous_tokens = token_ids return logprob_info @@ -332,11 +330,10 @@ def generate( try: next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) - current_ids = list(input_ids.numpy()) assert next_tokens is not None outputs = [] - for i, (sequence_id, new_token) in enumerate( - zip(sequence_ids, next_tokens) + for i, (sequence_id, new_token, token_ids) in enumerate( + zip(sequence_ids, next_tokens, all_token_ids) ): if not new_token in requests[i].sampling_params.appeared_tokens_freq: requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 @@ -348,20 +345,22 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=[attach_detokenization_info(logprob_infos[i], current_ids)], + logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - current_ids.append(new_token) + if logprob_infos[i]: + token_ids.append(new_token) else: outputs.append( TextGenerationResult( sequence_id=sequence_id, generated_tokens=[new_token], error=None, - logprob_info=[logprob_infos[i]], + logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - current_ids.append(new_token) + if logprob_infos[i]: + token_ids.append(new_token) return outputs except RuntimeError: From db311645aa85fb6bb3418277c5e37e4155df6901 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 11:25:19 +0400 Subject: [PATCH 05/39] extend benchmarks for logprobs --- serve/benchmarks/benchmark_latency.py | 2 ++ serve/benchmarks/benchmark_throughput.py | 2 ++ serve/benchmarks/utils.py | 16 ++++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/serve/benchmarks/benchmark_latency.py b/serve/benchmarks/benchmark_latency.py index c64f968640..84c377b710 100644 --- a/serve/benchmarks/benchmark_latency.py +++ b/serve/benchmarks/benchmark_latency.py @@ -34,6 +34,8 @@ def create_request(request_id): frequency_penalty=args.sampling_setting["frequency_penalty"], presence_penalty=args.sampling_setting["presence_penalty"], logit_bias=args.sampling_setting["logit_bias"], + logprobs = args.sampling_setting["logprobs"], + top_logprobs = args.sampling_setting["top_logprobs"], ), stopping_criteria=StoppingCriteria( max_tokens=args.num_output_tokens, stop_sequences=None diff --git 
a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index 1a7fb46451..4f82d80dd7 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -139,6 +139,8 @@ def run_mlc(engine, requests, args) -> float: frequency_penalty=args.sampling_setting["frequency_penalty"], presence_penalty=args.sampling_setting["presence_penalty"], logit_bias=args.sampling_setting["logit_bias"], + logprobs = args.sampling_setting["logprobs"], + top_logprobs = args.sampling_setting["top_logprobs"], ), stopping_criteria=StoppingCriteria( max_tokens=args.num_output_tokens, stop_sequences=None diff --git a/serve/benchmarks/utils.py b/serve/benchmarks/utils.py index defea7b099..c62c134de8 100644 --- a/serve/benchmarks/utils.py +++ b/serve/benchmarks/utils.py @@ -22,6 +22,18 @@ def add_sampling_flags(parser): action="store_true", help="Apply all penalties, logit bias, top-p and top-k.", ) + parser.add_argument( + "--logprobs", + action="store_true", + default=False, + help="Switch on logprobs output" + ) + parser.add_argument( + "--top-logprobs", + type=int, + default=5, + help="Number of top logprobs to output, limited by 5. Works only with logprobs true." + ) def postproc_sampling_args(args): @@ -51,3 +63,7 @@ def postproc_sampling_args(args): if args.apply_top_p_top_k: args.sampling_setting["top_k"] = 2 args.sampling_setting["top_p"] = 0.7 + + if args.logprobs: + args.sampling_setting["logprobs"] = True + args.sampling_setting["top_logprobs"] = args.top_logprobs From be817558c682efc6c03391fdeadd04a23e3ce8a6 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 13:19:21 +0400 Subject: [PATCH 06/39] fix test without logprobs --- serve/benchmarks/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/serve/benchmarks/utils.py b/serve/benchmarks/utils.py index c62c134de8..4507dc0dd2 100644 --- a/serve/benchmarks/utils.py +++ b/serve/benchmarks/utils.py @@ -45,6 +45,8 @@ def postproc_sampling_args(args): "repetition_penalty": 1.0, "top_p": 1.0, "top_k": -1, + "logprobs": False, + "top_logprobs": 5, } if args.apply_all_sampling_params: From e8ec3fc747a047a10249b51e8c42d3052f1825ca Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 14:27:22 +0400 Subject: [PATCH 07/39] clean code --- serve/mlc_serve/engine/engine_common.py | 8 ++++---- serve/mlc_serve/openai_logprob_protocol.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 73be36f9dd..a3519ced50 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -138,8 +138,8 @@ def detokenize_incrementally( def logprob_detokenize( - tokenizer: TokenizerP, - logprob_info: Optional[RawLogprobsInfo], + tokenizer: TokenizerP, + logprob_info: Optional[RawLogprobsInfo], ) -> Optional[LogprobsContent]: """Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent""" if logprob_info is None: @@ -176,8 +176,8 @@ def logprob_detokenize( def logprobs_detokenize( - tokenizer: TokenizerP, - logprob_info: List[Optional[RawLogprobsInfo]], + tokenizer: TokenizerP, + logprob_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: if logprob_info is None: return None diff --git a/serve/mlc_serve/openai_logprob_protocol.py b/serve/mlc_serve/openai_logprob_protocol.py index 9c2a4db502..36f2b693f6 100644 --- a/serve/mlc_serve/openai_logprob_protocol.py +++ 
b/serve/mlc_serve/openai_logprob_protocol.py @@ -25,4 +25,4 @@ class Logprobs(BaseModel): See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs """ - content: List[LogprobsContent] \ No newline at end of file + content: List[LogprobsContent] From 49187f58c25ac6add87e7c0a09027ac45075ba6e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 15:09:35 +0400 Subject: [PATCH 08/39] black format engine_common.py --- serve/mlc_serve/engine/engine_common.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index a3519ced50..01489b4ea3 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -146,23 +146,24 @@ def logprob_detokenize( return None top_logprobs: List[TopLogprobs] = [] - if ( - logprob_info.top_tokens is not None and - logprob_info.top_logprobs is not None - ): + if logprob_info.top_tokens is not None and logprob_info.top_logprobs is not None: top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) # dedup duplicates # Todo: Make sure decode can generate different tokens if logprob_info.previous_tokens is None: logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: - detokenized = tokenizer.convert_ids_to_tokens(logprob_info.previous_tokens + [top_token])[-1] - top_logprobs.append(TopLogprobs( - token=detokenized, - logprob=float(top_logprob), - # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object - bytes=None, - )) + detokenized = tokenizer.convert_ids_to_tokens( + logprob_info.previous_tokens + [top_token] + )[-1] + top_logprobs.append( + TopLogprobs( + token=detokenized, + logprob=float(top_logprob), + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + ) + ) logprobs_content = LogprobsContent( token=tokenizer.decode([logprob_info.current_token]), From 013ed5a7d338d4ea424d01baaf772f3d2473a273 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 17:45:57 +0400 Subject: [PATCH 09/39] logprobs is strictly bool, top_logprobs is int --- serve/mlc_serve/api/handler.py | 2 +- serve/mlc_serve/api/protocol.py | 4 ++-- serve/mlc_serve/engine/sampling_params.py | 12 ++++++------ serve/mlc_serve/model/model_common.py | 10 ++-------- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 730cba61bb..73e873bc73 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -64,7 +64,7 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params.top_p = request.top_p if request.logit_bias is not None: sampling_params.logit_bias = request.logit_bias - if request.logprobs is not None: + if request.logprobs: sampling_params.top_logprobs = request.top_logprobs sampling_params.logprobs = request.logprobs return sampling_params diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index b22a0ca54c..4f42f7233e 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -73,8 +73,8 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[int, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False - logprobs: Optional[bool] = False - top_logprobs: Optional[int] = None + logprobs: bool = False + 
top_logprobs: int = 0 class ChatCompletionResponseChoice(BaseModel): diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index d5ba5d109c..961b2b744a 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -49,7 +49,7 @@ class SamplingParams: logprobs: Optional[bool] Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. - top_logprobs: Optional[Integer] An integer between 1 and 5 specifying + top_logprobs: Optional[Integer] An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. @@ -65,8 +65,8 @@ class SamplingParams: appeared_tokens_freq: Dict[int, int] = None logit_bias_index: list[int] = None logit_bias_value: list[float] = None - logprobs: Optional[bool] = False - top_logprobs: Optional[int] = None + logprobs: bool = False + top_logprobs: int = 0 def __post_init__(self): self.appeared_tokens_freq = {} @@ -104,10 +104,10 @@ def _verify_args(self) -> None: raise ValueError( f"logit bias must be in [-100, 100], got {bias} for token {token}." ) - if self.logprobs is not None and self.logprobs: - if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX): + if self.logprobs: + if (self.top_logprobs < 0 or self.top_logprobs > LOGPROB_TOP_K_MAX): raise ValueError( - f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." + f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." ) def _verify_greedy_sampling(self) -> None: diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 127e551692..9ccb98cc62 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -46,19 +46,13 @@ def fetch_raw_logprob_infos( logprob_infos: List[Optional[RawLogprobsInfo]] = [] num_seq = logits.shape[0] for index in range(num_seq): - if ( - sampling_params[index].logprobs is not None and - sampling_params[index].logprobs - ): + if sampling_params[index].logprobs: # Logprob sampling logprobs = torch.log_softmax(logits[index], dim=-1) res_logprob = logprobs[res_tokens[index]].cpu().numpy() top_logprobs_num = sampling_params[index].top_logprobs - if ( - top_logprobs_num is None or - top_logprobs_num == 0 - ): + if top_logprobs_num == 0: top_logprobs = None top_tokens = None else: From 79ec4135b5d93a0be3684fb1dd84443be9cd533b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 28 Jan 2024 15:38:37 +0400 Subject: [PATCH 10/39] refactor logprob info collection to not reduce performance --- serve/mlc_serve/model/model_common.py | 120 +++++++++++++------------- 1 file changed, 58 insertions(+), 62 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 9ccb98cc62..1ed6381f94 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -38,41 +38,55 @@ def get_num_cache_blocks( ) -def fetch_raw_logprob_infos( +def get_raw_logprob_info( logits, - res_tokens, - sampling_params, + token, + top_logprobs_num, +) -> RawLogprobsInfo: + logprobs = torch.log_softmax(logits, dim=-1) + res_logprob = logprobs[token].cpu().numpy() + + if top_logprobs_num == 0: + top_logprobs = None + top_tokens = None + else: + assert top_logprobs_num <= LOGPROB_TOP_K_MAX, 
"Invalid input top_logprobs" + top_logprobs, top_tokens = torch.topk( + logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + ) + top_tokens=top_tokens.cpu().numpy() + top_logprobs=top_logprobs.cpu().numpy() + + # Set to raw logprob info + return RawLogprobsInfo( + # TODO(vvchernov): it is number, cpu().numpy()? + current_token=token.cpu().numpy(), + current_logprob=res_logprob, + top_tokens=top_tokens, + top_logprobs=top_logprobs, + previous_tokens=None + ) + + +def get_masked_logprobs( + logprob_infos: List[Optional[RawLogprobsInfo]], + mask: torch.Tensor, + sampling_params: List[SamplingParams], + logits: torch.Tensor, + tokens: torch.Tensor, ) -> List[Optional[RawLogprobsInfo]]: - logprob_infos: List[Optional[RawLogprobsInfo]] = [] - num_seq = logits.shape[0] - for index in range(num_seq): - if sampling_params[index].logprobs: - # Logprob sampling - logprobs = torch.log_softmax(logits[index], dim=-1) - res_logprob = logprobs[res_tokens[index]].cpu().numpy() - - top_logprobs_num = sampling_params[index].top_logprobs - if top_logprobs_num == 0: - top_logprobs = None - top_tokens = None - else: - assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs" - top_logprobs, top_tokens = torch.topk( - logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + num_seq = len(logprob_infos) + + mask_counter = 0 + for i in range(num_seq): + if mask[i]: + if sampling_params[i].logprobs: + logprob_infos[i] = get_raw_logprob_info( + logits[mask_counter], + tokens[mask_counter], + sampling_params[i].top_logprobs, ) - top_tokens=top_tokens.cpu().numpy() - top_logprobs=top_logprobs.cpu().numpy() - - # Set to raw logprob info - logprob_infos.append(RawLogprobsInfo( - current_token=res_tokens[index].cpu().numpy(), - current_logprob=res_logprob, - top_tokens=top_tokens, - top_logprobs=top_logprobs, - previous_tokens=None - )) - else: - logprob_infos.append(None) + mask_counter = mask_counter + 1 return logprob_infos @@ -100,25 +114,6 @@ def _apply_top_p_top_k(logits, top_ps, top_ks): return logits -def update_masked_list(input_list, mask, update): - j = 0 - for i in range(len(mask)): - if mask[i]: - input_list[i] = update[j] - j = j + 1 - - return input_list - - -def filter_list_by_mask(i_list, mask): - o_list = [] - for i in range(len(mask)): - if mask[i]: - o_list.append(i_list[i]) - - return o_list - - def sample( logits: Union[tvm.nd.NDArray, torch.Tensor], sampling_params: List[SamplingParams], @@ -135,6 +130,8 @@ def _is_safe_to_sample(prob_like): logits = torch.from_dlpack(logits) num_seq = len(sampling_params) + logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq + mask_random = torch.tensor( [p.sampling_type == SamplingType.RANDOM for p in sampling_params], dtype=torch.bool, @@ -146,17 +143,19 @@ def _is_safe_to_sample(prob_like): if logits_greedy.shape[0] > 0: res_greedy = torch.argmax(logits_greedy, -1) - logprob_infos_greedy = fetch_raw_logprob_infos( + logprob_infos = get_masked_logprobs( + logprob_infos, + mask_greedy, + sampling_params, logits_greedy, res_greedy, - filter_list_by_mask(sampling_params, mask_greedy) ) res_greedy = res_greedy.cpu().numpy() # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_greedy, logprob_infos_greedy + return res_greedy, logprob_infos temperatures = [] top_ps = [] @@ -225,29 +224,26 @@ def _is_safe_to_sample(prob_like): res_random = torch.multinomial(probs, 1, True)[:, 0] - logprob_infos_random = fetch_raw_logprob_infos( + logprob_infos = 
get_masked_logprobs( + logprob_infos, + mask_random, + sampling_params, logits_random, res_random, - filter_list_by_mask(sampling_params, mask_random), ) res_random = res_random.cpu().numpy() # Case when there's only random sampling if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_random, logprob_infos_random + return res_random, logprob_infos res = np.empty((num_seq,), dtype=np.int32) res[mask_random] = res_random - logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq - logprob_infos = update_masked_list(logprob_infos, mask_random, logprob_infos_random) - if logits_greedy.shape[0] > 0: res[mask_greedy] = res_greedy - logprob_infos = update_masked_list(logprob_infos, mask_greedy, logprob_infos_greedy) - torch.cuda.nvtx.range_pop() return res, logprob_infos From fca1a6f68b25832b3ec2a6bab4c3b3eef57c9dcc Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 10:55:22 +0400 Subject: [PATCH 11/39] quick fix for check --- serve/mlc_serve/model/tvm_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 70147d721f..ad14fc11fa 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -252,6 +252,8 @@ def generate( ) input_shape = input_ids.shape + # TODO(vvchernov): quick fix, but need to refactor logic + current_ids = list(input_ids.numpy()) if self.disco_session: input_ids = copy_to_worker_0(self.disco_session, input_ids) From 675b631aa986749b742f7afe93405a8ab9d38cbf Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 10:57:45 +0400 Subject: [PATCH 12/39] review fix --- serve/mlc_serve/engine/engine_common.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 01489b4ea3..9f151ba042 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -180,9 +180,6 @@ def logprobs_detokenize( tokenizer: TokenizerP, logprob_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: - if logprob_info is None: - return None - res: List[Optional[LogprobsContent]] = [] for info in logprob_info: res.append(logprob_detokenize(tokenizer, info)) From 18f80fa452d23889fb62cb0a72ada057ff49b710 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 11:31:13 +0400 Subject: [PATCH 13/39] fix list index out of range --- serve/mlc_serve/engine/engine_common.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 9f151ba042..5579d6d27a 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -153,9 +153,11 @@ def logprob_detokenize( if logprob_info.previous_tokens is None: logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: - detokenized = tokenizer.convert_ids_to_tokens( - logprob_info.previous_tokens + [top_token] - )[-1] + # TODO(vvchernov): not clear what do we want + # detokenized = tokenizer.convert_ids_to_tokens( + # logprob_info.previous_tokens + [top_token] + # )[-1] + detokenized = tokenizer.decode(top_token) top_logprobs.append( TopLogprobs( token=detokenized, From 29ea525fd85dd7316505b0da9a0c5fc15d2f735e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 16:57:56 +0400 Subject: [PATCH 14/39] rollback after rebase --- serve/mlc_serve/engine/engine_common.py | 10 
++++------ serve/mlc_serve/model/tvm_model.py | 2 -- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 5579d6d27a..f2509a5f1c 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -148,16 +148,14 @@ def logprob_detokenize( top_logprobs: List[TopLogprobs] = [] if logprob_info.top_tokens is not None and logprob_info.top_logprobs is not None: top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) - # dedup duplicates - # Todo: Make sure decode can generate different tokens if logprob_info.previous_tokens is None: logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: # TODO(vvchernov): not clear what do we want - # detokenized = tokenizer.convert_ids_to_tokens( - # logprob_info.previous_tokens + [top_token] - # )[-1] - detokenized = tokenizer.decode(top_token) + detokenized = tokenizer.convert_ids_to_tokens( + logprob_info.previous_tokens + [top_token] + )[-1] + # detokenized = tokenizer.decode(top_token) top_logprobs.append( TopLogprobs( token=detokenized, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index ad14fc11fa..70147d721f 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -252,8 +252,6 @@ def generate( ) input_shape = input_ids.shape - # TODO(vvchernov): quick fix, but need to refactor logic - current_ids = list(input_ids.numpy()) if self.disco_session: input_ids = copy_to_worker_0(self.disco_session, input_ids) From aa993224eea942809c1d3f362012a7965f604e0d Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 18:16:54 +0400 Subject: [PATCH 15/39] test --- serve/mlc_serve/engine/engine_common.py | 8 ++++---- serve/mlc_serve/model/tvm_model.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index f2509a5f1c..51293f51b4 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -152,10 +152,10 @@ def logprob_detokenize( logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: # TODO(vvchernov): not clear what do we want - detokenized = tokenizer.convert_ids_to_tokens( - logprob_info.previous_tokens + [top_token] - )[-1] - # detokenized = tokenizer.decode(top_token) + # detokenized = tokenizer.convert_ids_to_tokens( + # logprob_info.previous_tokens + [top_token] + # )[-1] + detokenized = tokenizer.decode(top_token) top_logprobs.append( TopLogprobs( token=detokenized, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 70147d721f..e9d849f9b0 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -348,8 +348,8 @@ def generate( logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - if logprob_infos[i]: - token_ids.append(new_token) + # if logprob_infos[i]: + # token_ids.append(new_token) else: outputs.append( TextGenerationResult( @@ -359,8 +359,8 @@ def generate( logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - if logprob_infos[i]: - token_ids.append(new_token) + # if logprob_infos[i]: + # token_ids.append(new_token) return outputs except RuntimeError: From d57b197e33149824df5099a75ad33277405ba5ee Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 22 Jan 
2024 07:43:49 +0000 Subject: [PATCH 16/39] Squashed commit for logprobs implementation. Co-authored-by: Valery Chernov Co-authored-by: Ilya Kozulin --- serve/mlc_serve/api/handler.py | 31 ++++-- serve/mlc_serve/api/protocol.py | 6 ++ serve/mlc_serve/engine/__init__.py | 3 +- serve/mlc_serve/engine/async_connector.py | 5 +- serve/mlc_serve/engine/base.py | 11 ++ serve/mlc_serve/engine/dummy.py | 6 +- serve/mlc_serve/engine/engine_common.py | 65 +++++++++++- serve/mlc_serve/engine/model_module.py | 10 +- serve/mlc_serve/engine/sampling_params.py | 16 ++- serve/mlc_serve/engine/staging_engine.py | 4 +- .../mlc_serve/engine/staging_engine_worker.py | 10 +- serve/mlc_serve/engine/sync_engine.py | 2 + serve/mlc_serve/model/model_common.py | 100 ++++++++++++++++-- serve/mlc_serve/model/paged_cache_manager.py | 10 +- serve/mlc_serve/model/paged_cache_model.py | 8 +- serve/mlc_serve/model/tvm_model.py | 27 ++--- serve/mlc_serve/openai_logprob_protocol.py | 28 +++++ .../unittest/test_engine_with_samplers.py | 43 +++++++- 18 files changed, 331 insertions(+), 54 deletions(-) create mode 100644 serve/mlc_serve/openai_logprob_protocol.py diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index df7a74e738..730cba61bb 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, StreamingResponse -# TODO(amalyshe): hadnle random_seed +# TODO(amalyshe): handle random_seed # from .base import set_global_random_seed from ..api.protocol import ( ChatCompletionRequest, @@ -20,6 +20,7 @@ ChatMessage, DeltaMessage, ErrorResponse, + Logprobs, UsageInfo, ) from ..engine import ( @@ -42,7 +43,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse router = APIRouter() - def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params = SamplingParams( # These params came from vllm @@ -64,6 +64,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params.top_p = request.top_p if request.logit_bias is not None: sampling_params.logit_bias = request.logit_bias + if request.logprobs is not None: + sampling_params.top_logprobs = request.top_logprobs + sampling_params.logprobs = request.logprobs return sampling_params @@ -156,7 +159,7 @@ async def generate_completion_stream( created_time = int(time.time()) def create_stream_response( - choices: list[ChatCompletionResponseStreamChoice], + choices: List[ChatCompletionResponseStreamChoice], ) -> ChatCompletionStreamResponse: return ChatCompletionStreamResponse( id=request_id, @@ -176,7 +179,6 @@ def create_stream_response( ], ) yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n" - async for res in result_generator: if res.error: raise RuntimeError(f"Error when generating: {res.error}") @@ -192,6 +194,7 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, + logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info else None ) for seq in res.sequences ] @@ -212,6 +215,7 @@ async def collect_result_stream( finish_reasons = [None] * num_sequences num_prompt_tokens = 0 num_generated_tokens = [0 for _ in range(num_sequences)] + logprob_infos = [[] for _ in range(num_sequences)] # type: ignore async for res in result_generator: # TODO: verify that the request cancellation happens after this returns if res.error: @@ -226,18 +230,27 @@ 
async def collect_result_stream( if seq.delta: sequences[seq.index].append(seq.delta) + if seq.logprob_info: + assert seq.delta + logprob_infos[seq.index].extend(seq.logprob_info) + if seq.is_finished: assert seq.finish_reason is not None finish_reasons[seq.index] = seq.finish_reason.value # type: ignore - - choices = [ - ChatCompletionResponseChoice( + + choices = [] + for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): + logprobs = None + if logprob_info_seq != []: + logprobs = Logprobs(content=logprob_info_seq) + + choice = ChatCompletionResponseChoice( index=index, message=ChatMessage(role="assistant", content="".join(chunks)), finish_reason=finish_reason, + logprobs=logprobs, ) - for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)) - ] + choices.append(choice) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index 286c13edd2..b22a0ca54c 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -6,6 +6,8 @@ from pydantic import BaseModel, Field +from ..openai_logprob_protocol import Logprobs + class ErrorResponse(BaseModel): object: str = "error" @@ -71,11 +73,14 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[int, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage + logprobs: Optional[Logprobs] = None finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None @@ -96,6 +101,7 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage + logprobs: Optional[Logprobs] = None finish_reason: Optional[Literal["stop", "length"]] = None diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 129b7c05ed..92d101ea95 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -16,5 +16,6 @@ RequestState, PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, + RawLogprobsInfo, ) -from .sampling_params import SamplingParams, SamplingType +from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py index 04a9233296..23d1afc426 100644 --- a/serve/mlc_serve/engine/async_connector.py +++ b/serve/mlc_serve/engine/async_connector.py @@ -1,7 +1,6 @@ import asyncio import structlog -from typing import AsyncIterator, Any -from concurrent.futures import ThreadPoolExecutor +from typing import AsyncIterator, Dict from collections import deque from .base import ( @@ -26,7 +25,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1): self.engine_loop_task = None self.engine_loop_exception = None self.shutdown_event = asyncio.Event() - self.result_queues = dict[RequestId, ResultQueue]() + self.result_queues: Dict[RequestId, ResultQueue] = {} self.recent_cancelled_requests = deque[RequestId](maxlen=64) async def start(self): diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 098209da55..d256cadbfe 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -6,13 +6,23 @@ from typing import List, Callable, Any, Optional, Dict import inspect +import numpy as np from .sampling_params import SamplingParams, 
SamplingType +from ..openai_logprob_protocol import LogprobsContent LOG = structlog.stdlib.get_logger(__name__) RequestId = str +@dataclass +class RawLogprobsInfo: + current_token: int + current_logprob: float + top_tokens: Optional[np.array] + top_logprobs: Optional[np.array] + + # TODO(@sunggg): consider transition to something like Pydantic @dataclass class MLCServeEngineConfig: @@ -155,6 +165,7 @@ class SequenceOutput: finish_reason: Optional[FinishReason] = None # Number of generated tokens so far num_generated_tokens: int = 0 + logprob_info: Optional[List[Optional[LogprobsContent]]] = None @property def is_finished(self) -> bool: diff --git a/serve/mlc_serve/engine/dummy.py b/serve/mlc_serve/engine/dummy.py index 3d7c10651c..03227dd3ca 100644 --- a/serve/mlc_serve/engine/dummy.py +++ b/serve/mlc_serve/engine/dummy.py @@ -12,9 +12,9 @@ class DummyInferenceEngine: - def __init__(self): - self.queue_lock = Lock() - self.has_new_requests = Condition(self.queue_lock) + def __init__(self) -> None: + self.queue_lock: Lock = Lock() + self.has_new_requests: Condition = Condition(self.queue_lock) self.request_queue: Dict[RequestId, int] = {} def add(self, requests: list[Request]): diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1bc252b48c..7a9c585659 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -3,17 +3,18 @@ """ import time -from typing import Tuple, Deque, Dict, Optional, Union, Callable +from typing import Tuple, Deque, Dict, Optional, Union, Callable, List from collections import deque from threading import Condition, Lock import structlog from .base import ( + GenerationSequence, + RawLogprobsInfo, Request, RequestId, RequestState, - GenerationSequence, SequenceId, StoppingCriteria, ) @@ -27,6 +28,7 @@ Tokenizer as TokenizerP, ) from ..model.base import ModelArtifactConfig +from ..openai_logprob_protocol import LogprobsContent, TopLogprobs LOG = structlog.stdlib.get_logger(__name__) @@ -135,6 +137,65 @@ def detokenize_incrementally( return delta +def logprob_detokenize( + tokenizer: TokenizerP, + logprob_info: Optional[RawLogprobsInfo], +) -> Optional[LogprobsContent]: + """Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent""" + if logprob_info is None: + return None + + top_logprobs: List[TopLogprobs] = [] + if ( + logprob_info.top_tokens is not None and + logprob_info.top_logprobs is not None + ): + top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) + count: Dict[str, int] = {} + # dedup duplicates + # Todo: Make sure decode can generate different tokens + for top_token, _ in top_tokens: + detokenized = tokenizer.decode(top_token) + if detokenized in count: + count[detokenized] += 1 + else: + count[detokenized] = 1 + for top_token, top_logprob in top_tokens: + detokenized = tokenizer.decode(top_token) + if count[detokenized] != 1: + detokenized = f"{detokenized}_{top_token}" + top_logprobs.append(TopLogprobs( + token=detokenized, + logprob=float(top_logprob), + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + )) + + logprobs_content = LogprobsContent( + token=tokenizer.decode([logprob_info.current_token]), + logprob=logprob_info.current_logprob, + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + top_logprobs=top_logprobs, + ) + + return logprobs_content + + +def logprobs_detokenize( + 
tokenizer: TokenizerP, + logprobs_info: List[Optional[RawLogprobsInfo]], +) -> Optional[List[Optional[LogprobsContent]]]: + res: List[Optional[LogprobsContent]] = [] + for logprob_info in logprobs_info: + res.append(logprob_detokenize(tokenizer, logprob_info)) + + check_all = all([x is None for x in res]) + if check_all: + return None + return res + + def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended): if stopping_criteria.stop_sequences: for t in stopping_criteria.stop_sequences: diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 79b77e93a3..d305c514b8 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -4,7 +4,14 @@ from dataclasses import dataclass from typing import Optional, Protocol, Union, List, Sequence -from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId +from .base import ( + ChatMessage, + MLCServeEngineConfig, + RawLogprobsInfo, + RequestId, + RequestState, + SequenceId, +) from ..model.base import ModelArtifactConfig from .sampling_params import SamplingParams @@ -44,6 +51,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] + logprob_info: Optional[List[Optional[RawLogprobsInfo]]] = None class KVCache(Protocol): diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index bbcdc7fd2c..d5ba5d109c 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -9,8 +9,8 @@ from functools import cached_property from typing import Dict, Optional - _SAMPLING_EPS = 1e-5 +LOGPROB_TOP_K_MAX = 5 class SamplingType(IntEnum): @@ -46,6 +46,13 @@ class SamplingParams: to -1 to consider all tokens. logit_bias: The bias applied on the logit before sampling. Must be in [-100, 100]. + logprobs: Optional[bool] Whether to return log probabilities of the output + tokens or not. If true, returns the log probabilities of each output + token returned in the content of message. + top_logprobs: Optional[Integer] An integer between 1 and 5 specifying + the number of most likely tokens to return at each token position, + each with an associated log probability. logprobs must be set to + true if this parameter is used. """ presence_penalty: float = 0.0 @@ -58,6 +65,8 @@ class SamplingParams: appeared_tokens_freq: Dict[int, int] = None logit_bias_index: list[int] = None logit_bias_value: list[float] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None def __post_init__(self): self.appeared_tokens_freq = {} @@ -95,6 +104,11 @@ def _verify_args(self) -> None: raise ValueError( f"logit bias must be in [-100, 100], got {bias} for token {token}." ) + if self.logprobs is not None and self.logprobs: + if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX): + raise ValueError( + f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." 
+ ) def _verify_greedy_sampling(self) -> None: if self.top_p < 1.0 - _SAMPLING_EPS: diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index 3d1fe70b99..c8354e4c5b 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -5,8 +5,8 @@ import multiprocessing import queue from threading import Lock -from typing import Callable from collections import defaultdict +from typing import Callable import structlog @@ -24,6 +24,7 @@ from .engine_common import ( get_new_request_state, update_sequence, + logprobs_detokenize ) from .model_module import ModelModule, TokenizerModule from .staging_engine_worker import ( @@ -251,6 +252,7 @@ def step(self) -> InferenceStepResult: delta, finish_reason, num_generated_tokens=len(gen_seq.generated_token_ids), + logprob_info=logprobs_detokenize(self.tokenizer, seq_output.logprob_info), ) seq_outputs[request_id].append(output) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index e74a6181c8..ceb84857ae 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -12,12 +12,14 @@ from .base import ( FinishReason, + RawLogprobsInfo, RequestId, RequestState, ValidationError, SequenceId, GenerationSequence, ) + from .metrics import PrometheusMetrics from .metrics_labels import * from .model_module import ( @@ -40,7 +42,7 @@ class ShutdownCommand: @dataclass class AddRequestsCommand: - request_states: list[RequestState] + request_states: List[RequestState] @dataclass @@ -61,14 +63,15 @@ class StopSequenceCommand: @dataclass class SequenceGenerationOutput: id: SequenceId - new_tokens: list[int] + new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None + logprob_info: Optional[List[RawLogprobsInfo]] = None @dataclass class GenerationLoopWorkerOutput: - sequences: list[SequenceGenerationOutput] + sequences: List[SequenceGenerationOutput] error: Optional[BaseException] = None @@ -288,6 +291,7 @@ def step(self) -> GenerationLoopWorkerOutput: id=res.sequence_id, new_tokens=new_tokens, finish_reason=finish_reason, + logprob_info=res.logprob_info, ) ) diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index 5dcf80eda7..c400ec6b4a 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -22,6 +22,7 @@ get_requests_to_process, update_sequence, EngineBase, + logprobs_detokenize ) from .model_module import ( ModelModule, @@ -222,6 +223,7 @@ def step(self) -> InferenceStepResult: delta, num_generated_tokens=len(gen_seq.generated_token_ids), finish_reason=finish_reason, + logprob_info=logprobs_detokenize(self.tokenizer, res.logprob_info), ) ) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index d99dd0bbeb..0a337c1707 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional +from typing import List, Optional, Tuple, Union import structlog import numpy as np @@ -9,6 +9,8 @@ from ..engine import ( SamplingType, SamplingParams, + LOGPROB_TOP_K_MAX, + RawLogprobsInfo, ) LOG = structlog.stdlib.get_logger(__name__) @@ -36,6 +38,50 @@ def get_num_cache_blocks( ) +def fetch_raw_logprob_infos( + logits, + res_tokens, + sampling_params, +) -> List[Optional[RawLogprobsInfo]]: + logprob_infos: 
List[Optional[RawLogprobsInfo]] = [] + num_seq = logits.shape[0] + for index in range(num_seq): + if ( + sampling_params[index].logprobs is not None and + sampling_params[index].logprobs + ): + # Logprob sampling + logprobs = torch.log_softmax(logits[index], dim=-1) + res_logprob = logprobs[res_tokens[index]].cpu().numpy() + + top_logprobs_num = sampling_params[index].top_logprobs + if ( + top_logprobs_num is None or + top_logprobs_num == 0 + ): + top_logprobs = None + top_tokens = None + else: + assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs" + top_logprobs, top_tokens = torch.topk( + logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + ) + top_tokens=top_tokens.cpu().numpy() + top_logprobs=top_logprobs.cpu().numpy() + + # Set to raw logprob info + logprob_infos.append(RawLogprobsInfo( + current_token=res_tokens[index].cpu().numpy(), + current_logprob=res_logprob, + top_tokens=top_tokens, + top_logprobs=top_logprobs, + )) + else: + logprob_infos.append(None) + + return logprob_infos + + def _apply_top_p_top_k(logits, top_ps, top_ks): p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) @@ -59,12 +105,31 @@ def _apply_top_p_top_k(logits, top_ps, top_ks): return logits +def update_masked_list(input_list, mask, update): + j = 0 + for i in range(len(mask)): + if mask[i]: + input_list[i] = update[j] + j = j + 1 + + return input_list + + +def filter_list_by_mask(i_list, mask): + o_list = [] + for i in range(len(mask)): + if mask[i]: + o_list.append(i_list[i]) + + return o_list + + def sample( logits: Union[tvm.nd.NDArray, torch.Tensor], sampling_params: List[SamplingParams], vocab_size: int, check_safety=False, -) -> Optional[np.ndarray]: +) -> Optional[Tuple[np.ndarray, List[Optional[RawLogprobsInfo]]]]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -90,11 +155,19 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy_dvc] if logits_greedy.shape[0] > 0: - res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() + res_greedy = torch.argmax(logits_greedy, -1) + + logprob_infos_greedy = fetch_raw_logprob_infos( + logits_greedy, + res_greedy, + filter_list_by_mask(sampling_params, mask_greedy_dvc) + ) + res_greedy = res_greedy.cpu().numpy() + # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_greedy + return res_greedy, logprob_infos_greedy temperatures = [] top_ps = [] @@ -161,20 +234,33 @@ def _is_safe_to_sample(prob_like): torch.cuda.nvtx.range_pop() return None - res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy() + res_random = torch.multinomial(probs, 1, True)[:, 0] + logprob_infos_random = fetch_raw_logprob_infos( + logits_random, + res_random, + filter_list_by_mask(sampling_params, mask_random_dvc), + ) + + res_random = res_random.cpu().numpy() + # Case when there's only random sampling if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_random + return res_random, logprob_infos_random res = np.empty((num_seq,), dtype=np.int32) res[mask_random_cpu] = res_random + logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq + logprob_infos = update_masked_list(logprob_infos, mask_random_cpu, logprob_infos_random) + if logits_greedy.shape[0] > 0: res[mask_greedy_cpu] = res_greedy + logprob_infos = update_masked_list(logprob_infos, mask_greedy_cpu, 
logprob_infos_greedy) + torch.cuda.nvtx.range_pop() - return res + return res, logprob_infos def prepare_inputs( diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 7193c0eadb..1614b442d4 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -1,6 +1,6 @@ import math from collections import defaultdict -from typing import List, Optional +from typing import Any, List, Optional from ..engine import ( RequestId, @@ -104,16 +104,16 @@ def replace_head_prompt_block_with(self, new_block): class KVCacheInfo: def __init__( self, - block_size, + block_size: int ): self.block_size = block_size # SequenceId -> list[int] - self.prompt_block_tables = defaultdict(list) - self.slot_mappings = defaultdict(list) + self.prompt_block_tables = defaultdict(list) # type: ignore + self.slot_mappings = defaultdict(list) # type: ignore # The core data structure - self.decode_block_tables = dict[SequenceId, DecodeBlockTable]() + self.decode_block_tables: dict = dict[SequenceId, DecodeBlockTable]() # Record indices of blocks to copy after prefill in the format [src1, dst1, src2, dst2, ...] self.pending_copy_from_to: list[int] = [] diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index a7f11750eb..7daf3336f4 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -1,6 +1,6 @@ -from typing import Union from pathlib import Path import structlog +from typing import List, Union from .base import get_model_artifact_config from .paged_cache_manager import CacheManager @@ -10,11 +10,11 @@ from ..engine import MLCServeEngineConfig from ..engine.model_module import ( DecodeRequest, + ModelModule, PrefillRequest, TextGenerationResult, TextGenerator, ) -from ..engine.model_module import ModelModule LOG = structlog.stdlib.get_logger(__name__) @@ -24,8 +24,8 @@ def __init__(self, model: TextGenerator): self.model = model def generate( - self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache - ) -> list[TextGenerationResult]: + self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache + ) -> List[TextGenerationResult]: prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] decode_requests = [r for r in requests if isinstance(r, DecodeRequest)] diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index cfb528ac43..4144c862fb 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -282,13 +282,6 @@ def generate( slot_mapping, self.params, ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[ - 0 - ] # Ignore returned KV cache since it is updated in-place anyway. else: torch.cuda.nvtx.range_push(f"forward decode {input_shape}") @@ -305,10 +298,12 @@ def generate( self.params, ) - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[0] + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[ + 0 + ] # Ignore returned KV cache since it is updated in-place anyway. 
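With the hunks above, sample() now returns a pair (next token ids, per-sequence logprob info) instead of a bare array, and it still returns None when the safety check on the probabilities fails. A minimal caller-side sketch of that contract, mirroring how the tvm_model.py hunk below unpacks it (consume_sample_output and the fake result are invented for illustration):

    from typing import List, Optional, Tuple
    import numpy as np

    def consume_sample_output(out) -> Optional[List[Tuple[int, object]]]:
        # out is what sample() returned: None on a failed safety check, otherwise
        # (np.ndarray of next token ids, List[Optional[RawLogprobsInfo]]).
        if out is None:
            return None
        next_tokens, logprob_infos = out
        # logprob_infos[i] stays None unless sampling_params[i].logprobs was set.
        return [(int(tok), logprob_infos[i]) for i, tok in enumerate(next_tokens)]

    # Fake result for two sequences; only the first one asked for logprobs.
    fake_out = (np.array([42, 7], dtype=np.int32), [object(), None])
    print(consume_sample_output(fake_out))
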
torch.cuda.synchronize() torch.cuda.nvtx.range_pop() @@ -330,7 +325,7 @@ def generate( cache.pending_copy_from_to = [] try: - next_tokens = sample(logits, sampling_params, self.vocab_size) + next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( @@ -346,6 +341,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, + logprob_info=[logprob_infos[i]], ) ) else: @@ -354,6 +350,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, + logprob_info=[logprob_infos[i]], ) ) @@ -369,7 +366,7 @@ def generate( for i, (sequence_id, logits_per_token, sampling_param) in enumerate( zip(sequence_ids, torch.from_dlpack(logits), sampling_params) ): - maybe_new_token = sample( + maybe_new_token, logprob_infos = sample( torch.unsqueeze(logits_per_token, 0), [sampling_param], self.vocab_size, @@ -393,6 +390,7 @@ def generate( ), generated_tokens=[new_token], # type: ignore error=None, + logprob_info=[logprob_infos[0]] ) ) else: @@ -401,6 +399,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], # type: ignore error=None, + logprob_info=[logprob_infos[0]] ) ) else: @@ -413,6 +412,7 @@ def generate( ), generated_tokens=[], error=err_msg, + logprob_info=[logprob_infos[0]] ) ) else: @@ -421,6 +421,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, + logprob_info=[logprob_infos[0]] ) ) diff --git a/serve/mlc_serve/openai_logprob_protocol.py b/serve/mlc_serve/openai_logprob_protocol.py new file mode 100644 index 0000000000..9c2a4db502 --- /dev/null +++ b/serve/mlc_serve/openai_logprob_protocol.py @@ -0,0 +1,28 @@ +from typing import List, Optional + +from pydantic import BaseModel + +class TopLogprobs(BaseModel): + """An OpenAI API compatible schema for logprobs output.""" + + token: str + logprob: float + bytes: Optional[List] = None + + +class LogprobsContent(BaseModel): + """An OpenAI API compatible schema for logprobs output.""" + + token: str + logprob: float + bytes: Optional[List] = None + top_logprobs: List[TopLogprobs] # It can be empty + + +class Logprobs(BaseModel): + """ + An OpenAI API compatible schema for logprobs output. 
+ See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs + """ + + content: List[LogprobsContent] \ No newline at end of file diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index 1642364232..2a7bf1efd4 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -48,7 +48,7 @@ def create_engine( def create_request( - idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, logit_bias=None + idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, top_logprobs=0, logprobs=False, logit_bias=None ): return Request( request_id=str(idx), @@ -58,6 +58,8 @@ def create_request( frequency_penalty=freq_pen, presence_penalty=pre_pen, logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, ), stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop), debug_options=DebugOptions(ignore_eos=ignore_eos), @@ -337,6 +339,43 @@ def _test_penalty( if use_staging_engine: engine.stop() +def _test_logprobs( + model_artifact_path, + use_staging_engine, + max_num_sequences=4, + max_input_len=512, + num_requests=5, + top_logprobs=3, +): + prompt = "hi" + engine = create_engine( + model_artifact_path, + use_staging_engine, + max_num_sequences, + max_input_len, + ) + s = 113 + requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs, logprobs=True) for n in range(s, s+num_requests)] + engine.add(requests) + + generated = ["" for _ in range(num_requests)] + + while engine.has_pending_requests(): + results = engine.step() + for res in results.outputs: + assert len(res.sequences) == 1 + seq = res.sequences[0] + + assert seq.finish_reason is not None or len(list(seq.logprobs.content[0]["top_logprobs"])) == top_logprobs + + if seq.is_finished: + assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens + assert seq.finish_reason == FinishReason.Length + else: + generated[int(res.request_id)] += seq.delta + + if use_staging_engine: + engine.stop() if __name__ == "__main__": parser = get_default_mlc_serve_argparser("test engine with samplers") @@ -349,6 +388,8 @@ def _test_penalty( _test_ignore_eos(args.model_artifact_path, use_staging_engine=False) _test_stop(args.model_artifact_path, use_staging_engine=False) _test_stop(args.model_artifact_path, use_staging_engine=True) + _test_logprobs(args.model_artifact_path, use_staging_engine=True) + _test_logprobs(args.model_artifact_path, use_staging_engine=False) # These tests are broken since we are now imposing no length limit # if max_tokens = None. The tests do not finish in a reasonable time. 
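For reference, the new schema above nests as in the following sketch; the values are invented and the import assumes mlc_serve is on the Python path:

    from mlc_serve.openai_logprob_protocol import Logprobs, LogprobsContent, TopLogprobs

    # One LogprobsContent per generated token; top_logprobs carries up to
    # LOGPROB_TOP_K_MAX (5) alternatives for that position.
    chunk = Logprobs(
        content=[
            LogprobsContent(
                token="Hello",
                logprob=-0.11,
                bytes=None,
                top_logprobs=[
                    TopLogprobs(token="Hello", logprob=-0.11, bytes=None),
                    TopLogprobs(token="Hi", logprob=-2.35, bytes=None),
                ],
            )
        ]
    )
    print(chunk.dict())
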
# _test_max_context_length(model_artifact_path, use_staging_engine=True) From 7995c849e1dd272f1e0a17fa8469fd853d112dc6 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 23 Jan 2024 10:24:27 +0400 Subject: [PATCH 17/39] fix None check --- serve/mlc_serve/engine/engine_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 7a9c585659..1f08d892e0 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -186,6 +186,9 @@ def logprobs_detokenize( tokenizer: TokenizerP, logprobs_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: + if logprobs_info is None: + return None + res: List[Optional[LogprobsContent]] = [] for logprob_info in logprobs_info: res.append(logprob_detokenize(tokenizer, logprob_info)) From ae3fc5b31c05deeb2ae930581b884e4ad9f1f461 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 25 Jan 2024 15:03:41 +0000 Subject: [PATCH 18/39] Change detokenization to using token ids. --- serve/mlc_serve/engine/base.py | 1 + serve/mlc_serve/engine/engine_common.py | 22 ++++++++-------------- serve/mlc_serve/model/model_common.py | 1 + serve/mlc_serve/model/tvm_model.py | 11 ++++++++++- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index d256cadbfe..33e76ae743 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -21,6 +21,7 @@ class RawLogprobsInfo: current_logprob: float top_tokens: Optional[np.array] top_logprobs: Optional[np.array] + previous_tokens: Optional[List[int]] # TODO(@sunggg): consider transition to something like Pydantic diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1f08d892e0..4d0554294d 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -151,19 +151,13 @@ def logprob_detokenize( logprob_info.top_logprobs is not None ): top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) - count: Dict[str, int] = {} # dedup duplicates # Todo: Make sure decode can generate different tokens - for top_token, _ in top_tokens: - detokenized = tokenizer.decode(top_token) - if detokenized in count: - count[detokenized] += 1 - else: - count[detokenized] = 1 + if logprob_info.previous_tokens is None: + logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: - detokenized = tokenizer.decode(top_token) - if count[detokenized] != 1: - detokenized = f"{detokenized}_{top_token}" + detokenized = tokenizer.convert_ids_to_tokens(logprob_info.previous_tokens + [top_token])[-1] + LOG.info(f"detokenized: {detokenized}") top_logprobs.append(TopLogprobs( token=detokenized, logprob=float(top_logprob), @@ -184,14 +178,14 @@ def logprob_detokenize( def logprobs_detokenize( tokenizer: TokenizerP, - logprobs_info: List[Optional[RawLogprobsInfo]], + logprob_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: - if logprobs_info is None: + if logprob_info is None: return None res: List[Optional[LogprobsContent]] = [] - for logprob_info in logprobs_info: - res.append(logprob_detokenize(tokenizer, logprob_info)) + for info in logprob_info: + res.append(logprob_detokenize(tokenizer, info)) check_all = all([x is None for x in res]) if check_all: diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 
0a337c1707..32234a5884 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -75,6 +75,7 @@ def fetch_raw_logprob_infos( current_logprob=res_logprob, top_tokens=top_tokens, top_logprobs=top_logprobs, + previous_tokens=None )) else: logprob_infos.append(None) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 4144c862fb..e011ddf48b 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -22,6 +22,7 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, MLCServeEngineConfig, + RawLogprobsInfo, ) from ..engine.model_module import ( DecodeRequest, @@ -84,6 +85,11 @@ def get_tvm_model(config, dev): return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) +def attach_detokenization_info(logprob_info:RawLogprobsInfo, token_ids: List[int]): + if logprob_info is None: + return None + logprob_info.previous_tokens = token_ids + return logprob_info def _prepare_inputs( sequence_ids, @@ -326,6 +332,7 @@ def generate( try: next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) + current_ids = list(input_ids.numpy()) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( @@ -341,9 +348,10 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=[logprob_infos[i]], + logprob_info=[attach_detokenization_info(logprob_infos[i], current_ids)], ) ) + current_ids.append(new_token) else: outputs.append( TextGenerationResult( @@ -353,6 +361,7 @@ def generate( logprob_info=[logprob_infos[i]], ) ) + current_ids.append(new_token) return outputs except RuntimeError: From 0cb036fe31e9149019eb6440c6c7c28c0aee919d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 29 Jan 2024 08:46:54 +0000 Subject: [PATCH 19/39] Fix wrong usage of token ids. Remove logging. 
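The changes in [PATCH 18/39] and [PATCH 19/39] hinge on the difference between decoding a candidate token id on its own and detokenizing it after the tokens generated so far. A rough standalone illustration with a HuggingFace tokenizer; the model id is only an example and the exact pieces depend on the tokenizer:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    prev_ids = tok.encode("The quick brown", add_special_tokens=False)
    top_id = tok.encode("fox", add_special_tokens=False)[-1]

    # Context-free decode cleans up the piece, so word-boundary markers are lost.
    print(tok.decode(top_id))
    # Detokenizing in context returns the raw piece at the last position,
    # e.g. a SentencePiece piece such as "▁fox" that keeps the boundary marker.
    print(tok.convert_ids_to_tokens(prev_ids + [top_id])[-1])
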
--- serve/mlc_serve/engine/engine_common.py | 1 - serve/mlc_serve/model/tvm_model.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 4d0554294d..73be36f9dd 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -157,7 +157,6 @@ def logprob_detokenize( logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: detokenized = tokenizer.convert_ids_to_tokens(logprob_info.previous_tokens + [top_token])[-1] - LOG.info(f"detokenized: {detokenized}") top_logprobs.append(TopLogprobs( token=detokenized, logprob=float(top_logprob), diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index e011ddf48b..70147d721f 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -86,8 +86,6 @@ def get_tvm_model(config, dev): return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) def attach_detokenization_info(logprob_info:RawLogprobsInfo, token_ids: List[int]): - if logprob_info is None: - return None logprob_info.previous_tokens = token_ids return logprob_info @@ -332,11 +330,10 @@ def generate( try: next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) - current_ids = list(input_ids.numpy()) assert next_tokens is not None outputs = [] - for i, (sequence_id, new_token) in enumerate( - zip(sequence_ids, next_tokens) + for i, (sequence_id, new_token, token_ids) in enumerate( + zip(sequence_ids, next_tokens, all_token_ids) ): if not new_token in requests[i].sampling_params.appeared_tokens_freq: requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 @@ -348,20 +345,22 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=[attach_detokenization_info(logprob_infos[i], current_ids)], + logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - current_ids.append(new_token) + if logprob_infos[i]: + token_ids.append(new_token) else: outputs.append( TextGenerationResult( sequence_id=sequence_id, generated_tokens=[new_token], error=None, - logprob_info=[logprob_infos[i]], + logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - current_ids.append(new_token) + if logprob_infos[i]: + token_ids.append(new_token) return outputs except RuntimeError: From ed51e7d67053dd8c36098b2d919433f3a2d730dd Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 11:25:19 +0400 Subject: [PATCH 20/39] extend benchmarks for logprobs --- serve/benchmarks/benchmark_latency.py | 2 ++ serve/benchmarks/benchmark_throughput.py | 2 ++ serve/benchmarks/utils.py | 16 ++++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/serve/benchmarks/benchmark_latency.py b/serve/benchmarks/benchmark_latency.py index c64f968640..84c377b710 100644 --- a/serve/benchmarks/benchmark_latency.py +++ b/serve/benchmarks/benchmark_latency.py @@ -34,6 +34,8 @@ def create_request(request_id): frequency_penalty=args.sampling_setting["frequency_penalty"], presence_penalty=args.sampling_setting["presence_penalty"], logit_bias=args.sampling_setting["logit_bias"], + logprobs = args.sampling_setting["logprobs"], + top_logprobs = args.sampling_setting["top_logprobs"], ), stopping_criteria=StoppingCriteria( max_tokens=args.num_output_tokens, stop_sequences=None diff --git 
a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index 1a7fb46451..4f82d80dd7 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -139,6 +139,8 @@ def run_mlc(engine, requests, args) -> float: frequency_penalty=args.sampling_setting["frequency_penalty"], presence_penalty=args.sampling_setting["presence_penalty"], logit_bias=args.sampling_setting["logit_bias"], + logprobs = args.sampling_setting["logprobs"], + top_logprobs = args.sampling_setting["top_logprobs"], ), stopping_criteria=StoppingCriteria( max_tokens=args.num_output_tokens, stop_sequences=None diff --git a/serve/benchmarks/utils.py b/serve/benchmarks/utils.py index defea7b099..c62c134de8 100644 --- a/serve/benchmarks/utils.py +++ b/serve/benchmarks/utils.py @@ -22,6 +22,18 @@ def add_sampling_flags(parser): action="store_true", help="Apply all penalties, logit bias, top-p and top-k.", ) + parser.add_argument( + "--logprobs", + action="store_true", + default=False, + help="Switch on logprobs output" + ) + parser.add_argument( + "--top-logprobs", + type=int, + default=5, + help="Number of top logprobs to output, limited by 5. Works only with logprobs true." + ) def postproc_sampling_args(args): @@ -51,3 +63,7 @@ def postproc_sampling_args(args): if args.apply_top_p_top_k: args.sampling_setting["top_k"] = 2 args.sampling_setting["top_p"] = 0.7 + + if args.logprobs: + args.sampling_setting["logprobs"] = True + args.sampling_setting["top_logprobs"] = args.top_logprobs From ff17ae2a802dd0d128251edd8fe21a0cd692e58b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 13:19:21 +0400 Subject: [PATCH 21/39] fix test without logprobs --- serve/benchmarks/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/serve/benchmarks/utils.py b/serve/benchmarks/utils.py index c62c134de8..4507dc0dd2 100644 --- a/serve/benchmarks/utils.py +++ b/serve/benchmarks/utils.py @@ -45,6 +45,8 @@ def postproc_sampling_args(args): "repetition_penalty": 1.0, "top_p": 1.0, "top_k": -1, + "logprobs": False, + "top_logprobs": 5, } if args.apply_all_sampling_params: From f5e433907da6c83f8aa2f9a2e9fcef42a96a4f64 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 14:27:22 +0400 Subject: [PATCH 22/39] clean code --- serve/mlc_serve/engine/engine_common.py | 8 ++++---- serve/mlc_serve/openai_logprob_protocol.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 73be36f9dd..a3519ced50 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -138,8 +138,8 @@ def detokenize_incrementally( def logprob_detokenize( - tokenizer: TokenizerP, - logprob_info: Optional[RawLogprobsInfo], + tokenizer: TokenizerP, + logprob_info: Optional[RawLogprobsInfo], ) -> Optional[LogprobsContent]: """Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent""" if logprob_info is None: @@ -176,8 +176,8 @@ def logprob_detokenize( def logprobs_detokenize( - tokenizer: TokenizerP, - logprob_info: List[Optional[RawLogprobsInfo]], + tokenizer: TokenizerP, + logprob_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: if logprob_info is None: return None diff --git a/serve/mlc_serve/openai_logprob_protocol.py b/serve/mlc_serve/openai_logprob_protocol.py index 9c2a4db502..36f2b693f6 100644 --- a/serve/mlc_serve/openai_logprob_protocol.py +++ 
b/serve/mlc_serve/openai_logprob_protocol.py @@ -25,4 +25,4 @@ class Logprobs(BaseModel): See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs """ - content: List[LogprobsContent] \ No newline at end of file + content: List[LogprobsContent] From a3f6e8b68630f833adb648a847cd20da88f147f2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 15:09:35 +0400 Subject: [PATCH 23/39] black format engine_common.py --- serve/mlc_serve/engine/engine_common.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index a3519ced50..01489b4ea3 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -146,23 +146,24 @@ def logprob_detokenize( return None top_logprobs: List[TopLogprobs] = [] - if ( - logprob_info.top_tokens is not None and - logprob_info.top_logprobs is not None - ): + if logprob_info.top_tokens is not None and logprob_info.top_logprobs is not None: top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) # dedup duplicates # Todo: Make sure decode can generate different tokens if logprob_info.previous_tokens is None: logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: - detokenized = tokenizer.convert_ids_to_tokens(logprob_info.previous_tokens + [top_token])[-1] - top_logprobs.append(TopLogprobs( - token=detokenized, - logprob=float(top_logprob), - # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object - bytes=None, - )) + detokenized = tokenizer.convert_ids_to_tokens( + logprob_info.previous_tokens + [top_token] + )[-1] + top_logprobs.append( + TopLogprobs( + token=detokenized, + logprob=float(top_logprob), + # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object + bytes=None, + ) + ) logprobs_content = LogprobsContent( token=tokenizer.decode([logprob_info.current_token]), From c54a4103984994eb16253e122e769045ea82a3b9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 26 Jan 2024 17:45:57 +0400 Subject: [PATCH 24/39] logprobs is strictly bool, top_logprobs is int --- serve/mlc_serve/api/handler.py | 2 +- serve/mlc_serve/api/protocol.py | 4 ++-- serve/mlc_serve/engine/sampling_params.py | 12 ++++++------ serve/mlc_serve/model/model_common.py | 10 ++-------- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 730cba61bb..73e873bc73 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -64,7 +64,7 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params.top_p = request.top_p if request.logit_bias is not None: sampling_params.logit_bias = request.logit_bias - if request.logprobs is not None: + if request.logprobs: sampling_params.top_logprobs = request.top_logprobs sampling_params.logprobs = request.logprobs return sampling_params diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py index b22a0ca54c..4f42f7233e 100644 --- a/serve/mlc_serve/api/protocol.py +++ b/serve/mlc_serve/api/protocol.py @@ -73,8 +73,8 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[int, float]] = None user: Optional[str] = None ignore_eos: Optional[bool] = False - logprobs: Optional[bool] = False - top_logprobs: Optional[int] = None + logprobs: bool = False + 
top_logprobs: int = 0 class ChatCompletionResponseChoice(BaseModel): diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index d5ba5d109c..961b2b744a 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -49,7 +49,7 @@ class SamplingParams: logprobs: Optional[bool] Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. - top_logprobs: Optional[Integer] An integer between 1 and 5 specifying + top_logprobs: Optional[Integer] An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. @@ -65,8 +65,8 @@ class SamplingParams: appeared_tokens_freq: Dict[int, int] = None logit_bias_index: list[int] = None logit_bias_value: list[float] = None - logprobs: Optional[bool] = False - top_logprobs: Optional[int] = None + logprobs: bool = False + top_logprobs: int = 0 def __post_init__(self): self.appeared_tokens_freq = {} @@ -104,10 +104,10 @@ def _verify_args(self) -> None: raise ValueError( f"logit bias must be in [-100, 100], got {bias} for token {token}." ) - if self.logprobs is not None and self.logprobs: - if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX): + if self.logprobs: + if (self.top_logprobs < 0 or self.top_logprobs > LOGPROB_TOP_K_MAX): raise ValueError( - f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." + f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}." ) def _verify_greedy_sampling(self) -> None: diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 32234a5884..9d5e662b57 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -46,19 +46,13 @@ def fetch_raw_logprob_infos( logprob_infos: List[Optional[RawLogprobsInfo]] = [] num_seq = logits.shape[0] for index in range(num_seq): - if ( - sampling_params[index].logprobs is not None and - sampling_params[index].logprobs - ): + if sampling_params[index].logprobs: # Logprob sampling logprobs = torch.log_softmax(logits[index], dim=-1) res_logprob = logprobs[res_tokens[index]].cpu().numpy() top_logprobs_num = sampling_params[index].top_logprobs - if ( - top_logprobs_num is None or - top_logprobs_num == 0 - ): + if top_logprobs_num == 0: top_logprobs = None top_tokens = None else: From 379d99108db62c0fd435f2a55bf55fd0cbbf92c7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 28 Jan 2024 15:38:37 +0400 Subject: [PATCH 25/39] refactor logprob info collection to not reduce performance --- serve/mlc_serve/model/model_common.py | 120 +++++++++++++------------- 1 file changed, 58 insertions(+), 62 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 9d5e662b57..252e958a30 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -38,41 +38,55 @@ def get_num_cache_blocks( ) -def fetch_raw_logprob_infos( +def get_raw_logprob_info( logits, - res_tokens, - sampling_params, + token, + top_logprobs_num, +) -> RawLogprobsInfo: + logprobs = torch.log_softmax(logits, dim=-1) + res_logprob = logprobs[token].cpu().numpy() + + if top_logprobs_num == 0: + top_logprobs = None + top_tokens = None + else: + assert top_logprobs_num <= LOGPROB_TOP_K_MAX, 
"Invalid input top_logprobs" + top_logprobs, top_tokens = torch.topk( + logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + ) + top_tokens=top_tokens.cpu().numpy() + top_logprobs=top_logprobs.cpu().numpy() + + # Set to raw logprob info + return RawLogprobsInfo( + # TODO(vvchernov): it is number, cpu().numpy()? + current_token=token.cpu().numpy(), + current_logprob=res_logprob, + top_tokens=top_tokens, + top_logprobs=top_logprobs, + previous_tokens=None + ) + + +def get_masked_logprobs( + logprob_infos: List[Optional[RawLogprobsInfo]], + mask: torch.Tensor, + sampling_params: List[SamplingParams], + logits: torch.Tensor, + tokens: torch.Tensor, ) -> List[Optional[RawLogprobsInfo]]: - logprob_infos: List[Optional[RawLogprobsInfo]] = [] - num_seq = logits.shape[0] - for index in range(num_seq): - if sampling_params[index].logprobs: - # Logprob sampling - logprobs = torch.log_softmax(logits[index], dim=-1) - res_logprob = logprobs[res_tokens[index]].cpu().numpy() - - top_logprobs_num = sampling_params[index].top_logprobs - if top_logprobs_num == 0: - top_logprobs = None - top_tokens = None - else: - assert top_logprobs_num <= LOGPROB_TOP_K_MAX, "Invalid input top_logprobs" - top_logprobs, top_tokens = torch.topk( - logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True + num_seq = len(logprob_infos) + + mask_counter = 0 + for i in range(num_seq): + if mask[i]: + if sampling_params[i].logprobs: + logprob_infos[i] = get_raw_logprob_info( + logits[mask_counter], + tokens[mask_counter], + sampling_params[i].top_logprobs, ) - top_tokens=top_tokens.cpu().numpy() - top_logprobs=top_logprobs.cpu().numpy() - - # Set to raw logprob info - logprob_infos.append(RawLogprobsInfo( - current_token=res_tokens[index].cpu().numpy(), - current_logprob=res_logprob, - top_tokens=top_tokens, - top_logprobs=top_logprobs, - previous_tokens=None - )) - else: - logprob_infos.append(None) + mask_counter = mask_counter + 1 return logprob_infos @@ -100,25 +114,6 @@ def _apply_top_p_top_k(logits, top_ps, top_ks): return logits -def update_masked_list(input_list, mask, update): - j = 0 - for i in range(len(mask)): - if mask[i]: - input_list[i] = update[j] - j = j + 1 - - return input_list - - -def filter_list_by_mask(i_list, mask): - o_list = [] - for i in range(len(mask)): - if mask[i]: - o_list.append(i_list[i]) - - return o_list - - def sample( logits: Union[tvm.nd.NDArray, torch.Tensor], sampling_params: List[SamplingParams], @@ -149,20 +144,24 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy_dvc] + logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq + if logits_greedy.shape[0] > 0: res_greedy = torch.argmax(logits_greedy, -1) - logprob_infos_greedy = fetch_raw_logprob_infos( + logprob_infos = get_masked_logprobs( + logprob_infos, + mask_greedy_dvc, + sampling_params, logits_greedy, res_greedy, - filter_list_by_mask(sampling_params, mask_greedy_dvc) ) res_greedy = res_greedy.cpu().numpy() # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_greedy, logprob_infos_greedy + return res_greedy, logprob_infos temperatures = [] top_ps = [] @@ -231,29 +230,26 @@ def _is_safe_to_sample(prob_like): res_random = torch.multinomial(probs, 1, True)[:, 0] - logprob_infos_random = fetch_raw_logprob_infos( + logprob_infos = get_masked_logprobs( + logprob_infos, + mask_random_dvc, + sampling_params, logits_random, res_random, - filter_list_by_mask(sampling_params, mask_random_dvc), ) res_random = 
res_random.cpu().numpy() # Case when there's only random sampling if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_random, logprob_infos_random + return res_random, logprob_infos res = np.empty((num_seq,), dtype=np.int32) res[mask_random_cpu] = res_random - logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq - logprob_infos = update_masked_list(logprob_infos, mask_random_cpu, logprob_infos_random) - if logits_greedy.shape[0] > 0: res[mask_greedy_cpu] = res_greedy - logprob_infos = update_masked_list(logprob_infos, mask_greedy_cpu, logprob_infos_greedy) - torch.cuda.nvtx.range_pop() return res, logprob_infos From 58bac8f313ee2c53bd3927cc6f1710c849f1ea3e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 10:55:22 +0400 Subject: [PATCH 26/39] quick fix for check --- serve/mlc_serve/model/tvm_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 70147d721f..ad14fc11fa 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -252,6 +252,8 @@ def generate( ) input_shape = input_ids.shape + # TODO(vvchernov): quick fix, but need to refactor logic + current_ids = list(input_ids.numpy()) if self.disco_session: input_ids = copy_to_worker_0(self.disco_session, input_ids) From 7de8d88bd2da43b1b39275f063c780398e944e20 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 10:57:45 +0400 Subject: [PATCH 27/39] review fix --- serve/mlc_serve/engine/engine_common.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 01489b4ea3..9f151ba042 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -180,9 +180,6 @@ def logprobs_detokenize( tokenizer: TokenizerP, logprob_info: List[Optional[RawLogprobsInfo]], ) -> Optional[List[Optional[LogprobsContent]]]: - if logprob_info is None: - return None - res: List[Optional[LogprobsContent]] = [] for info in logprob_info: res.append(logprob_detokenize(tokenizer, info)) From 661fa1827611675769b09bc78fbda657b9903042 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 11:31:13 +0400 Subject: [PATCH 28/39] fix list index out of range --- serve/mlc_serve/engine/engine_common.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 9f151ba042..5579d6d27a 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -153,9 +153,11 @@ def logprob_detokenize( if logprob_info.previous_tokens is None: logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: - detokenized = tokenizer.convert_ids_to_tokens( - logprob_info.previous_tokens + [top_token] - )[-1] + # TODO(vvchernov): not clear what do we want + # detokenized = tokenizer.convert_ids_to_tokens( + # logprob_info.previous_tokens + [top_token] + # )[-1] + detokenized = tokenizer.decode(top_token) top_logprobs.append( TopLogprobs( token=detokenized, From 6662a656d77cb53e1b5d9c9276f58dc682224968 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 16:57:56 +0400 Subject: [PATCH 29/39] rollback after rebase --- serve/mlc_serve/engine/engine_common.py | 10 ++++------ serve/mlc_serve/model/tvm_model.py | 2 -- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git 
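The helper get_raw_logprob_info introduced in the refactor above boils down to a log_softmax over the vocabulary plus an optional top-k. A self-contained PyTorch sketch of the same computation, using toy logits for a single position:

    import torch

    def raw_logprob_info(logits: torch.Tensor, token_id: int, top_k: int):
        # Log-probabilities over the whole vocabulary for this position.
        logprobs = torch.log_softmax(logits, dim=-1)
        chosen_logprob = logprobs[token_id].item()
        if top_k == 0:
            return chosen_logprob, None, None
        top_vals, top_ids = torch.topk(logprobs, k=top_k, dim=-1, largest=True, sorted=True)
        return chosen_logprob, top_ids.tolist(), top_vals.tolist()

    # Toy vocabulary of four tokens; ask for the two most likely alternatives.
    print(raw_logprob_info(torch.tensor([2.0, 1.0, 0.5, -1.0]), token_id=1, top_k=2))
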
a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 5579d6d27a..f2509a5f1c 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -148,16 +148,14 @@ def logprob_detokenize( top_logprobs: List[TopLogprobs] = [] if logprob_info.top_tokens is not None and logprob_info.top_logprobs is not None: top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) - # dedup duplicates - # Todo: Make sure decode can generate different tokens if logprob_info.previous_tokens is None: logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: # TODO(vvchernov): not clear what do we want - # detokenized = tokenizer.convert_ids_to_tokens( - # logprob_info.previous_tokens + [top_token] - # )[-1] - detokenized = tokenizer.decode(top_token) + detokenized = tokenizer.convert_ids_to_tokens( + logprob_info.previous_tokens + [top_token] + )[-1] + # detokenized = tokenizer.decode(top_token) top_logprobs.append( TopLogprobs( token=detokenized, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index ad14fc11fa..70147d721f 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -252,8 +252,6 @@ def generate( ) input_shape = input_ids.shape - # TODO(vvchernov): quick fix, but need to refactor logic - current_ids = list(input_ids.numpy()) if self.disco_session: input_ids = copy_to_worker_0(self.disco_session, input_ids) From 970d7f85eee2aaa14e5f77a13f5c53bf51207eab Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 29 Jan 2024 18:16:54 +0400 Subject: [PATCH 30/39] test --- serve/mlc_serve/engine/engine_common.py | 8 ++++---- serve/mlc_serve/model/tvm_model.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index f2509a5f1c..51293f51b4 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -152,10 +152,10 @@ def logprob_detokenize( logprob_info.previous_tokens = [] for top_token, top_logprob in top_tokens: # TODO(vvchernov): not clear what do we want - detokenized = tokenizer.convert_ids_to_tokens( - logprob_info.previous_tokens + [top_token] - )[-1] - # detokenized = tokenizer.decode(top_token) + # detokenized = tokenizer.convert_ids_to_tokens( + # logprob_info.previous_tokens + [top_token] + # )[-1] + detokenized = tokenizer.decode(top_token) top_logprobs.append( TopLogprobs( token=detokenized, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 70147d721f..e9d849f9b0 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -348,8 +348,8 @@ def generate( logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - if logprob_infos[i]: - token_ids.append(new_token) + # if logprob_infos[i]: + # token_ids.append(new_token) else: outputs.append( TextGenerationResult( @@ -359,8 +359,8 @@ def generate( logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], ) ) - if logprob_infos[i]: - token_ids.append(new_token) + # if logprob_infos[i]: + # token_ids.append(new_token) return outputs except RuntimeError: From c58d69c8fb8ee451b6c05cc4cfb4ff20e167e64c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 30 Jan 2024 10:36:32 +0400 Subject: [PATCH 31/39] small fix --- serve/mlc_serve/model/model_common.py | 3 +-- 1 file 
changed, 1 insertion(+), 2 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 252e958a30..50dcb7d2b7 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -59,8 +59,7 @@ def get_raw_logprob_info( # Set to raw logprob info return RawLogprobsInfo( - # TODO(vvchernov): it is number, cpu().numpy()? - current_token=token.cpu().numpy(), + current_token=token, current_logprob=res_logprob, top_tokens=top_tokens, top_logprobs=top_logprobs, From ebae20023a7d4b77a2c3b3e1f21c8279682dcc2b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 30 Jan 2024 10:43:15 +0400 Subject: [PATCH 32/39] rename for the sake of clarity --- serve/mlc_serve/engine/base.py | 4 ++-- serve/mlc_serve/engine/engine_common.py | 10 +++++----- serve/mlc_serve/model/model_common.py | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 33e76ae743..aaeb96ba9e 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -17,9 +17,9 @@ @dataclass class RawLogprobsInfo: - current_token: int + current_token_id: int current_logprob: float - top_tokens: Optional[np.array] + top_token_ids: Optional[np.array] top_logprobs: Optional[np.array] previous_tokens: Optional[List[int]] diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 51293f51b4..b3c437af96 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -146,16 +146,16 @@ def logprob_detokenize( return None top_logprobs: List[TopLogprobs] = [] - if logprob_info.top_tokens is not None and logprob_info.top_logprobs is not None: - top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs)) + if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None: + top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs)) if logprob_info.previous_tokens is None: logprob_info.previous_tokens = [] - for top_token, top_logprob in top_tokens: + for top_token_id, top_logprob in top_tokens: # TODO(vvchernov): not clear what do we want # detokenized = tokenizer.convert_ids_to_tokens( # logprob_info.previous_tokens + [top_token] # )[-1] - detokenized = tokenizer.decode(top_token) + detokenized = tokenizer.decode(top_token_id) top_logprobs.append( TopLogprobs( token=detokenized, @@ -166,7 +166,7 @@ def logprob_detokenize( ) logprobs_content = LogprobsContent( - token=tokenizer.decode([logprob_info.current_token]), + token=tokenizer.decode([logprob_info.current_token_id]), logprob=logprob_info.current_logprob, # TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object bytes=None, diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 50dcb7d2b7..ebcc21e64a 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -40,11 +40,11 @@ def get_num_cache_blocks( def get_raw_logprob_info( logits, - token, + token_id, top_logprobs_num, ) -> RawLogprobsInfo: logprobs = torch.log_softmax(logits, dim=-1) - res_logprob = logprobs[token].cpu().numpy() + res_logprob = logprobs[token_id].cpu().numpy() if top_logprobs_num == 0: top_logprobs = None @@ -59,9 +59,9 @@ def get_raw_logprob_info( # Set to raw logprob info return RawLogprobsInfo( - current_token=token, + current_token_id=token_id, current_logprob=res_logprob, - 
top_tokens=top_tokens, + top_token_ids=top_tokens, top_logprobs=top_logprobs, previous_tokens=None ) @@ -72,7 +72,7 @@ def get_masked_logprobs( mask: torch.Tensor, sampling_params: List[SamplingParams], logits: torch.Tensor, - tokens: torch.Tensor, + token_ids: torch.Tensor, ) -> List[Optional[RawLogprobsInfo]]: num_seq = len(logprob_infos) @@ -82,7 +82,7 @@ def get_masked_logprobs( if sampling_params[i].logprobs: logprob_infos[i] = get_raw_logprob_info( logits[mask_counter], - tokens[mask_counter], + token_ids[mask_counter], sampling_params[i].top_logprobs, ) mask_counter = mask_counter + 1 From b2863d59d34de02d7d4e4dde3e4cf526e8c097ee Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 30 Jan 2024 15:42:03 +0400 Subject: [PATCH 33/39] some fixes with cpu-gpu tensor copying --- serve/mlc_serve/model/model_common.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index ebcc21e64a..639322ff6d 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -44,7 +44,7 @@ def get_raw_logprob_info( top_logprobs_num, ) -> RawLogprobsInfo: logprobs = torch.log_softmax(logits, dim=-1) - res_logprob = logprobs[token_id].cpu().numpy() + res_logprob = logprobs[token_id] if top_logprobs_num == 0: top_logprobs = None @@ -146,7 +146,7 @@ def _is_safe_to_sample(prob_like): logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq if logits_greedy.shape[0] > 0: - res_greedy = torch.argmax(logits_greedy, -1) + res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() logprob_infos = get_masked_logprobs( logprob_infos, @@ -156,7 +156,6 @@ def _is_safe_to_sample(prob_like): res_greedy, ) - res_greedy = res_greedy.cpu().numpy() # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: torch.cuda.nvtx.range_pop() @@ -227,7 +226,7 @@ def _is_safe_to_sample(prob_like): torch.cuda.nvtx.range_pop() return None - res_random = torch.multinomial(probs, 1, True)[:, 0] + res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy() logprob_infos = get_masked_logprobs( logprob_infos, @@ -237,7 +236,6 @@ def _is_safe_to_sample(prob_like): res_random, ) - res_random = res_random.cpu().numpy() # Case when there's only random sampling if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop() From 57b3a3554837b4a87d4f5099cda079de7b8bba00 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 30 Jan 2024 17:08:42 +0400 Subject: [PATCH 34/39] refactor logprob pass to calculate --- serve/mlc_serve/model/model_common.py | 60 +++++++++++++++++---------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 639322ff6d..bb4b65f9e1 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -67,25 +67,41 @@ def get_raw_logprob_info( ) -def get_masked_logprobs( - logprob_infos: List[Optional[RawLogprobsInfo]], - mask: torch.Tensor, +def get_logprob_indices( sampling_params: List[SamplingParams], + num_seq: int, +) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]: + lgp_inds_greedy: List[Tuple[int, int, int]] = [] + lgp_inds_random: List[Tuple[int, int, int]] = [] + + g_ind = 0 + r_ind = 0 + for i in range(num_seq): + sampling_param = sampling_params[i] + if sampling_param.sampling_type == SamplingType.RANDOM: + if sampling_param.logprobs: + lgp_inds_random.append((i, r_ind, sampling_param.top_logprobs)) 
+ r_ind = r_ind + 1 + else: + if sampling_param.logprobs: + lgp_inds_greedy.append((i, g_ind, sampling_param.top_logprobs)) + g_ind = g_ind + 1 + + return lgp_inds_greedy, lgp_inds_random + + +def get_raw_logprob_infos( + logprob_infos: List[Optional[RawLogprobsInfo]], + indices: List[Tuple[int, int, int]], logits: torch.Tensor, token_ids: torch.Tensor, ) -> List[Optional[RawLogprobsInfo]]: - num_seq = len(logprob_infos) - - mask_counter = 0 - for i in range(num_seq): - if mask[i]: - if sampling_params[i].logprobs: - logprob_infos[i] = get_raw_logprob_info( - logits[mask_counter], - token_ids[mask_counter], - sampling_params[i].top_logprobs, - ) - mask_counter = mask_counter + 1 + for (i, ind, top_logprobs) in indices: + logprob_infos[i] = get_raw_logprob_info( + logits[ind], + token_ids[ind], + top_logprobs, + ) return logprob_infos @@ -144,14 +160,17 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy_dvc] logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq + lgp_inds_greedy, lgp_inds_random = get_logprob_indices( + sampling_params, + num_seq, + ) if logits_greedy.shape[0] > 0: res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() - logprob_infos = get_masked_logprobs( + logprob_infos = get_raw_logprob_infos( logprob_infos, - mask_greedy_dvc, - sampling_params, + lgp_inds_greedy, logits_greedy, res_greedy, ) @@ -228,10 +247,9 @@ def _is_safe_to_sample(prob_like): res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy() - logprob_infos = get_masked_logprobs( + logprob_infos = get_raw_logprob_infos( logprob_infos, - mask_random_dvc, - sampling_params, + lgp_inds_random, logits_random, res_random, ) From 4e29403d5900ba80736f245d06b99c67853acf37 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 30 Jan 2024 18:06:18 +0400 Subject: [PATCH 35/39] remove excess deps for token detokenization --- serve/mlc_serve/engine/base.py | 1 - serve/mlc_serve/engine/engine_common.py | 14 ++------------ serve/mlc_serve/engine/model_module.py | 2 +- serve/mlc_serve/model/model_common.py | 1 - serve/mlc_serve/model/tvm_model.py | 24 ++++++++---------------- 5 files changed, 11 insertions(+), 31 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index aaeb96ba9e..53d991a30c 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -21,7 +21,6 @@ class RawLogprobsInfo: current_logprob: float top_token_ids: Optional[np.array] top_logprobs: Optional[np.array] - previous_tokens: Optional[List[int]] # TODO(@sunggg): consider transition to something like Pydantic diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index b3c437af96..df1c7233f9 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -148,17 +148,10 @@ def logprob_detokenize( top_logprobs: List[TopLogprobs] = [] if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None: top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs)) - if logprob_info.previous_tokens is None: - logprob_info.previous_tokens = [] for top_token_id, top_logprob in top_tokens: - # TODO(vvchernov): not clear what do we want - # detokenized = tokenizer.convert_ids_to_tokens( - # logprob_info.previous_tokens + [top_token] - # )[-1] - detokenized = tokenizer.decode(top_token_id) top_logprobs.append( TopLogprobs( - token=detokenized, + token=tokenizer.decode(top_token_id), logprob=float(top_logprob), # TODO(vvchernov): implement bytes 
based on https://platform.openai.com/docs/api-reference/chat/object bytes=None, @@ -179,14 +172,11 @@ def logprob_detokenize( def logprobs_detokenize( tokenizer: TokenizerP, logprob_info: List[Optional[RawLogprobsInfo]], -) -> Optional[List[Optional[LogprobsContent]]]: +) -> List[Optional[LogprobsContent]]: res: List[Optional[LogprobsContent]] = [] for info in logprob_info: res.append(logprob_detokenize(tokenizer, info)) - check_all = all([x is None for x in res]) - if check_all: - return None return res diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index d305c514b8..5ff91816a0 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -51,7 +51,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] - logprob_info: Optional[List[Optional[RawLogprobsInfo]]] = None + logprob_info: List[Optional[RawLogprobsInfo]] class KVCache(Protocol): diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index bb4b65f9e1..ae4fa85e66 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -63,7 +63,6 @@ def get_raw_logprob_info( current_logprob=res_logprob, top_token_ids=top_tokens, top_logprobs=top_logprobs, - previous_tokens=None ) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index e9d849f9b0..258f8b61c0 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -22,7 +22,6 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, MLCServeEngineConfig, - RawLogprobsInfo, ) from ..engine.model_module import ( DecodeRequest, @@ -85,9 +84,6 @@ def get_tvm_model(config, dev): return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) -def attach_detokenization_info(logprob_info:RawLogprobsInfo, token_ids: List[int]): - logprob_info.previous_tokens = token_ids - return logprob_info def _prepare_inputs( sequence_ids, @@ -332,8 +328,8 @@ def generate( next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) assert next_tokens is not None outputs = [] - for i, (sequence_id, new_token, token_ids) in enumerate( - zip(sequence_ids, next_tokens, all_token_ids) + for i, (sequence_id, new_token) in enumerate( + zip(sequence_ids, next_tokens) ): if not new_token in requests[i].sampling_params.appeared_tokens_freq: requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 @@ -345,22 +341,18 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], + logprob_info=[logprob_infos[i]], ) ) - # if logprob_infos[i]: - # token_ids.append(new_token) else: outputs.append( TextGenerationResult( sequence_id=sequence_id, generated_tokens=[new_token], error=None, - logprob_info=[attach_detokenization_info(logprob_infos[i], token_ids) if logprob_infos[i] else None], + logprob_info=[logprob_infos[i]], ) ) - # if logprob_infos[i]: - # token_ids.append(new_token) return outputs except RuntimeError: @@ -398,7 +390,7 @@ def generate( ), generated_tokens=[new_token], # type: ignore error=None, - logprob_info=[logprob_infos[0]] + logprob_info=logprob_infos ) ) else: @@ -407,7 +399,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], # type: ignore error=None, - 
logprob_info=[logprob_infos[0]] + logprob_info=logprob_infos ) ) else: @@ -420,7 +412,7 @@ def generate( ), generated_tokens=[], error=err_msg, - logprob_info=[logprob_infos[0]] + logprob_info=logprob_infos ) ) else: @@ -429,7 +421,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, - logprob_info=[logprob_infos[0]] + logprob_info=logprob_infos ) ) From a9157b918d2a44239eb79c6220d4db8aa1b46c46 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 30 Jan 2024 19:53:27 +0400 Subject: [PATCH 36/39] small clean --- serve/mlc_serve/api/handler.py | 2 +- serve/mlc_serve/engine/base.py | 4 ++-- serve/mlc_serve/engine/staging_engine_worker.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index 73e873bc73..ee61f2cb57 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -194,7 +194,7 @@ def create_stream_response( finish_reason=seq.finish_reason.value if seq.finish_reason is not None else None, - logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info else None + logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info != [] else None ) for seq in res.sequences ] diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 53d991a30c..c674270673 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -165,7 +165,7 @@ class SequenceOutput: finish_reason: Optional[FinishReason] = None # Number of generated tokens so far num_generated_tokens: int = 0 - logprob_info: Optional[List[Optional[LogprobsContent]]] = None + logprob_info: List[Optional[LogprobsContent]] = field(default_factory=list) @property def is_finished(self) -> bool: @@ -175,7 +175,7 @@ def is_finished(self) -> bool: @dataclass class RequestOutput: request_id: RequestId - sequences: list[SequenceOutput] + sequences: List[SequenceOutput] # TODO: reconsider the place to put this number # Only set for outputs with valid sequence outputs num_prompt_tokens: Optional[int] = None diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index ceb84857ae..e9d62ec5e0 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -4,7 +4,7 @@ import time import multiprocessing import multiprocessing.synchronize -from dataclasses import dataclass +from dataclasses import dataclass, field from threading import Thread, Lock from typing import Callable, Optional, Union, Any, Dict, List @@ -66,7 +66,7 @@ class SequenceGenerationOutput: new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None - logprob_info: Optional[List[RawLogprobsInfo]] = None + logprob_info: List[RawLogprobsInfo] = field(default_factory=list) @dataclass From 39efb61d23c04c1b02b826f9673ead5bdc60372a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 31 Jan 2024 15:25:29 +0400 Subject: [PATCH 37/39] small clean --- serve/mlc_serve/api/handler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index ee61f2cb57..1c558609c8 100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -43,6 +43,7 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse router = APIRouter() + def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams: sampling_params 
= SamplingParams( # These params came from vllm @@ -179,6 +180,7 @@ def create_stream_response( ], ) yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n" + async for res in result_generator: if res.error: raise RuntimeError(f"Error when generating: {res.error}") @@ -237,7 +239,7 @@ async def collect_result_stream( if seq.is_finished: assert seq.finish_reason is not None finish_reasons[seq.index] = seq.finish_reason.value # type: ignore - + choices = [] for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)): logprobs = None From 601e68dfb2db8e0edc2ed50a75029c40f6d10127 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 31 Jan 2024 18:31:05 +0400 Subject: [PATCH 38/39] return None instead of list of Nones --- serve/mlc_serve/engine/__init__.py | 1 + serve/mlc_serve/engine/base.py | 2 ++ serve/mlc_serve/engine/engine_common.py | 6 +++- serve/mlc_serve/engine/model_module.py | 4 +-- .../mlc_serve/engine/staging_engine_worker.py | 4 +-- serve/mlc_serve/model/model_common.py | 28 ++++++++++++++----- serve/mlc_serve/model/tvm_model.py | 27 ++++++++++++------ 7 files changed, 52 insertions(+), 20 deletions(-) diff --git a/serve/mlc_serve/engine/__init__.py b/serve/mlc_serve/engine/__init__.py index 92d101ea95..b2fb08a079 100644 --- a/serve/mlc_serve/engine/__init__.py +++ b/serve/mlc_serve/engine/__init__.py @@ -17,5 +17,6 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, RawLogprobsInfo, + RawLogprobsInfos, ) from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index c674270673..b66dea3479 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -22,6 +22,8 @@ class RawLogprobsInfo: top_token_ids: Optional[np.array] top_logprobs: Optional[np.array] +RawLogprobsInfos = List[Optional[RawLogprobsInfo]] + # TODO(@sunggg): consider transition to something like Pydantic @dataclass diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index df1c7233f9..675e97e173 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -12,6 +12,7 @@ from .base import ( GenerationSequence, RawLogprobsInfo, + RawLogprobsInfos, Request, RequestId, RequestState, @@ -171,8 +172,11 @@ def logprob_detokenize( def logprobs_detokenize( tokenizer: TokenizerP, - logprob_info: List[Optional[RawLogprobsInfo]], + logprob_info: Optional[RawLogprobsInfos], ) -> List[Optional[LogprobsContent]]: + if logprob_info is None: + return [] + res: List[Optional[LogprobsContent]] = [] for info in logprob_info: res.append(logprob_detokenize(tokenizer, info)) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 5ff91816a0..c5937dd18b 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -7,7 +7,7 @@ from .base import ( ChatMessage, MLCServeEngineConfig, - RawLogprobsInfo, + RawLogprobsInfos, RequestId, RequestState, SequenceId, @@ -51,7 +51,7 @@ class TextGenerationResult: # making this a list of token ids to leave room for speculative decoding generated_tokens: List[int] error: Optional[str] - logprob_info: List[Optional[RawLogprobsInfo]] + logprob_info: Optional[RawLogprobsInfos] class KVCache(Protocol): diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index e9d62ec5e0..6c02c0811c 
100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -12,7 +12,7 @@ from .base import ( FinishReason, - RawLogprobsInfo, + RawLogprobsInfos, RequestId, RequestState, ValidationError, @@ -66,7 +66,7 @@ class SequenceGenerationOutput: new_tokens: List[int] finish_reason: Optional[FinishReason] = None error: Optional[Union[str, ValidationError]] = None - logprob_info: List[RawLogprobsInfo] = field(default_factory=list) + logprob_info: Optional[RawLogprobsInfos] = None @dataclass diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index ae4fa85e66..b9e23ddad0 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -11,6 +11,7 @@ SamplingParams, LOGPROB_TOP_K_MAX, RawLogprobsInfo, + RawLogprobsInfos, ) LOG = structlog.stdlib.get_logger(__name__) @@ -90,11 +91,11 @@ def get_logprob_indices( def get_raw_logprob_infos( - logprob_infos: List[Optional[RawLogprobsInfo]], + logprob_infos: RawLogprobsInfos, indices: List[Tuple[int, int, int]], logits: torch.Tensor, token_ids: torch.Tensor, -) -> List[Optional[RawLogprobsInfo]]: +) -> RawLogprobsInfos: for (i, ind, top_logprobs) in indices: logprob_infos[i] = get_raw_logprob_info( logits[ind], @@ -105,6 +106,19 @@ def get_raw_logprob_infos( return logprob_infos +def check_logprob_infos( + logprob_infos: RawLogprobsInfos, +) -> Optional[RawLogprobsInfos]: + check = False + for info in logprob_infos: + if info is not None: + check = True + break + if check: + return logprob_infos + return None + + def _apply_top_p_top_k(logits, top_ps, top_ks): p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) @@ -133,7 +147,7 @@ def sample( sampling_params: List[SamplingParams], vocab_size: int, check_safety=False, -) -> Optional[Tuple[np.ndarray, List[Optional[RawLogprobsInfo]]]]: +) -> Optional[Tuple[np.ndarray, Optional[RawLogprobsInfos]]]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -158,7 +172,7 @@ def _is_safe_to_sample(prob_like): logits_greedy = logits[mask_greedy_dvc] - logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * num_seq + logprob_infos: RawLogprobsInfos = [None] * num_seq lgp_inds_greedy, lgp_inds_random = get_logprob_indices( sampling_params, num_seq, @@ -177,7 +191,7 @@ def _is_safe_to_sample(prob_like): # Case when there's only greedy sampling if logits_greedy.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_greedy, logprob_infos + return res_greedy, check_logprob_infos(logprob_infos) temperatures = [] top_ps = [] @@ -256,7 +270,7 @@ def _is_safe_to_sample(prob_like): # Case when there's only random sampling if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop() - return res_random, logprob_infos + return res_random, check_logprob_infos(logprob_infos) res = np.empty((num_seq,), dtype=np.int32) res[mask_random_cpu] = res_random @@ -265,7 +279,7 @@ def _is_safe_to_sample(prob_like): res[mask_greedy_cpu] = res_greedy torch.cuda.nvtx.range_pop() - return res, logprob_infos + return res, check_logprob_infos(logprob_infos) def prepare_inputs( diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 258f8b61c0..eb5cfc30d1 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -1,6 +1,6 @@ import math import os -from typing import List, Union, Tuple, 
Sequence +from typing import List, Optional, Union, Tuple, Sequence import structlog import numpy as np @@ -18,8 +18,9 @@ ) from ..engine import ( - SequenceId, PROMPT_SEQEUNCE_INDEX, + RawLogprobsInfos, + SequenceId, get_prompt_sequence_id, MLCServeEngineConfig, ) @@ -203,6 +204,16 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() + def get_logprob_infos( + self, + i: int, + logprob_infos: Optional[RawLogprobsInfos], + ) -> Optional[RawLogprobsInfos]: + if logprob_infos is None or logprob_infos[i] is None: + return None + return [logprob_infos[i]] + + def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], @@ -341,7 +352,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, - logprob_info=[logprob_infos[i]], + logprob_info=self.get_logprob_infos(i, logprob_infos), ) ) else: @@ -350,7 +361,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, - logprob_info=[logprob_infos[i]], + logprob_info=self.get_logprob_infos(i, logprob_infos), ) ) @@ -390,7 +401,7 @@ def generate( ), generated_tokens=[new_token], # type: ignore error=None, - logprob_info=logprob_infos + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) else: @@ -399,7 +410,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], # type: ignore error=None, - logprob_info=logprob_infos + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) else: @@ -412,7 +423,7 @@ def generate( ), generated_tokens=[], error=err_msg, - logprob_info=logprob_infos + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) else: @@ -421,7 +432,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[], error=err_msg, - logprob_info=logprob_infos + logprob_info=self.get_logprob_infos(0, logprob_infos), ) ) From 7ec21a721c494bcafbc6a21d59e863ff2812ef14 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 31 Jan 2024 21:55:45 +0400 Subject: [PATCH 39/39] fix mypy --- serve/mlc_serve/model/dummy_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/serve/mlc_serve/model/dummy_model.py b/serve/mlc_serve/model/dummy_model.py index 5d1084951d..4e41508c1e 100644 --- a/serve/mlc_serve/model/dummy_model.py +++ b/serve/mlc_serve/model/dummy_model.py @@ -123,6 +123,7 @@ def generate( generated_tokens=[req.token_ids[-1] + 1], # generated_tokens=[1], error=None, + logprob_info=None, ) ) return result
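
The series above converges on a simple per-sequence contract: sampling returns a RawLogprobsInfos list with one optional RawLogprobsInfo per sequence, an all-None list collapses to None so the engine and API handler can skip logprob detokenization entirely, and each info record carries the chosen token id, its logprob, and optional top-k arrays. The snippet below is a minimal, self-contained sketch of that flow, not the patch itself: the dataclass fields mirror the final base.py, while the top-k extraction, the dummy logits, and the driver at the bottom are illustrative assumptions rather than verbatim mlc-serve code.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class RawLogprobsInfo:
    # Mirrors the final serve/mlc_serve/engine/base.py fields.
    current_token_id: int
    current_logprob: float
    top_token_ids: Optional[torch.Tensor]
    top_logprobs: Optional[torch.Tensor]


RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


def get_raw_logprob_info(
    logits: torch.Tensor, token_id: int, top_logprobs_num: int
) -> RawLogprobsInfo:
    # Log-softmax over the vocab dimension, then read off the sampled token's logprob.
    logprobs = torch.log_softmax(logits, dim=-1)
    current_logprob = float(logprobs[token_id])
    if top_logprobs_num == 0:
        top_logprobs, top_token_ids = None, None
    else:
        # Assumed top-k extraction; the real helper also bounds k by LOGPROB_TOP_K_MAX.
        top_logprobs, top_token_ids = torch.topk(
            logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
        )
    return RawLogprobsInfo(
        current_token_id=token_id,
        current_logprob=current_logprob,
        top_token_ids=top_token_ids,
        top_logprobs=top_logprobs,
    )


def check_logprob_infos(
    logprob_infos: RawLogprobsInfos,
) -> Optional[RawLogprobsInfos]:
    # Collapse an all-None list to None so downstream code (TextGenerationResult,
    # logprobs_detokenize, the API handler) can cheaply skip logprob handling.
    if any(info is not None for info in logprob_infos):
        return logprob_infos
    return None


if __name__ == "__main__":
    # Hypothetical batch: 3 sequences over a 32-token vocab, only sequence 1
    # requested logprobs with top_logprobs=2.
    vocab_size, num_seq = 32, 3
    logits = torch.randn(num_seq, vocab_size)
    sampled = torch.argmax(logits, dim=-1)

    infos: RawLogprobsInfos = [None] * num_seq
    infos[1] = get_raw_logprob_info(logits[1], int(sampled[1]), top_logprobs_num=2)
    print(check_logprob_infos(infos))

The None-instead-of-list-of-Nones shape (PATCH 38/39) is what lets the handler guard with a plain truthiness check before building a Logprobs payload, and it is why TextGenerationResult.logprob_info became Optional[RawLogprobsInfos] rather than a list that is mostly empty.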