Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Logprobs in MLC Batch Serving #82

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
ab47b41
Squashed commit for logprobs implementation.
zxybazh Jan 22, 2024
86f6fa1
fix None check
Jan 23, 2024
9a29650
Change detokenization to using token ids.
zxybazh Jan 25, 2024
012388d
Fix wrong usage of token ids. Remove logging.
zxybazh Jan 29, 2024
db31164
extend benchmarks for logprobs
Jan 26, 2024
be81755
fix test without logprobs
Jan 26, 2024
e8ec3fc
clean code
Jan 26, 2024
49187f5
black format engine_common.py
Jan 26, 2024
013ed5a
logprobs is strictly bool, top_logprobs is int
Jan 26, 2024
79ec413
refactor logprob info collection to not reduce performance
Jan 28, 2024
fca1a6f
quick fix for check
Jan 29, 2024
675b631
review fix
Jan 29, 2024
18f80fa
fix list index out of range
Jan 29, 2024
29ea525
rollback after rebase
Jan 29, 2024
aa99322
test
Jan 29, 2024
8fa785e
Merge pull request #7 from Deelvin/vc/benchmark
Jan 29, 2024
d57b197
Squashed commit for logprobs implementation.
zxybazh Jan 22, 2024
7995c84
fix None check
Jan 23, 2024
ae3fc5b
Change detokenization to using token ids.
zxybazh Jan 25, 2024
0cb036f
Fix wrong usage of token ids. Remove logging.
zxybazh Jan 29, 2024
ed51e7d
extend benchmarks for logprobs
Jan 26, 2024
ff17ae2
fix test without logprobs
Jan 26, 2024
f5e4339
clean code
Jan 26, 2024
a3f6e8b
black format engine_common.py
Jan 26, 2024
c54a410
logprobs is strictly bool, top_logprobs is int
Jan 26, 2024
379d991
refactor logprob info collection to not reduce performance
Jan 28, 2024
58bac8f
quick fix for check
Jan 29, 2024
7de8d88
review fix
Jan 29, 2024
661fa18
fix list index out of range
Jan 29, 2024
6662a65
rollback after rebase
Jan 29, 2024
970d7f8
test
Jan 29, 2024
c58d69c
small fix
Jan 30, 2024
ebae200
rename for the sake of clarity
Jan 30, 2024
b2863d5
some fixes with cpu-gpu tensor copying
Jan 30, 2024
57b3a35
refactor logprob pass to calculate
Jan 30, 2024
4e29403
remove excess deps for token detokenization
Jan 30, 2024
a9157b9
small clean
Jan 30, 2024
39efb61
small clean
Jan 31, 2024
601e68d
return None instead of list of Nones
Jan 31, 2024
4f9241b
resolve conflicts
Jan 31, 2024
7ec21a7
fix mypy
Jan 31, 2024
7aa60ed
Merge pull request #8 from Deelvin/vc/perf
Jan 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions serve/benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def create_request(request_id):
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
Expand Down
2 changes: 2 additions & 0 deletions serve/benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def run_mlc(engine, requests, args) -> float:
frequency_penalty=args.sampling_setting["frequency_penalty"],
presence_penalty=args.sampling_setting["presence_penalty"],
logit_bias=args.sampling_setting["logit_bias"],
logprobs = args.sampling_setting["logprobs"],
top_logprobs = args.sampling_setting["top_logprobs"],
),
stopping_criteria=StoppingCriteria(
max_tokens=args.num_output_tokens, stop_sequences=None
Expand Down
18 changes: 18 additions & 0 deletions serve/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ def add_sampling_flags(parser):
action="store_true",
help="Apply all penalties, logit bias, top-p and top-k.",
)
parser.add_argument(
"--logprobs",
action="store_true",
default=False,
help="Switch on logprobs output"
)
parser.add_argument(
"--top-logprobs",
type=int,
default=5,
help="Number of top logprobs to output, limited by 5. Works only with logprobs true."
)


