Parallel sampling eviction (#157)
* add new model for evaluating logits over multiple queries using KV cache

* add test

* clean

* Only the number of past tokens is needed

* fix build

* fix

* correctly handle num_past_tokens > sliding_window case

* wip

* black

* wip

* wip

* remove cancel callback in eviction

* Create MultiQueryDecodeRequest

* only the number of past tokens is needed

* wip

* wip

* wip

* fix

* wip

* wip

* wip

* wip

* working?

* remove dbg print

* multi gpu works

* fixed sliding window logic

* remove debug print

* clean and fix

* mypy

* generate signature update

* more

* fix mypy

* fix

* fix

* mypy fix

* refactor

* fix

* rename

* Disallow preempting when a request has generated more than max_num_batched_tokens
masahi authored Feb 2, 2024
1 parent 4ebb5a3 commit ed0e52f
Showing 7 changed files with 498 additions and 185 deletions.
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/base.py
@@ -19,8 +19,8 @@
class RawLogprobsInfo:
current_token_id: int
current_logprob: float
top_token_ids: Optional[np.array]
top_logprobs: Optional[np.array]
top_token_ids: Optional[np.ndarray]
top_logprobs: Optional[np.ndarray]

RawLogprobsInfos = List[Optional[RawLogprobsInfo]]

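For context on why the annotation changes: np.array is a factory function rather than a type, while np.ndarray is the class that NumPy arrays actually are, so only the latter is meaningful in a type hint. A standalone sketch mirroring the dataclass above (the field defaults and values are invented for illustration):

from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class RawLogprobsInfo:
    current_token_id: int
    current_logprob: float
    # np.ndarray is the array type; np.array is only the constructor function,
    # so annotating with it confuses type checkers such as mypy.
    top_token_ids: Optional[np.ndarray] = None
    top_logprobs: Optional[np.ndarray] = None


info = RawLogprobsInfo(
    current_token_id=42,
    current_logprob=-0.7,
    top_token_ids=np.array([42, 7, 11]),
    top_logprobs=np.array([-0.7, -1.3, -2.9]),
)
assert isinstance(info.top_token_ids, np.ndarray)
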
119 changes: 94 additions & 25 deletions serve/mlc_serve/engine/engine_common.py
@@ -22,6 +22,8 @@
from .model_module import (
DecodeRequest,
PrefillRequest,
EvalMultiQueryRequest,
EvictedTokens,
ConversationTemplate,
KVCacheManager,
ModelModule,
@@ -226,26 +228,70 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
requests: list[Union[PrefillRequest, DecodeRequest]] = []
) -> Tuple[
list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
]:
requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
# TODO: consider having a hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)

token_counts = 0

is_evicted_parallel_sampling_request = (
lambda state: not state.is_prefilled
and state.num_sequences > 1
and any(
len(gen_seq.generated_token_ids) > 0
for gen_seq in state.generation_sequences
)
)

if is_prompt_batch:
for state in current_states:
if not state.is_prefilled:
if is_evicted_parallel_sampling_request(state):
requests.append(
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
)

token_counts += len(state.prompt_token_ids)

for gen_seq in state.generation_sequences:
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
queries=EvictedTokens(gen_seq.generated_token_ids),
sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(gen_seq.generated_token_ids) + 1,
)

# TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
# the Prometheus metrics?
elif not state.is_prefilled:
token_ids = state.prompt_token_ids
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.

if (
state.num_sequences == 1
and state.generation_sequences[0].generated_token_ids
):
token_ids += state.generation_sequences[0].generated_token_ids

requests.append(
# generated_token_ids is added for the case where the request is
# recovering from cache eviction.
# TODO(masahi): This needs an update when we support evicting
# a parallel-sampling request.
PrefillRequest(
request_id=state.request_id,
token_ids=state.prompt_token_ids
+ state.generation_sequences[0].generated_token_ids,
token_ids=token_ids,
num_sequence=state.num_sequences,
sampling_params=state.sampling_params,
)
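To make the eviction-recovery branch above concrete, here is a small standalone sketch (toy stand-in types and made-up token ids, not the engine's real classes) of what gets scheduled for an evicted parallel-sampling request: one PrefillRequest over the shared prompt, plus one EvalMultiQueryRequest per generation sequence replaying its already-generated tokens, with the KV cache extended by len(generated_token_ids) + 1 slots per sequence (the extra slot reserves room for the next sampled token).

from dataclasses import dataclass
from typing import List


@dataclass
class ToyGenSeq:
    seq_id: int
    generated_token_ids: List[int]


prompt_token_ids = [101, 102, 103, 104, 105]                 # shared prompt (5 tokens)
gen_seqs = [ToyGenSeq(0, [7, 8, 9]), ToyGenSeq(1, [4, 5])]   # evicted mid-generation

# One prefill over the shared prompt...
num_prefill_tokens = len(prompt_token_ids)                   # 5
# ...and one multi-query evaluation per sequence over its evicted tokens,
# each reserving len(generated) + 1 KV slots (+1 for the token sampled next).
kv_extend_per_seq = {s.seq_id: len(s.generated_token_ids) + 1 for s in gen_seqs}
print(num_prefill_tokens, kv_extend_per_seq)                 # 5 {0: 4, 1: 3}
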
@@ -392,16 +438,28 @@ def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int:
candidate_victims = parallel_sample_requests

request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens)

# TODO(masahi): Properly support evicting a multi-sequence request
if self.current_batch[request_to_remove.request_id].num_sequences != 1:
cancell_callback(request_to_remove.request_id)
self.remove_request_from_batch(request_to_remove.request_id)
LOG.warn(
"Preempting a multi-sequence request is currently not supported,"
f" cancelling request '{request_to_remove.request_id}'",
victim_state = self.current_batch[request_to_remove.request_id]

if victim_state.num_sequences != 1:
prev_generated_token_counts = sum(
[
len(gen_seq.generated_token_ids)
for gen_seq in victim_state.generation_sequences
]
)
continue
# We could allow evicting and restoring a parallel-sampling request whose prev_generated_token_counts
# is > max_num_batched_tokens, by making the model split a list of EvalMultiQuery requests into parts,
# so that an inference on each part can be done with the max_num_batched_tokens budget.
# But this introduces an undesirable coupling between the engine and the model.
if prev_generated_token_counts >= self.max_num_batched_tokens:
cancell_callback(request_to_remove.request_id)
self.remove_request_from_batch(request_to_remove.request_id)
LOG.warn(
f"Cancelling a parallel-sampling request '{request_to_remove.request_id}'"
f"since it has generated more than {self.max_num_batched_tokens} tokens in total"
"and currently we do not support preempting such request.",
)
continue

self.remove_request_from_batch(request_to_remove.request_id)
request_to_remove.is_prefilled = False
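The policy added above can be condensed into a small predicate. This is a sketch under the assumption that restoring a parallel-sampling victim must replay all of its previously generated tokens within a single max_num_batched_tokens budget; the helper name is made up for illustration.

def can_preempt_parallel_sampling_request(
    prev_generated_token_counts: int, max_num_batched_tokens: int
) -> bool:
    # Eviction is only allowed if the later restore (replaying every generated
    # token through EvalMultiQueryRequests) fits in one batched-token budget;
    # otherwise the request is cancelled instead of preempted.
    return prev_generated_token_counts < max_num_batched_tokens


assert can_preempt_parallel_sampling_request(100, 4096)
assert not can_preempt_parallel_sampling_request(5000, 4096)
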
@@ -446,14 +504,27 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
gen_seq.next_start_position = (
num_new_batched_tokens
) = num_tokens = self.max_num_batched_tokens

num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
else:
# Evicting and recovering multi-sequence requests is not supported for now.
assert all(
gen_seq.next_start_position == state.prompt_len
for gen_seq in state.generation_sequences
prev_generated_token_counts = sum(
[
len(gen_seq.generated_token_ids)
for gen_seq in state.generation_sequences
]
)

# Restoring an evicted parallel-sampling request with sliding-window attention is
# difficult to reason about, so we use crude upper bounds below for now.
num_tokens = state.prompt_len
num_new_batched_tokens += num_tokens
num_kv_slots_needed = state.prompt_len + prev_generated_token_counts
# Restoring an evicted parallel-sampling request is done by separate
# Prefill and MultiQuery requests. The maximum below is an upper bound on the
# batch size increase due to this request.
# TODO(masahi): Prefill and EvalMultiQuery requests are handled separately by the model.
# So comparing the sum of their batched token counts against max_num_batched_tokens
# is not optimal.
num_new_batched_tokens += max(state.prompt_len, prev_generated_token_counts)

if num_new_batched_tokens > self.max_num_batched_tokens:
LOG.debug(
@@ -465,7 +536,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
# We make sure that the KV cache will have enough free space for this request to proceed
# decoding for at least self.max_decode_steps steps.
# See the comment in check_prompt_too_long for the optimization involving the window size.
num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
if (self.cache_manager.get_free_space() - num_kv_slots_needed) / (
len(self.current_batch) + 1
) < self.max_decode_steps * state.num_sequences:
@@ -477,7 +547,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
return None

self.queue.popleft()
# TODO parallel sampling: Need update here when evicting multi-sequence requests is supported.
self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
self.current_batch[state.request_id] = state

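The budgeting in try_grow_batch for a restored parallel-sampling request can be spelled out with toy numbers (values invented for illustration): the KV-slot estimate is a crude upper bound of the prompt length plus everything generated so far, while the batched-token increase takes the maximum of the two because the Prefill and EvalMultiQuery requests are executed as separate model calls and never occupy the same batch.

prompt_len = 512
prev_generated_token_counts = 3 + 5 + 2       # across three generation sequences

# Upper bound on KV cache slots the restored request may need.
num_kv_slots_needed = prompt_len + prev_generated_token_counts           # 522

# Prefill (prompt) and EvalMultiQuery (generated tokens) run as separate batches,
# so only the larger of the two counts against max_num_batched_tokens at once.
batched_token_increase = max(prompt_len, prev_generated_token_counts)    # 512
print(num_kv_slots_needed, batched_token_increase)
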
30 changes: 28 additions & 2 deletions serve/mlc_serve/engine/model_module.py
@@ -35,11 +35,37 @@ class PrefillRequest:
class DecodeRequest:
sequence_id: SequenceId
prompt_token_counts: int
# All tokens for this request, including prompt
# Decoded tokens for this sequence
token_ids: List[int]
sampling_params: SamplingParams


@dataclass
class DraftTokens:
token_ids: List[int]

@property
def num_tokens(self):
return len(self.token_ids)


@dataclass
class EvictedTokens:
token_ids: List[int]

@property
def num_tokens(self):
return len(self.token_ids)


@dataclass
class EvalMultiQueryRequest:
sequence_id: SequenceId
num_past_tokens: int
queries: Union[DraftTokens, EvictedTokens]
sampling_params: SamplingParams


@dataclass
class TextGenerationResult:
"""
@@ -125,7 +151,7 @@ class TextGenerator(Protocol):

def generate(
self,
requests: Sequence[Union[PrefillRequest, DecodeRequest]],
requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
kv_cache,
) -> List[TextGenerationResult]:
"""
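A minimal usage sketch of the new query containers (the dataclass shapes are copied from the diff above; SequenceId and SamplingParams are outside this diff and are left out). EvictedTokens carries tokens replayed after cache eviction, which is the case this commit exercises; DraftTokens appears intended for a separate draft-token path such as speculative decoding, which is an assumption rather than something this diff states.

from dataclasses import dataclass
from typing import List, Union


@dataclass
class DraftTokens:
    token_ids: List[int]

    @property
    def num_tokens(self) -> int:
        return len(self.token_ids)


@dataclass
class EvictedTokens:
    token_ids: List[int]

    @property
    def num_tokens(self) -> int:
        return len(self.token_ids)


# The model evaluates logits for these query tokens on top of num_past_tokens
# already held in the KV cache, instead of re-running a full prefill.
queries: Union[DraftTokens, EvictedTokens] = EvictedTokens([7, 8, 9])
print(queries.num_tokens)  # 3
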
(The remaining 4 changed files are not shown.)
