
Commit

Revert "Parallel sampling eviction (#157)"
This reverts commit ed0e52f.
sunggg authored Feb 2, 2024
1 parent ed0e52f · commit d324181
Showing 7 changed files with 185 additions and 498 deletions.
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/base.py
@@ -19,8 +19,8 @@
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.ndarray]
top_logprobs: Optional[np.ndarray]
top_token_ids: Optional[np.array]
top_logprobs: Optional[np.array]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]
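
For context on the annotation change above: `np.ndarray` is NumPy's array class and is valid as a type annotation, while `np.array` is a factory function, so static checkers flag it when used as a type even though it is accepted at runtime. The snippet below is an illustrative sketch, not part of this diff, showing the two in their respective roles.

```python
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class RawLogprobsInfo:
    # Sketch mirroring the dataclass in base.py. np.ndarray is the array class;
    # np.array is a function, so type checkers reject it as an annotation,
    # although Python accepts it at runtime.
    current_token_id: int
    current_logprob: float
    top_token_ids: Optional[np.ndarray]
    top_logprobs: Optional[np.ndarray]


RawLogprobsInfos = List[Optional[RawLogprobsInfo]]

info = RawLogprobsInfo(
    current_token_id=42,
    current_logprob=-0.7,
    top_token_ids=np.array([42, 7, 13]),       # np.array used as a constructor is fine
    top_logprobs=np.array([-0.7, -1.9, -2.4]),
)
```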

119 changes: 25 additions & 94 deletions serve/mlc_serve/engine/engine_common.py
@@ -22,8 +22,6 @@
from .model_module import (
DecodeRequest,
PrefillRequest,
EvalMultiQueryRequest,
EvictedTokens,
ConversationTemplate,
KVCacheManager,
ModelModule,
@@ -228,70 +226,26 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[
list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
]:
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
requests: list[Union[PrefillRequest, DecodeRequest]] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)

token_counts = 0

is_evicted_parallel_sampling_request = (
lambda state: not state.is_prefilled
and state.num_sequences > 1
and any(
len(gen_seq.generated_token_ids) > 0
for gen_seq in state.generation_sequences
)
)

if is_prompt_batch:
for state in current_states:
if is_evicted_parallel_sampling_request(state):
requests.append(
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
)

token_counts += len(state.prompt_token_ids)

for gen_seq in state.generation_sequences:
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
queries=EvictedTokens(gen_seq.generated_token_ids),
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(gen_seq.generated_token_ids) + 1,
)

# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# Prometheus metric?
elif not state.is_prefilled:
token_ids = state.prompt_token_ids
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
):
token_ids += state.generation_sequences[0].generated_token_ids

if not state.is_prefilled:
requests.append(
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.
# TODO(masahi): This needs an update when we support evicting
# a parallel-sampling request.
PrefillRequest(
request_id=state.request_id,
token_ids=token_ids,
token_ids=state.prompt_token_ids
+ state.generation_sequences[0].generated_token_ids,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
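
As the restored comment above notes, after this revert a single-sequence request recovering from cache eviction is simply re-prefilled with its prompt plus the tokens it had already generated; the removed `EvalMultiQueryRequest` path handled the parallel-sampling case instead. A minimal sketch of that concatenation, using hypothetical stand-ins for the engine's state classes:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class GenerationSequence:
    # Hypothetical stand-in for the engine's per-sequence generation state.
    generated_token_ids: List[int] = field(default_factory=list)


@dataclass
class RequestState:
    # Hypothetical stand-in with only the fields used in this sketch.
    request_id: str
    prompt_token_ids: List[int]
    generation_sequences: List[GenerationSequence]
    is_prefilled: bool = False

    @property
    def num_sequences(self) -> int:
        return len(self.generation_sequences)


def prefill_token_ids(state: RequestState) -> List[int]:
    # Post-revert behaviour: re-prefill the prompt plus any tokens generated
    # before eviction; only the single-sequence case is supported here.
    assert state.num_sequences == 1, "evicting parallel-sampling requests is unsupported"
    return state.prompt_token_ids + state.generation_sequences[0].generated_token_ids


state = RequestState(
    request_id="r0",
    prompt_token_ids=[1, 2, 3],
    generation_sequences=[GenerationSequence(generated_token_ids=[7, 8])],
)
assert prefill_token_ids(state) == [1, 2, 3, 7, 8]
```
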
@@ -438,28 +392,16 @@ def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int:
candidate_victims = parallel_sample_requests

request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens)
victim_state = self.current_batch[request_to_remove.request_id]

if victim_state.num_sequences != 1:
prev_generated_token_counts = sum(
[
len(gen_seq.generated_token_ids)
for gen_seq in victim_state.generation_sequences
]

# TODO(masahi): Properly support evicting a multi-sequence request
if self.current_batch[request_to_remove.request_id].num_sequences != 1:
cancell_callback(request_to_remove.request_id)
self.remove_request_from_batch(request_to_remove.request_id)
LOG.warn(
"Preempting a multi-sequence request is currently not supported,"
f" cancelling request '{request_to_remove.request_id}'",
)
# We could allow evicting and restoring a parallel-sampling request whose prev_generated_token_counts
# is > max_num_batched_tokens, by making the model split a list of EvalMultiQuery requests into parts,
# so that an inference on each part can be done with the max_num_batched_tokens budget.
# But this introduces an undesirable coupling between the engine and the model.
if prev_generated_token_counts >= self.max_num_batched_tokens:
cancell_callback(request_to_remove.request_id)
self.remove_request_from_batch(request_to_remove.request_id)
LOG.warn(
f"Cancelling a parallel-sampling request '{request_to_remove.request_id}'"
f"since it has generated more than {self.max_num_batched_tokens} tokens in total"
"and currently we do not support preempting such request.",
)
continue
continue

self.remove_request_from_batch(request_to_remove.request_id)
request_to_remove.is_prefilled = False
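
The hunk above picks the eviction victim as the in-flight request with the fewest total tokens and, with this revert, cancels it outright when it has more than one sequence instead of evicting and later restoring it. A rough sketch of that selection under assumed field names:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class BatchedRequest:
    # Hypothetical summary of a request currently in the running batch.
    request_id: str
    num_total_tokens: int
    num_sequences: int


def pick_eviction_victim(current_batch: Dict[str, BatchedRequest]) -> BatchedRequest:
    # Mirrors the selection in evict_request: the cheapest request to redo
    # (fewest total tokens) is preempted first.
    return min(current_batch.values(), key=lambda r: r.num_total_tokens)


batch = {
    "a": BatchedRequest("a", num_total_tokens=900, num_sequences=1),
    "b": BatchedRequest("b", num_total_tokens=120, num_sequences=1),
    "c": BatchedRequest("c", num_total_tokens=450, num_sequences=4),
}
victim = pick_eviction_victim(batch)
assert victim.request_id == "b"
# After the revert, a victim with num_sequences != 1 is cancelled rather than
# evicted and later restored.
```
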
@@ -504,27 +446,14 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
gen_seq.next_start_position = (
num_new_batched_tokens
) = num_tokens = self.max_num_batched_tokens

num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
else:
prev_generated_token_counts = sum(
[
len(gen_seq.generated_token_ids)
for gen_seq in state.generation_sequences
]
# Evicting and recovering multi-sequence requests is not supported for now.
assert all(
gen_seq.next_start_position == state.prompt_len
for gen_seq in state.generation_sequences
)

# Restoring an evicted parallel-sampling request with sliding-window attention is
# difficult to reason about, so we use crude upper bounds below for now.
num_tokens = state.prompt_len
num_kv_slots_needed = state.prompt_len + prev_generated_token_counts
# Restoring an evicted parallel-sampling request is done by separate
# Prefill and MultiQuery requests. The maximum below is an upper bound on the
# batch size increase due to this request.
# TODO(masahi): Prefill and EvalMultiQuery requests are handled separately by the model.
# So comparing the sum of their batched token counts against max_num_batched_tokens
# is not optimal.
num_new_batched_tokens += max(state.prompt_len, prev_generated_token_counts)
num_new_batched_tokens += num_tokens

if num_new_batched_tokens > self.max_num_batched_tokens:
LOG.debug(
@@ -536,6 +465,7 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
# We make sure that the KV cache will have enough free space for this request to proceed
# decoding for at least self.max_decode_steps steps.
# See the comment in check_prompt_too_long for the optimization involving the window size.
num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
if (self.cache_manager.get_free_space() - num_kv_slots_needed) / (
len(self.current_batch) + 1
) < self.max_decode_steps * state.num_sequences:
@@ -547,6 +477,7 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
return None

self.queue.popleft()
# TODO parallel sampling: Need update here when evicting multi-sequence requests is supported.
self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
self.current_batch[state.request_id] = state
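
The checks above keep two budgets when growing the batch: the newly batched tokens must stay within max_num_batched_tokens, and the KV cache must retain enough free slots for every request in the grown batch to keep decoding for at least max_decode_steps steps. A small numeric sketch of the KV-slot check, with made-up numbers but the same inequality as in try_grow_batch:

```python
def can_admit(
    free_kv_slots: int,
    num_tokens: int,
    context_window: int,
    batch_size_after_admit: int,
    max_decode_steps: int,
    num_sequences: int,
) -> bool:
    # Reserve slots for the new request's tokens (capped by the sliding
    # window), then require that the remaining space, split across the grown
    # batch, covers max_decode_steps decode steps per sequence of the request.
    num_kv_slots_needed = min(num_tokens, context_window)
    return (free_kv_slots - num_kv_slots_needed) / batch_size_after_admit >= (
        max_decode_steps * num_sequences
    )


# Made-up numbers: 10,000 free slots, a 1,500-token prompt, a 4,096-token
# window, 3 requests already running, 512 decode steps reserved per sequence.
assert can_admit(10_000, 1_500, 4_096, batch_size_after_admit=4,
                 max_decode_steps=512, num_sequences=1)
assert not can_admit(10_000, 1_500, 4_096, batch_size_after_admit=4,
                     max_decode_steps=512, num_sequences=8)
```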

30 changes: 2 additions & 28 deletions serve/mlc_serve/engine/model_module.py
@@ -35,37 +35,11 @@ class PrefillRequest:
class DecodeRequest:
sequence_id: SequenceId
prompt_token_counts: int
# Decoded tokens for this sequence
# All tokens for this request, including prompt
token_ids: List[int]
sampling_params: SamplingParams


@dataclass
class DraftTokens:
token_ids: List[int]

@property
def num_tokens(self):
return len(self.token_ids)


@dataclass
class EvictedTokens:
token_ids: List[int]

@property
def num_tokens(self):
return len(self.token_ids)


@dataclass
class EvalMultiQueryRequest:
sequence_id: SequenceId
num_past_tokens: int
queries: Union[DraftTokens, EvictedTokens]
sampling_params: SamplingParams


@dataclass
class TextGenerationResult:
"""
@@ -151,7 +125,7 @@ class TextGenerator(Protocol):

def generate(
self,
requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
requests: Sequence[Union[PrefillRequest, DecodeRequest]],
kv_cache,
) -> List[TextGenerationResult]:
"""
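
With the request types above reverted, TextGenerator.generate only needs to accept prefill and decode requests. Below is a toy, shape-compatible sketch (stand-in classes, not the real model_module types) illustrating the narrowed signature:

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class PrefillRequest:
    # Toy stand-in; the real class also carries sampling parameters.
    request_id: str
    token_ids: List[int]
    num_sequence: int


@dataclass
class DecodeRequest:
    # Toy stand-in; after the revert, token_ids holds all tokens for the
    # request, including the prompt.
    sequence_id: str
    prompt_token_counts: int
    token_ids: List[int]


@dataclass
class TextGenerationResult:
    sequence_id: str
    generated_tokens: List[int]


class DummyTextGenerator:
    """Shape-compatible with the reverted TextGenerator protocol."""

    def generate(
        self,
        requests: Sequence[Union[PrefillRequest, DecodeRequest]],
        kv_cache: object,
    ) -> List[TextGenerationResult]:
        results: List[TextGenerationResult] = []
        for req in requests:
            seq_id = req.request_id if isinstance(req, PrefillRequest) else req.sequence_id
            # Pretend the model emits a single dummy token per request.
            results.append(TextGenerationResult(seq_id, [0]))
        return results
```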