def postproc_sampling_args(args):
Expand All @@ -33,6 +45,8 @@ def postproc_sampling_args(args):
"repetition_penalty": 1.0,
"top_p": 1.0,
"top_k": -1,
"logprobs": False,
"top_logprobs": 5,
}

if args.apply_all_sampling_params:
Expand All @@ -51,3 +65,7 @@ def postproc_sampling_args(args):
if args.apply_top_p_top_k:
args.sampling_setting["top_k"] = 2
args.sampling_setting["top_p"] = 0.7

if args.logprobs:
args.sampling_setting["logprobs"] = True
args.sampling_setting["top_logprobs"] = args.top_logprobs
27 changes: 21 additions & 6 deletions serve/mlc_serve/api/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse

# TODO(amalyshe): hadnle random_seed
# TODO(amalyshe): handle random_seed
# from .base import set_global_random_seed
from ..api.protocol import (
ChatCompletionRequest,
Expand All @@ -20,6 +20,7 @@
ChatMessage,
DeltaMessage,
ErrorResponse,
Logprobs,
UsageInfo,
)
from ..engine import (
Expand Down Expand Up @@ -64,6 +65,9 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params.top_p = request.top_p
if request.logit_bias is not None:
sampling_params.logit_bias = request.logit_bias
if request.logprobs:
sampling_params.top_logprobs = request.top_logprobs
sampling_params.logprobs = request.logprobs
return sampling_params


Expand Down Expand Up @@ -156,7 +160,7 @@ async def generate_completion_stream(
created_time = int(time.time())

def create_stream_response(
choices: list[ChatCompletionResponseStreamChoice],
choices: List[ChatCompletionResponseStreamChoice],
) -> ChatCompletionStreamResponse:
return ChatCompletionStreamResponse(
id=request_id,
Expand Down Expand Up @@ -192,6 +196,7 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=Logprobs(content=seq.logprob_info) if seq.logprob_info != [] else None
)
for seq in res.sequences
]
Expand All @@ -212,6 +217,7 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
Expand All @@ -226,18 +232,27 @@ async def collect_result_stream(
if seq.delta:
sequences[seq.index].append(seq.delta)

if seq.logprob_info:
assert seq.delta
logprob_infos[seq.index].extend(seq.logprob_info)

if seq.is_finished:
assert seq.finish_reason is not None
finish_reasons[seq.index] = seq.finish_reason.value # type: ignore

choices = [
ChatCompletionResponseChoice(
choices = []
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
logprobs = None
if logprob_info_seq != []:
logprobs = Logprobs(content=logprob_info_seq)

choice = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
logprobs=logprobs,
)
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
]
choices.append(choice)

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
Expand Down
6 changes: 6 additions & 0 deletions serve/mlc_serve/api/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from pydantic import BaseModel, Field

from ..openai_logprob_protocol import Logprobs


class ErrorResponse(BaseModel):
object: str = "error"
Expand Down Expand Up @@ -71,11 +73,14 @@ class ChatCompletionRequest(BaseModel):
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False
logprobs: bool = False
top_logprobs: int = 0


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


Expand All @@ -96,6 +101,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


Expand Down
4 changes: 3 additions & 1 deletion serve/mlc_serve/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@
RequestState,
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
RawLogprobsInfo,
RawLogprobsInfos,
)
from .sampling_params import SamplingParams, SamplingType
from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
5 changes: 2 additions & 3 deletions serve/mlc_serve/engine/async_connector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import structlog
from typing import AsyncIterator, Any
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, Dict
from collections import deque

from .base import (
Expand All @@ -26,7 +25,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine_loop_task = None
self.engine_loop_exception = None
self.shutdown_event = asyncio.Event()
self.result_queues = dict[RequestId, ResultQueue]()
self.result_queues: Dict[RequestId, ResultQueue] = {}
self.recent_cancelled_requests = deque[RequestId](maxlen=64)

async def start(self):
Expand Down
15 changes: 14 additions & 1 deletion serve/mlc_serve/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,25 @@

from typing import List, Callable, Any, Optional, Dict
import inspect
import numpy as np

from .sampling_params import SamplingParams, SamplingType
from ..openai_logprob_protocol import LogprobsContent

LOG = structlog.stdlib.get_logger(__name__)
RequestId = str


@dataclass
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.array]
top_logprobs: Optional[np.array]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
Expand Down Expand Up @@ -155,6 +167,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: List[Optional[LogprobsContent]] = field(default_factory=list)

@property
def is_finished(self) -> bool:
Expand All @@ -164,7 +177,7 @@ def is_finished(self) -> bool:
@dataclass
class RequestOutput:
request_id: RequestId
sequences: list[SequenceOutput]
sequences: List[SequenceOutput]
# TODO: reconsider the place to put this number
# Only set for outputs with valid sequence outputs
num_prompt_tokens: Optional[int] = None
Expand Down
6 changes: 3 additions & 3 deletions serve/mlc_serve/engine/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@


class DummyInferenceEngine:
def __init__(self):
self.queue_lock = Lock()
self.has_new_requests = Condition(self.queue_lock)
def __init__(self) -> None:
self.queue_lock: Lock = Lock()
self.has_new_requests: Condition = Condition(self.queue_lock)
self.request_queue: Dict[RequestId, int] = {}

def add(self, requests: list[Request]):
Expand Down
53 changes: 51 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
"""

import time
from typing import Tuple, Deque, Dict, Optional, Union, Callable
from typing import Tuple, Deque, Dict, Optional, Union, Callable, List
from collections import deque
from threading import Condition, Lock

import structlog

from .base import (
GenerationSequence,
RawLogprobsInfo,
RawLogprobsInfos,
Request,
RequestId,
RequestState,
GenerationSequence,
SequenceId,
StoppingCriteria,
)
Expand All @@ -27,6 +29,7 @@
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from ..openai_logprob_protocol import LogprobsContent, TopLogprobs

LOG = structlog.stdlib.get_logger(__name__)

Expand Down Expand Up @@ -135,6 +138,52 @@ def detokenize_incrementally(
return delta


def logprob_detokenize(
tokenizer: TokenizerP,
logprob_info: Optional[RawLogprobsInfo],
) -> Optional[LogprobsContent]:
masahi marked this conversation as resolved.
Show resolved Hide resolved
"""Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent"""
if logprob_info is None:
return None

top_logprobs: List[TopLogprobs] = []
if logprob_info.top_token_ids is not None and logprob_info.top_logprobs is not None:
top_tokens = list(zip(logprob_info.top_token_ids, logprob_info.top_logprobs))
for top_token_id, top_logprob in top_tokens:
top_logprobs.append(
TopLogprobs(
token=tokenizer.decode(top_token_id),
logprob=float(top_logprob),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
)
)

logprobs_content = LogprobsContent(
token=tokenizer.decode([logprob_info.current_token_id]),
logprob=logprob_info.current_logprob,
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
top_logprobs=top_logprobs,
)

return logprobs_content


def logprobs_detokenize(
tokenizer: TokenizerP,
logprob_info: Optional[RawLogprobsInfos],
) -> List[Optional[LogprobsContent]]:
if logprob_info is None:
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
return []

res: List[Optional[LogprobsContent]] = []
for info in logprob_info:
res.append(logprob_detokenize(tokenizer, info))

return res
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a personal taste but I want to see new lines here between L190 / L191 and L192 / L193. I think that would be more readable.



def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
Expand Down
10 changes: 9 additions & 1 deletion serve/mlc_serve/engine/model_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
from dataclasses import dataclass
from typing import Optional, Protocol, Union, List, Sequence

from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
from .base import (
ChatMessage,
MLCServeEngineConfig,
RawLogprobsInfos,
RequestId,
RequestState,
SequenceId,
)
from ..model.base import ModelArtifactConfig
from .sampling_params import SamplingParams

Expand Down Expand Up @@ -44,6 +51,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: List[int]
error: Optional[str]
logprob_info: Optional[RawLogprobsInfos]


class KVCache(Protocol):
Expand Down
Loading
Loading