Major fix, serve working great.
zxybazh committed Nov 27, 2023
1 parent b051604 commit a509ded
Showing 7 changed files with 62 additions and 38 deletions.
24 changes: 13 additions & 11 deletions serve/mlc_serve/api/handler.py
@@ -3,6 +3,7 @@
import json
from http import HTTPStatus
from typing import Annotated, AsyncIterator, List
from itertools import accumulate

from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
@@ -204,21 +205,22 @@ async def collect_result_stream(
finish_reasons[seq.index] = seq.finish_reason.value
else:
sequences[seq.index].append(seq.delta)
breakpoint()
choices = [
ChatCompletionResponseChoice(

choices = []
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)):
choice = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
logprobs={
"token_logprobs": [float(logprob_info[0]) for logprob_info in logprob_infos[index]],
# "tokens": [],
# "offset": [],
"top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]]
},
)
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
]
if logprob_infos[index] != []:
choice.logprobs={
"token_logprobs": [float(logprob_info[0][1]) for logprob_info in logprob_infos[index]],
"tokens": [str(logprob_info[0][0]) for logprob_info in logprob_infos[index]],
"offset": list(accumulate([len(str(logprob_info[0][0])) for logprob_info in logprob_infos[index]])),
"top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]]
}
choices.append(choice)

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
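For context, the reworked loop in collect_result_stream assembles an OpenAI-style logprobs block per choice. A minimal sketch of the resulting shape, assuming a hypothetical list of already-detokenized entries of the form ((token_text, logprob), {top_token_text: top_logprob, ...}) with made-up values:

from itertools import accumulate

# Hypothetical per-sequence entries; in the handler these come from logprob_infos[index].
entries = [
    (("Hello", -0.11), {"Hello": -0.11, "Hi": -2.30}),
    ((" world", -0.52), {" world": -0.52, " there": -1.90}),
]

logprobs = {
    "token_logprobs": [float(e[0][1]) for e in entries],             # [-0.11, -0.52]
    "tokens": [str(e[0][0]) for e in entries],                       # ['Hello', ' world']
    "offset": list(accumulate(len(str(e[0][0])) for e in entries)),  # running lengths: [5, 11]
    "top_logprobs": [e[1] for e in entries],
}

As in the diff, "offset" holds the running character lengths of the concatenated tokens produced by itertools.accumulate.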
6 changes: 3 additions & 3 deletions serve/mlc_serve/api/protocol.py
@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py
import time
from typing import Dict, List, Literal, Optional, Union, Any
from typing import Dict, List, Literal, Optional, Union, Tuple

from pydantic import BaseModel, Field

@@ -76,7 +76,7 @@ class ChatCompletionRequest(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Dict[str, Any]]
logprobs: Optional[Dict[str, Union[List, Dict]]]
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


@@ -97,7 +97,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprob_info: Optional[Any]
logprob_info: Optional[Tuple[Tuple, List[Tuple]]]
finish_reason: Optional[Literal["stop", "length"]] = None


4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/base.py
@@ -4,7 +4,7 @@
from enum import Enum
from abc import ABC, abstractmethod

from typing import List, Callable, Any, Optional, Dict
from typing import List, Callable, Any, Optional, Dict, Tuple
import inspect

from .sampling_params import SamplingParams, SamplingType
@@ -150,7 +150,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: Optional[Any] = None
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None

@property
def is_finished(self) -> bool:
5 changes: 3 additions & 2 deletions serve/mlc_serve/engine/model_module.py
@@ -10,8 +10,9 @@
from ..model.base import ModelArtifactConfig
from .sampling_params import SamplingParams

LOGPROBS_TYPE = Tuple[np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]

LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]]
# ((token, logprob), [(top1_token, top1_logprob), ...])

@dataclass
class SequenceId:
@@ -60,7 +61,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: list[int]
error: Optional[str]
logprob_info: Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]] = None
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None


class KVCache(Protocol):
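The new LOGPROBS_TYPE alias describes the per-sequence payload in plain Python tuples rather than raw numpy arrays. A tiny illustration with invented token ids and logprobs, following the comment in the diff:

# Invented values; shape is ((token, logprob), [(top1_token, top1_logprob), ...])
logprob_info = (
    (1234, -0.11),                                  # sampled token id and its logprob
    [(1234, -0.11), (5678, -2.30), (9012, -3.70)],  # top-k candidate tokens and logprobs
)
(sampled_token, sampled_logprob), top_candidates = logprob_info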
33 changes: 24 additions & 9 deletions serve/mlc_serve/engine/staging_engine.py
@@ -5,7 +5,7 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable, Optional
from typing import Callable, Tuple, List

import os

@@ -21,7 +21,7 @@
SequenceOutput,
check_stopping_sequences,
)
from .model_module import ModelModule, TokenizerModule
from .model_module import ModelModule, TokenizerModule, Tokenizer
from .staging_engine_worker import (
AddRequestsCommand,
CancelRequestCommand,
@@ -35,14 +35,29 @@
LOG = structlog.stdlib.get_logger(__name__)


def logprob_detok(tokenizer, logprob_info):
def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Tuple[Tuple, List[Tuple]]) -> Tuple[Tuple, List[Tuple]]:
"""Detokenize logprob information"""
if logprob_info is None:
return None
return (
logprob_info[0], {
tokenizer.decode(top_token): float(logprob) for top_token, logprob in logprob_info[1]
}
)
(res, res_logprob), top_tokens = logprob_info
top_tokens = list(top_tokens)
count = {}
logprob_dict = {}
# dedup duplicates
# Todo: Make sure decode can generate different tokens
for top_token, _ in top_tokens:
detokenized = tokenizer.decode(top_token)
if detokenized in count:
count[detokenized] += 1
else:
count[detokenized] = 1
for top_token, top_logprob in top_tokens:
detokenized = tokenizer.decode(top_token)
if count[detokenized] == 1:
logprob_dict[detokenized] = float(top_logprob)
else:
logprob_dict[f"{detokenized}_{top_token}"] = float(top_logprob)
return (str(tokenizer.decode(res)), res_logprob), logprob_dict

class StagingInferenceEngine(ScopedInferenceEngine):
"""
@@ -232,7 +247,7 @@ def step(self) -> InferenceStepResult:
len(state.token_ids) - state.prompt_len
),
finish_reason=seq_output.finish_reason,
logprob_info=logprob_detok(self.tokenizer, seq_output.logprob_info),
logprob_info=logprob_detokenize(self.tokenizer, seq_output.logprob_info),
),
],
num_prompt_tokens=state.prompt_len,
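The new logprob_detokenize collapses top tokens by their decoded text and disambiguates collisions by appending the raw token id. A small sketch of that behavior using a stub tokenizer (not the real Tokenizer) and invented ids and logprobs:

# Stub tokenizer: ids 7 and 9 decode to the same text, so their keys get an id suffix.
class StubTokenizer:
    _vocab = {7: "the", 9: "the", 11: " cat"}
    def decode(self, token_id):
        return self._vocab[token_id]

info = ((7, -0.2), [(7, -0.2), (9, -1.5), (11, -2.1)])
print(logprob_detokenize(StubTokenizer(), info))
# -> (('the', -0.2), {'the_7': -0.2, 'the_9': -1.5, ' cat': -2.1})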
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/staging_engine_worker.py
@@ -48,7 +48,7 @@ class SequenceGenerationOutput:
new_tokens: List[int]
finish_reason: Optional[FinishReason] = None
error: Optional[str] = None
logprob_info: Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]] = None
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None


@dataclass
Expand Down
26 changes: 16 additions & 10 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -271,6 +271,7 @@ def _is_safe_to_sample(prob_like):
# Greedy sampling
logprobs = torch.log(torch.softmax(logits_greedy, dim=-1))
res_greedy_logprob, res_greedy = torch.max(logprobs, dim=-1)

top_greedy_logprob, top_greedy = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True)
# Convert to numpy
res_greedy_logprob = res_greedy_logprob.cpu().numpy()
@@ -279,7 +280,7 @@ def _is_safe_to_sample(prob_like):
top_greedy = top_greedy.cpu().numpy()
# Case when there's only greedy sampling
if logits_greedy.shape[0] == num_seq:
return res_greedy, (res_greedy_logprob, (top_greedy, top_greedy_logprob))
return res_greedy, ((res_greedy, res_greedy_logprob), (top_greedy, top_greedy_logprob))

temperatures = []
top_ps = []
@@ -341,7 +342,7 @@ def _is_safe_to_sample(prob_like):
top[mask_greedy] = top_greedy
top_logprobs[mask_greedy] = top_greedy_logprob

return res, (res_logprobs, (top, top_logprobs))
return res, ((res, res_logprobs), (top, top_logprobs))


def load_disco_module(artifact_path, lib_path, num_shards):
@@ -382,13 +383,18 @@ def get_tvm_model(config, dev):

def fetch_logprobs(
logprob_info: LOGPROBS_TYPE,
idx: int,
index: int,
sampling_param: SamplingParams,
) -> Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]]:
"""Fetch the logprob information with index"""
if sampling_param.logprobs is None or logprob_info is None:
return None
res_logprobs, (top, top_logprobs) = logprob_info
return res_logprobs[idx], zip(top[idx][:sampling_param.logprobs], top_logprobs[idx][:sampling_param.logprobs])
(res, res_logprobs), (top, top_logprobs) = logprob_info
return (res[index],res_logprobs[index]), \
zip(
top[index][:sampling_param.logprobs],
top_logprobs[index][:sampling_param.logprobs]
)


def _prepare_inputs(
@@ -654,9 +660,9 @@ def generate(
sequence_id=sequence_id,
generated_tokens=[next_token],
error=None,
logprob_info=fetch_logprobs(logprob_info, idx, sampling_params[idx]),
logprob_info=fetch_logprobs(logprob_info, index, sampling_params[index]),
)
for idx, (sequence_id, next_token) in enumerate(zip(sequence_ids, next_tokens))
for index, (sequence_id, next_token) in enumerate(zip(sequence_ids, next_tokens))
]
except RuntimeError:
# Fallback to per-token sampling in case some logits values are corrupted.
@@ -666,7 +672,7 @@
" or element < 0"
)

for idx, sequence_id, logits_per_token, sampling_param in enumerate(
for index, sequence_id, logits_per_token, sampling_param in enumerate(
zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
):
maybe_new_token, logprob_info = sample(
@@ -682,7 +688,7 @@ def generate(
sequence_id=sequence_id,
generated_tokens=[maybe_new_token[0]],
error=None,
logprob_info=fetch_logprobs(logprob_info, idx, sampling_param)
logprob_info=fetch_logprobs(logprob_info, index, sampling_param)
)
)
else:
Expand All @@ -691,7 +697,7 @@ def generate(
sequence_id=sequence_id,
generated_tokens=[],
error=err_msg,
logprob_info=fetch_logprobs(logprob_info, idx, sampling_param)
logprob_info=fetch_logprobs(logprob_info, index, sampling_param)
)
)
return outputs
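The renamed fetch_logprobs simply slices the batched arrays returned by sample for one sequence and truncates the candidates to sampling_param.logprobs. A rough sketch with fabricated arrays and a stand-in for SamplingParams:

import numpy as np

# Fabricated batched output from sample() for two sequences, five candidates each.
res = np.array([11, 42])
res_logprobs = np.array([-0.1, -0.3])
top = np.array([[11, 12, 13, 14, 15],
                [42, 43, 44, 45, 46]])
top_logprobs = np.array([[-0.1, -1.2, -2.0, -2.5, -3.0],
                         [-0.3, -1.1, -1.8, -2.2, -2.9]])
batched = ((res, res_logprobs), (top, top_logprobs))

class FakeParams:  # stand-in for SamplingParams(logprobs=2)
    logprobs = 2

picked, top_pairs = fetch_logprobs(batched, index=1, sampling_param=FakeParams())
# picked -> (42, -0.3); list(top_pairs) -> [(42, -0.3), (43, -1.1)]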
