Skip to content

Commit

Permalink
Fix integer overflows in attention & cache ops (vllm-project#1514)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Oct 31, 2023
1 parent 37a54d5 commit 5f7ae01
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 47 deletions.
10 changes: 8 additions & 2 deletions csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ __device__ void paged_attention_kernel(
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
// because int32 can lead to overflow when this variable is multiplied by large numbers
// (e.g., kv_block_stride).
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

// Load a key to registers.
// Each thread in a thread group has a different part of the key.
Expand Down Expand Up @@ -285,7 +288,10 @@ __device__ void paged_attention_kernel(
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
// because int32 can lead to overflow when this variable is multiplied by large numbers
// (e.g., kv_block_stride).
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
Expand Down
72 changes: 36 additions & 36 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,26 @@ template<typename scalar_t>
// Copies whole cache blocks from a source block to a destination block inside
// each layer's key cache and value cache.
// blockIdx.x selects the layer and blockIdx.y selects one (src, dst) pair of
// block_mapping; the threads of the block step through the numel_per_block
// elements of that cache block with a stride of blockDim.x.
// NOTE(review): this listing is a rendered diff without +/- markers, so several
// declarations appear twice. Presumably the `int` form is the removed line and
// the `int64_t` form is its replacement -- confirm against the commit.
__global__ void copy_blocks_kernel(
// Per-layer cache base addresses stored as 64-bit integers; they are
// reinterpreted as scalar_t* below.
int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs,
// Flattened pairs: element 2 * p is the source block number and element
// 2 * p + 1 is the destination block number of pair p.
const int* __restrict__ block_mapping,
const int64_t* __restrict__ block_mapping,
// Number of scalar_t elements in one cache block.
const int numel_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;

// Recover the typed cache pointers for this layer from the integer addresses.
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int src_block_number = block_mapping[2 * pair_idx];
int dst_block_number = block_mapping[2 * pair_idx + 1];
int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

// Element offset of the first entry of each block. With the 64-bit block
// numbers the product block_number * numel_per_block is evaluated in 64 bits;
// the 32-bit form can overflow for large block numbers (the stated subject of
// this commit).
const int src_block_offset = src_block_number * numel_per_block;
const int dst_block_offset = dst_block_number * numel_per_block;
const int64_t src_block_offset = src_block_number * numel_per_block;
const int64_t dst_block_offset = dst_block_number * numel_per_block;
// Copy the key block: each thread handles elements threadIdx.x,
// threadIdx.x + blockDim.x, ...
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int src_offset = src_block_offset + i;
int dst_offset = dst_block_offset + i;
int64_t src_offset = src_block_offset + i;
int64_t dst_offset = dst_block_offset + i;
key_cache[dst_offset] = key_cache[src_offset];
}
// Copy the value block with the same offsets and the same thread striding.
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int src_offset = src_block_offset + i;
int dst_offset = dst_block_offset + i;
int64_t src_offset = src_block_offset + i;
int64_t dst_offset = dst_block_offset + i;
value_cache[dst_offset] = value_cache[src_offset];
}
}
Expand Down Expand Up @@ -102,15 +102,15 @@ void copy_blocks(
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
// Create block mapping array.
std::vector<int> block_mapping_vec;
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int src_block_number = pair.first;
for (int dst_block_number : pair.second) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int* block_mapping_array = block_mapping_vec.data();
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;

// Move the data structures to the GPU.
Expand All @@ -120,7 +120,7 @@ void copy_blocks(
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

// Launch the kernel.
const int numel_per_block = key_caches[0][0].numel();
Expand All @@ -132,7 +132,7 @@ void copy_blocks(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int>(),
block_mapping_tensor.data_ptr<int64_t>(),
numel_per_block);
}));
}
Expand All @@ -141,46 +141,46 @@ namespace vllm {

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens]
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x) {
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}

const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;

const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int src_key_idx = token_idx * key_stride + i;
const int src_value_idx = token_idx * value_stride + i;
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;

const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;

const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
key_cache[tgt_key_idx] = key[src_key_idx];
value_cache[tgt_value_idx] = value[src_value_idx];
}
Expand Down Expand Up @@ -216,7 +216,7 @@ void reshape_and_cache(
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
slot_mapping.data_ptr<int64_t>(),
key_stride,
value_stride,
num_heads,
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 128 # Arbitrary values for testing
NUM_BLOCKS = 40000 # Arbitrary values for testing
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
Expand Down
14 changes: 7 additions & 7 deletions tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from vllm import cache_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
NUM_LAYERS = [5] # Arbitrary values for testing
NUM_TOKENS = [83] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024] # Arbitrary values for testing
NUM_MAPPINGS = [32, 256] # Arbitrary values for testing
NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]


Expand Down Expand Up @@ -69,9 +69,9 @@ def test_copy_blocks(
for src, dsts in block_mapping.items():
for dst in dsts:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst] = cloned_key_cache[src]
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst] = cloned_value_cache[src]
cloned_value_cache[dst].copy_(cloned_value_cache[src])

# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_reshape_and_cache(
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")

qkv = torch.randn(num_tokens,
3,
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _prepare_inputs(
dtype=torch.long,
device="cuda")
slot_mapping_tensor = torch.tensor(padded_slot_mapping,
dtype=torch.int,
dtype=torch.long,
device="cuda")
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
Expand Down

0 comments on commit 5f7ae01

Please sign in to comment.