
Commit

Init with tests.
Server working.
Major fix, serve working great.
Minor fix and tests.
Remove extra line.
zxybazh committed Nov 27, 2023
1 parent 8e14a7c commit e232862
Showing 11 changed files with 187 additions and 46 deletions.
30 changes: 21 additions & 9 deletions serve/mlc_serve/api/handler.py
@@ -2,7 +2,8 @@
import uuid
import json
from http import HTTPStatus
from typing import Annotated, AsyncIterator
from typing import Annotated, AsyncIterator, List
from itertools import accumulate

from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
@@ -40,7 +41,6 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse

router = APIRouter()


def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params = SamplingParams(
# These params came from vllm
@@ -58,6 +58,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params.temperature = request.temperature
if request.top_p is not None:
sampling_params.top_p = request.top_p
if request.logprobs is not None:
sampling_params.logprobs = request.logprobs
return sampling_params


@@ -128,7 +130,7 @@ async def generate_completion_stream(
created_time = int(time.time())

def create_stream_response(
choices: list[ChatCompletionResponseStreamChoice],
choices: List[ChatCompletionResponseStreamChoice],
) -> ChatCompletionStreamResponse:
return ChatCompletionStreamResponse(
id=request_id,
@@ -148,7 +150,6 @@ def create_stream_response(
],
)
yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n"

async for res in result_generator:
if res.error:
raise RuntimeError(f"Error when generating: {res.error}")
@@ -164,6 +165,7 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=seq.logprob_info[0] if seq.logprob_info else None
)
for seq in res.sequences
]
@@ -184,29 +186,39 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)]
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
raise RuntimeError(f"Error when generating: {res.error}")
if res.num_prompt_tokens is not None:
num_prompt_tokens = res.num_prompt_tokens
for seq in res.sequences:
if seq.logprob_info:
logprob_infos[seq.index].append(seq.logprob_info)
if seq.index >= len(sequences):
raise RuntimeError(f"Unexpected sequence index: {seq.index}.")
num_generated_tokens[seq.index] = seq.num_generated_tokens
if seq.is_finished:
finish_reasons[seq.index] = seq.finish_reason.value
else:
sequences[seq.index].append(seq.delta)

choices = [
ChatCompletionResponseChoice(

choices = []
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons)):
choice = ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
)
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
]
if logprob_infos[index] != []:
choice.logprobs={
"token_logprobs": [float(logprob_info[0][1]) for logprob_info in logprob_infos[index]],
"tokens": [str(logprob_info[0][0]) for logprob_info in logprob_infos[index]],
"offset": list(accumulate([len(str(logprob_info[0][0])) for logprob_info in logprob_infos[index]])),
"top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]]
}
choices.append(choice)

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
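For reference, a minimal sketch (not part of the commit) of how the new collect_result_stream logic folds per-token entries into the OpenAI-style logprobs dict. Each entry follows the ((token, logprob), {top_token: top_logprob, ...}) layout produced by logprob_detokenize later in this commit; the concrete tokens and values here are invented.

from itertools import accumulate

entries = [
    (("Hello", -0.12), {"Hello": -0.12, "Hi": -2.3}),
    ((" world", -0.48), {" world": -0.48, " there": -1.9}),
]

logprobs = {
    "token_logprobs": [float(e[0][1]) for e in entries],
    "tokens": [str(e[0][0]) for e in entries],
    # running total of token string lengths, as in the handler code above
    "offset": list(accumulate(len(str(e[0][0])) for e in entries)),
    "top_logprobs": [e[1] for e in entries],
}
# logprobs["tokens"] == ["Hello", " world"]; logprobs["offset"] == [5, 11]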
5 changes: 4 additions & 1 deletion serve/mlc_serve/api/protocol.py
@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union, Tuple

from pydantic import BaseModel, Field

@@ -70,11 +70,13 @@ class ChatCompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False
logprobs: Optional[int] = None


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Dict[str, Union[List, Dict]]]
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


Expand All @@ -95,6 +97,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprob_info: Optional[Tuple[Tuple, List[Tuple]]]
finish_reason: Optional[Literal["stop", "length"]] = None


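As an illustration only, a hypothetical request exercising the new optional logprobs field; the endpoint path, port, and model name are assumptions, not taken from this commit.

import requests  # any HTTP client would do

payload = {
    "model": "local-model",  # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "logprobs": 3,  # new field: request log probabilities for sampled tokens
}
resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
# With logprobs set, each returned choice may carry a "logprobs" dict
# alongside "message", as assembled in handler.py above.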
5 changes: 3 additions & 2 deletions serve/mlc_serve/engine/async_connector.py
@@ -1,6 +1,7 @@
import asyncio
import structlog
from typing import AsyncIterator, Any
from typing import AsyncIterator, Any, Dict
import logging

from .base import (
InferenceEngine,
@@ -29,7 +30,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine_loop_task = None
self.engine_loop_exception = None
self.shutdown_event = asyncio.Event()
self.result_queues = dict[RequestId, ResultQueue]()
self.result_queues: Dict[RequestId, ResultQueue] = {}

async def start(self):
"""
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/base.py
@@ -4,7 +4,7 @@
from enum import Enum
from abc import ABC, abstractmethod

from typing import List, Callable, Any, Optional, Dict
from typing import List, Callable, Any, Optional, Dict, Tuple
import inspect

from .sampling_params import SamplingParams, SamplingType
@@ -150,6 +150,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None

@property
def is_finished(self) -> bool:
8 changes: 7 additions & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -2,13 +2,18 @@
Required interfaces for the actual inference capability in InferenceEngine.
"""
from dataclasses import dataclass
from typing import Optional, Protocol, Union
from typing import Optional, Protocol, Union, Tuple, List

import numpy as np

from .base import ChatMessage, RequestId, MLCServeEngineConfig
from ..model.base import ModelArtifactConfig
from .sampling_params import SamplingParams


LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]]
# ((token, logprob), [(top1_token, top1_logprob), ...])

@dataclass
class SequenceId:
"""
@@ -56,6 +61,7 @@ class TextGenerationResult:
# making this a list of token ids to leave room for speculative decoding
generated_tokens: list[int]
error: Optional[str]
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None


class KVCache(Protocol):
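For clarity, a made-up value matching the LOGPROBS_TYPE layout introduced above, ((token, logprob), [(top1_token, top1_logprob), ...]); the token ids and logprobs are invented.

logprob_info = (
    (1524, -0.25),  # sampled token id and its log probability
    [(1524, -0.25), (318, -1.90), (373, -2.70)],  # top alternatives with their logprobs
)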
9 changes: 9 additions & 0 deletions serve/mlc_serve/engine/sampling_params.py
@@ -7,6 +7,7 @@
from enum import IntEnum
from functools import cached_property

from typing import Optional

_SAMPLING_EPS = 1e-5

@@ -37,13 +38,17 @@ class SamplingParams:
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
logprobs: Optional[Integer] that determines number of log probabilities
to return per sampled tokens, default to None meaning disabled,
otherwise minimum 0, maximum 5.
"""

presence_penalty: float = 0.0
frequency_penalty: float = 0.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: Optional[int] = None

def __post_init__(self):
self._verify_args()
@@ -71,6 +76,10 @@ def _verify_args(self) -> None:
raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
)
if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 5):
raise ValueError(
f"logprobs must be between 0 and 5, got {self.logprobs}."
)

def _verify_greedy_sampling(self) -> None:
if self.top_p < 1.0 - _SAMPLING_EPS:
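A minimal sketch (not part of the commit) of the new validation path; the import path is assumed from the file layout above.

from mlc_serve.engine.sampling_params import SamplingParams

SamplingParams(logprobs=5)      # accepted: within [0, 5]
try:
    SamplingParams(logprobs=6)  # rejected by _verify_args in __post_init__
except ValueError as err:
    print(err)                  # logprobs must be between 0 and 5, got 6.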
29 changes: 27 additions & 2 deletions serve/mlc_serve/engine/staging_engine.py
@@ -5,7 +5,7 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable, Optional
from typing import Callable, Tuple, List

import os

@@ -21,7 +21,7 @@
SequenceOutput,
check_stopping_sequences,
)
from .model_module import ModelModule, TokenizerModule
from .model_module import ModelModule, TokenizerModule, Tokenizer
from .staging_engine_worker import (
AddRequestsCommand,
CancelRequestCommand,
@@ -35,6 +35,30 @@
LOG = structlog.stdlib.get_logger(__name__)


def logprob_detokenize(tokenizer: Tokenizer, logprob_info: Tuple[Tuple, List[Tuple]]) -> Tuple[Tuple, List[Tuple]]:
"""Detokenize logprob information"""
if logprob_info is None:
return None
(res, res_logprob), top_tokens = logprob_info
top_tokens = list(top_tokens)
count = {}
logprob_dict = {}
# dedup duplicates
# Todo: Make sure decode can generate different tokens
for top_token, _ in top_tokens:
detokenized = tokenizer.decode(top_token)
if detokenized in count:
count[detokenized] += 1
else:
count[detokenized] = 1
for top_token, top_logprob in top_tokens:
detokenized = tokenizer.decode(top_token)
if count[detokenized] == 1:
logprob_dict[detokenized] = float(top_logprob)
else:
logprob_dict[f"{detokenized}_{top_token}"] = float(top_logprob)
return (str(tokenizer.decode(res)), res_logprob), logprob_dict

class StagingInferenceEngine(ScopedInferenceEngine):
"""
An implementation of InferenceEngine that offloads the text generation loop to another worker process,
@@ -223,6 +247,7 @@ def step(self) -> InferenceStepResult:
len(state.token_ids) - state.prompt_len
),
finish_reason=seq_output.finish_reason,
logprob_info=logprob_detokenize(self.tokenizer, seq_output.logprob_info),
),
],
num_prompt_tokens=state.prompt_len,
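To illustrate the deduplication in logprob_detokenize above: two token ids that decode to the same string are kept apart by suffixing the token id. The fake tokenizer, token ids, and logprobs below are invented.

class FakeTokenizer:
    table = {7: "yes", 8: "yes", 9: " no"}

    def decode(self, token_id):
        return self.table[token_id]

logprob_info = ((7, -0.3), [(7, -0.3), (8, -1.1), (9, -2.0)])
# logprob_detokenize(FakeTokenizer(), logprob_info) would return
# (("yes", -0.3), {"yes_7": -0.3, "yes_8": -1.1, " no": -2.0})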
23 changes: 12 additions & 11 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -6,9 +6,9 @@
from collections import deque
from dataclasses import dataclass
from threading import Condition, Lock, Thread
from typing import Callable, Optional, Union, Any, Dict, Deque, List

from typing import Callable, Optional, Union, Tuple, Any, Dict, Deque, List
import structlog
import numpy as np

from .base import FinishReason, RequestId, RequestState
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
@@ -24,7 +24,7 @@ class ShutdownCommand:

@dataclass
class AddRequestsCommand:
request_states: list[RequestState]
request_states: List[RequestState]


@dataclass
@@ -45,14 +45,15 @@ class StopRequestCommand:
@dataclass
class SequenceGenerationOutput:
id: SequenceId
new_tokens: list[int]
new_tokens: List[int]
finish_reason: Optional[FinishReason] = None
error: Optional[str] = None
logprob_info: Optional[Tuple[Tuple, List[Tuple]]] = None


@dataclass
class GenerationLoopWorkerOutput:
sequences: list[SequenceGenerationOutput]
sequences: List[SequenceGenerationOutput]
error: Optional[str] = None


@@ -96,13 +97,13 @@ def __init__(
assert self.prompt_allocate_ratio >= 1.0

self.queue_lock = Lock()
self.queue = deque[RequestState]()
self.queue: Deque[RequestState] = deque()
self.has_new_requests = Condition(lock=self.queue_lock)

self.cancelled_requests = list[RequestState]()
self.stopped_requests = list[RequestState]()
self.cancelled_requests: List[RequestState] = []
self.stopped_requests: List[RequestState] = []

self.current_batch = dict[RequestId, RequestState]()
self.current_batch: Dict[RequestId, RequestState] = {}

def add(self, request_states: list[RequestState]):
LOG.debug("GenerationLoopWorker", requests_states=request_states)
@@ -158,7 +159,7 @@ def has_pending_requests(self) -> bool:
def step(self) -> GenerationLoopWorkerOutput:
LOG.debug("Starting new inference step.")

outputs = list[SequenceGenerationOutput]()
outputs: List[SequenceGenerationOutput] = []
result = GenerationLoopWorkerOutput(sequences=outputs)

# TODO: consolidate into a single function
@@ -253,7 +254,7 @@ def step(self) -> GenerationLoopWorkerOutput:

state.token_ids.extend(new_tokens)
outputs.append(
SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens)
SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens, logprob_info=res.logprob_info)
)

LOG.debug("Finished state update and stopping criteria check.")
1 change: 1 addition & 0 deletions serve/mlc_serve/engine/sync_engine.py
@@ -225,6 +225,7 @@ def step(self) -> InferenceStepResult:
num_generated_tokens=(
len(state.token_ids) - state.prompt_len
),
logprob_info=res.logprob_info
),
],
num_prompt_tokens=state.prompt_len,