Merge pull request #3 from Deelvin/vc/update
Updates after review
zxybazh authored Jan 9, 2024
2 parents f11b7f8 + b2850ba commit 4c56eac
Showing 9 changed files with 153 additions and 122 deletions.
26 changes: 7 additions & 19 deletions serve/mlc_serve/api/handler.py
@@ -196,7 +196,7 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=seq.logprob_info[0] if seq.logprob_info else None
logprob_info=Logprobs(content=[seq.logprob_info]) if seq.logprob_info else None
)
for seq in res.sequences
]
@@ -241,28 +241,16 @@ async def collect_result_stream(
finish_reasons[seq.index] = seq.finish_reason.value # type: ignore

choices = []
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)):
content = []
if logprob_infos[index] != []:
for logprob_info in logprob_infos[index]:
top_logprobs = [TopLogprobs(
token=str(token),
logprob=float(logprob),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
) for token, logprob in logprob_info[1]]
content.append(LogprobsContent(
token=str(logprob_info[0][0]),
logprob=float(logprob_info[0][1]),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
top_logprobs=top_logprobs,
))
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(zip(logprob_infos, sequences, finish_reasons)):
logprobs = None
if logprob_info_seq != []:
logprobs = Logprobs(content=logprob_info_seq)

choice = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
logprobs=Logprobs(content=content),
logprobs=logprobs,
)
choices.append(choice)

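For the non-streaming path above, a minimal sketch of the object the simplified loop now builds, assuming the Pydantic models from serve/mlc_serve/api/protocol.py (shown next) and purely illustrative token/logprob values:

```python
from mlc_serve.api.protocol import (
    ChatCompletionResponseChoice,
    ChatMessage,
    Logprobs,
    LogprobsContent,
    TopLogprobs,
)

# One LogprobsContent entry per generated token; top_logprobs may be empty.
logprob_info_seq = [
    LogprobsContent(
        token="Hi",
        logprob=-0.3,
        bytes=None,
        top_logprobs=[
            TopLogprobs(token="Hi", logprob=-0.3, bytes=None),
            TopLogprobs(token="Hello", logprob=-1.7, bytes=None),
        ],
    ),
    LogprobsContent(token="!", logprob=-0.9, bytes=None, top_logprobs=[]),
]

choice = ChatCompletionResponseChoice(
    index=0,
    message=ChatMessage(role="assistant", content="Hi!"),
    finish_reason="stop",
    logprobs=Logprobs(content=logprob_info_seq),  # None when logprobs were not requested
)
```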
4 changes: 2 additions & 2 deletions serve/mlc_serve/api/protocol.py
@@ -89,7 +89,7 @@ class LogprobsContent(BaseModel):
token: str
logprob: float
bytes: Optional[List] = None
top_logprobs: List[TopLogprobs]
top_logprobs: List[TopLogprobs] # It can be empty


class Logprobs(BaseModel):
@@ -98,7 +98,7 @@ class Logprobs(BaseModel):
See details in https://platform.openai.com/docs/api-reference/chat/object#chat-create-logprobs
"""

content: Optional[List[LogprobsContent]]
content: List[LogprobsContent]


class ChatCompletionResponseChoice(BaseModel):
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/__init__.py
@@ -16,6 +16,6 @@
RequestState,
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
LOGPROBS_TYPE,
RawLogprobsInfo,
)
from .sampling_params import SamplingParams, SamplingType, TOP_LOGPROBS_NUMBER
from .sampling_params import SamplingParams, SamplingType, LOGPROB_TOP_K_MAX
14 changes: 11 additions & 3 deletions serve/mlc_serve/engine/base.py
@@ -6,12 +6,20 @@

from typing import List, Callable, Any, Optional, Dict, Tuple
import inspect
import numpy as np

from .sampling_params import SamplingParams, SamplingType
from ..api.protocol import LogprobsContent

RequestId = str
LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]]
# ((token, logprob), [(top1_token, top1_logprob), ...])


@dataclass
class RawLogprobsInfo:
current_token: int
current_logprob: float
top_tokens: Optional[np.ndarray]
top_logprobs: Optional[np.ndarray]


# TODO(@sunggg): consider transition to something like Pydantic
@@ -163,7 +171,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: Optional[LOGPROBS_TYPE] = None
logprob_info: Optional[LogprobsContent] = None

@property
def is_finished(self) -> bool:
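As a rough illustration of the data `RawLogprobsInfo` is meant to carry, here is a hedged sketch of how a model backend could fill it from one row of log-probabilities; the helper and its argument names are made up and are not part of this change:

```python
import numpy as np

from mlc_serve.engine import LOGPROB_TOP_K_MAX, RawLogprobsInfo


def make_raw_logprobs_info(
    logprobs_row: np.ndarray,  # log-probabilities over the vocabulary, shape [vocab_size]
    chosen_token: int,
    top_k: int,
) -> RawLogprobsInfo:
    """Keep the chosen token's logprob plus the top-k alternatives, still as token ids."""
    top_k = min(top_k, LOGPROB_TOP_K_MAX)
    # Indices of the top-k largest logprobs, sorted from most to least likely.
    top_tokens = np.argpartition(logprobs_row, -top_k)[-top_k:]
    top_tokens = top_tokens[np.argsort(logprobs_row[top_tokens])[::-1]]
    return RawLogprobsInfo(
        current_token=chosen_token,
        current_logprob=float(logprobs_row[chosen_token]),
        top_tokens=top_tokens,
        top_logprobs=logprobs_row[top_tokens],
    )
```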
67 changes: 44 additions & 23 deletions serve/mlc_serve/engine/engine_common.py
@@ -10,13 +10,13 @@
import structlog

from .base import (
GenerationSequence,
RawLogprobsInfo,
Request,
RequestId,
RequestState,
GenerationSequence,
SequenceId,
StoppingCriteria,
LOGPROBS_TYPE,
)
from .model_module import (
DecodeRequest,
@@ -28,6 +28,7 @@
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from ..api.protocol import LogprobsContent, TopLogprobs

LOG = structlog.stdlib.get_logger(__name__)

@@ -133,29 +134,49 @@ def detokenize_incrementally(
return delta


def logprob_detokenize(tokenizer: TokenizerP, logprob_info: Optional[LOGPROBS_TYPE]) -> Optional[LOGPROBS_TYPE]:
"""Detokenize top tokens in logprob information"""
def logprob_detokenize(
tokenizer: TokenizerP,
logprob_info: Optional[RawLogprobsInfo],
) -> Optional[LogprobsContent]:
"""Detokenize tokens from RawLogprobInfo and convert the latter to LogprobContent"""
if logprob_info is None:
return None
(res, res_logprob), top_tokens = logprob_info
top_tokens = list(top_tokens)
count: Dict[str, int] = {}
top_logprobs: List[Tuple] = []
# dedup duplicates
# Todo: Make sure decode can generate different tokens
for top_token, _ in top_tokens:
detokenized = tokenizer.decode(top_token)
if detokenized in count:
count[detokenized] += 1
else:
count[detokenized] = 1
for top_token, top_logprob in top_tokens:
detokenized = tokenizer.decode(top_token)
if count[detokenized] == 1:
top_logprobs.append((detokenized, float(top_logprob)))
else:
top_logprobs.append((f"{detokenized}_{top_token}", float(top_logprob)))
return (str(tokenizer.decode(res)), res_logprob), top_logprobs

top_logprobs: List[TopLogprobs] = []
if (
logprob_info.top_tokens is not None and
logprob_info.top_logprobs is not None
):
# Materialize the pairs so they can be iterated twice below.
top_tokens = list(zip(logprob_info.top_tokens, logprob_info.top_logprobs))
count: Dict[str, int] = {}
# dedup duplicates
# Todo: Make sure decode can generate different tokens
for top_token, _ in top_tokens:
detokenized = tokenizer.decode(top_token)
if detokenized in count:
count[detokenized] += 1
else:
count[detokenized] = 1
for top_token, top_logprob in top_tokens:
detokenized = tokenizer.decode(top_token)
if count[detokenized] != 1:
detokenized = f"{detokenized}_{top_token}"
top_logprobs.append(TopLogprobs(
token=detokenized,
logprob=float(top_logprob),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
))

logprobs_content = LogprobsContent(
token=tokenizer.decode([logprob_info.current_token]),
logprob=float(logprob_info.current_logprob),
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
top_logprobs=top_logprobs,
)

return logprobs_content


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
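A quick way to exercise `logprob_detokenize` in isolation is a stub tokenizer that implements only `decode`; the stub and its three-entry vocabulary below are made up for illustration:

```python
import numpy as np

from mlc_serve.engine import RawLogprobsInfo
from mlc_serve.engine.engine_common import logprob_detokenize


class StubTokenizer:
    """Stand-in exposing only the decode() method that logprob_detokenize relies on."""

    vocab = {0: "Hello", 1: "Hello", 2: "Hi"}  # ids 0 and 1 collide on purpose

    def decode(self, tokens):
        if isinstance(tokens, (list, np.ndarray)):
            return "".join(self.vocab[int(t)] for t in tokens)
        return self.vocab[int(tokens)]


info = RawLogprobsInfo(
    current_token=0,
    current_logprob=-0.1,
    top_tokens=np.array([0, 1, 2]),
    top_logprobs=np.array([-0.1, -1.5, -2.0]),
)
content = logprob_detokenize(StubTokenizer(), info)
# Duplicate detokenized strings are disambiguated with the token id:
print([t.token for t in content.top_logprobs])  # ['Hello_0', 'Hello_1', 'Hi']
```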
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/model_module.py
@@ -6,8 +6,8 @@

from .base import (
ChatMessage,
LOGPROBS_TYPE,
MLCServeEngineConfig,
RawLogprobsInfo,
RequestId,
RequestState,
SequenceId,
@@ -51,7 +51,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: List[int]
error: Optional[str]
logprob_info: Optional[LOGPROBS_TYPE] = None
logprob_info: Optional[RawLogprobsInfo] = None


class KVCache(Protocol):
6 changes: 3 additions & 3 deletions serve/mlc_serve/engine/sampling_params.py
@@ -10,7 +10,7 @@
from typing import Dict, Optional

_SAMPLING_EPS = 1e-5
TOP_LOGPROBS_NUMBER = 5
LOGPROB_TOP_K_MAX = 5


class SamplingType(IntEnum):
@@ -105,9 +105,9 @@ def _verify_args(self) -> None:
f"logit bias must be in [-100, 100], got {bias} for token {token}."
)
if self.logprobs is not None and self.logprobs:
if (self.top_logprobs < 1 or self.top_logprobs > TOP_LOGPROBS_NUMBER):
if (self.top_logprobs < 1 or self.top_logprobs > LOGPROB_TOP_K_MAX):
raise ValueError(
f"top_logprobs must be between 1 and {TOP_LOGPROBS_NUMBER}, got {self.top_logprobs}."
f"top_logprobs must be between 1 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}."
)

def _verify_greedy_sampling(self) -> None:
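The renamed bound keeps the same behaviour; a hedged example, assuming `SamplingParams` is a dataclass that accepts `logprobs`/`top_logprobs` keywords and calls `_verify_args` on construction (that part is not shown in this hunk):

```python
from mlc_serve.engine.sampling_params import LOGPROB_TOP_K_MAX, SamplingParams

# Within bounds: at most LOGPROB_TOP_K_MAX (5) alternatives per generated token.
ok = SamplingParams(logprobs=True, top_logprobs=LOGPROB_TOP_K_MAX)

# Out of bounds: rejected by the check shown above.
try:
    SamplingParams(logprobs=True, top_logprobs=LOGPROB_TOP_K_MAX + 1)
except ValueError as err:
    print(err)  # top_logprobs must be between 1 and 5, got 6.
```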
7 changes: 3 additions & 4 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -6,14 +6,13 @@
import multiprocessing.synchronize
from dataclasses import dataclass
from threading import Thread, Lock
from typing import Callable, Optional, Union, Tuple, Any, Dict, List
from typing import Callable, Optional, Union, Any, Dict, List

import structlog
import numpy as np

from .base import (
FinishReason,
LOGPROBS_TYPE,
RawLogprobsInfo,
RequestId,
RequestState,
ValidationError,
@@ -67,7 +66,7 @@ class SequenceGenerationOutput:
new_tokens: List[int]
finish_reason: Optional[FinishReason] = None
error: Optional[Union[str, ValidationError]] = None
logprob_info: Optional[LOGPROBS_TYPE] = None
logprob_info: Optional[RawLogprobsInfo] = None


@dataclass