[Misc] Remove cache stream and cache events #3461

Merged (3 commits) on Mar 20, 2024
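
Before this change, CacheEngine issued its swap kernels on a dedicated CUDA stream and recorded a per-layer torch.cuda.Event that the worker later waited on; the diff below deletes that machinery so the swaps are simply enqueued on the current stream. For background, here is a minimal, hypothetical sketch of the general side-stream-plus-events pattern being removed (illustrative PyTorch only, not vLLM code; requires a CUDA device):

import torch

# Illustrative only: launch copies on a side stream, record an event per copy,
# and make the current (default) stream wait on those events before using the results.
side_stream = torch.cuda.Stream()
events = []

src = [torch.randn(1024, device="cuda") for _ in range(4)]
dst = [torch.empty(1024, device="cpu", pin_memory=True) for _ in range(4)]

with torch.cuda.stream(side_stream):
    for s, d in zip(src, dst):
        d.copy_(s, non_blocking=True)  # async device-to-host copy on the side stream
        event = torch.cuda.Event()
        event.record(side_stream)      # mark completion of this copy
        events.append(event)

for event in events:
    event.wait()                       # order the current stream after the copies

After this PR the equivalent copies run on whatever stream is current, so ordinary same-stream ordering already makes later kernels on that stream see the completed swaps, which is presumably the rationale for dropping the separate stream and events.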
77 changes: 77 additions & 0 deletions tests/worker/test_swap.py
@@ -0,0 +1,77 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port


def test_swap() -> None:
    # Configure the engine.
    engine_args = EngineArgs(model="facebook/opt-125m",
                             dtype="half",
                             load_format="dummy")
    (model_config, cache_config, parallel_config, scheduler_config,
     device_config, _) = engine_args.create_engine_configs()
    cache_config.num_gpu_blocks = 100
    cache_config.num_cpu_blocks = 100

    # Create the worker.
    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
    worker = Worker(
        model_config=model_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=True,
    )

    # Initialize the worker.
    worker.init_model()
    worker.load_model()
    worker.init_cache_engine(cache_config)
    worker.warm_up_model()

    # Randomly initialize the cache.
    gpu_cache = worker.cache_engine.gpu_cache
    cpu_cache = worker.cache_engine.cpu_cache
    num_layers = len(gpu_cache)
    for i in range(num_layers):
        gpu_key_cache, gpu_value_cache = gpu_cache[i]
        gpu_key_cache.random_()
        gpu_value_cache.random_()
        cpu_key_cache, cpu_value_cache = cpu_cache[i]
        cpu_key_cache.random_()
        cpu_value_cache.random_()

    allclose = lambda a, b: torch.allclose(
        a.cuda(), b.cuda(), rtol=0.0, atol=0.0)

    # Test swap out.
    blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
    worker.execute_model(seq_group_metadata_list=[],
                         blocks_to_swap_in={},
                         blocks_to_swap_out=blocks_to_swap_out,
                         blocks_to_copy={})
    for i in range(num_layers):
        gpu_key_cache, gpu_value_cache = gpu_cache[i]
        cpu_key_cache, cpu_value_cache = cpu_cache[i]
        for src, dst in blocks_to_swap_out.items():
            assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
            assert allclose(gpu_value_cache[src], cpu_value_cache[dst])

    # Test swap in.
    blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
    worker.execute_model(seq_group_metadata_list=[],
                         blocks_to_swap_in=blocks_to_swap_in,
                         blocks_to_swap_out={},
                         blocks_to_copy={})
    for i in range(num_layers):
        gpu_key_cache, gpu_value_cache = gpu_cache[i]
        cpu_key_cache, cpu_value_cache = cpu_cache[i]
        for src, dst in blocks_to_swap_in.items():
            assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
            assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
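
The test calls Worker.execute_model with an empty seq_group_metadata_list so that effectively only the cache operations run, then checks block-for-block equality between the GPU and CPU caches in both directions. Assuming a CUDA-capable machine with vLLM installed, it should run under pytest in the usual way, e.g.:

    pytest tests/worker/test_swap.py -q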
26 changes: 8 additions & 18 deletions vllm/worker/cache_engine.py
@@ -38,7 +38,7 @@ def __init__(
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks

-        # Skip initializing CUDA stream and buffer for Neuron backend.
+        # Skip initializing KV cache for Neuron backend.
         if is_neuron():
             return

@@ -51,12 +51,6 @@ def __init__(
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()

-        # Initialize the stream for caching operations.
-        self.cache_stream = torch.cuda.Stream()
-        assert self.cache_stream != torch.cuda.current_stream()
-        # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
-
     def get_key_block_shape(self) -> Tuple[int, int, int, int]:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
         x = 16 // element_size
@@ -126,17 +120,13 @@ def _swap(
     ) -> None:
         from vllm._C import cache_ops

-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
-                # Copy the value blocks.
-                cache_ops.swap_blocks(src_value_cache, dst_value_cache,
-                                      src_to_dst)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
+        for i in range(self.num_layers):
+            src_key_cache, src_value_cache = src[i]
+            dst_key_cache, dst_value_cache = dst[i]
+            # Copy the key blocks.
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+            # Copy the value blocks.
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
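
The rewritten _swap still delegates the per-block copies to cache_ops.swap_blocks, only now on the current stream. As a rough guide to the semantics the test above relies on, a hypothetical pure-PyTorch stand-in (not the actual CUDA kernel) would copy block s of the source cache into block d of the destination cache for every entry in the mapping:

from typing import Dict

import torch


def swap_blocks_reference(src_cache: torch.Tensor, dst_cache: torch.Tensor,
                          src_to_dst: Dict[int, int]) -> None:
    # Hypothetical stand-in for cache_ops.swap_blocks, for illustration only.
    # The first dimension of each cache tensor indexes blocks.
    for src_block, dst_block in src_to_dst.items():
        dst_cache[dst_block].copy_(src_cache[src_block])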
15 changes: 1 addition & 14 deletions vllm/worker/worker.py
@@ -65,7 +65,6 @@ def __init__(
         # self.init_cache_engine().
         self.cache_config = None
         self.cache_engine = None
-        self.cache_events = None
         self.gpu_cache = None

     def init_model(self, cupy_port: Optional[int] = None) -> None:
@@ -148,7 +147,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
-        self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)

@@ -166,24 +164,13 @@ def cache_swap(
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # Issue cache operations.
-        issued_cache_op = False
+        # TODO(woosuk): Profile swapping overhead and optimize if needed.
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            issued_cache_op = True
-
-        cache_events = self.cache_events if issued_cache_op else None
-
-        # Wait for cache operations to finish.
-        # TODO(woosuk): Profile swapping overhead and optimize if needed.
-        if cache_events is not None:
-            for event in cache_events:
-                event.wait()

     @torch.inference_mode()
     def execute_model(