Decompose sample_from_logits for clarity and further development #190

Merged · 3 commits · Feb 2, 2024
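In brief, this change pulls two pieces of duplicated logic out of `sample_from_logits` into standalone helpers: `update_tokens_frequency`, which maintains the per-request `appeared_tokens_freq` counters, and `append_text_gen_res`, which builds the `TextGenerationResult`s (including the prompt fan-out and the error case). It also introduces the `RequestType` / `RequestsType` aliases in `model_module.py` so the three-way `Union` annotation is written once and imported where needed. Both the batched sampling path and the per-request fallback now call the same helpers. An abbreviated outline of the resulting control flow follows; it is not verbatim, the helpers it calls are the ones defined in the `model_common.py` diff below, and the fallback loop header and per-request re-sampling are paraphrased because they fall outside the visible hunks:

```python
# Abbreviated outline of the refactored sample_from_logits (not verbatim);
# sample, get_logprob_infos, update_tokens_frequency and append_text_gen_res
# are the functions shown in the model_common.py diff below.
def sample_from_logits(logits, sequence_ids, requests, vocab_size):
    sampling_params = [req.sampling_params for req in requests]
    outputs = []
    try:
        # Happy path: one batched sampling call covers every request.
        next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size)
        for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)):
            update_tokens_frequency(requests[i], new_token)
            outputs = append_text_gen_res(
                outputs, requests[i], [new_token],
                sequence_id, get_logprob_infos(i, logprob_infos),
            )
        return outputs
    except RuntimeError:
        # Fallback: some row of the logits contained inf/nan/negative values,
        # so each request is re-sampled on its own. logprob_infos is indexed
        # with 0 here because the effective batch size is 1.
        err_msg = (
            "Error from sampling: probability tensor contains either `inf`, `nan`"
            " or element < 0"
        )
        for i, sequence_id in enumerate(sequence_ids):
            maybe_new_token, logprob_infos = ...  # per-request re-sampling, elided
            if maybe_new_token is not None:
                update_tokens_frequency(requests[i], maybe_new_token[0])
                outputs = append_text_gen_res(
                    outputs, requests[i], [maybe_new_token[0]],
                    sequence_id, get_logprob_infos(0, logprob_infos),
                )
            else:
                # Still failing: emit an empty token list plus the error message.
                outputs = append_text_gen_res(
                    outputs, requests[i], [], sequence_id,
                    get_logprob_infos(0, logprob_infos), err_msg,
                )
        return outputs
```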
7 changes: 3 additions & 4 deletions serve/mlc_serve/engine/engine_common.py
@@ -27,6 +27,7 @@
ConversationTemplate,
KVCacheManager,
ModelModule,
RequestType,
TextGenerator,
Tokenizer as TokenizerP,
)
@@ -228,10 +229,8 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[
list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
]:
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
) -> Tuple[list[RequestType], bool, int]:
requests: list[RequestType] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
4 changes: 4 additions & 0 deletions serve/mlc_serve/engine/model_module.py
@@ -66,6 +66,10 @@ class EvalMultiQueryRequest:
sampling_params: SamplingParams


RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
RequestsType = Sequence[RequestType]


@dataclass
class TextGenerationResult:
"""
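The two aliases above let downstream modules write `RequestType` / `RequestsType` instead of repeating the three-way `Union` at every signature (compare the `engine_common.py`, `paged_cache_model.py`, and `tvm_model.py` hunks in this PR). A minimal, self-contained illustration of the pattern, with empty placeholder classes standing in for the real request dataclasses:

```python
# Minimal illustration of the RequestType / RequestsType aliases. The three
# placeholder classes below stand in for the real PrefillRequest, DecodeRequest
# and EvalMultiQueryRequest dataclasses from model_module.py.
from typing import Sequence, Union


class PrefillRequest: ...
class DecodeRequest: ...
class EvalMultiQueryRequest: ...


RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
RequestsType = Sequence[RequestType]


# Before: every call site spells out the full Union.
def generate_before(
    requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
) -> None: ...


# After: the alias is defined once and imported wherever it is needed.
def generate_after(requests: RequestsType) -> None: ...
```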
143 changes: 70 additions & 73 deletions serve/mlc_serve/model/model_common.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union, Sequence
from typing import List, Optional, Tuple, Union

import structlog
import numpy as np
@@ -11,16 +11,16 @@
SamplingParams,
get_prompt_sequence_id,
LOGPROB_TOP_K_MAX,
RawLogprobsInfo,
RawLogprobsInfos,
PROMPT_SEQEUNCE_INDEX,
RawLogprobsInfo,
RawLogprobsInfos,
SequenceId,
)
from ..engine.model_module import (
DecodeRequest,
PrefillRequest,
EvalMultiQueryRequest,
RequestType,
RequestsType,
TextGenerationResult,
)

@@ -302,49 +302,73 @@ def _is_safe_to_sample(prob_like):
return res, check_logprob_infos(logprob_infos)


def update_tokens_frequency(
request: RequestType,
new_token: int
):
if not new_token in request.sampling_params.appeared_tokens_freq:
request.sampling_params.appeared_tokens_freq[new_token] = 0
request.sampling_params.appeared_tokens_freq[new_token] += 1


def append_text_gen_res(
outputs: List[TextGenerationResult],
request: RequestType,
new_token: List[int],
sequence_id: SequenceId,
logprob_info: Optional[RawLogprobsInfos],
err_msg: Optional[str]=None,
) -> List[TextGenerationResult]:
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(request, PrefillRequest)
for seq_id in range(request.num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=new_token,
error=err_msg,
logprob_info=logprob_info,
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=new_token,
error=err_msg,
logprob_info=logprob_info,
)
)
return outputs


def sample_from_logits(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sequence_ids: List[SequenceId],
requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
requests: RequestsType,
vocab_size,
) -> List[TextGenerationResult]:
assert logits.shape[0] == len(requests)

sampling_params = [req.sampling_params for req in requests]
outputs: List[TextGenerationResult] = []

try:
next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size)
assert next_tokens is not None
outputs = []
for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)):
if not new_token in sampling_params[i].appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[new_token],
error=None,
logprob_info=get_logprob_infos(i, logprob_infos),
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[new_token],
error=None,
logprob_info=get_logprob_infos(i, logprob_infos),
)
)
update_tokens_frequency(requests[i], new_token)
outputs = append_text_gen_res(
outputs,
requests[i],
[new_token],
sequence_id,
get_logprob_infos(i, logprob_infos),
)

return outputs
except RuntimeError:
# Fallback to per-token sampling in case some logits values are corrupted.
outputs = []
err_msg = (
"Error from sampling: probability tensor contains either `inf`, `nan`"
" or element < 0"
@@ -362,50 +386,23 @@ def sample_from_logits(

if maybe_new_token is not None:
new_token = maybe_new_token[0]
if not new_token in requests[i].sampling_params.appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
update_tokens_frequency(requests[i], new_token)
outputs = append_text_gen_res(
outputs,
requests[i],
[new_token],
sequence_id,
get_logprob_infos(0, logprob_infos),
)
else:
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[],
error=err_msg,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[],
error=err_msg,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
outputs = append_text_gen_res(
outputs,
requests[i],
[], # new_token
sequence_id,
get_logprob_infos(0, logprob_infos),
err_msg,
)

return outputs

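One behavior `append_text_gen_res` carries over unchanged from the previously inlined code: when the sequence index is the prompt sentinel (`PROMPT_SEQEUNCE_INDEX`), the request must be a `PrefillRequest` and the single sampled token is fanned out into one `TextGenerationResult` per generated sequence (`num_sequence` of them); a regular decode sequence gets exactly one result. A small self-contained sketch of that fan-out, using simplified stand-ins for the real `SequenceId`, request, and result types (the sentinel value is assumed here for illustration):

```python
# Self-contained sketch of the prompt fan-out performed by append_text_gen_res.
# The types and the sentinel value below are simplified stand-ins, not the real
# mlc_serve definitions.
from dataclasses import dataclass
from typing import List, Optional

PROMPT_SEQEUNCE_INDEX = -1  # assumed sentinel for "this is the prompt"


@dataclass(frozen=True)
class SequenceId:
    request_id: int
    sequence_index: int


@dataclass
class Result:
    sequence_id: SequenceId
    generated_tokens: List[int]
    error: Optional[str] = None


def fan_out(outputs: List[Result], num_sequence: int, new_token: List[int],
            sequence_id: SequenceId, err_msg: Optional[str] = None) -> List[Result]:
    if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
        # Prompt-phase result: replicate the token list across every sequence.
        for seq_id in range(num_sequence):
            outputs.append(Result(SequenceId(sequence_id.request_id, seq_id), new_token, err_msg))
    else:
        outputs.append(Result(sequence_id, new_token, err_msg))
    return outputs


# A prefill request generating two sequences yields two results for one token.
print(fan_out([], 2, [42], SequenceId(7, PROMPT_SEQEUNCE_INDEX)))
```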
7 changes: 4 additions & 3 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -1,6 +1,6 @@
from pathlib import Path
import structlog
from typing import List, Union
from typing import List

from .base import get_model_artifact_config
from .paged_cache_manager import CacheManager
@@ -13,6 +13,7 @@
ModelModule,
PrefillRequest,
EvalMultiQueryRequest,
RequestType,
TextGenerationResult,
TextGenerator,
)
@@ -26,9 +27,9 @@ def __init__(self, model: TextGenerator):

def generate(
self,
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
requests: List[RequestType],
kv_cache,
) -> list[TextGenerationResult]:
) -> List[TextGenerationResult]:
prefill_requests = []
decode_requests = []
multi_query_decode_requests = []
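In `generate`, the batch is first bucketed into prefill, decode, and multi-query lists before being handed to the model. The routing itself sits below the visible part of this hunk, so the following is only an assumed sketch of that bucketing (isinstance-based dispatch), reusing placeholder classes rather than the real imports:

```python
# Assumed sketch of the request bucketing at the top of generate(); the real
# routing is below the visible hunk and may differ. Placeholder classes stand
# in for the imported request types.
from typing import List, Tuple, Union


class PrefillRequest: ...
class DecodeRequest: ...
class EvalMultiQueryRequest: ...


RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]


def bucket_requests(
    requests: List[RequestType],
) -> Tuple[List[PrefillRequest], List[DecodeRequest], List[EvalMultiQueryRequest]]:
    prefill_requests: List[PrefillRequest] = []
    decode_requests: List[DecodeRequest] = []
    multi_query_decode_requests: List[EvalMultiQueryRequest] = []
    for request in requests:
        if isinstance(request, PrefillRequest):
            prefill_requests.append(request)
        elif isinstance(request, DecodeRequest):
            decode_requests.append(request)
        else:
            multi_query_decode_requests.append(request)
    return prefill_requests, decode_requests, multi_query_decode_requests
```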
9 changes: 4 additions & 5 deletions serve/mlc_serve/model/tvm_model.py
@@ -1,6 +1,6 @@
import math
import os
from typing import List, Union, Tuple, Sequence
from typing import List, Tuple

import structlog
import numpy as np
@@ -24,9 +24,10 @@
)
from ..engine.model_module import (
DecodeRequest,
PrefillRequest,
DraftTokens,
EvalMultiQueryRequest,
PrefillRequest,
RequestsType,
TextGenerationResult,
TextGenerator,
)
@@ -276,9 +277,7 @@ def generate_multi_query(

def generate(
self,
requests: Sequence[
Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
],
requests: RequestsType,
cache: KVCacheInfo,
) -> List[TextGenerationResult]:
if len(requests) == 0: