Skip to content

Commit

Permalink
remove LOGPROBS_TYPE
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Jan 8, 2024
1 parent 8922a42 commit cb9bd88
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from .sampling_params import SamplingParams, SamplingType
from ..api.protocol import LogprobsContent

RequestId = str
LOGPROBS_TYPE = Tuple[Tuple, List[Tuple]]
Expand Down Expand Up @@ -172,7 +173,7 @@ class SequenceOutput:
finish_reason: Optional[FinishReason] = None
# Number of generated tokens so far
num_generated_tokens: int = 0
logprob_info: Optional[LOGPROBS_TYPE] = None
logprob_info: Optional[LogprobsContent] = None

@property
def is_finished(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/engine_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def logprob_detokenize(
))

logprobs_content = LogprobsContent(
token=tokenizer.decode(logprob_info.current_token),
token=tokenizer.decode([logprob_info.current_token]),
logprob=logprob_info,
# TODO(vvchernov): implement bytes based on https://platform.openai.com/docs/api-reference/chat/object
bytes=None,
Expand Down
7 changes: 3 additions & 4 deletions serve/mlc_serve/engine/staging_engine_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import multiprocessing.synchronize
from dataclasses import dataclass
from threading import Thread, Lock
from typing import Callable, Optional, Union, Tuple, Any, Dict, List
from typing import Callable, Optional, Union, Any, Dict, List

import structlog
import numpy as np

from .base import (
FinishReason,
LOGPROBS_TYPE,
RequestId,
RequestState,
ValidationError,
Expand All @@ -32,6 +30,7 @@
EngineBase,
)
from ..logging_utils import configure_logging
from ..api.protocol import LogprobsContent

LOG = structlog.stdlib.get_logger(__name__)

Expand Down Expand Up @@ -67,7 +66,7 @@ class SequenceGenerationOutput:
new_tokens: List[int]
finish_reason: Optional[FinishReason] = None
error: Optional[Union[str, ValidationError]] = None
logprob_info: Optional[LOGPROBS_TYPE] = None
logprob_info: Optional[LogprobsContent] = None


@dataclass
Expand Down

0 comments on commit cb9bd88

Please sign in to comment.