Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][misc] remove logical block #5882

Merged
merged 2 commits into the base branch from the pull-request branch
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 1 addition & 81 deletions vllm/block.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,10 @@
"""Token blocks."""
import weakref
from collections import defaultdict
from typing import Dict, List
from typing import List

from vllm.utils import Device

_BLANK_TOKEN_ID = -1

DEFAULT_LAST_ACCESSED_TIME = -1

TokensBlock = List[int]


class BlockPool:
    """Recycling pool for logical-block token buffers.

    Requests create many logical blocks and destroy them again when they
    finish.  Rebuilding the ``token_ids`` list (a ``block_size``-long list
    of ints) from scratch each time is expensive, so freed buffers are kept
    here, keyed by their length, and handed to later requests instead of
    being reallocated.
    """

    def __init__(self) -> None:
        # Maps block size -> free buffers of exactly that length.
        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)

    def alloc_block(self, block_size: int) -> TokensBlock:
        """Return a buffer of ``block_size`` slots, reusing a freed one if any."""
        # ``.get`` (not ``[]``) so a miss does not insert an empty free list
        # into the defaultdict.
        free_list = self.pool.get(block_size)
        if free_list:
            return free_list.pop()
        return [_BLANK_TOKEN_ID] * block_size

    def del_block(self, block: TokensBlock) -> None:
        """Hand ``block`` back to the pool for later reuse."""
        self.pool[len(block)].append(block)


_BLOCK_POOL = BlockPool()


class LogicalTokenBlock:
    """A block that stores a contiguous chunk of tokens from left to right.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size

        # Backing storage comes from the module-level pool so repeated
        # allocation/destruction of blocks stays cheap.
        self.token_ids = _BLOCK_POOL.alloc_block(block_size)
        # Return the storage to the pool once this object is collected.
        # A weakref finalizer (rather than __del__) is used because __del__
        # cannot guarantee finalization order: `self.token_ids` could be
        # torn down before `self`, losing the chance to recycle the buffer.
        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
                                           self.token_ids)
        # Number of slots in ``token_ids`` that are actually filled.
        self.num_tokens = 0

    def is_empty(self) -> bool:
        """True when no token has been appended yet."""
        return not self.num_tokens

    def get_num_empty_slots(self) -> int:
        """Remaining capacity of this block."""
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        """True once every slot holds a token."""
        return self.num_tokens == self.block_size

    def append_tokens(self, token_ids: List[int]) -> None:
        """Write ``token_ids`` into the next free slots (must fit)."""
        n_new = len(token_ids)
        assert n_new <= self.get_num_empty_slots()
        start = self.num_tokens
        self.token_ids[start:start + n_new] = token_ids
        self.num_tokens = start + n_new

    def get_token_ids(self) -> List[int]:
        """The filled prefix of the backing buffer."""
        return self.token_ids[:self.num_tokens]

    def get_last_token_id(self) -> int:
        """Most recently appended token; block must be non-empty."""
        assert self.num_tokens > 0
        return self.token_ids[self.num_tokens - 1]


class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
Expand Down
19 changes: 9 additions & 10 deletions vllm/core/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,7 @@ def __init__(
self.cross_block_tables: Dict[str, BlockTable] = {}

def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
    """Number of logical blocks ``seq`` needs; 0 when there is no sequence.

    NOTE: the scraped diff interleaved the old return (which read the
    removed ``logical_token_blocks`` attribute) after the new one, leaving
    an unreachable duplicate statement; only the ``n_blocks`` form is kept.
    """
    return 0 if seq is None else seq.n_blocks

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
Expand Down Expand Up @@ -298,7 +297,7 @@ def _allocate_sequence(self, \
ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
num_prompt_blocks = seq.n_blocks

block_table: BlockTable = []
for logical_idx in range(num_prompt_blocks):
Expand Down Expand Up @@ -367,7 +366,7 @@ def _promote_last_block(

# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
new_hash = seq.hash_of_block(seq.n_blocks - 1)

# if new_hash is already in the cached table, then free last_block
# and return the cached version
Expand Down Expand Up @@ -407,10 +406,10 @@ def _allocate_last_physical_block(
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
n_blocks = seq.n_blocks
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
block_hash = seq.hash_of_block(n_blocks - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)

# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
Expand All @@ -429,12 +428,12 @@ def append_slots(
num_lookahead_slots: int = 0,
) -> List[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
n_blocks = seq.n_blocks
block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks):
if len(block_table) < n_blocks:
# Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1
assert len(block_table) == n_blocks - 1

if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
Expand Down
35 changes: 6 additions & 29 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Sequence and its related classes."""
import copy
import enum
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
Expand Down Expand Up @@ -236,9 +236,6 @@ def __init__(
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None

Expand All @@ -248,6 +245,10 @@ def __init__(
# Input + output tokens
self.tokens: Optional[List[str]] = None

@property
def n_blocks(self) -> int:
    """Number of logical blocks needed to hold every token in the sequence.

    Uses exact integer ceiling division instead of ``math.ceil(a / b)``:
    the latter routes through float division, which can round incorrectly
    once the length exceeds 2**53, and allocates a float needlessly.
    """
    return (self.get_len() + self.block_size - 1) // self.block_size

@property
def prompt(self) -> Optional[str]:
    """The prompt string from the inputs, or None when none was provided."""
    return self.inputs.get("prompt", None)
def reset_state_for_recompute(self):
    """Reset the sequence state so it can be recomputed.

    Pure delegation: the actual state lives on ``self.data``.
    """
    self.data.reset_state_for_recompute()

def _append_logical_block(self) -> None:
    """Grow the sequence by one fresh (empty) logical block."""
    # The new block's number is its position in the list.
    new_block = LogicalTokenBlock(
        block_number=len(self.logical_token_blocks),
        block_size=self.block_size,
    )
    self.logical_token_blocks.append(new_block)

def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
    """Spread ``token_ids`` over the logical blocks, appending blocks as needed."""
    remaining = token_ids
    while remaining:
        # Make sure the tail block exists and has at least one free slot.
        if (not self.logical_token_blocks
                or self.logical_token_blocks[-1].is_full()):
            self._append_logical_block()

        tail = self.logical_token_blocks[-1]
        capacity = tail.get_num_empty_slots()
        tail.append_tokens(remaining[:capacity])
        remaining = remaining[capacity:]

def append_token_id(
    self,
    token_id: int,
    logprobs: Dict[int, Logprob],
) -> None:
    """Append one generated token, recording its logprob information.

    ``logprobs`` must contain an entry for ``token_id`` itself.
    """
    assert token_id in logprobs
    chosen_logprob = logprobs[token_id].logprob
    self._append_tokens_to_blocks([token_id])
    self.output_logprobs.append(logprobs)
    self.data.append_token_id(token_id, chosen_logprob)

Expand Down Expand Up @@ -388,7 +365,7 @@ def is_prefill(self) -> bool:
def __repr__(self) -> str:
    """Concise debug representation of the sequence.

    NOTE: the scraped diff left both the old and new f-string lines in
    place (an unbalanced stray line), and the PR's replacement line ended
    the repr with ", " instead of a closing ")"; this version keeps the
    new ``n_blocks`` source and restores the terminating ")".
    """
    return (f"Sequence(seq_id={self.seq_id}, "
            f"status={self.status.name}, "
            f"num_blocks={self.n_blocks})")


@dataclass
Expand Down
Loading