[Misc] Remove cache stream and cache events (#3461)
WoosukKwon authored Mar 20, 2024
1 parent 4ad521d commit 5ee1449
Showing 3 changed files with 86 additions and 32 deletions.
77 changes: 77 additions & 0 deletions tests/worker/test_swap.py
@@ -0,0 +1,77 @@
+import torch
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.worker.worker import Worker
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+
+
+def test_swap() -> None:
+    # Configure the engine.
+    engine_args = EngineArgs(model="facebook/opt-125m",
+                             dtype="half",
+                             load_format="dummy")
+    (model_config, cache_config, parallel_config, scheduler_config,
+     device_config, _) = engine_args.create_engine_configs()
+    cache_config.num_gpu_blocks = 100
+    cache_config.num_cpu_blocks = 100
+
+    # Create the worker.
+    distributed_init_method = get_distributed_init_method(
+        get_ip(), get_open_port())
+    worker = Worker(
+        model_config=model_config,
+        parallel_config=parallel_config,
+        scheduler_config=scheduler_config,
+        device_config=device_config,
+        local_rank=0,
+        rank=0,
+        distributed_init_method=distributed_init_method,
+        is_driver_worker=True,
+    )
+
+    # Initialize the worker.
+    worker.init_model()
+    worker.load_model()
+    worker.init_cache_engine(cache_config)
+    worker.warm_up_model()
+
+    # Randomly initialize the cache.
+    gpu_cache = worker.cache_engine.gpu_cache
+    cpu_cache = worker.cache_engine.cpu_cache
+    num_layers = len(gpu_cache)
+    for i in range(num_layers):
+        gpu_key_cache, gpu_value_cache = gpu_cache[i]
+        gpu_key_cache.random_()
+        gpu_value_cache.random_()
+        cpu_key_cache, cpu_value_cache = cpu_cache[i]
+        cpu_key_cache.random_()
+        cpu_value_cache.random_()
+
+    allclose = lambda a, b: torch.allclose(
+        a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
+
+    # Test swap out.
+    blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
+    worker.execute_model(seq_group_metadata_list=[],
+                         blocks_to_swap_in={},
+                         blocks_to_swap_out=blocks_to_swap_out,
+                         blocks_to_copy={})
+    for i in range(num_layers):
+        gpu_key_cache, gpu_value_cache = gpu_cache[i]
+        cpu_key_cache, cpu_value_cache = cpu_cache[i]
+        for src, dst in blocks_to_swap_out.items():
+            assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
+            assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
+
+    # Test swap in.
+    blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
+    worker.execute_model(seq_group_metadata_list=[],
+                         blocks_to_swap_in=blocks_to_swap_in,
+                         blocks_to_swap_out={},
+                         blocks_to_copy={})
+    for i in range(num_layers):
+        gpu_key_cache, gpu_value_cache = gpu_cache[i]
+        cpu_key_cache, cpu_value_cache = cpu_cache[i]
+        for src, dst in blocks_to_swap_in.items():
+            assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
+            assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
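
Note: blocks_to_swap_out and blocks_to_swap_in both map a source block number to a destination block number, which is exactly what the assertions above check. As a rough illustration of that semantics, here is a plain-torch sketch under assumed tensor shapes; it is not the vllm._C.cache_ops kernel:

import torch

def swap_blocks_reference(src: torch.Tensor, dst: torch.Tensor,
                          src_to_dst: dict) -> None:
    # Copy each source block into the destination block it is mapped to.
    for src_block, dst_block in src_to_dst.items():
        dst[dst_block].copy_(src[src_block])

# Stand-in caches: 100 blocks of an arbitrary block shape.
gpu_key_cache = torch.rand(100, 16, 16)
cpu_key_cache = torch.zeros(100, 16, 16)
swap_blocks_reference(gpu_key_cache, cpu_key_cache, {3: 72, 56: 35, 84: 34})
assert torch.allclose(gpu_key_cache[3], cpu_key_cache[72])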
26 changes: 8 additions & 18 deletions vllm/worker/cache_engine.py
@@ -38,7 +38,7 @@ def __init__(
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks
 
-        # Skip initializing CUDA stream and buffer for Neuron backend.
+        # Skip initializing KV cache for Neuron backend.
         if is_neuron():
             return
 
@@ -51,12 +51,6 @@ def __init__(
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
 
-        # Initialize the stream for caching operations.
-        self.cache_stream = torch.cuda.Stream()
-        assert self.cache_stream != torch.cuda.current_stream()
-        # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
-
     def get_key_block_shape(self) -> Tuple[int, int, int, int]:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
         x = 16 // element_size
@@ -126,17 +120,13 @@ def _swap(
     ) -> None:
         from vllm._C import cache_ops
 
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
-                # Copy the value blocks.
-                cache_ops.swap_blocks(src_value_cache, dst_value_cache,
-                                      src_to_dst)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
+        for i in range(self.num_layers):
+            src_key_cache, src_value_cache = src[i]
+            dst_key_cache, dst_value_cache = dst[i]
+            # Copy the key blocks.
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+            # Copy the value blocks.
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
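
The deleted lines were the only users of the dedicated cache stream and the per-layer events: _swap now issues cache_ops.swap_blocks on the current CUDA stream, so it is ordered with the rest of the worker's GPU work and needs no explicit synchronization. A minimal plain-torch sketch of the two patterns, illustrative only and not the vLLM kernels:

import torch

if torch.cuda.is_available():
    src = torch.rand(4, 16, device="cuda")
    dst = torch.empty(4, 16, pin_memory=True)  # pinned CPU buffer

    # Old pattern: issue the copy on a side stream and record an event
    # that the consumer must explicitly wait on.
    side_stream = torch.cuda.Stream()
    done = torch.cuda.Event()
    with torch.cuda.stream(side_stream):
        dst.copy_(src, non_blocking=True)
        done.record(side_stream)
    done.wait()  # make the current stream wait for the copy

    # New pattern: issue the copy on the current stream; ordering with
    # later kernels on that stream is implicit, so no event is needed.
    dst.copy_(src, non_blocking=True)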
15 changes: 1 addition & 14 deletions vllm/worker/worker.py
@@ -65,7 +65,6 @@ def __init__(
         # self.init_cache_engine().
         self.cache_config = None
         self.cache_engine = None
-        self.cache_events = None
         self.gpu_cache = None
 
     def init_model(self, cupy_port: Optional[int] = None) -> None:
@@ -148,7 +147,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
-        self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)
 
@@ -166,24 +164,13 @@ def cache_swap(
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # Issue cache operations.
-        issued_cache_op = False
+        # TODO(woosuk): Profile swapping overhead and optimize if needed.
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            issued_cache_op = True
-
-        cache_events = self.cache_events if issued_cache_op else None
-
-        # Wait for cache operations to finish.
-        # TODO(woosuk): Profile swapping overhead and optimize if needed.
-        if cache_events is not None:
-            for event in cache_events:
-                event.wait()
 
     @torch.inference_mode()
     def execute_model(
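
After this commit, cache_swap simply issues the three cache operations and returns; there are no events left to wait on. The sketch below shows the resulting control flow against a stubbed cache engine; the leading parameters are not visible in the hunk above and are reconstructed from the call sites in the new test, so treat the signature as an approximation:

from typing import Dict, List

class _CacheEngineStub:
    # Hypothetical stand-in for vllm.worker.cache_engine.CacheEngine.
    def swap_in(self, src_to_dst: Dict[int, int]) -> None: ...
    def swap_out(self, src_to_dst: Dict[int, int]) -> None: ...
    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: ...

class _WorkerSketch:
    def __init__(self) -> None:
        self.cache_engine = _CacheEngineStub()

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # Issue cache operations on the current stream; no explicit waits.
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)

_WorkerSketch().cache_swap({19: 45}, {3: 72}, {})  # usage example, no-op stubs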