Decompose sample_from_logits for clarity and further development (#190)
* refactor sampling from logits, other small fixes

* use RequestType

* fix lint
vvchernov authored Feb 2, 2024
1 parent c16f3f0 commit c36d47c
Showing 5 changed files with 85 additions and 85 deletions.
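
For orientation before the hunks: the commit factors the per-token bookkeeping out of sample_from_logits into two helpers — update_tokens_frequency (repetition-penalty counting) and append_text_gen_res (packaging TextGenerationResult objects, fanning out over parallel sequences for prompt-stage requests) — so the batched path and the per-token fallback share one code path. A minimal, self-contained sketch of that split follows; Request, Result and append_results are simplified stand-ins for illustration, not the project's real classes.

# Minimal sketch of the decomposition; Request/Result are simplified stand-ins,
# not the project's PrefillRequest/TextGenerationResult classes.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Request:
    appeared_tokens_freq: Dict[int, int] = field(default_factory=dict)
    num_sequence: int = 1        # >1 means the request fans out over parallel sequences


@dataclass
class Result:
    sequence_index: int
    generated_tokens: List[int]
    error: Optional[str] = None


def update_tokens_frequency(request: Request, new_token: int) -> None:
    # Track how often each token has been generated; used by repetition penalties.
    freq = request.appeared_tokens_freq
    freq[new_token] = freq.get(new_token, 0) + 1


def append_results(outputs: List[Result], request: Request, new_tokens: List[int],
                   err_msg: Optional[str] = None) -> List[Result]:
    # Pared-down analogue of append_text_gen_res: one result per parallel sequence.
    for seq_index in range(request.num_sequence):
        outputs.append(Result(seq_index, new_tokens, err_msg))
    return outputs


# With the helpers in place, each iteration of the sampling loop reduces to:
#     update_tokens_frequency(requests[i], new_token)
#     outputs = append_results(outputs, requests[i], [new_token])
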
7 changes: 3 additions & 4 deletions serve/mlc_serve/engine/engine_common.py
@@ -27,6 +27,7 @@
ConversationTemplate,
KVCacheManager,
ModelModule,
RequestType,
TextGenerator,
Tokenizer as TokenizerP,
)
@@ -228,10 +229,8 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[
list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
]:
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
) -> Tuple[list[RequestType], bool, int]:
requests: list[RequestType] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
4 changes: 4 additions & 0 deletions serve/mlc_serve/engine/model_module.py
@@ -66,6 +66,10 @@ class EvalMultiQueryRequest:
sampling_params: SamplingParams


RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
RequestsType = Sequence[RequestType]


@dataclass
class TextGenerationResult:
"""
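
The two aliases added above are what the later hunks in engine_common.py, model_common.py, paged_cache_model.py and tvm_model.py switch to, replacing the repeated three-way Union in signatures. A short, self-contained sketch of the pattern, using hypothetical stand-in request classes (the real alias also covers EvalMultiQueryRequest):

# Sketch of the alias pattern; the request classes below are simplified stand-ins.
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class PrefillRequest:
    prompt_tokens: List[int]
    num_sequence: int = 1


@dataclass
class DecodeRequest:
    prompt_tokens: List[int]


RequestType = Union[PrefillRequest, DecodeRequest]   # real alias also includes EvalMultiQueryRequest
RequestsType = Sequence[RequestType]


def count_prompt_tokens(requests: RequestsType) -> int:
    # Hypothetical helper: the alias keeps signatures short, but call sites still
    # need isinstance() narrowing wherever Union members differ (as
    # append_text_gen_res does before reading PrefillRequest.num_sequence).
    return sum(len(r.prompt_tokens) for r in requests)


print(count_prompt_tokens([PrefillRequest([1, 2, 3]), DecodeRequest([4])]))  # 4
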
143 changes: 70 additions & 73 deletions serve/mlc_serve/model/model_common.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union, Sequence
from typing import List, Optional, Tuple, Union

import structlog
import numpy as np
@@ -11,16 +11,16 @@
SamplingParams,
get_prompt_sequence_id,
LOGPROB_TOP_K_MAX,
RawLogprobsInfo,
RawLogprobsInfos,
PROMPT_SEQEUNCE_INDEX,
RawLogprobsInfo,
RawLogprobsInfos,
SequenceId,
)
from ..engine.model_module import (
DecodeRequest,
PrefillRequest,
EvalMultiQueryRequest,
RequestType,
RequestsType,
TextGenerationResult,
)

@@ -302,49 +302,73 @@ def _is_safe_to_sample(prob_like):
return res, check_logprob_infos(logprob_infos)


def update_tokens_frequency(
request: RequestType,
new_token: int
):
if not new_token in request.sampling_params.appeared_tokens_freq:
request.sampling_params.appeared_tokens_freq[new_token] = 0
request.sampling_params.appeared_tokens_freq[new_token] += 1


def append_text_gen_res(
outputs: List[TextGenerationResult],
request: RequestType,
new_token: List[int],
sequence_id: SequenceId,
logprob_info: Optional[RawLogprobsInfos],
err_msg: Optional[str]=None,
) -> List[TextGenerationResult]:
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(request, PrefillRequest)
for seq_id in range(request.num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=new_token,
error=err_msg,
logprob_info=logprob_info,
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=new_token,
error=err_msg,
logprob_info=logprob_info,
)
)
return outputs


def sample_from_logits(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sequence_ids: List[SequenceId],
requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
requests: RequestsType,
vocab_size,
) -> List[TextGenerationResult]:
assert logits.shape[0] == len(requests)

sampling_params = [req.sampling_params for req in requests]
outputs: List[TextGenerationResult] = []

try:
next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size)
assert next_tokens is not None
outputs = []
for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)):
if not new_token in sampling_params[i].appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[new_token],
error=None,
logprob_info=get_logprob_infos(i, logprob_infos),
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[new_token],
error=None,
logprob_info=get_logprob_infos(i, logprob_infos),
)
)
update_tokens_frequency(requests[i], new_token)
outputs = append_text_gen_res(
outputs,
requests[i],
[new_token],
sequence_id,
get_logprob_infos(i, logprob_infos),
)

return outputs
except RuntimeError:
# Fallback to per-token sampling in case some logits values are corrupted.
outputs = []
err_msg = (
"Error from sampling: probability tensor contains either `inf`, `nan`"
" or element < 0"
@@ -362,50 +386,23 @@ def sample_from_logits

if maybe_new_token is not None:
new_token = maybe_new_token[0]
if not new_token in requests[i].sampling_params.appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[new_token], # type: ignore
error=None,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
update_tokens_frequency(requests[i], new_token)
outputs = append_text_gen_res(
outputs,
requests[i],
[new_token],
sequence_id,
get_logprob_infos(0, logprob_infos),
)
else:
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(requests[i], PrefillRequest)
for seq_id in range(requests[i].num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=[],
error=err_msg,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[],
error=err_msg,
logprob_info=get_logprob_infos(0, logprob_infos),
)
)
outputs = append_text_gen_res(
outputs,
requests[i],
[], # new_token
sequence_id,
get_logprob_infos(0, logprob_infos),
err_msg,
)

return outputs

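
Structurally, the refactored sample_from_logits above tries one batched sample() call and, on a RuntimeError from corrupted logits, retries token by token; both paths now route through update_tokens_frequency and append_text_gen_res, with only the error message and logprob index differing. A self-contained toy sketch of that control flow (sample_batch and sample_one are illustrative stand-ins, not the project's functions):

# Toy sketch of the batched-then-fallback flow; sample_batch/sample_one are
# hypothetical stand-ins for the real batched sampler and per-token retry.
from typing import List, Optional


def sample_batch(logits: List[List[float]]) -> List[int]:
    # Stand-in sampler: greedy argmax per row; raises when a row is corrupted,
    # mimicking the "probability tensor contains `inf`, `nan`..." failure mode.
    out = []
    for row in logits:
        if any(x != x for x in row):          # NaN check
            raise RuntimeError("corrupted logits row")
        out.append(max(range(len(row)), key=row.__getitem__))
    return out


def sample_one(row: List[float]) -> Optional[int]:
    try:
        return sample_batch([row])[0]
    except RuntimeError:
        return None                           # only this row carries an error result


def sample_from_logits(logits: List[List[float]]) -> List[Optional[int]]:
    try:
        # Fast path: one batched call, then shared bookkeeping per token.
        return list(sample_batch(logits))
    except RuntimeError:
        # Fallback: retry row by row so one bad row does not fail the whole batch;
        # each row's outcome goes through the same bookkeeping helpers.
        return [sample_one(row) for row in logits]


print(sample_from_logits([[0.1, 0.9], [float("nan"), 0.5]]))  # [1, None]
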
7 changes: 4 additions & 3 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -1,6 +1,6 @@
from pathlib import Path
import structlog
from typing import List, Union
from typing import List

from .base import get_model_artifact_config
from .paged_cache_manager import CacheManager
@@ -13,6 +13,7 @@
ModelModule,
PrefillRequest,
EvalMultiQueryRequest,
RequestType,
TextGenerationResult,
TextGenerator,
)
@@ -26,9 +27,9 @@ def __init__(self, model: TextGenerator):

def generate(
self,
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
requests: List[RequestType],
kv_cache,
) -> list[TextGenerationResult]:
) -> List[TextGenerationResult]:
prefill_requests = []
decode_requests = []
multi_query_decode_requests = []
9 changes: 4 additions & 5 deletions serve/mlc_serve/model/tvm_model.py
@@ -1,6 +1,6 @@
import math
import os
from typing import List, Union, Tuple, Sequence
from typing import List, Tuple

import structlog
import numpy as np
@@ -24,9 +24,10 @@
)
from ..engine.model_module import (
DecodeRequest,
PrefillRequest,
DraftTokens,
EvalMultiQueryRequest,
PrefillRequest,
RequestsType,
TextGenerationResult,
TextGenerator,
)
@@ -276,9 +277,7 @@ def generate_multi_query(

def generate(
self,
requests: Sequence[
Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
],
requests: RequestsType,
cache: KVCacheInfo,
) -> List[TextGenerationResult]:
if len(requests) == 0: