
Commit

Server working.
zxybazh committed Nov 22, 2023
1 parent e642ef4 commit 0994bd8
Showing 6 changed files with 45 additions and 23 deletions.
22 changes: 17 additions & 5 deletions serve/mlc_serve/api/handler.py
@@ -2,7 +2,7 @@
import uuid
import json
from http import HTTPStatus
from typing import Annotated, AsyncIterator
from typing import Annotated, AsyncIterator, List

from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
@@ -39,7 +39,8 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse


router = APIRouter()

import logging
logger = logging.getLogger(__name__)

def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params = SamplingParams(
@@ -58,6 +59,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params.temperature = request.temperature
if request.top_p is not None:
sampling_params.top_p = request.top_p
if request.logprobs is not None:
sampling_params.logprobs = request.logprobs
return sampling_params
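
The new branch above simply forwards an integer logprobs value from the request onto the sampling parameters. A minimal sketch of that mapping, assuming the usual OpenAI-style model/messages fields on ChatCompletionRequest and a logprobs attribute on SamplingParams as this diff implies (model name and values are invented):

    # Hypothetical request exercising the new field.
    request = ChatCompletionRequest(
        model="llama-2-7b-chat",
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.7,
        logprobs=5,  # presumably: return the top-5 log-probabilities per generated token
    )
    params = _get_sampling_params(request)
    assert params.temperature == 0.7 and params.logprobs == 5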


@@ -128,7 +131,7 @@ async def generate_completion_stream(
created_time = int(time.time())

def create_stream_response(
choices: list[ChatCompletionResponseStreamChoice],
choices: List[ChatCompletionResponseStreamChoice],
) -> ChatCompletionStreamResponse:
return ChatCompletionStreamResponse(
id=request_id,
@@ -148,7 +151,6 @@ def create_stream_response(
],
)
yield f"data: {json.dumps(first_chunk.dict(exclude_unset=True), ensure_ascii=False)}\n\n"

async for res in result_generator:
if res.error:
raise RuntimeError(f"Error when generating: {res.error}")
@@ -164,6 +166,7 @@ def create_stream_response(
finish_reason=seq.finish_reason.value
if seq.finish_reason is not None
else None,
logprob_info=seq.logprob_info[0] if seq.logprob_info else None
)
for seq in res.sequences
]
@@ -184,26 +187,35 @@ async def collect_result_stream(
finish_reasons = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
logprob_infos = [[] for _ in range(num_sequences)]
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
raise RuntimeError(f"Error when generating: {res.error}")
if res.num_prompt_tokens is not None:
num_prompt_tokens = res.num_prompt_tokens
for seq in res.sequences:
if seq.logprob_info:
logprob_infos[seq.index].append(seq.logprob_info)
if seq.index >= len(sequences):
raise RuntimeError(f"Unexpected sequence index: {seq.index}.")
num_generated_tokens[seq.index] = seq.num_generated_tokens
if seq.is_finished:
finish_reasons[seq.index] = seq.finish_reason.value
else:
sequences[seq.index].append(seq.delta)

choices = [
ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content="".join(chunks)),
finish_reason=finish_reason,
logprobs={
"token_logprobs": [float(logprob_info[0]) for logprob_info in logprob_infos[index]],
# "tokens": [],
# "offset": [],
"top_logprobs": [logprob_info[1] for logprob_info in logprob_infos[index]]
},
)
for index, (chunks, finish_reason) in enumerate(zip(sequences, finish_reasons))
]
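
For orientation, the non-streaming path above collects one (token logprob, top logprobs) entry per generated token and exposes it on the choice as a plain dict. A sketch of the resulting payload with made-up numbers, assuming the tuple shape produced by logprob_detok in staging_engine.py below:

    logprobs = {
        "token_logprobs": [-0.12, -1.37],   # one float per generated token
        "top_logprobs": [                   # one {token: logprob} dict per generated token
            {"Hello": -0.12, "Hi": -2.30},
            {" world": -1.37, " there": -1.90},
        ],
    }

The commented-out "tokens" and "offset" keys hint that the full OpenAI completions-style logprobs layout is the eventual target.
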
5 changes: 4 additions & 1 deletion serve/mlc_serve/api/protocol.py
@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union, Any

from pydantic import BaseModel, Field

@@ -70,11 +70,13 @@ class ChatCompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False
logprobs: Optional[int] = None


class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[Dict[str, Any]]
finish_reason: Optional[Literal["stop", "length", "cancelled"]] = None


@@ -95,6 +97,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprob_info: Optional[Any]
finish_reason: Optional[Literal["stop", "length"]] = None


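To see the new protocol fields end to end, here is a hypothetical client call against a locally running server; the endpoint path, port, and model name are assumptions and not part of this diff:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "llama-2-7b-chat",
            "messages": [{"role": "user", "content": "Say hi"}],
            "logprobs": 3,  # new optional request field added above
        },
    )
    print(resp.json()["choices"][0]["logprobs"])  # new optional response field
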
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/async_connector.py
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import AsyncIterator, Any
from typing import AsyncIterator, Any, Dict

from .base import (
InferenceEngine,
@@ -27,7 +27,7 @@ def __init__(self, engine: InferenceEngine, engine_wait_timeout=1):
self.engine_loop_task = None
self.engine_loop_exception = None
self.shutdown_event = asyncio.Event()
self.result_queues = dict[RequestId, ResultQueue]()
self.result_queues: Dict[RequestId, ResultQueue] = {}

async def start(self):
"""
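
The result_queues change is purely an annotation-style fix, but it matters on older interpreters: subscripting the built-in dict, as the old dict[RequestId, ResultQueue]() did, raises a TypeError at import time on Python 3.8, whereas typing.Dict works on every supported version. A minimal sketch with placeholder key/value types (the same consideration applies to the deque, list, and dict collections touched in staging_engine_worker.py below):

    from typing import Dict

    # dict[str, object]() would raise "TypeError: 'type' object is not
    # subscriptable" on Python 3.8; the annotated form below runs everywhere
    # and produces the same plain dict at runtime.
    result_queues: Dict[str, object] = {}
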
11 changes: 10 additions & 1 deletion serve/mlc_serve/engine/staging_engine.py
@@ -29,6 +29,15 @@
logger = logging.getLogger(__name__)


def logprob_detok(tokenizer, logprob_info):
if logprob_info is None:
return None
return (
logprob_info[0], {
tokenizer.decode(top_token): float(logprob) for top_token, logprob in logprob_info[1]
}
)

class StagingInferenceEngine(ScopedInferenceEngine):
"""
An implementation of InferenceEngine that offloads the text generation loop to another worker process,
@@ -200,7 +209,7 @@ def step(self) -> InferenceStepResult:
len(state.token_ids) - state.prompt_len
),
finish_reason=seq_output.finish_reason,
logprob_info=seq_output.logprob_info,
logprob_info=logprob_detok(self.tokenizer, seq_output.logprob_info),
),
],
num_prompt_tokens=state.prompt_len,
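
To make the new helper concrete, this is roughly what logprob_detok (defined above) returns for one generated token, using a stand-in tokenizer; token ids and probabilities are invented:

    class FakeTokenizer:
        vocab = {15043: "Hello", 6324: "Hi"}

        def decode(self, token_id):
            return self.vocab[token_id]

    # logprob_info = (logprob of the sampled token, [(top token id, logprob), ...])
    logprob_info = (-0.12, [(15043, -0.12), (6324, -2.30)])
    print(logprob_detok(FakeTokenizer(), logprob_info))
    # -> (-0.12, {'Hello': -0.12, 'Hi': -2.3})
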
15 changes: 7 additions & 8 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -7,7 +7,7 @@
from collections import deque
from dataclasses import dataclass
from threading import Condition, Lock, Thread
from typing import Callable, Optional, Union, Any, Tuple, List
from typing import Callable, Optional, Union, Any, Tuple, List, Deque, Dict
import numpy as np

from .base import FinishReason, RequestId, RequestState
@@ -79,15 +79,15 @@ def __init__(
assert self.prompt_allocate_ratio >= 1.0

self.queue_lock = Lock()
self.queue = deque[RequestState]()
self.queue: Deque[RequestState] = deque()
self.has_new_requests = Condition(lock=self.queue_lock)

self.cancelled_requests = list[RequestState]()
self.stopped_requests = list[RequestState]()
self.cancelled_requests: List[RequestState] = []
self.stopped_requests: List[RequestState] = []

self.current_batch = dict[RequestId, RequestState]()
self.current_batch: Dict[RequestId, RequestState] = {}

def add(self, request_states: list[RequestState]):
def add(self, request_states: List[RequestState]):
with self.queue_lock:
# States which have been invalidated should never be added, directly
# cancel them instead.
@@ -140,7 +140,7 @@ def has_pending_requests(self) -> bool:
def step(self) -> GenerationLoopWorkerOutput:
logger.debug("Starting new inference step.")

outputs = list[SequenceGenerationOutput]()
outputs: List[SequenceGenerationOutput] = []
result = GenerationLoopWorkerOutput(sequences=outputs)

# TODO: consolidate into a single function
@@ -215,7 +215,6 @@ def step(self) -> GenerationLoopWorkerOutput:
id=res.sequence_id,
new_tokens=[],
error=res.error,
logprob_info=res.logprob_info,
)
)
continue
11 changes: 5 additions & 6 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -650,12 +650,12 @@ def generate(
next_tokens, logprob_info = sample(logits, sampling_params, self.vocab_size)
return [
TextGenerationResult(
sequence_id=zipped[0],
generated_tokens=[zipped[1]],
sequence_id=sequence_id,
generated_tokens=[next_token],
error=None,
logprob_info=fetch_logprobs(logprob_info, idx, sampling_params[idx]),
)
for idx, zipped in enumerate(zip(sequence_ids, next_tokens))
for idx, (sequence_id, next_token) in enumerate(zip(sequence_ids, next_tokens))
]
except RuntimeError:
# Fallback to per-token sampling in case some logits values are corrupted.
@@ -690,7 +690,6 @@ def generate(
logprob_info=fetch_logprobs(logprob_info, idx, sampling_param)
)
)

return outputs


@@ -721,8 +720,8 @@ def __init__(self, model: Model):
self.model = model

def generate(
self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache
) -> list[TextGenerationResult]:
self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache
) -> List[TextGenerationResult]:
prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)]
decode_requests = [r for r in requests if isinstance(r, DecodeRequest)]

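The change in generate is a readability fix: rather than indexing the zipped pair as zipped[0] and zipped[1], the comprehension unpacks it in place. The same pattern in isolation, with invented ids and token values:

    sequence_ids = ["seq-0", "seq-1"]
    next_tokens = [15043, 6324]

    rows = [
        (idx, sequence_id, [next_token])
        for idx, (sequence_id, next_token) in enumerate(zip(sequence_ids, next_tokens))
    ]
    print(rows)  # [(0, 'seq-0', [15043]), (1, 'seq-1', [6324])]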
