fall back to full prefill when any of the KV cache receive fails.
Signed-off-by: Kuntai Du <[email protected]>
KuntaiDu committed Nov 20, 2024
1 parent bae609a commit 1780820
Showing 4 changed files with 228 additions and 288 deletions.
11 changes: 5 additions & 6 deletions benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
@@ -17,9 +17,8 @@ set -ex

kill_gpu_processes() {
  # kill all processes on GPU.
-  pkill -f pt_main_thread
-  pkill -f python3
-  pgrep pt_main_thread | xargs kill -9
+  pgrep pt_main_thread | xargs -r kill -9
+  pgrep python3 | xargs -r kill -9
  for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
  sleep 1
}
@@ -64,7 +63,7 @@ launch_disagg_prefill() {
  # disagg prefill
  CUDA_VISIBLE_DEVICES=0 python3 \
    -m vllm.entrypoints.openai.api_server \
-    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+    --model $model \
    --port 8100 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
@@ -75,7 +74,7 @@
    --kv-buffer-size 5e9 &
  CUDA_VISIBLE_DEVICES=1 python3 \
    -m vllm.entrypoints.openai.api_server \
-    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+    --model $model \
    --port 8200 \
    --max-model-len 10000 \
    --gpu-memory-utilization 0.6 \
@@ -93,7 +92,7 @@

benchmark() {
  results_folder="./results"
-  model="meta-llama/Meta-Llama-3.1-70B-Instruct"
+  model="meta-llama/Meta-Llama-3.1-8B-Instruct"
  dataset_name="sonnet"
  dataset_path="../sonnet_4x.txt"
  num_prompts=100
80 changes: 69 additions & 11 deletions vllm/distributed/kv_transfer/kv_connector/base.py
@@ -8,10 +8,12 @@
"""

from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

+from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.config import KVTransferConfig
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@@ -119,18 +121,74 @@ def close(self) -> None:
        raise NotImplementedError

    @abstractmethod
-    def build_partial_prefill_input(
+    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
-        input_tokens_list: List[torch.Tensor],
-        num_computed_tokens_list: List[int],
-        start_pos_list: List[int],
-        slot_mapping_flat: torch.Tensor,
-        device: torch.device,
-    ) -> "ModelInputForGPUWithSamplingMetadata":
-        """Rebuild the model input based on how many KV caches are received
+        kv_caches: List[torch.Tensor],
+        hidden_or_intermediate_states: Union[torch.Tensor,
+                                             IntermediateTensors],
+    ) -> None:
+        """
+        Send KV caches and hidden states to the connector.
+        This method processes the input tokens, KV caches, and
+        hidden/intermediate states for a given model and sends the data to the
+        decode instance.
+        Args:
+            model_executable (torch.nn.Module): The model executable containing
+                start and end layer information.
+            model_input (ModelInputForGPUWithSamplingMetadata): The input
+                metadata from vLLM.
+            kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
+                for each layer.
+            hidden_or_intermediate_states (Union[torch.Tensor,
+                IntermediateTensors]):
+                The hidden or intermediate states associated with the tokens.
+        Returns:
+            None
+        Raises:
+            NotImplementedError: This method must be implemented in subclasses.
+        """
+
+        raise NotImplementedError
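For context on what a concrete connector has to implement here, the following is a minimal, self-contained sketch (illustrative only, not vLLM code): it slices each layer's KV cache at the slots that belong to one request and pushes those slices, together with that request's hidden states, onto a stand-in transport toward the decode instance. The names `transport` and `send_kv_and_hidden` are made up for this sketch, and a real connector would pull `seq_lens` and `slot_mapping` out of `model_input` rather than take them as arguments.

```python
# Illustrative sketch only; an in-memory queue stands in for the real
# prefill -> decode transport (e.g. a lookup buffer or pipe).
from queue import Queue
from typing import List

import torch

transport: Queue = Queue()  # stand-in for the prefill -> decode channel


def send_kv_and_hidden(
    kv_caches: List[torch.Tensor],  # one (2, num_slots, head_dim) tensor per layer
    slot_mapping: torch.Tensor,     # flat slot index for every token in the batch
    seq_lens: List[int],            # number of tokens per request, in batch order
    hidden_states: torch.Tensor,    # (total_tokens, hidden_dim)
) -> None:
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        slots = slot_mapping[start:end]
        payload = {
            # key/value entries for this request, one tensor per layer
            "kv": [layer_cache[:, slots].clone() for layer_cache in kv_caches],
            "hidden": hidden_states[start:end].clone(),
        }
        transport.put(payload)
        start = end


# Smoke test with random data: 2 layers, 2 requests of 3 and 5 tokens.
kv = [torch.randn(2, 16, 8) for _ in range(2)]
send_kv_and_hidden(kv, torch.arange(8), [3, 5], torch.randn(8, 32))
print(transport.qsize())  # -> 2, one payload per request
```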

+    @abstractmethod
+    def recv_kv_caches_and_hidden_states(
+        self, model_executable: torch.nn.Module,
+        model_input: "ModelInputForGPUWithSamplingMetadata",
+        kv_caches: List[torch.Tensor]
+    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+               "ModelInputForGPUWithSamplingMetadata"]:
+        """
+        Receive KV caches and hidden states from the connector.
+        This method attempts to retrieve KV caches and hidden states for the
+        input tokens. If all required KV caches and hidden states are
+        received, it bypasses model execution; otherwise it falls back to
+        normal vLLM model forwarding.
+        Args:
+            model_executable (torch.nn.Module):
+                The model executable from the vLLM model runner.
+            model_input (ModelInputForGPUWithSamplingMetadata):
+                The model input from the vLLM model runner.
+            kv_caches (List[torch.Tensor]):
+                List of KV caches for each layer.
+        Returns:
+            - hidden_or_intermediate_states (torch.Tensor or
+              IntermediateTensors):
+                Concatenated hidden states if all required data is retrieved,
+                otherwise `None`.
+            - bypass_model_exec (bool):
+                Indicates whether the model execution can be skipped (True) or
+                needs to be redone (False).
+            - model_input (ModelInputForGPUWithSamplingMetadata):
+                Optionally adjusted input metadata for re-execution when
+                `bypass_model_exec=False`.
+        """
+
+        raise NotImplementedError
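The commit title, "fall back to full prefill when any of the KV cache receive fails", is about the `bypass_model_exec` flag documented above. The toy sketch below shows that contract under simplified assumptions; `toy_recv`, `run_full_prefill`, and `execute` are hypothetical stand-ins rather than vLLM functions. The point is only that a single failed receive makes the connector report `bypass_model_exec=False`, after which the caller redoes the entire prefill instead of reusing partial KV caches.

```python
# Illustrative sketch of the fallback contract, not vLLM's implementation.
from typing import List, Optional, Tuple

import torch


def toy_recv(
    received: List[Optional[torch.Tensor]],
) -> Tuple[Optional[torch.Tensor], bool]:
    """Return (hidden_states, bypass_model_exec) for a batch of chunks."""
    if any(chunk is None for chunk in received):
        # At least one KV-cache receive failed: tell the caller to redo the
        # whole prefill instead of stitching together a partial input.
        return None, False
    return torch.cat(received, dim=0), True


def run_full_prefill(num_tokens: int, hidden_dim: int = 32) -> torch.Tensor:
    # Stand-in for a normal forward pass over all prompt tokens.
    return torch.zeros(num_tokens, hidden_dim)


def execute(received: List[Optional[torch.Tensor]],
            num_tokens: int) -> torch.Tensor:
    hidden, bypass_model_exec = toy_recv(received)
    if bypass_model_exec:
        return hidden                     # every chunk arrived: skip the model
    return run_full_prefill(num_tokens)   # any failure: fall back to full prefill


# One chunk missing -> the full-prefill path is taken.
out = execute([torch.randn(3, 32), None], num_tokens=8)
print(out.shape)  # torch.Size([8, 32])
```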
