From c614cfee5861e5715a023fa501e432d4acf910fe Mon Sep 17 00:00:00 2001 From: ifsheldon <39153080+ifsheldon@users.noreply.github.com> Date: Wed, 20 Mar 2024 01:54:59 +0800 Subject: [PATCH 01/13] Update dockerfile with ModelScope support (#3429) --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 6a56a33cfe7ac..1f254c76fe5af 100644 --- a/Dockerfile +++ b/Dockerfile @@ -122,7 +122,7 @@ RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,ta FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer + pip install accelerate hf_transfer modelscope COPY --from=build /workspace/vllm/*.so /workspace/vllm/ COPY vllm vllm From 2a60c9bd174c4eaba790ecb36d13fa4c145d99f4 Mon Sep 17 00:00:00 2001 From: Jim Burtoft <39492751+jimburtoft@users.noreply.github.com> Date: Tue, 19 Mar 2024 16:21:35 -0400 Subject: [PATCH 02/13] [Doc] minor fix to neuron-installation.rst (#3505) --- docs/source/getting_started/neuron-installation.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/getting_started/neuron-installation.rst b/docs/source/getting_started/neuron-installation.rst index 0aff1037d8a29..62bf779c339d5 100644 --- a/docs/source/getting_started/neuron-installation.rst +++ b/docs/source/getting_started/neuron-installation.rst @@ -128,6 +128,7 @@ Once neuronx-cc and transformers-neuronx packages are installed, we will be able .. code-block:: console + $ git clone https://github.com/vllm-project/vllm.git $ cd vllm $ pip install -U -r requirements-neuron.txt $ pip install . From cc63d03fbb93f2b984d38e1f5626f523c1f9f1a4 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 19 Mar 2024 13:22:58 -0700 Subject: [PATCH 03/13] Revert "[Core] Cache some utils" (#3507) --- vllm/utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 729a4332af967..d4a8c962c3bfc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,7 +4,6 @@ import subprocess import uuid import gc -from functools import cache from platform import uname from typing import List, Tuple, Union from packaging.version import parse, Version @@ -121,7 +120,6 @@ def is_hip() -> bool: return torch.version.hip is not None -@cache def is_neuron() -> bool: try: import transformers_neuronx @@ -130,7 +128,6 @@ def is_neuron() -> bool: return transformers_neuronx is not None -@cache def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # NOTE: This import statement should be executed lazily since @@ -154,7 +151,6 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -@cache def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() @@ -229,7 +225,6 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) -@cache def get_nvcc_cuda_version() -> Optional[Version]: cuda_home = os.environ.get('CUDA_HOME') if not cuda_home: From 63e8b28a990ef1584975c642b1ee5ae8a65b3183 Mon Sep 17 00:00:00 2001 From: Jim Burtoft <39492751+jimburtoft@users.noreply.github.com> Date: Tue, 19 Mar 2024 16:32:30 -0400 Subject: [PATCH 04/13] [Doc] minor fix of spelling in amd-installation.rst (#3506) --- docs/source/getting_started/amd-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 5d9fdf4056709..3d736bf7120ec 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -100,7 +100,7 @@ You can build and install vLLM from source: Build a docker image from `Dockerfile.rocm`, and launch a docker container. -The `Dokerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of docker image using the following arguments: +The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of docker image using the following arguments: * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` From 20478c4d3abcd0aa8a1d9ace9c76ea3a2e04cb5e Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 19 Mar 2024 14:34:15 -0700 Subject: [PATCH 05/13] Use lru_cache for some environment detection utils (#3508) --- vllm/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index d4a8c962c3bfc..7c73062e809f3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -11,7 +11,7 @@ import psutil import torch import asyncio -from functools import partial +from functools import partial, lru_cache from typing import ( Awaitable, Callable, @@ -120,6 +120,7 @@ def is_hip() -> bool: return torch.version.hip is not None +@lru_cache(maxsize=None) def is_neuron() -> bool: try: import transformers_neuronx @@ -128,6 +129,7 @@ def is_neuron() -> bool: return transformers_neuronx is not None +@lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # NOTE: This import statement should be executed lazily since @@ -151,6 +153,7 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) +@lru_cache(maxsize=None) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() @@ -225,6 +228,7 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) +@lru_cache(maxsize=None) def get_nvcc_cuda_version() -> Optional[Version]: cuda_home = os.environ.get('CUDA_HOME') if not cuda_home: From 9474e89ba4ecae253b585eb6b3e1d85f4e108f01 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 20 Mar 2024 08:11:11 +0100 Subject: [PATCH 06/13] [PREFIX CACHING FOLLOW UP] A bunch of fixes to block allocator performance when automatic prefix caching is disabled (#3357) Co-authored-by: Zhuohan Li --- tests/core/test_block_manager.py | 14 +- tests/prefix_caching/test_prefix_caching.py | 12 +- vllm/core/block_manager.py | 189 +++++++++++++++----- vllm/core/evictor.py | 71 +------- 4 files changed, 165 insertions(+), 121 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 44ac05a1430b3..9473a33f0ee68 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -4,7 +4,7 @@ from vllm import SamplingParams from vllm.block import PhysicalTokenBlock -from 
vllm.core.block_manager import (BlockAllocator, BlockSpaceManager, +from vllm.core.block_manager import (UncachedBlockAllocator, BlockSpaceManager, AllocStatus) from vllm.utils import Device from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob @@ -15,7 +15,8 @@ def test_block_allocator_allocate(): block_size = 4 num_cpu_blocks = 4 - cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size, + num_cpu_blocks) # Allocate all available cpu blocks. num_free = num_cpu_blocks @@ -24,7 +25,7 @@ def test_block_allocator_allocate(): block = cpu_allocator.allocate() num_free -= 1 - assert block.block_hash not in cpu_allocator.evictor + assert block not in cpu_allocator.free_blocks assert cpu_allocator.get_num_free_blocks() == num_free with pytest.raises(ValueError): @@ -34,14 +35,15 @@ def test_block_allocator_allocate(): def test_block_allocator_free(): block_size = 4 num_cpu_blocks = 4 - cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size, + num_cpu_blocks) # Allocate all available cpu blocks. blocks: List[PhysicalTokenBlock] = [] for _ in range(num_cpu_blocks): block = cpu_allocator.allocate() blocks.append(block) - assert block.block_hash not in cpu_allocator.evictor + assert block not in cpu_allocator.free_blocks # Free all allocated cpu blocks. num_free = 0 @@ -49,7 +51,7 @@ def test_block_allocator_free(): for block in blocks: cpu_allocator.free(block) num_free += 1 - assert block.block_hash in cpu_allocator.evictor + assert block in cpu_allocator.free_blocks assert cpu_allocator.get_num_free_blocks() == num_free with pytest.raises(ValueError): diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index c83551c36ef10..cb61aac3975a8 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -4,7 +4,7 @@ """ import pytest -from vllm.core.block_manager import BlockAllocator +from vllm.core.block_manager import CachedBlockAllocator from vllm.utils import Device @@ -15,10 +15,7 @@ def test_block_allocator( num_blocks: int, ): block_hash = 1 - block_allocator = BlockAllocator(Device.CPU, - block_size, - num_blocks, - enable_caching=True) + block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks) # Allocate two PysicalTokenBlocks with the same hash and check # that they are the same PhysicalTokenBlock @@ -45,10 +42,7 @@ def test_block_allocator( @pytest.mark.parametrize("num_blocks", [16]) def test_eviction(num_blocks: int, ): block_size = 16 - block_allocator = BlockAllocator(Device.CPU, - block_size, - num_blocks, - enable_caching=True) + block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks) blocks = [] for i in range(num_blocks): diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 8b089a5650f48..ad9b557fd9a83 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -3,6 +3,7 @@ from itertools import count, takewhile from os.path import commonprefix from typing import Dict, List, Optional, Set, Tuple +from abc import ABC, abstractmethod from vllm.block import BlockTable, PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus @@ -10,7 +11,7 @@ from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor -class BlockAllocator: +class BlockAllocatorBase(ABC): """Manages free physical token blocks for a 
device. The allocator maintains a list of free blocks and allocates a block when @@ -18,23 +19,57 @@ class BlockAllocator: the reference count becomes zero, the block is added back to the free list. """ + @abstractmethod def __init__(self, device: Device, block_size: int, num_blocks: int, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU, - enable_caching: bool = False) -> None: + eviction_policy: EvictionPolicy = EvictionPolicy.LRU): + pass + + @abstractmethod + def allocate(self, + block_hash: Optional[int] = None, + num_hashed_tokens: int = 0) -> PhysicalTokenBlock: + pass + + @abstractmethod + def free(self, block: PhysicalTokenBlock) -> None: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: + pass + + @abstractmethod + def contains_block(self, block_hash: int) -> bool: + pass + + @abstractmethod + def update_hash(self, block_hash: int, block: PhysicalTokenBlock): + pass + + +class CachedBlockAllocator(BlockAllocatorBase): + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + def __init__(self, + device: Device, + block_size: int, + num_blocks: int, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None: self.device = device self.block_size = block_size self.num_blocks = num_blocks - self.enable_caching = enable_caching self.current_num_blocks = 0 self.cached_blocks: Dict[int, PhysicalTokenBlock] = {} - # Switch over to FIFO eviction when caching is disabled - if not self.enable_caching: - eviction_policy = EvictionPolicy.FIFO self.evictor: Evictor = make_evictor(eviction_policy) self.default_hash_ctr = count() @@ -57,13 +92,6 @@ def allocate_block(self, block_hash: int, def allocate(self, block_hash: Optional[int] = None, num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - # If caching is disabled, just allocate a new block and return it - if not self.enable_caching: - block = self.allocate_block(next(self.default_hash_ctr), - num_hashed_tokens) - block.ref_count += 1 - return block - if block_hash is None: block_hash = next(self.default_hash_ctr) if block_hash in self.evictor: @@ -90,9 +118,8 @@ def free(self, block: PhysicalTokenBlock) -> None: assert block.block_hash not in self.evictor self.evictor.add(block) - # If caching is enabled, remove the block from the cached_blocks - if self.enable_caching: - del self.cached_blocks[block.block_hash] + # Remove the block from the cached_blocks + del self.cached_blocks[block.block_hash] def get_num_free_blocks(self) -> int: return (self.num_blocks - self.current_num_blocks + @@ -102,14 +129,68 @@ def contains_block(self, block_hash: int) -> bool: return block_hash in self.cached_blocks or block_hash in self.evictor def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - # If caching is enabled, update the hash of block and the - # cached_blocks dictionary. - if self.enable_caching: - assert not self.contains_block(block_hash) - old_hash = block.block_hash - block.block_hash = block_hash - del self.cached_blocks[old_hash] - self.cached_blocks[block_hash] = block + # Update the hash of block and the cached_blocks dictionary. 
+ assert not self.contains_block(block_hash) + old_hash = block.block_hash + block.block_hash = block_hash + del self.cached_blocks[old_hash] + self.cached_blocks[block_hash] = block + + +class UncachedBlockAllocator(BlockAllocatorBase): + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + def __init__( + self, + device: Device, + block_size: int, + num_blocks: int, + ) -> None: + self.device = device + self.block_size = block_size + self.num_blocks = num_blocks + + # Initialize the free blocks. + self.free_blocks: BlockTable = [] + for i in range(num_blocks): + block = PhysicalTokenBlock(device=device, + block_number=i, + block_size=block_size, + block_hash=-1, + num_hashed_tokens=0) + self.free_blocks.append(block) + + def allocate(self, + block_hash: Optional[int] = None, + num_hashed_tokens: int = 0) -> PhysicalTokenBlock: + if not self.free_blocks: + raise ValueError("Out of memory! No free blocks are available.") + block = self.free_blocks.pop() + block.ref_count = 1 + return block + + def free(self, block: PhysicalTokenBlock) -> None: + if block.ref_count == 0: + raise ValueError(f"Double free! {block} is already freed.") + block.ref_count -= 1 + if block.ref_count == 0: + self.free_blocks.append(block) + + def get_num_free_blocks(self) -> int: + return len(self.free_blocks) + + def contains_block(self, block_hash: int) -> bool: + raise NotImplementedError( + "Invalid codepath for uncached block allocator.") + + def update_hash(self, block_hash: int, block: PhysicalTokenBlock): + raise NotImplementedError( + "Invalid codepath for uncached block allocator.") class AllocStatus(enum.Enum): @@ -142,6 +223,10 @@ def __init__( self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks + if enable_caching and sliding_window is not None: + raise NotImplementedError( + "Sliding window is not allowed with prefix caching enabled!") + self.block_sliding_window = None if sliding_window is not None: assert sliding_window % block_size == 0, (sliding_window, @@ -154,14 +239,17 @@ def __init__( self.enable_caching = enable_caching self.watermark_blocks = int(watermark * num_gpu_blocks) - self.gpu_allocator = BlockAllocator(Device.GPU, - block_size, - num_gpu_blocks, - enable_caching=enable_caching) - self.cpu_allocator = BlockAllocator(Device.CPU, - block_size, - num_cpu_blocks, - enable_caching=enable_caching) + + if self.enable_caching: + self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, + num_gpu_blocks) + self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, + num_cpu_blocks) + else: + self.gpu_allocator = UncachedBlockAllocator( + Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator = UncachedBlockAllocator( + Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} @@ -198,10 +286,16 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] - else: + # Set the reference counts of the token blocks. 
+ block.ref_count = seq_group.num_seqs() + elif self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) + else: + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + block.ref_count = seq_group.num_seqs() block_table.append(block) # Assign the block table for each sequence. @@ -220,8 +314,10 @@ def _promote_last_block( seq: Sequence, last_block: PhysicalTokenBlock, ) -> PhysicalTokenBlock: - # Compute a new hash for the block so that it can be shared by - # other Sequences + assert self.enable_caching + + # Compute a new hash for the block so that it can be shared by other + # Sequences new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) # if new_hash is already in the cached table, then free last_block @@ -254,6 +350,8 @@ def _allocate_last_physical_block( self, seq: Sequence, ) -> PhysicalTokenBlock: + if not self.enable_caching: + return self.gpu_allocator.allocate() block_hash: Optional[int] = None if (self._is_last_block_full(seq)): block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) @@ -293,10 +391,12 @@ def append_slot( assert last_block.device == Device.GPU if last_block.ref_count == 1: # Not shared with other sequences. Appendable. - # If the last block is now complete, promote it to a full block so - # that it can be shared - new_block = self._maybe_promote_last_block(seq, last_block) - block_table[-1] = new_block + if self.enable_caching: + # If the last block is now complete, we may reuse an old block + # to save memory. + maybe_new_block = self._maybe_promote_last_block( + seq, last_block) + block_table[-1] = maybe_new_block return None else: # The last block is shared with other sequences. @@ -440,9 +540,12 @@ def access_all_blocks_in_seq( seq: Sequence, access_time: float, ) -> None: - block_table = self.block_tables[seq.seq_id] - for block in block_table: - block.last_accessed = access_time + if self.enable_caching: + # Update the last accessed time of all the blocks accessed + # in this step. + block_table = self.block_tables[seq.seq_id] + for block in block_table: + block.last_accessed = access_time def compute_full_blocks_in_seq(self, seq: Sequence): if seq.seq_id not in self.block_tables: diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index 1d81f5a97d71c..9f401cba3fbea 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -1,5 +1,5 @@ import enum -from typing import Dict, List, Optional +from typing import Dict from abc import ABC, abstractmethod, abstractproperty from vllm.block import PhysicalTokenBlock @@ -10,7 +10,6 @@ class EvictionPolicy(enum.Enum): Evictor subclass. """ LRU = enum.auto() - FIFO = enum.auto() class Evictor(ABC): @@ -66,37 +65,18 @@ def __contains__(self, block_hash: int) -> bool: # TODO: The performance of this evict function can be optimized further. 
def evict(self) -> PhysicalTokenBlock: - free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values()) - if len(free_blocks) == 0: + if len(self.free_table) == 0: raise ValueError("No usable cache memory left") + free_blocks = self.free_table.values() - # Find lowest timestamp - lowest_timestamp = free_blocks[0].last_accessed - for block in free_blocks: - if block.last_accessed < lowest_timestamp: - lowest_timestamp = block.last_accessed + # Get evicted block + evicted_block: PhysicalTokenBlock = next(iter(free_blocks)) - # Find all blocks with the lowest timestamp - least_recent: List[PhysicalTokenBlock] = [] for block in free_blocks: - if block.last_accessed == lowest_timestamp: - least_recent.append(block) - - # Find highest prefix count per block - highest_num_hashed_tokens = 0 - for block in least_recent: - if block.num_hashed_tokens > highest_num_hashed_tokens: - highest_num_hashed_tokens = block.num_hashed_tokens - - evicted_block: Optional[PhysicalTokenBlock] = None - - # Find the first block with the lowest timestamp - for block in least_recent: - if block.num_hashed_tokens == highest_num_hashed_tokens: + if (block.last_accessed < evicted_block.last_accessed + or block.last_accessed == evicted_block.last_accessed and + block.num_hashed_tokens > evicted_block.num_hashed_tokens): evicted_block = block - break - - assert evicted_block is not None del self.free_table[evicted_block.block_hash] @@ -119,43 +99,8 @@ def num_blocks(self) -> int: return len(self.free_table) -class RandomEvictor(Evictor): - """Evicts in a first-in-first-out order""" - - def __init__(self): - self.free_table: Dict[int, PhysicalTokenBlock] = {} - - def __contains__(self, block_hash: int) -> bool: - return block_hash in self.free_table - - def evict(self) -> PhysicalTokenBlock: - if len(self.free_table) == 0: - raise ValueError("No usable cache memory left") - evicted_block = next(iter(self.free_table.values())) - evicted_block.computed = False - del self.free_table[evicted_block.block_hash] - return evicted_block - - def add(self, block: PhysicalTokenBlock): - self.free_table[block.block_hash] = block - - def remove(self, block_hash: int) -> PhysicalTokenBlock: - if block_hash not in self.free_table: - raise ValueError( - "Attempting to remove block that's not in the evictor") - block: PhysicalTokenBlock = self.free_table[block_hash] - del self.free_table[block_hash] - return block - - @property - def num_blocks(self) -> int: - return len(self.free_table) - - def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: if eviction_policy == EvictionPolicy.LRU: return LRUEvictor() - elif eviction_policy == EvictionPolicy.FIFO: - return RandomEvictor() else: raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") From 4ad521d8b51145a55c1be6b8e451f76423cc2d87 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 20 Mar 2024 00:36:09 -0700 Subject: [PATCH 07/13] [Core] Add generic typing to `LRUCache` (#3511) --- vllm/lora/models.py | 6 +++--- .../tokenizer_group/base_tokenizer_group.py | 19 ++++++++++++------ .../tokenizer_group/tokenizer_group.py | 6 ++---- vllm/utils.py | 20 ++++++++++--------- 4 files changed, 29 insertions(+), 22 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 238da256b7cdc..6fe07b69b3203 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,7 @@ import math import os import re -from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type) +from typing import (Callable, Dict, Hashable, List, Optional, Tuple, 
Type) import safetensors.torch import torch @@ -535,14 +535,14 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: replacement_loras) -class LoRALRUCache(LRUCache): +class LoRALRUCache(LRUCache[LoRAModel]): def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], None]): super().__init__(capacity) self.deactivate_lora_fn = deactivate_lora_fn - def _on_remove(self, key: Hashable, value: Any): + def _on_remove(self, key: Hashable, value: LoRAModel): logger.debug(f"Removing LoRA. int id: {key}") self.deactivate_lora_fn(key) return super()._on_remove(key, value) diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index 99518a606fabe..3cce96e06d1a0 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -22,27 +22,34 @@ def get_max_input_len(self, pass @abstractmethod - def encode(self, prompt: str, request_id: Optional[str], - lora_request: Optional[LoRARequest]) -> List[int]: + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass @abstractmethod - async def encode_async(self, prompt: str, request_id: Optional[str], - lora_request: Optional[LoRARequest]) -> List[int]: + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass @abstractmethod def get_lora_tokenizer( self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": """Get a tokenizer for a LoRA request.""" pass @abstractmethod async def get_lora_tokenizer_async( self, - lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": """Get a tokenizer for a LoRA request.""" pass diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 3af1334cb5ede..ec20d0fb713a4 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -21,10 +21,8 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.enable_lora = enable_lora self.max_input_length = max_input_length self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) - if enable_lora: - self.lora_tokenizers = LRUCache(capacity=max_num_seqs) - else: - self.lora_tokenizers = None + self.lora_tokenizers = LRUCache[PreTrainedTokenizer]( + capacity=max_num_seqs) if enable_lora else None def ping(self) -> bool: """Check if the tokenizer group is alive.""" diff --git a/vllm/utils.py b/vllm/utils.py index 7c73062e809f3..8fa372b5f7f09 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,7 +5,7 @@ import uuid import gc from platform import uname -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Generic from packaging.version import parse, Version import psutil @@ -53,10 +53,10 @@ def reset(self) -> None: self.counter = 0 -class LRUCache: +class LRUCache(Generic[T]): def __init__(self, capacity: int): - self.cache = OrderedDict() + self.cache = OrderedDict[Hashable, T]() self.capacity = capacity def __contains__(self, key: Hashable) 
-> bool: @@ -65,10 +65,10 @@ def __contains__(self, key: Hashable) -> bool: def __len__(self) -> int: return len(self.cache) - def __getitem__(self, key: Hashable) -> Any: + def __getitem__(self, key: Hashable) -> T: return self.get(key) - def __setitem__(self, key: Hashable, value: Any) -> None: + def __setitem__(self, key: Hashable, value: T) -> None: self.put(key, value) def __delitem__(self, key: Hashable) -> None: @@ -77,7 +77,9 @@ def __delitem__(self, key: Hashable) -> None: def touch(self, key: Hashable) -> None: self.cache.move_to_end(key) - def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: + def get(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: if key in self.cache: value = self.cache[key] self.cache.move_to_end(key) @@ -85,12 +87,12 @@ def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: value = default_value return value - def put(self, key: Hashable, value: Any) -> None: + def put(self, key: Hashable, value: T) -> None: self.cache[key] = value self.cache.move_to_end(key) self._remove_old_if_needed() - def _on_remove(self, key: Hashable, value: Any): + def _on_remove(self, key: Hashable, value: T): pass def remove_oldest(self): @@ -103,7 +105,7 @@ def _remove_old_if_needed(self) -> None: while len(self.cache) > self.capacity: self.remove_oldest() - def pop(self, key: int, default_value: Optional[Any] = None) -> Any: + def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T: run_on_remove = key in self.cache value = self.cache.pop(key, default_value) if run_on_remove: From 5ee14494e4c78769fa10af8b58c3e7808053da0d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 20 Mar 2024 00:38:53 -0700 Subject: [PATCH 08/13] [Misc] Remove cache stream and cache events (#3461) --- tests/worker/test_swap.py | 77 +++++++++++++++++++++++++++++++++++++ vllm/worker/cache_engine.py | 26 ++++--------- vllm/worker/worker.py | 15 +------- 3 files changed, 86 insertions(+), 32 deletions(-) create mode 100644 tests/worker/test_swap.py diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py new file mode 100644 index 0000000000000..35630a06a900f --- /dev/null +++ b/tests/worker/test_swap.py @@ -0,0 +1,77 @@ +import torch + +from vllm.engine.arg_utils import EngineArgs +from vllm.worker.worker import Worker +from vllm.utils import get_distributed_init_method, get_ip, get_open_port + + +def test_swap() -> None: + # Configure the engine. + engine_args = EngineArgs(model="facebook/opt-125m", + dtype="half", + load_format="dummy") + (model_config, cache_config, parallel_config, scheduler_config, + device_config, _) = engine_args.create_engine_configs() + cache_config.num_gpu_blocks = 100 + cache_config.num_cpu_blocks = 100 + + # Create the worker. + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + worker = Worker( + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + is_driver_worker=True, + ) + + # Initialize the worker. + worker.init_model() + worker.load_model() + worker.init_cache_engine(cache_config) + worker.warm_up_model() + + # Randomly initialize the cache. 
+ gpu_cache = worker.cache_engine.gpu_cache + cpu_cache = worker.cache_engine.cpu_cache + num_layers = len(gpu_cache) + for i in range(num_layers): + gpu_key_cache, gpu_value_cache = gpu_cache[i] + gpu_key_cache.random_() + gpu_value_cache.random_() + cpu_key_cache, cpu_value_cache = cpu_cache[i] + cpu_key_cache.random_() + cpu_value_cache.random_() + + allclose = lambda a, b: torch.allclose( + a.cuda(), b.cuda(), rtol=0.0, atol=0.0) + + # Test swap out. + blocks_to_swap_out = {3: 72, 56: 35, 84: 34} + worker.execute_model(seq_group_metadata_list=[], + blocks_to_swap_in={}, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy={}) + for i in range(num_layers): + gpu_key_cache, gpu_value_cache = gpu_cache[i] + cpu_key_cache, cpu_value_cache = cpu_cache[i] + for src, dst in blocks_to_swap_out.items(): + assert allclose(gpu_key_cache[src], cpu_key_cache[dst]) + assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) + + # Test swap in. + blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71} + worker.execute_model(seq_group_metadata_list=[], + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out={}, + blocks_to_copy={}) + for i in range(num_layers): + gpu_key_cache, gpu_value_cache = gpu_cache[i] + cpu_key_cache, cpu_value_cache = cpu_cache[i] + for src, dst in blocks_to_swap_in.items(): + assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) + assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 880299783935c..1782fe7e57177 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -38,7 +38,7 @@ def __init__( self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks - # Skip initializing CUDA stream and buffer for Neuron backend. + # Skip initializing KV cache for Neuron backend. if is_neuron(): return @@ -51,12 +51,6 @@ def __init__( self.gpu_cache = self.allocate_gpu_cache() self.cpu_cache = self.allocate_cpu_cache() - # Initialize the stream for caching operations. - self.cache_stream = torch.cuda.Stream() - assert self.cache_stream != torch.cuda.current_stream() - # Initialize the events for stream synchronization. - self.events = [torch.cuda.Event() for _ in range(self.num_layers)] - def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() x = 16 // element_size @@ -126,17 +120,13 @@ def _swap( ) -> None: from vllm._C import cache_ops - with torch.cuda.stream(self.cache_stream): - for i in range(self.num_layers): - src_key_cache, src_value_cache = src[i] - dst_key_cache, dst_value_cache = dst[i] - # Copy the key blocks. - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - # Copy the value blocks. - cache_ops.swap_blocks(src_value_cache, dst_value_cache, - src_to_dst) - event = self.events[i] - event.record(stream=self.cache_stream) + for i in range(self.num_layers): + src_key_cache, src_value_cache = src[i] + dst_key_cache, dst_value_cache = dst[i] + # Copy the key blocks. + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + # Copy the value blocks. 
+ cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) def swap_in(self, src_to_dst: Dict[int, int]) -> None: self._swap(self.cpu_cache, self.gpu_cache, src_to_dst) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0dcd4018afa5f..81beb5ce4d8d4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -65,7 +65,6 @@ def __init__( # self.init_cache_engine(). self.cache_config = None self.cache_engine = None - self.cache_events = None self.gpu_cache = None def init_model(self, cupy_port: Optional[int] = None) -> None: @@ -148,7 +147,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.cache_config = cache_config self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) - self.cache_events = self.cache_engine.events self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) @@ -166,24 +164,13 @@ def cache_swap( blocks_to_copy: Dict[int, List[int]], ) -> None: # Issue cache operations. - issued_cache_op = False + # TODO(woosuk): Profile swapping overhead and optimize if needed. if blocks_to_swap_in: self.cache_engine.swap_in(blocks_to_swap_in) - issued_cache_op = True if blocks_to_swap_out: self.cache_engine.swap_out(blocks_to_swap_out) - issued_cache_op = True if blocks_to_copy: self.cache_engine.copy(blocks_to_copy) - issued_cache_op = True - - cache_events = self.cache_events if issued_cache_op else None - - # Wait for cache operations to finish. - # TODO(woosuk): Profile swapping overhead and optimize if needed. - if cache_events is not None: - for event in cache_events: - event.wait() @torch.inference_mode() def execute_model( From 84eaa68425807a490f363d2e5ddf9bee3d362b0d Mon Sep 17 00:00:00 2001 From: "Allen.Dou" Date: Thu, 21 Mar 2024 00:28:29 +0800 Subject: [PATCH 09/13] Abort when nvcc command is not found in the PATH (#3527) --- CMakeLists.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 29a531d44a9d5..150fcebeb8878 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,12 @@ endif() # append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") +# Ensure the 'nvcc' command is in the PATH +find_program(NVCC_EXECUTABLE nvcc) +if (NOT NVCC_EXECUTABLE) + message(FATAL_ERROR "nvcc not found") +endif() + # # Import torch cmake configuration. 
# Torch also imports CUDA (and partially HIP) languages with some customizations, From ba8ae1d84f66dd804a97182350fee6ffcadf0faf Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Wed, 20 Mar 2024 13:06:56 -0400 Subject: [PATCH 10/13] Check for _is_cuda() in compute_num_jobs (#3481) --- setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 88787334be21a..67575a0e04bf0 100644 --- a/setup.py +++ b/setup.py @@ -61,12 +61,12 @@ def compute_num_jobs(self): except AttributeError: num_jobs = os.cpu_count() - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version >= Version("11.2"): - nvcc_threads = int(os.getenv("NVCC_THREADS", 8)) - num_jobs = max(1, round(num_jobs / (nvcc_threads / 4))) - else: - nvcc_threads = None + nvcc_threads = None + if _is_cuda(): + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version >= Version("11.2"): + nvcc_threads = int(os.getenv("NVCC_THREADS", 8)) + num_jobs = max(1, round(num_jobs / (nvcc_threads / 4))) return num_jobs, nvcc_threads From 80e254834de9c3c34eaca02d8880e952b3daf344 Mon Sep 17 00:00:00 2001 From: James Whedbee Date: Wed, 20 Mar 2024 16:05:03 -0500 Subject: [PATCH 11/13] [Bugfix] Fix ROCm support in CMakeLists.txt (#3534) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 150fcebeb8878..66842e6845edd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,7 +51,7 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") # Ensure the 'nvcc' command is in the PATH find_program(NVCC_EXECUTABLE nvcc) -if (NOT NVCC_EXECUTABLE) +if (CUDA_FOUND AND NOT NVCC_EXECUTABLE) message(FATAL_ERROR "nvcc not found") endif() From 426ec4ec6711b4180538cd56b9f6b856e5276a1f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Mar 2024 14:45:08 -0700 Subject: [PATCH 12/13] [1/n] Triton sampling kernel (#3186) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- tests/kernels/test_rand.py | 51 +++ tests/kernels/test_sampler.py | 196 ++++++++++ tests/samplers/test_sampler.py | 6 +- vllm/model_executor/layers/ops/__init__.py | 0 vllm/model_executor/layers/ops/rand.py | 157 ++++++++ vllm/model_executor/layers/ops/sample.py | 405 +++++++++++++++++++++ vllm/model_executor/layers/sampler.py | 109 +++++- vllm/model_executor/sampling_metadata.py | 129 ++++++- vllm/sequence.py | 3 + vllm/worker/model_runner.py | 40 +- 10 files changed, 1072 insertions(+), 24 deletions(-) create mode 100644 tests/kernels/test_rand.py create mode 100644 tests/kernels/test_sampler.py create mode 100644 vllm/model_executor/layers/ops/__init__.py create mode 100644 vllm/model_executor/layers/ops/rand.py create mode 100644 vllm/model_executor/layers/ops/sample.py diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py new file mode 100644 index 0000000000000..3b9d0d732acf5 --- /dev/null +++ b/tests/kernels/test_rand.py @@ -0,0 +1,51 @@ +import torch +import pytest +import random + +from vllm.model_executor.layers.ops.rand import seeded_uniform +from vllm.model_executor.utils import set_random_seed + + +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("use_3d", [True, False]) +def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): + device = "cuda" + for seed in range(512): + set_random_seed(seed) + rows = random.randint(1, 512) + cols = random.randint(1, 64000) + if use_3d: + third_dim = 
random.randint(2, 10) + dims = [rows, third_dim, cols] + else: + dims = [rows, cols] + seeds = torch.randint(torch.iinfo(torch.long).min, + torch.iinfo(torch.long).max, (rows, ), + device=device) + + # Test that the same seed produces the same output + out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out2) + # del to save memory + del out2 + + out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out3) + # del to save memory + del out3 + + # Initialize out tensor with garbage to ensure that it is overwritten + out_with_tensor = seeded_uniform( + *dims, + out=torch.full( + (*dims, ), + -1, + dtype=dtype, + device=device, + ), + seeds=seeds, + dtype=dtype, + ) + torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py new file mode 100644 index 0000000000000..5f8c51fb074f4 --- /dev/null +++ b/tests/kernels/test_sampler.py @@ -0,0 +1,196 @@ +import gc + +import torch +import pytest +import triton +import triton.language as tl + +from vllm.model_executor.layers.ops.sample import ( + _uniform_to_exponential, sample, get_num_triton_sampler_splits, + MAX_TRITON_N_COLS) +from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.sampling_metadata import SamplingTensors + +SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size +MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 + + +@pytest.fixture(autouse=True) +def _cleanup(): + yield + gc.collect() + torch.cuda.empty_cache() + + +@triton.jit +def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = _uniform_to_exponential(x) + tl.store(output + idx, y) + + +def test_uniform_to_exponential(): + """Test that we can convert uniform to exponential without div by 0.""" + input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], + dtype=torch.float32, + device="cuda") + output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") + _uniform_to_exponential_kernel[(1, )](input, output, 2) + assert torch.all(torch.isfinite(output)) + assert torch.all(output > 0) + assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +@pytest.mark.parametrize("save_logprobs", [True, False]) +def test_sample_decoding_only(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size, + save_logprobs): + set_random_seed(seed) + bs = 8 + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = (torch.rand( + (1, bs), device="cuda") < 0.5).expand(n_splits, bs) + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, bs), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, bs), + dtype=torch.bool, + device="cuda") + + seeds 
= torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, bs), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + _save_modified_probs=True) + assert sampled_tokens.shape == (bs, max_best_of) + for i in range(bs): + assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) + request_uses_random_sampling = random_sampling_mask[0, i] + if modify_greedy_probs and not request_uses_random_sampling: + # If we are modifying greedy probs and the request is greedy, + # we want to make sure the probs tensor is modified in place + assert torch.allclose( + probs[i][sampled_tokens[i]], + torch.full_like(probs[i][sampled_tokens[i]], 1.0)) + assert torch.sum(probs[i]) == 1.0 + assert torch.allclose( + sampled_modified_probs[i][0], + torch.full_like(sampled_modified_probs[i][0], 1.0)) + elif request_uses_random_sampling: + # If the request is random, we want to make sure + # sampled_modified_probs tensor has noise added + # (and thus is different from probs tensor) + assert not torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + elif not request_uses_random_sampling: + # If the request is greedy and we are not modifying greedy probs, + # we want to make sure sampled_modified_probs tensor is the same as + # the probs tensor. + assert torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + + if save_logprobs: + assert sampled_logprobs.shape == (bs, max_best_of) + for i in range(bs): + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[i][ + sampled_tokens[i, best_of]]) + else: + assert sampled_logprobs is None + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +def test_sample_prompt_logprobs(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size): + set_random_seed(seed) + prompt_sizes = [16, 32, 64, 128] * 2 + samples = 8 + bs = samples + sum(prompt_sizes) + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.tensor(prompt_sizes, + dtype=torch.long, + device="cuda").cumsum_(0) + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = torch.rand( + (n_splits, samples), device="cuda") < 0.5 + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, samples), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, samples), + dtype=torch.bool, + device="cuda") + + seeds = torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, samples), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, _ = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=True) + assert sampled_tokens.shape == (samples, max_best_of) + assert sampled_logprobs.shape == (samples, max_best_of) + for i, t in 
enumerate(sample_indices): + assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] + [sampled_tokens[i, best_of]]) + + +@pytest.mark.parametrize("seed", list(range(16))) +def test_get_sequence_seeds(seed): + """Ensure that we get a different child seed from base + seed + extra entropy""" + starting_seed = seed + seq_seed = None + extra_entropy = 1 + for i in range(512): + new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, + i, + seeds_to_generate=1, + is_greedy=False)[0] + new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( + starting_seed, + i, + extra_entropy, + seeds_to_generate=1, + is_greedy=False)[0] + assert new_seq_seed_extra_entropy != new_seq_seed + assert seq_seed != new_seq_seed + seq_seed = new_seq_seed diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1bc8703d1a8e0..b0c6e1c09eebc 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -302,11 +302,11 @@ def test_sampler_logits_processors(seed: int, device: str): batch_size = random.randint(1, 256) input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - # This sample logits processor gives infinite score to the i-th token, + # This sample logits processor gives maximum score to the i-th token, # where i is the length of the input sequence. # We therefore expect the output token sequence to be [0, 1, 2, ...] def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") + logits[len(token_ids)] = torch.finfo(logits.dtype).max return logits seq_group_metadata_list = [] @@ -385,7 +385,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): sample_probs = None - def mock_sample(probs, logprobs, sampling_metadata): + def mock_sample(probs, *args, **kwargs): nonlocal sample_probs sample_probs = probs return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] diff --git a/vllm/model_executor/layers/ops/__init__.py b/vllm/model_executor/layers/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py new file mode 100644 index 0000000000000..5b4b7a153351f --- /dev/null +++ b/vllm/model_executor/layers/ops/rand.py @@ -0,0 +1,157 @@ +import torch +import triton +import triton.language as tl + +from typing import Optional, Union + + +def seeded_uniform( + *size, + seeds: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, + pin_memory: Optional[bool] = False, +) -> torch.Tensor: + """Similar to torch.rand, but allows for seeds to be set per row. + + seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. + If it is 3d, the additional seeds needed will be derived automatically + in a deterministic fashion: + [ + row 0: [columns_with_seed_0], [columns_with_seed0^1], ... 
+ ] + """ + n_dims = len(size) + + if n_dims > 3: + raise ValueError("seeded_uniform only supports up to 3D tensors") + + if out is None: + out = torch.empty(*size, + dtype=dtype, + device=device, + pin_memory=pin_memory) + elif out.shape != size: + raise ValueError("shape of out and size must be the same") + + if n_dims == 3: + n_rows, n_3d, n_cols = out.shape + stride_row = out.stride(0) + stride_3d = out.stride(1) + elif n_dims == 2: + n_rows, n_cols = out.shape + n_3d = 1 + stride_row = out.stride(0) + stride_3d = 1 + else: + n_cols = out.shape[0] + n_rows = 1 + n_3d = 1 + stride_row = 1 + stride_3d = 1 + + if seeds.ndim != 1: + raise ValueError("seeds must be a 1D tensor") + + if seeds.numel() != n_rows: + raise ValueError( + "seeds must have the same number of elements as out has rows") + + # The philox PRNG Triton uses generates 4 random numbers at once. + # Therefore, the most efficient use of it is to divide the + # block size by 4, and then save the generated random numbers to + # each of the 4 slices of the tensor. + full_block_size = triton.next_power_of_2(n_cols) + philox_block_size = max(full_block_size // 4, 1) + n_slices = full_block_size // philox_block_size + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. + if philox_block_size >= 8192: + num_warps = 32 + elif philox_block_size >= 4096: + num_warps = 16 + elif philox_block_size >= 2048: + num_warps = 8 + + _seeded_uniform_triton[(n_rows, n_3d)]( + out, + seeds, + stride_row, + stride_3d, + seeds.stride(0), + n_rows, + n_3d, + n_cols, + n_slices=n_slices, + num_warps=num_warps, + block_size=philox_block_size, + ) + return out + + +@triton.jit +def _seeded_uniform_triton( + out_ptr: torch.Tensor, + seed_ptr: torch.Tensor, + out_row_stride: int, + out_3d_stride: int, + seed_row_stride: int, + n_rows: int, + n_3d: int, + n_cols: int, + n_slices: tl.constexpr, + block_size: tl.constexpr, +): + """ + Generate a random float32 number in [0, 1) for each element in the output + tensor. The random numbers in a row generated using the seed for that row. + + Args: + out_ptr: The output tensor. + seed_ptr: The per-row seeds to use for random number generation. + out_row_stride: The stride between rows of the output tensor. + out_3d_stride: The stride between 3D slices of the output tensor. + seed_row_stride: The stride between rows of the seed tensor. + n_rows: The number of rows in the output tensor. + n_3d: The size of second dimension of the output tensor, + if output tensor is 3D. + n_cols: The number of columns in the output tensor. + n_slices: The number of philox outputs to use. + """ + tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") + + # Get the row index. + row_idx = tl.program_id(axis=0) + three_d_idx = tl.program_id(axis=1) + + philox_offsets = tl.arange(0, block_size) + # Get the seed for the current element. + seed = tl.load(seed_ptr + row_idx * seed_row_stride) + if three_d_idx > 0: + seed ^= three_d_idx + # Generate random numbers in [0, 1). 
+ out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) + + output_row_start_ptr = (out_ptr + row_idx * out_row_stride + + three_d_idx * out_3d_stride) + out1_offsets = philox_offsets + tl.store(output_row_start_ptr + out1_offsets, + out1, + mask=out1_offsets < n_cols) + if n_slices > 1: + out2_offsets = tl.arange(block_size, block_size * 2) + tl.store(output_row_start_ptr + out2_offsets, + out2, + mask=out2_offsets < n_cols) + if n_slices > 2: + out3_offsets = tl.arange(block_size * 2, block_size * 3) + tl.store(output_row_start_ptr + out3_offsets, + out3, + mask=out3_offsets < n_cols) + if n_slices > 3: + out4_offsets = tl.arange(block_size * 3, block_size * 4) + tl.store(output_row_start_ptr + out4_offsets, + out4, + mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py new file mode 100644 index 0000000000000..0077317282204 --- /dev/null +++ b/vllm/model_executor/layers/ops/sample.py @@ -0,0 +1,405 @@ +import math +from typing import Tuple, Optional + +import torch +import triton +import triton.language as tl + +from vllm.model_executor.layers.ops.rand import seeded_uniform + +_EPS = 1e-6 + +# This is a hardcoded limit in Triton (max block size). +MAX_TRITON_N_COLS = 131072 + + +def get_num_triton_sampler_splits(n_cols: int) -> int: + """Get the number of splits to use for Triton sampling. + + Triton has a limit on the number of columns it can handle, so we need to + split the tensor and call the kernel multiple times if it's too large. + """ + return math.ceil(n_cols / MAX_TRITON_N_COLS) + + +def _multi_split_sample( + probs: torch.Tensor, + seeds: torch.Tensor, + n_splits: int, + sampled_tokens_size: Tuple[int, int], + sampled_logprobs_size: Tuple[int, int], + sample_indices: torch.Tensor, + *, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, +): + """Sample tokens where vocab size is split into multiple parts + (too large for Triton otherwise).""" + assert seeds.ndim == 2 and seeds.shape[0] == n_splits + split_probs = probs.tensor_split(n_splits, 1) + split_logprobs = logprobs.tensor_split(n_splits, 1) + sampled_tokens_tmp = [ + torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) + for _ in range(n_splits) + ] + sampled_logprobs_tmp = [ + torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + # We are purposefuly using sampled_tokens_size as we need to always + # save modified probs in this case. + sampled_modified_probs_tmp = [ + torch.empty(sampled_tokens_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + for i in range(n_splits): + n_samples = sample_indices.shape[0] + n_cols = split_probs[i].shape[1] + n_best = sampled_tokens_tmp[i].shape[1] + uniform_noise = seeded_uniform(n_samples, + n_best, + n_cols, + seeds=seeds[i].flatten(), + device=split_probs[i].device, + dtype=split_probs[i].dtype) + # TODO(yard1): See if we can remove the contiguous() calls. + # Will need kernel support. 
+ _sample( + split_probs[i].contiguous(), + split_logprobs[i].contiguous(), + sample_indices, + sampled_tokens_tmp[i], + sampled_logprobs_tmp[i], + sampled_modified_probs_tmp[i], + seeds[i], + uniform_noise, + modify_greedy_probs=False, + save_logprobs=save_logprobs, + save_modified_probs=True, + ) + if i > 0: + # Add offset to sampled tokens + sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) + sampled_tokens = torch.stack(sampled_tokens_tmp) + sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) + # Reduce the results from the splits. + sampled_modified_probs, indices = torch.max(sampled_modified_probs, + dim=0, + keepdim=True) + sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) + if save_logprobs: + sampled_logprobs = torch.stack(sampled_logprobs_tmp) + sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) + else: + sampled_logprobs = None + sampled_modified_probs = sampled_modified_probs.squeeze(0) + + if modify_greedy_probs: + # We need to modify the greedy probs for the sampled tokens. + # We can't do this in the kernel as we need to know the + # sampled tokens. + probs.fill_(0.0) + probs.scatter_(1, sampled_tokens, 1.0) + + return (sampled_tokens, sampled_logprobs, sampled_modified_probs) + + +def sample( + probs: torch.Tensor, + seeds: torch.Tensor, + *, + max_best_of: int = 1, + sample_indices: Optional[torch.Tensor] = None, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, + _save_modified_probs: bool = False, # pylint: disable=invalid-name +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Sample tokens from probs. with per-sequence seeds. + + Can sample from a subset of sequences through sample_indices. + + Args: + probs: Probabilities to sample from. + shape = [batch_size, vocab_size] + seeds: Per-sequence seed values. + shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] + max_best_of: Number of samples to generate per sequence. + Sequence seed will be incremented by 1 each time. + sample_indices: Indices of sequences to sample from. + If not provided, will sample from all sequences. + shape = [n] + logprobs: Log-probabilities of the sampled tokens. + Only used for saving the logprobs if save_logprobs is True. + shape = [batch_size, vocab_size] + modify_greedy_probs: Whether to modify the greedy probabilities + for speculative sampling (sampled token = 1.0, + everything else = 0.0). + save_logprobs: Whether to save the log-probabilities of the + sampled tokens to a tensor. + _save_modified_probs: Whether to save the modified probabilities + (including gumbel noise) of the sampled tokens to a tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + This is exposed only for testing. 
+ + Returns: + sampled_tokens: shape = [n, max_best_of] + sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None + sampled_modified_probs: shape = [n, max_best_of] + if save_modified_probs else None + """ + if sample_indices is None: + sample_indices = torch.arange(0, probs.shape[0], device=probs.device) + + sampled_tokens_size = (sample_indices.size(0), max_best_of) + if save_logprobs: + if logprobs is None: + raise ValueError( + "logprobs tensor must be provided if save_logprobs is True") + sampled_logprobs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_logprobs_size = (0, 0) + logprobs = probs + + if _save_modified_probs: + sampled_modified_probs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_modified_probs_size = (0, 0) + + # If the number of columns in probs is too large for Triton to handle, + # we split the tensor and sample from each split separately, and then + # do an argmax+gather to combine the results. + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if n_splits > 1: + (sampled_tokens, sampled_logprobs, + sampled_modified_probs) = _multi_split_sample( + probs, + seeds, + n_splits, + sampled_tokens_size, + sampled_logprobs_size, + sample_indices, + logprobs=logprobs, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs) + else: + sampled_tokens = torch.empty(sampled_tokens_size, + dtype=torch.long, + device=probs.device) + sampled_logprobs = torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) + sampled_modified_probs = torch.empty(sampled_modified_probs_size, + dtype=probs.dtype, + device=probs.device) + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + uniform_noise = seeded_uniform(n_samples, + max_best_of, + n_cols, + seeds=seeds.flatten(), + device=probs.device, + dtype=probs.dtype) + + _sample( + probs, + logprobs, + sample_indices, + sampled_tokens, + sampled_logprobs, + sampled_modified_probs, + seeds, + uniform_noise, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=_save_modified_probs, + ) + return (sampled_tokens, sampled_logprobs if save_logprobs else None, + sampled_modified_probs if _save_modified_probs else None) + + +def _sample(probs: torch.Tensor, + logprobs: torch.Tensor, + sample_indices: torch.Tensor, + output_samples: torch.Tensor, + output_logprobs: torch.Tensor, + output_modified_probs: torch.Tensor, + seeds: torch.Tensor, + uniform_noise: torch.Tensor, + *, + modify_greedy_probs: bool = False, + save_logprobs: bool = True, + save_modified_probs: bool = False) -> torch.Tensor: + """Sample tokens from probs. + + Args: + probs [batch_size, vocab_size]: probs to sample from. + logprobs [batch_size, vocab_size]: logprobs (used when + save_logprobsis True). + sample_indices [n]: Indices of the samples to use for each row of probs. + output_samples [n, n_best]: Output tensor to store samples in. + output_logprobs [n, n_best]: Output tensor to store logprobs in. + output_modified_probs [n, n_best]: Output tensor to store + probs of chosen tokens in (modified with noise). + seeds [n]: Seeds to use for sampling. If the seed is 0, we use + greedy sampling. Note this is ONLY used for determining + whether to use random sampling or not. The actual random + noise should be passed as uniform_noise. + uniform_noise [batch_size, n_best, vocab_size]: Uniform + noise to use for random sampling (will be converted + to exponential gumbel noise by the kernel). 
+ modify_greedy_probs: If True, we modify the probs tensor in-place + to encode the sampling method used for each row. This is used + in speculative decoding. Only applies in greedy decoding. + save_logprobs: If True, we save the logprobs of the sampled tokens + in the output_logprobs tensor. + save_modified_probs: If True, we save the modified probs (with noise) + of the sampled tokens in the output_modified_probs tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + """ + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 + + # The block size is the smallest power of two greater than the number of + # columns in probs + block_size = triton.next_power_of_2(n_cols) + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. + if block_size >= 8192: + num_warps = 32 + elif block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + + # Enqueue kernel. The 1D launch grid is simple: we have one kernel + # instance per row of the probs matrix + _sample_triton[(n_samples, n_best)]( + sample_indices, + output_samples, + output_logprobs, + output_modified_probs, + probs, + logprobs, + seeds, + uniform_noise, + output_samples.stride(0), + probs.stride(0), + uniform_noise.stride(0), + uniform_noise.stride(1) if n_best > 1 else 1, + n_samples, + n_cols, + n_best, + num_warps=num_warps, + block_size=block_size, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=save_modified_probs, + ) + return output_samples, output_logprobs, output_modified_probs + + +@triton.jit +def _uniform_to_exponential(uniform_noise): + """Convert uniform samples to exponential samples.""" + # tl.rand returns values in [0, 1), so we clamp lower bound + # to _EPS to avoid log(0) and thus division by 0 later + lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) + uniform_noise = tl.maximum(uniform_noise, lb) + # Use the inversion method to turn uniform samples + # into exponential samples + exponential_noise = -tl.log(uniform_noise) + return exponential_noise + + +@triton.jit +def _sample_triton( + sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, + output_logprobs_ptr: torch.Tensor, + output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, + logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, + uniform_noise_ptr: torch.Tensor, output_row_stride: int, + probs_row_stride: int, uniform_noise_row_stride: int, + uniform_noise_best_stride: int, n_samples: int, n_cols: int, + n_best: int, block_size: tl.constexpr, + modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, + save_modified_probs: tl.constexpr): + # The rows are independent, so we parallelize across those + sample_idx = tl.program_id(0) + best_idx = tl.program_id(1) + + # Load the row index from DRAM + row_idx = tl.load(sample_indices_ptr + sample_idx) + seed = tl.load(seeds_ptr + sample_idx) + uses_random_sampling = seed != 0 + + # The stride represents how much we need to increase the + # pointer to advance 1 row + row_start_ptr = probs_ptr + row_idx * probs_row_stride + + # The block size is the next power of two greater than n_cols, + # so we can fit each row in a single block + col_offsets = tl.arange(0, block_size) + + # Load the row into SRAM, using a mask since block_size may be > than n_cols + row = 
tl.load(row_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=float("-inf")) + + if uses_random_sampling: + uniform_noise_start_ptr = (uniform_noise_ptr + + sample_idx * uniform_noise_row_stride + + best_idx * uniform_noise_best_stride) + uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=0.5) + exponential_noise = _uniform_to_exponential(uniform_noise) + row /= exponential_noise + + sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) + # clamp sampled token to n_cols - 1 + # this should not be necessary, but we do it + # just in case + if sampled_token >= n_cols: + sampled_token = n_cols - 1 + # Write back output to DRAM + output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + + best_idx) + tl.store(output_row_start_ptr, sampled_token) + + if modify_greedy_probs: # noqa + if not uses_random_sampling: + # Set the probability of the sampled token to 1, all other + # tokens to zero. This is used in speculative decoding where + # the sampling method must be encoded within the sampled + # probability distributions. + row = tl.where(col_offsets == sampled_token, 1.0, 0.0) + tl.store(row_start_ptr + col_offsets, + row, + mask=col_offsets < n_cols) + + if save_modified_probs: + output_row_start_ptr = (output_modified_probs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_value) + + if save_logprobs: + # Load the row into SRAM, using a mask since block_size + # may be > than n_cols + sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + + sampled_token) + # Write back output to DRAM + output_row_start_ptr = (output_logprobs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 4377b845df628..1fab1e734e1d7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -12,6 +12,7 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.model_executor.layers.ops.sample import (sample as sample_triton) from vllm.utils import is_neuron @@ -114,7 +115,8 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata) + sample_results = _sample(probs, logprobs, sampling_metadata, + sampling_tensors) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) @@ -375,7 +377,7 @@ def _multinomial( return probs.div_(q).argmax(dim=1).view(-1, num_samples) -def _sample( +def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -394,7 +396,7 @@ def _sample( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. 
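# A standalone sketch (toy helper, not part of this module) of the
# two-phase pattern used below: the first loop only enqueues device work,
# and the single GPU<->CPU sync is paid when the results are copied back.
def _toy_two_phase_argmax(prob_groups: List[torch.Tensor]) -> List[List[int]]:
    pending = [torch.argmax(p, dim=-1) for p in prob_groups]  # async launches
    return [t.tolist() for t in pending]  # .tolist() forces the device sync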
for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] + sample_indices = categorized_sample_indices[sampling_type][:, 0] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -407,17 +409,19 @@ def _sample( greedy_samples = torch.argmax(logprobs[sample_indices.long()], dim=-1) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_best_of = 1 + max_best_of_in_batch = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, "generators": sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of, **seeded_args) + probs[sample_indices.long()], max_best_of_in_batch, + **seeded_args) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -448,6 +452,99 @@ def _sample( return sample_results +def _sample_with_triton_kernel( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + categorized_seq_group_ids = {t: [] for t in SamplingType} + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + _, sampling_params = seq_group + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata = {} + max_best_of_in_batch = 1 + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. + for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type][:, 0] + sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] + sample_metadata[sampling_type] = (seq_group_ids, seq_groups, + is_prompts, sample_indices, + sampled_token_indices) + if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + for seq_group, is_prompt in zip(seq_groups, is_prompts): + if is_prompt: + _, sampling_params = seq_group + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + sampled_tokens, _, _ = sample_triton( + probs=probs, + seeds=sampling_tensors.sampling_seeds, + max_best_of=max_best_of_in_batch, + sample_indices=sampling_tensors.sample_indices, + logprobs=logprobs, + # don't save logprobs because we have logic for that below + # TODO: use this instead of the CPU-based logic below + save_logprobs=False, + ) + + # GPU<->CPU sync happens in the loop below. 
+ + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_ids, seq_groups, is_prompts, sample_indices, + sampled_token_indices) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample( + seq_groups, sampled_tokens[sampled_token_indices][:, 0]) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, is_prompts, + sampling_metadata.seq_data, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_ids, sample_results)) + + sample_results = [ + sample_results_dict[i] + for i in range(len(sampling_metadata.seq_groups)) + ] + return sample_results + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + return _sample_with_torch(probs, logprobs, sampling_metadata) + + # TODO: Enable once Triton kernel & associated code is faster. + # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, + # sampling_tensors) + + def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index b23f0170a6ca5..7d08feb3fee1c 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -2,12 +2,16 @@ from typing import Dict, List, Optional, Tuple import torch +import random from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl, is_neuron +from vllm.model_executor.layers.ops.sample import ( + get_num_triton_sampler_splits) _SAMPLING_EPS = 1e-5 +_SEED_0_REPLACEMENT = 3403598558 class SamplingMetadata: @@ -67,14 +71,28 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor + sampling_seeds: torch.Tensor + sample_indices: torch.Tensor + extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @classmethod def from_sampling_metadata( - cls, sampling_metadata: "SamplingMetadata", vocab_size: int, - device: torch.device, - dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]: + cls, + sampling_metadata: "SamplingMetadata", + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + *, + extra_seeds_to_generate: int = 0, + extra_entropy: Optional[Tuple[int, ...]] = None + ) -> Tuple["SamplingTensors", bool, bool, bool]: + """ + extra_seeds_to_generate: extra seeds to generate using the + user-defined seed for each sequence. + extra_entropy: extra entropy to use when generating seeds. + """ prompt_tokens: List[List[int]] = [] output_tokens: List[List[int]] = [] top_ks: List[int] = [] @@ -84,9 +102,18 @@ def from_sampling_metadata( presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] + sampling_seeds: List[int] = [] + sample_indices: List[int] = [] + prompt_best_of: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False + + # We need one base seed per Triton slice. 
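# Worked example (hypothetical vocabulary size): a 152,064-token vocabulary
# needs math.ceil(152_064 / 131_072) == 2 base seeds per sequence, so with
# extra_seeds_to_generate == 1 each sequence gets 3 seeds in total.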
+ seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) + + sample_indices_start_idx = 0 for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -95,6 +122,10 @@ def from_sampling_metadata( r = sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p + seed = sampling_params.seed + + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) top_k = vocab_size if top_k == -1 else top_k @@ -112,6 +143,7 @@ def from_sampling_metadata( or abs(f) >= _SAMPLING_EPS or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get @@ -138,10 +170,34 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) + is_prompt = i < sampling_metadata.num_prompts + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + prompt_len = sampling_metadata.prompt_lens[i] + + if sampling_params.prompt_logprobs is not None: + # NOTE: the sampling position is the last token + # in the prompt + sample_indices_start_idx += prompt_len - 1 + for seq_id in seq_ids: + seq_data = sampling_metadata.seq_data[seq_id] + extra_entropy = extra_entropy or () + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seq_id, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.append(sample_indices_start_idx) + sample_indices_start_idx += 1 + sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, prompt_tokens, - output_tokens, vocab_size, device, dtype) + frequency_penalties, repetition_penalties, sampling_seeds, + sample_indices, prompt_tokens, output_tokens, vocab_size, + extra_seeds_to_generate, device, dtype) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod @@ -150,9 +206,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], presence_penalties: List[float], frequency_penalties: List[float], repetition_penalties: List[float], + sampling_seeds: List[int], sample_indices: List[int], prompt_tokens: List[List[int]], output_tokens: List[List[int]], vocab_size: int, - device: torch.device, + extra_seeds_to_generate: int, device: torch.device, dtype: torch.dtype) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. @@ -210,6 +267,12 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.int, pin_memory=pin_memory, ) + sample_indices_t = torch.tensor( + sample_indices, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) prompt_tensor = torch.tensor( prompt_padded_tokens, device="cpu", @@ -222,8 +285,28 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.long, pin_memory=pin_memory, ) + # need to transpose and make contiguous to + # copy the tensor correctly. + # [batch_size, n_seeds] -> [n_seeds, batch_size] + sampling_seeds_t = torch.tensor( + sampling_seeds, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ).T.contiguous() + # Because the memory is pinned, we can do non-blocking # transfer to device. 
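# For illustration: .to(device, non_blocking=True) below only overlaps with
# other GPU work when the source CPU tensor is page-locked (pin_memory=True
# as requested above); with ordinary pageable memory the copy is effectively
# synchronous.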
+ + # How many seeds the sample operation itself will need. + num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate + sampling_seeds_gpu = sampling_seeds_t.to(device=device, + non_blocking=True) + extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] + if not extra_seeds_gpu.numel(): + extra_seeds_gpu = None + sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] + return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -237,4 +320,38 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), output_tokens=output_tensor.to(device=device, non_blocking=True), + sampling_seeds=sampling_seeds_gpu, + sample_indices=sample_indices_t.to(device=device, + non_blocking=True), + extra_seeds=extra_seeds_gpu, ) + + @staticmethod + def _get_sequence_seeds( + seed: int, + *extra_entropy: int, + seeds_to_generate: int, + is_greedy: bool, + ): + """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" + if not is_greedy: + if seed is None: + randint_fn = random.randint + else: + generator = random.Random(str((seed, ) + extra_entropy)) + randint_fn = generator.randint + lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max + # If the user/random sets seed = 0 but request should + # have sampling, we need to change it to something + # else. We use a constant in that case. + # This way we don't need to create and load a bool + # matrix in the sampling kernel, which reduces CPU + # overhead and latency. + seq_seeds = [ + randint_fn(lo, hi) or _SEED_0_REPLACEMENT + for _ in range(seeds_to_generate) + ] + else: + # For the kernel, seed == 0 means greedy decoding. 
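# For illustration: if randint_fn above happens to return 0, the
# `or _SEED_0_REPLACEMENT` expression maps it to 3403598558, so a seed of
# exactly 0 only ever reaches the kernel for the greedy branch below.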
+ seq_seeds = [0] * seeds_to_generate + return seq_seeds diff --git a/vllm/sequence.py b/vllm/sequence.py index 4a002edaf580f..ff96dd306791c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -242,6 +242,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.data.get_token_ids() + def get_prompt_token_ids(self) -> List[int]: + return self.data.get_prompt_token_ids() + def get_last_token_id(self) -> int: return self.data.get_last_token_id() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 27213887ed265..7e25311fa2268 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -408,6 +408,7 @@ def _prepare_sample( selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron max_subquery_len = max(subquery_lens) if subquery_lens else 1 @@ -425,9 +426,12 @@ def _prepare_sample( categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append( - categorized_sample_indices_start_idx) + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( @@ -449,9 +453,17 @@ def _prepare_sample( categorized_sample_indices[ sampling_params.sampling_type].extend( - range(categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + num_seqs)) + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs if sampling_params.seed is not None: generators.append(seq_group_metadata.state.generator) @@ -459,12 +471,14 @@ def _prepare_sample( selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, target_device=self.device, - pin_memory=pin_memory) + pin_memory=not self.in_wsl) + categorized_sample_indices = { - t: _async_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=pin_memory) + t: _maybe_expand_dim( + _async_h2d(seq_ids, + dtype=torch.int, + target_device=self.device, + pin_memory=pin_memory), 2, 2) for t, seq_ids in categorized_sample_indices.items() } @@ -884,3 +898,11 @@ def _async_h2d( ) -> torch.Tensor: t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") return t.to(device=target_device, non_blocking=True) + + +def _maybe_expand_dim(tensor: torch.Tensor, + target_dims: int, + size: int = 1) -> torch.Tensor: + if tensor.ndim < target_dims: + tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) + return tensor From 6e435de766c7749b214b637ac58570a221006c95 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 21 Mar 2024 06:46:05 +0900 Subject: [PATCH 13/13] [1/n][Chunked Prefill] Refactor input query shapes (#3236) --- .buildkite/test-pipeline.yaml | 4 +- .../test_basic_correctness.py | 4 +- tests/core/test_scheduler.py | 18 +- tests/lora/test_worker.py | 2 +- tests/spec_decode/test_multi_step_worker.py | 4 +- tests/worker/test_model_runner.py | 161 +++++++++++- vllm/config.py | 3 - 
vllm/core/scheduler.py | 13 +- vllm/engine/arg_utils.py | 8 +- vllm/engine/llm_engine.py | 1 - vllm/model_executor/input_metadata.py | 82 +++++- vllm/model_executor/layers/activation.py | 4 +- .../layers/attention/attention.py | 3 +- .../layers/attention/backends/flash_attn.py | 46 +++- .../layers/attention/backends/xformers.py | 232 ++++++++++------- .../layers/attention/ops/paged_attn.py | 9 +- vllm/model_executor/layers/sampler.py | 1 - vllm/worker/model_runner.py | 239 +++++++++++------- 18 files changed, 575 insertions(+), 259 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6ae351130f203..17f4c33670821 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -47,7 +47,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test - command: pytest -v -s samplers --forked + command: pytest -v -s samplers - label: Worker Test command: pytest -v -s worker @@ -56,7 +56,7 @@ steps: command: pytest -v -s spec_decode - label: LoRA Test %N - command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Metrics Test diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index fe67e0f2f4808..da0176306b4ee 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -13,6 +13,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, @@ -20,12 +21,13 @@ def test_models( model: str, dtype: str, max_tokens: int, + enforce_eager: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ebfeb8ba04812..397101fa86104 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,7 @@ def test_scheduler_add_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -26,7 +26,7 @@ def test_scheduler_add_seq_group(): def test_scheduler_abort_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -50,7 +50,7 @@ def test_scheduler_schedule_simple(): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -64,10 +64,10 @@ def test_scheduler_schedule_simple(): running.append(seq_group) # Schedule seq groups prompts. 
+ num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(running) - assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( - )[0].get_len() + assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == num_seq_group @@ -84,7 +84,7 @@ def test_scheduler_schedule_simple(): def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len, 256) + scheduler_config = SchedulerConfig(64, 2, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 @@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] - assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 2 @@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort(): scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_b] - assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 @@ -136,7 +136,7 @@ def test_scheduler_max_seqs(): num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 31a7c716afbf2..e4538de35169b 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files): revision=None, ), parallel_config=ParallelConfig(1, 1, False), - scheduler_config=SchedulerConfig(32, 32, 32, 256), + scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), local_rank=0, rank=0, diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 45b43ec59ee8f..5f788549d44d0 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -92,8 +92,8 @@ def test_same_output_for_single_step(): num_gpu_blocks, seed, ) - multi_step_worker.model_runner = worker.model_runner - multi_step_worker.cache_engine = worker.cache_engine + # multi_step_worker.model_runner = worker.model_runner + # multi_step_worker.cache_engine = worker.cache_engine num_steps = 1 diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7e..44b22c2bd8a21 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,8 +1,13 @@ import random import torch +from vllm.config import ModelConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.worker.model_runner import 
ModelRunner, _BATCH_SIZE_ALIGNMENT + + +def get_aligned_size(batch_size: int, alignment: int): + return ((batch_size + alignment - 1) // alignment * alignment) def test_prepare_prompt(): @@ -12,6 +17,7 @@ def test_prepare_prompt(): batch_size = random.randint(1, 256) prompt_lens = [] seq_group_metadata_list = [] + block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block prompt_len = i % (model_runner.block_size - 1) + 1 @@ -23,26 +29,165 @@ def test_prepare_prompt(): is_prompt=True, seq_data={0: SequenceData(seq_data)}, sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, + block_tables=block_tables, )) expected_selected_token_indices = [] selected_token_start_idx = 0 - max_seq_len = max(prompt_lens) for prompt_len in prompt_lens: expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) - selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = ( - model_runner._prepare_prompt(seq_group_metadata_list)) + selected_token_start_idx += prompt_len + (input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, + _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is True + assert torch.allclose(input_metadata.prompt_lens_tensor, + torch.tensor(prompt_lens, device=device)) + assert input_metadata.prompt_lens == prompt_lens + assert input_metadata.num_prompt_tokens == sum(prompt_lens) + assert input_metadata.num_generation_tokens == 0 + assert input_metadata.max_seq_len == max(prompt_lens) + + # Test subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for prompt_len in prompt_lens: + start_idx += prompt_len + start_loc.append(start_idx) + assert torch.allclose( + input_metadata.subquery_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + # Test seq start locs. Note that for normal prefill it is + # equivalent to subquery_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for prompt_len in prompt_lens: + start_idx += prompt_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + input_metadata.seq_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + assert input_metadata.max_context_len is None + assert torch.allclose( + input_metadata.context_lens, + torch.zeros(input_metadata.context_lens.shape[0], + dtype=torch.int, + device=device)) + + expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(input_metadata.block_tables, expected) + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is False + assert input_metadata.kv_cache_dtype == "auto" + + assert input_tokens.shape == (sum(prompt_lens), ) + assert input_positions.shape == (sum(prompt_lens), ) + torch.testing.assert_close(input_tokens, input_positions) + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (batch_size, max_seq_len) - assert input_positions.shape == (batch_size, max_seq_len) + assert input_tokens.shape == (sum(prompt_lens), ) + assert input_positions.shape == (sum(prompt_lens), ) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(input_tokens, input_positions) + + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + + +def test_prepare_decode_cuda_graph(): + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + + batch_size = random.randint(1, 256) + prompt_lens = [] + seq_group_metadata_list = [] + for i in range(batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = list(range(prompt_len)) + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: SequenceData(seq_data)}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + )) + + input_tokens, input_positions, input_metadata, _, _, _ = ( + model_runner._prepare_decode(seq_group_metadata_list)) + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is False + assert input_metadata.prompt_lens is None + assert input_metadata.num_prompt_tokens == 0 + assert input_metadata.num_generation_tokens == (get_aligned_size( + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT)) + assert input_metadata.max_seq_len is None + assert input_metadata.subquery_start_loc is None + assert input_metadata.seq_start_loc is None + assert input_metadata.max_context_len == max(prompt_lens) + assert torch.allclose( + input_metadata.context_lens[:len(prompt_lens)], + torch.tensor(prompt_lens, dtype=torch.int, device=device)) + + # block table's first index corresponds to each batch, meaning in + # decoding it is each token. + assert input_metadata.block_tables.shape[0] == len(input_tokens) + # Block table's second dim correspondsd to each token's block number. + # It is padded up to + assert input_metadata.block_tables.shape[1] == ( + model_runner.get_max_block_per_batch()) + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is True + assert input_metadata.kv_cache_dtype == "auto" + + assert input_tokens.shape == (get_aligned_size( + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), ) + assert input_positions.shape == (get_aligned_size( + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), ) torch.testing.assert_close(input_tokens, input_positions) + # Verify Sampling + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for prompt_len in prompt_lens: + expected_selected_token_indices.append(selected_token_start_idx) + selected_token_start_idx += 1 + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, diff --git a/vllm/config.py b/vllm/config.py index 51ae66e2375ab..b769ecdce8808 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -535,7 +535,6 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - max_paddings: Maximum number of paddings to be added to a batch. """ def __init__( @@ -543,7 +542,6 @@ def __init__( max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - max_paddings: int, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -553,7 +551,6 @@ def __init__( self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.max_paddings = max_paddings self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3f93a2928df5..be55e8520a55f 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -173,12 +173,12 @@ def _schedule(self) -> SchedulerOutputs: curr_loras = set( seq_group.lora_int_id for seq_group in self.running) if self.lora_enabled else None - seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. leftover_waiting_sequences = deque() + num_batched_tokens = 0 while self.waiting: seq_group = self.waiting[0] waiting_seqs = seq_group.get_seqs( @@ -223,8 +223,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the number of batched tokens exceeds the limit, stop. 
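# Worked example (hypothetical lengths): under the old padded layout, three
# prompts of lengths [5, 9, 2] were budgeted as 3 * max(5, 9, 2) = 27
# tokens; with the flattened 1D layout introduced below they cost
# 5 + 9 + 2 = 16, which is why the separate max_paddings cap is removed.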
- new_seq_lens = seq_lens + [num_prompt_tokens] - num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + num_batched_tokens += num_prompt_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -236,11 +235,6 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - num_paddings = num_batched_tokens - sum(new_seq_lens) - if num_paddings > self.scheduler_config.max_paddings: - break - seq_lens = new_seq_lens - if lora_int_id > 0: curr_loras.add(lora_int_id) self.waiting.popleft() @@ -255,8 +249,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, + num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3e146d2e6c0c4..94c80f4284067 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,7 +31,6 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 - max_paddings: int = 256 max_logprobs: int = 5 # OpenAI default value disable_log_stats: bool = False revision: Optional[str] = None @@ -213,10 +212,6 @@ def add_cli_args( type=int, default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') parser.add_argument( '--max-logprobs', type=int, @@ -347,8 +342,7 @@ def create_engine_configs( ), self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + model_config.max_model_len) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 71798ab7d17c0..2280481cca9cb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -561,7 +561,6 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - return request_outputs def step(self) -> List[RequestOutput]: diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 01bba70ac10a8..35245865fb1b1 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,36 +1,92 @@ from dataclasses import dataclass, fields -from typing import Optional, Any, Dict +from typing import Optional, List, Any, Dict import torch +from xformers.ops.fmha.attn_bias import AttentionBias @dataclass class InputMetadata: """Metadata for input sequences. Used in PagedAttention. - Args: - prompt_lens: Lengths of prompts. - slot_mapping: The address to write the new KV to of each token. - max_context_len: The maximum context length. - context_lens: the length of attention context for each sequence. - block_tables: The block tables. (Seq id -> list of physical block) - kv_cache_dtype: Data type to store kv cache. + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. 
""" - + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. is_prompt: bool + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - prompt_lens: Optional[torch.Tensor] - max_seq_len: Optional[int] - start_loc: Optional[torch.Tensor] + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + """ + Definition of context_len, subquery_len, and seqlen. + |---------- N-1 iteration --------| + |---------------- N iteration ---------------------| + |- tokenA -|......................|-- newTokens ---| + |---------- context_len ----------| + |-------------------- seqlen ----------------------| + |- subquery_len -| + + WARNING: context_len has different definition depending on if it is + prefill vs decoding. When it is prefill, it doesn't include new + tokens. When it is for decoding, it includes a new token. + """ + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum context length in the batch. max_context_len: Optional[int] + # FIXME: It is for flash attn. + # Maximum sequence length in the batch. + max_seq_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,). The length of context (tokens stored in KV cache) per + # sequence. WARNING: When it is a prefill request, it doesn't include new + # tokens. When it is for decoding, it includes a new token. context_lens: Optional[torch.Tensor] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. block_tables: Optional[torch.Tensor] + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. use_cuda_graph: bool kv_cache_dtype: str def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. # will not appear in the __repr__ and __init__ - self.attn_bias = None + self.attn_bias: Optional[List[AttentionBias]] = None + + # Cuda graph is only used for decoding now. 
+ if self.use_cuda_graph: + assert self.num_prompt_tokens == 0 def asdict_zerocopy(self) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 3eb73ee109f50..f569a5a49cbdf 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -20,8 +20,8 @@ class SiluAndMul(nn.Module): The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. Shapes: - x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) - return: (batch_size, seq_len, d) or (num_tokens, d) + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) """ def _forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4b63b9eaf59a7..ae598b029a007 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -17,11 +17,12 @@ class Attention(nn.Module): This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. + The class does the following: 1. Store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. + 3. Output the output tensor. """ def __init__( diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 58ccd461b993e..9ce5851f3650d 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -1,7 +1,7 @@ """Attention layer with Flash and PagedAttention.""" from typing import List, Optional -from flash_attn import flash_attn_func +from flash_attn import flash_attn_varlen_func import torch from vllm.model_executor.input_metadata import InputMetadata @@ -10,6 +10,21 @@ class FlashAttentionBackend: + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ def __init__( self, @@ -52,18 +67,18 @@ def forward( """Forward pass with FlashAttention and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. 
Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -82,13 +97,16 @@ def forward( if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=input_metadata.seq_start_loc, + cu_seqlens_k=input_metadata.seq_start_loc, + max_seqlen_q=input_metadata.max_seq_len, + max_seqlen_k=input_metadata.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -118,4 +136,4 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(num_tokens, hidden_size) diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index bad2a648b6703..f0ef9fac9aaa4 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -14,6 +14,21 @@ class XFormersBackend: + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens --------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ def __init__( self, @@ -55,19 +70,18 @@ def forward( """Forward pass with xFormers and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape - # Reshape the query, key, and value tensors. + num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -82,9 +96,10 @@ def forward( if input_metadata.is_prompt: # Prompt run. + # key_cache and value_cache are None when it is a profiling run. + # block tables are empty if the prompt has never been computed. 
if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): - # normal attention if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of @@ -103,61 +118,33 @@ def forward( self.num_queries_per_kv, value.shape[-1]) - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. - if input_metadata.attn_bias is None: - if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias - else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) - if self.use_ref_attention: - output = _ref_masked_attention( - query, - key, - value, - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, - ) + print("ref attention used.") + output = torch.empty_like(query) + start = 0 + for _, prompt_len in enumerate(input_metadata.prompt_lens): + end = start + prompt_len + out = _ref_masked_attention( + query[None, start:end], + key[None, start:end], + value[None, start:end], + self.num_heads, + self.num_kv_heads, + self.head_size, + self.scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out) + start += prompt_len + # Using view got RuntimeError: view size is not compatible # with input tensor's size and stride (at least one # dimension spans across two contiguous subspaces). # Use reshape instead. - return output.reshape(batch_size, seq_len, hidden_size) - - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. - if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) + return output.reshape(num_tokens, hidden_size) + output = self._run_memory_efficient_xformer_forward( + query, key, value, input_metadata) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( @@ -182,41 +169,117 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(-1, self.num_heads * self.head_size) + + def _run_memory_efficient_xformer_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + Args: + output: shape = [num_prompt_tokens, num_heads, head_size] + query: shape = [num_prompt_tokens, num_heads, head_size] + key: shape = [num_prompt_tokens, num_kv_heads, head_size] + value: shape = [num_prompt_tokens, num_kv_heads, head_size] + input_metadata: metadata for paged attention. + """ + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. 
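# For illustration: BlockDiagonalCausalMask.from_seqlens(prompt_lens) lays
# one causal mask per prompt along the diagonal, so for prompt_lens = [2, 3]
# the flattened 5-token query attends causally within positions 0-1 and
# 2-4 but never across the prompt boundary.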
+ # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + input_metadata.prompt_lens) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = [attn_bias] + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, query.dtype, + input_metadata) + + op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( + is_hip()) else None + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias[0], + p=0.0, + scale=self.scale, + op=op) + + return out.view_as(query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + output = torch.empty_like(query) + start = 0 + for i, prompt_len in enumerate(input_metadata.prompt_lens): + end = start + prompt_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=input_metadata.attn_bias[i], + p=0.0, + scale=self.scale, + op=op) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.squeeze(0)) + start += prompt_len + return output def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, - batch_size: int, - seq_len: int, dtype: torch.dtype, + input_metadata: InputMetadata, ) -> LowerTriangularMaskWithTensorBias: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - batch_size, - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - return attn_bias + attn_biases = [] + for prompt_len in input_metadata.prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. 
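# A worked example of the bias computed below, assuming prompt_len = 4
# (hypothetical value):
#   bias = arange(4) = [0, 1, 2, 3]
#   bias[None, :] - bias[:, None] =
#       [[ 0,  1,  2,  3],
#        [-1,  0,  1,  2],
#        [-2, -1,  0,  1],
#        [-3, -2, -1,  0]]
# i.e. entry (i, j) equals j - i, later scaled per attention head by
# alibi_slopes.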
+ bias = bias[None, :] - bias[:, None] + + padded_len = (prompt_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + prompt_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :prompt_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases def _check_use_ref_attention() -> bool: @@ -239,7 +302,6 @@ def _ref_masked_attention( query = query.view(-1, num_heads, head_size) key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) - seq_len, _, _ = query.shape attn_mask = torch.triu(torch.ones(seq_len, seq_len, diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py index c5a9618c2395b..3105ba37b9832 100644 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -128,11 +128,12 @@ def forward_prefix( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] - input_metadata.start_loc, - input_metadata.prompt_lens, + input_metadata.block_tables, + # subquery_start_loc is (batch_size + 1,) + input_metadata.subquery_start_loc[:-1], + input_metadata.prompt_lens_tensor, input_metadata.context_lens, - input_metadata.max_seq_len, + input_metadata.max_subquery_len, alibi_slopes, ) return output diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1fab1e734e1d7..ac8336ca0f9ad 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -128,7 +128,6 @@ def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) return hidden_states.index_select(0, sampling_metadata.selected_token_indices) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7e25311fa2268..cfccbbb20adc5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -28,9 +28,12 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 -# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +_BATCH_SIZE_ALIGNMENT = 8 +# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. 
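# Worked examples of the rounding done by _get_graph_batch_size (defined later
# in this file), given the capture list above:
#   batch_size 1 -> 1, 2 -> 2, 3 -> 4, 5 -> 8, 9 -> 16, 17 -> 24, 250 -> 256
# Small batches map to the exact sizes 1/2/4; larger ones round up to the next
# multiple of _BATCH_SIZE_ALIGNMENT (8).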
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] class ModelRunner: @@ -107,8 +110,7 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens + - self.scheduler_config.max_paddings, self.vocab_size, + self.scheduler_config.max_num_batched_tokens, self.vocab_size, self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) @@ -116,10 +118,13 @@ def load_model(self) -> None: def set_block_size(self, block_size: int) -> None: self.block_size = block_size - max_num_blocks = (self.max_context_len_to_capture + block_size - - 1) // block_size self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + + def get_max_block_per_batch(self) -> int: + block_size = self.block_size + return (self.max_context_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -127,9 +132,9 @@ def _prepare_prompt( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() @@ -158,16 +163,18 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + context_len = computed_len else: prefix_block_tables.append([]) + context_len = 0 # actual prompt lens - context_lens.append(computed_len) + context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) - input_tokens.append(prompt_tokens) + input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.append( + input_positions.extend( list(range(computed_len, computed_len + len(prompt_tokens)))) lora_id = seq_group_metadata.lora_int_id @@ -175,7 +182,7 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping.append([lora_id] * (prompt_len - computed_len)) + lora_index_mapping += [lora_id] * (prompt_len - computed_len) lora_prompt_mapping.extend( [lora_id] * (prompt_len - computed_len @@ -184,11 +191,10 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) continue # Compute the slot mapping. - slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, prompt_len - sliding_window). 
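# A worked example of the slot-mapping rule above, assuming block_size = 4,
# sliding_window = 4, prompt_len = 6, computed_len = 0 and block_table = [7, 9]
# (hypothetical values):
#   start_idx = max(0, 6 - 4) = 2
#   tokens 0..1 -> _PAD_SLOT_ID (outside the sliding window)
#   token  2    -> block 7, offset 2 -> slot 7 * 4 + 2 = 30
#   token  5    -> block 9, offset 1 -> slot 9 * 4 + 1 = 37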
@@ -203,35 +209,30 @@ def _prepare_prompt( start_idx = max(0, prompt_len - self.sliding_window) for i in range(computed_len, prompt_len): if i < start_idx: - slot_mapping[-1].append(_PAD_SLOT_ID) + slot_mapping.append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) - - max_prompt_len = max(subquery_lens) - assert max_prompt_len > 0 - input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) - lora_index_mapping = [ - _pad_to_max(mapping, max_prompt_len, pad=0) - for mapping in lora_index_mapping - ] + slot_mapping.append(slot) + + max_subquery_len = max(subquery_lens) + max_seq_len = max(prompt_lens) + num_prompt_tokens = len(input_tokens) + assert max_subquery_len > 0 + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + lora_index_mapping = lora_index_mapping + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -244,22 +245,45 @@ def _prepare_prompt( dtype=torch.int, device=self.device, ) - start_loc_tensor = torch.arange(0, - len(prompt_lens) * max_prompt_len, - max_prompt_len, - dtype=torch.long, - device=self.device) + + # Query length can be shorter than key (i.e., prompt) when prefill + # is chunked or prefix cached. 
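# A worked example of the two start-location tensors built below, assuming two
# prompts of lengths [8, 4] where the first has 4 prefix-cached tokens
# (hypothetical values):
#   prompt_lens   = [8, 4]  ->  seq_start_loc      = [0, 8, 12]
#   subquery_lens = [4, 4]  ->  subquery_start_loc = [0, 4, 8]
# The key/value side spans whole prompts, while the query side spans only the
# not-yet-computed suffixes.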
+ subquery_lens_tensor = torch.tensor(subquery_lens, + dtype=torch.long, + device=self.device) + subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.long, device=self.device) + seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(subquery_lens_tensor, + dim=0, + dtype=subquery_start_loc.dtype, + out=subquery_start_loc[1:]) + + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, - prompt_lens=prompt_lens_tensor, - max_seq_len=max_prompt_len, - start_loc=start_loc_tensor, + prompt_lens=prompt_lens, + prompt_lens_tensor=prompt_lens_tensor, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=0, + max_subquery_len=max_subquery_len, max_context_len=None, + max_seq_len=max_seq_len, + subquery_start_loc=subquery_start_loc, + seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -275,9 +299,9 @@ def _prepare_decode( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] @@ -296,11 +320,11 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append([position]) + input_positions.append(position) context_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -310,8 +334,8 @@ def _prepare_decode( block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) - lora_index_mapping.append([lora_id]) + slot_mapping.append(slot) + lora_index_mapping.append(lora_id) lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: @@ -320,6 +344,9 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + # vLLM uses cuda graph only for decoding requests. + # See `capture_model` API for more details. + # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) max_context_len = max(context_lens) use_captured_graph = ( @@ -327,38 +354,37 @@ def _prepare_decode( and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_context_len <= self.max_context_len_to_capture) if use_captured_graph: - # Pad the input tokens, positions, and slot mapping to match the - # batch size of the captured graph. 
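# A worked example of the padding applied below, assuming a decode batch of
# 5 sequences (hypothetical value): _get_graph_batch_size(5) = 8, so 3 dummy
# entries are appended (token 0, position 0, slot _PAD_SLOT_ID, context_len 1,
# an empty block table and lora index 0), giving every tensor the fixed shape
# captured for batch size 8.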
graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size for _ in range(graph_batch_size - batch_size): - input_tokens.append([]) - input_positions.append([]) - slot_mapping.append([]) + input_tokens.append(0) + input_positions.append(0) + slot_mapping.append(_PAD_SLOT_ID) context_lens.append(1) block_tables.append([]) + lora_index_mapping.append(0) batch_size = graph_batch_size - input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) if use_captured_graph: + # When using cuda-graph all these tensors should be + # padded. + assert context_lens.shape[0] == input_tokens.shape[0] + assert context_lens.shape[0] == input_positions.shape[0] + assert context_lens.shape[0] == slot_mapping.shape[0] + # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -377,17 +403,18 @@ def _prepare_decode( device=self.device, ) - lora_index_mapping = [ - _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping - ] - input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, - max_seq_len=None, - start_loc=None, + prompt_lens_tensor=None, + num_prompt_tokens=0, + num_generation_tokens=len(input_tokens), + max_subquery_len=None, max_context_len=max_context_len, + max_seq_len=None, + subquery_start_loc=None, + seq_start_loc=None, context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, @@ -411,7 +438,6 @@ def _prepare_sample( categorized_sampled_token_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron - max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -439,7 +465,7 @@ def _prepare_sample( selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + subquery_len - 1) - selected_token_start_idx += max_subquery_len + selected_token_start_idx += subquery_len if sampling_params.seed is not None: seq_group_metadata.state.generator = torch.Generator( @@ -521,11 +547,8 @@ def prepare_input_tensors( subquery_lens) if self.lora_config: - flat_lora_index_mapping = [ - item for sublist in lora_index_mapping for item in sublist - ] lora_mapping = LoRAMapping( - flat_lora_index_mapping, + lora_index_mapping, lora_prompt_mapping, ) else: @@ -679,6 +702,18 @@ def list_loras(self) -> Set[int]: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. 
And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. + """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() @@ -697,10 +732,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, 1, - dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() @@ -726,9 +760,14 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, - max_seq_len=None, - start_loc=None, + prompt_lens_tensor=None, + num_prompt_tokens=0, + num_generation_tokens=batch_size, + max_subquery_len=None, max_context_len=self.max_context_len_to_capture, + max_seq_len=None, + subquery_start_loc=None, + seq_start_loc=None, context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, @@ -845,7 +884,6 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) - # Run the graph. self.graph.replay() @@ -877,17 +915,28 @@ def _make_tensor_with_pad( dtype: torch.dtype, device: Optional[Union[str, torch.device]], ) -> torch.Tensor: + """Make a padded tensor of a 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] return torch.tensor(padded_x, dtype=dtype, device=device) def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ if batch_size <= 2: return batch_size elif batch_size <= 4: return 4 else: - return (batch_size + 7) // 8 * 8 + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) def _async_h2d(