Enable Logprobs in MLC Batch Serving (#82)
* Squashed commit for logprobs implementation.

Co-authored-by: Valery Chernov <[email protected]>
Co-authored-by: Ilya Kozulin <[email protected]>

* fix None check

* Change detokenization to use token ids.

* Fix wrong usage of token ids. Remove logging.

* extend benchmarks for logprobs

* fix test without logprobs

* clean code

* black format engine_common.py

* logprobs is strictly bool, top_logprobs is int

* refactor logprob info collection to not reduce performance

* quick fix for check

* review fix

* fix list index out of range

* rollback after rebase

* test

* small fix

* rename for the sake of clarity

* some fixes with cpu-gpu tensor copying

* refactor logprob pass to calculate

* remove excess deps for token detokenization

* small clean

* small clean

* return None instead of list of Nones

* fix mypy

---------

Co-authored-by: Valery Chernov <[email protected]>
Co-authored-by: Ilya Kozulin <[email protected]>
Co-authored-by: Valery Chernov <[email protected]>
4 people authored Jan 31, 2024
1 parent 4535ff5 commit 2b3fcf0
Showing 22 changed files with 376 additions and 53 deletions.
2 changes: 2 additions & 0 deletions serve/benchmarks/benchmark_latency.py
@@ -34,6 +34,8 @@ def create_request(request_id):
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
2 changes: 2 additions & 0 deletions serve/benchmarks/benchmark_throughput.py
@@ -139,6 +139,8 @@ def run_mlc(engine, requests, args) -> float:
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
18 changes: 18 additions & 0 deletions serve/benchmarks/utils.py
@@ -22,6 +22,18 @@ def add_sampling_flags(parser):
action="store_true",
help="Apply all penalties, logit bias, top-p and top-k.",
)
parser.add_argument(
"--logprobs",
action="store_true",
default=False,
help="Switch on logprobs output"
)
parser.add_argument(
"--top-logprobs",
type=int,
default=5,
help="Number of top logprobs to output, limited by 5. Works only with logprobs true."
)


def postproc_sampling_args(args):
@@ -33,6 +45,8 @@ def postproc_sampling_args(args):
"repetition_penalty": 1.0,
"top_p": 1.0,
"top_k": -1,
"logprobs": False,
"top_logprobs": 5,
}

if args.apply_all_sampling_params:
@@ -51,3 +65,7 @@ def postproc_sampling_args(args):
if args.apply_top_p_top_k:
args.sampling_setting["top_k"] = 2
args.sampling_setting["top_p"] = 0.7

if args.logprobs:
args.sampling_setting["logprobs"] = True
args.sampling_setting["top_logprobs"] = args.top_logprobs
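
For reference, a minimal standalone sketch of how the new `--logprobs` / `--top-logprobs` flags flow into the benchmark's `sampling_setting` dict. This is a self-contained reproduction of the logic added above, not the benchmark module itself:

```python
import argparse

# Mirror of the two new benchmark flags (standalone sketch, not serve/benchmarks/utils.py itself).
parser = argparse.ArgumentParser()
parser.add_argument("--logprobs", action="store_true", default=False,
                    help="Switch on logprobs output")
parser.add_argument("--top-logprobs", type=int, default=5,
                    help="Number of top logprobs to return (capped at 5); only used with --logprobs")

# Simulate invoking a benchmark with `--logprobs --top-logprobs 3`.
args = parser.parse_args(["--logprobs", "--top-logprobs", "3"])

# Defaults match the sampling_setting entries added in this diff.
sampling_setting = {"logprobs": False, "top_logprobs": 5}
if args.logprobs:
    sampling_setting["logprobs"] = True
    sampling_setting["top_logprobs"] = args.top_logprobs

print(sampling_setting)  # -> {'logprobs': True, 'top_logprobs': 3}
```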
27 changes: 21 additions & 6 deletions serve/mlc_serve/api/handler.py
@@ -9,7 +9,7 @@
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse

# TODO(amalyshe): hadnle random_seed
# TODO(amalyshe): handle random_seed
# from .base import set_global_random_seed
from ..api.protocol import (
ChatCompletionRequest,
@@ -20,6 +20,7 @@
ChatMessage,
DeltaMessage,
ErrorResponse,
Logprobs,
UsageInfo,
)
from ..engine import (
@@ -64,6 +65,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params.top_p = request.top_p
if request.logit_bias is not None:
sampling_params.logit_bias = request.logit_bias
if request.logprobs:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs
return sampling_params


@@ -156,7 +160,7 @@ async def generate_completion_stream(
created_time = int(time.time())

def create_stream_response(
choices: list[ChatCompletionResponseStreamChoice],
choices: List[ChatCompletionResponseStreamChoice],
) -> ChatCompletionStreamResponse:
return ChatCompletionStreamResponse(
id=request_id,
@@ -192,6 +196,7 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info != [] else None
)
for seq in res.sequences
]
@@ -212,6 +217,7 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
@@ -226,18 +232,27 @@
if seq.delta:
sequences[seq.index].append(seq.delta)

if seq.logprob_info:
assert seq.delta
logprob_infos[seq.index].extend(seq.logprob_info)

if seq.is_finished:
assert seq.finish_reason is not None
finish_reasons[seq.index] = seq.finish_reason.value # type: ignore

choices = [
ChatCompletionResponseChoice(
choices = []
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
logprobs = None
if logprob_info_seq != []:
logprobs = Logprobs(content=logprob_info_seq)

choice = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
logprobs=logprobs,
)
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
]
choices.append(choice)

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
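
As a usage illustration, here is a hedged client-side sketch of a chat completion request that exercises the new `logprobs` / `top_logprobs` request fields handled above. The `localhost:8000` address, the OpenAI-compatible `/v1/chat/completions` route, and the model name are assumptions made for the sketch, not something this diff pins down:

```python
import json
import urllib.request

# Hypothetical endpoint and model name; adjust to the actual server configuration.
url = "http://localhost:8000/v1/chat/completions"
payload = {
    "model": "mlc-model",                                    # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 8,
    "logprobs": True,                                        # new boolean switch
    "top_logprobs": 3,                                       # alternatives per token (<= 5)
}

req = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())

# With logprobs enabled, each choice carries an optional `logprobs` object
# whose `content` list holds per-token entries (token, logprob, top_logprobs).
print(json.dumps(body["choices"][0].get("logprobs"), indent=2))
```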
6 changes: 6 additions & 0 deletions serve/mlc_serve/api/protocol.py
@@ -6,6 +6,8 @@

from pydantic import BaseModel, Field

from ..openai_logprob_protocol import Logprobs


class ErrorResponse(BaseModel):
object: str = "error"
@@ -71,11 +73,14 @@ class ChatCompletionRequest(BaseModel):
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False
logprobs: bool = False
top_logprobs: int = 0


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


@@ -96,6 +101,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


4 changes: 3 additions & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -16,5 +16,7 @@
RequestState,
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
RawLogprobsInfo,
RawLogprobsInfos,
)
from .sampling_params import SamplingParams, SamplingType
from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
5 changes: 2 additions & 3 deletions serve/mlc_serve/engine/async_connector.py
@@ -1,7 +1,6 @@
import asyncio
import structlog
from typing import AsyncIterator, Any
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, Dict
from collections import deque

from .base import (
@@ -26,7 +25,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine_loop_task = None
self.engine_loop_exception = None
self.shutdown_event = asyncio.Event()
self.result_queues = dict[RequestId, ResultQueue]()
self.result_queues: Dict[RequestId, ResultQueue] = {}
self.recent_cancelled_requests = deque[RequestId](maxlen=64)

async def start(self):
15 changes: 14 additions & 1 deletion serve/mlc_serve/engine/base.py
@@ -6,13 +6,25 @@

from typing import List, Callable, Any, Optional, Dict
import inspect
import numpy as np

from .sampling_params import SamplingParams, SamplingType
from ..openai_logprob_protocol import LogprobsContent

LOG = structlog.stdlib.get_logger(__name__)
RequestId = str


@dataclass
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.array]
top_logprobs: Optional[np.array]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
@@ -155,6 +167,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: List[Optional[LogprobsContent]] = field(default_factory=list)

@property
def is_finished(self) -> bool:
@@ -164,7 +177,7 @@ def is_finished(self) -> bool:
@dataclass
class RequestOutput:
request_id: RequestId
sequences: list[SequenceOutput]
sequences: List[SequenceOutput]
# TODO: reconsider the place to put this number
# Only set for outputs with valid sequence outputs
num_prompt_tokens: Optional[int] = None
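
To make the new `RawLogprobsInfo` fields concrete, here is a minimal numpy-only sketch of filling such a record from a single row of logits. It uses a plain dict as a stand-in for the dataclass and is not the engine's actual sampling code:

```python
import numpy as np

def make_raw_logprobs_info(logits: np.ndarray, sampled_token_id: int, top_k: int) -> dict:
    """Build a RawLogprobsInfo-like record (dict stand-in) from one logits row."""
    # Numerically stable log-softmax over the vocabulary.
    shifted = logits - logits.max()
    logprobs = shifted - np.log(np.exp(shifted).sum())

    # Top-k token ids, highest logprob first.
    top_token_ids = np.argsort(logprobs)[::-1][:top_k]
    return {
        "current_token_id": sampled_token_id,
        "current_logprob": float(logprobs[sampled_token_id]),
        "top_token_ids": top_token_ids,
        "top_logprobs": logprobs[top_token_ids],
    }

row = np.random.randn(32000)  # one vocabulary-sized row of logits (size is illustrative)
info = make_raw_logprobs_info(row, sampled_token_id=42, top_k=5)
print(info["current_logprob"], info["top_token_ids"])
```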
6 changes: 3 additions & 3 deletions serve/mlc_serve/engine/dummy.py
@@ -12,9 +12,9 @@


class DummyInferenceEngine:
def __init__(self):
self.queue_lock = Lock()
self.has_new_requests = Condition(self.queue_lock)
def __init__(self) -> None:
self.queue_lock: Lock = Lock()
self.has_new_requests: Condition = Condition(self.queue_lock)
self.request_queue: Dict[RequestId, int] = {}

def add(self, requests: list[Request]):
53 changes: 51 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
@@ -3,17 +3,19 @@
"""

import time
from typing import Tuple, Deque, Dict, Optional, Union, Callable
from typing import Tuple, Deque, Dict, Optional, Union, Callable, List
from collections import deque
from threading import Condition, Lock

import structlog

from .base import (
GenerationSequence,
RawLogprobsInfo,
RawLogprobsInfos,
Request,
RequestId,
RequestState,
GenerationSequence,
SequenceId,
StoppingCriteria,
)
@@ -27,6 +29,7 @@
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs

LOG = structlog.stdlib.get_logger(__name__)

@@ -135,6 +138,52 @@ def detokenize_incrementally(
return delta


def logprob_detokenize(
tokenizer: TokenizerP,
logprob_info: Optional[RawLogprobsInfo],
) -> Optional[LogprobsContent]:
"""Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent"""
if logprob_info is None:
return None

top_logprobs: List[TopLogprobs] = []
if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None:
top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs))
for top_token_id, top_logprob in top_tokens:
top_logprobs.append(
TopLogprobs(
token=tokenizer.decode(top_token_id),
logprob=float(top_logprob),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
)
)

logprobs_content = LogprobsContent(
token=tokenizer.decode([logprob_info.current_token_id]),
logprob=logprob_info.current_logprob,
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
top_logprobs=top_logprobs,
)

return logprobs_content


def logprobs_detokenize(
tokenizer: TokenizerP,
logprob_info: Optional[RawLogprobsInfos],
) -> List[Optional[LogprobsContent]]:
if logprob_info is None:
return []

res: List[Optional[LogprobsContent]] = []
for info in logprob_info:
res.append(logprob_detokenize(tokenizer, info))

return res


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
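
For clarity, a hedged sketch of what `logprob_detokenize` produces per token, using a stub tokenizer and plain dicts in place of the repo's `RawLogprobsInfo` / `LogprobsContent` types. The field names mirror the code above; the token ids and logprob values are made up:

```python
import numpy as np

class StubTokenizer:
    """Tiny stand-in for the real tokenizer; only a few ids are mapped."""
    vocab = {7: "Hello", 13: ",", 42: " world"}

    def decode(self, ids):
        if isinstance(ids, (int, np.integer)):
            ids = [ids]
        return "".join(self.vocab.get(int(i), "<unk>") for i in ids)

def detokenize_one(tokenizer, info: dict) -> dict:
    """Dict-based mirror of logprob_detokenize for a single RawLogprobsInfo-like record."""
    top = [
        {"token": tokenizer.decode(tid), "logprob": float(lp), "bytes": None}
        for tid, lp in zip(info["top_token_ids"], info["top_logprobs"])
    ]
    return {
        "token": tokenizer.decode([info["current_token_id"]]),
        "logprob": info["current_logprob"],
        "bytes": None,
        "top_logprobs": top,
    }

info = {
    "current_token_id": 42,
    "current_logprob": -0.11,
    "top_token_ids": np.array([42, 7]),
    "top_logprobs": np.array([-0.11, -2.3]),
}
print(detokenize_one(StubTokenizer(), info))
# {'token': ' world', 'logprob': -0.11, 'bytes': None, 'top_logprobs': [...]}
```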
10 changes: 9 additions & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -4,7 +4,14 @@
from dataclasses import dataclass
from typing import Optional, Protocol, Union, List, Sequence

from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
from .base import (
ChatMessage,
MLCServeEngineConfig,
RawLogprobsInfos,
RequestId,
RequestState,
SequenceId,
)
from ..model.base import ModelArtifactConfig
from .sampling_params import SamplingParams

@@ -44,6 +51,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: List[int]
error: Optional[str]
logprob_info: Optional[RawLogprobsInfos]


class KVCache(Protocol):