diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index ee6b715adaef0..78e8d8ecd6d41 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -175,7 +175,10 @@ __device__ void paged_attention_kernel(
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    const int physical_block_number = block_table[block_idx];
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

     // Load a key to registers.
     // Each thread in a thread group has a different part of the key.
@@ -285,7 +288,10 @@ __device__ void paged_attention_kernel(
   scalar_t zero_value;
   zero(zero_value);
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    const int physical_block_number = block_table[block_idx];
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 4c80682812298..3ad52b1681c0c 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -55,26 +55,26 @@ template<typename scalar_t>
 __global__ void copy_blocks_kernel(
   int64_t* key_cache_ptrs,
   int64_t* value_cache_ptrs,
-  const int* __restrict__ block_mapping,
+  const int64_t* __restrict__ block_mapping,
   const int numel_per_block) {
   const int layer_idx = blockIdx.x;
   const int pair_idx = blockIdx.y;

   scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
   scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
-  int src_block_number = block_mapping[2 * pair_idx];
-  int dst_block_number = block_mapping[2 * pair_idx + 1];
+  int64_t src_block_number = block_mapping[2 * pair_idx];
+  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

-  const int src_block_offset = src_block_number * numel_per_block;
-  const int dst_block_offset = dst_block_number * numel_per_block;
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
     key_cache[dst_offset] = key_cache[src_offset];
   }
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
     value_cache[dst_offset] = value_cache[src_offset];
   }
 }
@@ -102,15 +102,15 @@ void copy_blocks(
     value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
   }
   // Create block mapping array.
-  std::vector<int> block_mapping_vec;
+  std::vector<int64_t> block_mapping_vec;
   for (const auto& pair : block_mapping) {
-    int src_block_number = pair.first;
-    for (int dst_block_number : pair.second) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
       block_mapping_vec.push_back(src_block_number);
       block_mapping_vec.push_back(dst_block_number);
     }
   }
-  int* block_mapping_array = block_mapping_vec.data();
+  int64_t* block_mapping_array = block_mapping_vec.data();
   int num_pairs = block_mapping_vec.size() / 2;

   // Move the data structures to the GPU.
@@ -120,7 +120,7 @@ void copy_blocks(
   torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
     value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
   torch::Tensor block_mapping_tensor = torch::from_blob(
-    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

   // Launch the kernel.
   const int numel_per_block = key_caches[0][0].numel();
@@ -132,7 +132,7 @@ void copy_blocks(
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
         value_cache_ptrs_tensor.data_ptr<int64_t>(),
-        block_mapping_tensor.data_ptr<int>(),
+        block_mapping_tensor.data_ptr<int64_t>(),
         numel_per_block);
     }));
 }
@@ -141,46 +141,46 @@ namespace vllm {
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
-  const scalar_t* __restrict__ key,         // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ value,       // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key_cache,         // [num_blocks, num_heads, head_size/x, block_size, x]
-  scalar_t* __restrict__ value_cache,       // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,     // [num_tokens]
+  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
+  scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
+  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
   const int key_stride,
   const int value_stride,
   const int num_heads,
   const int head_size,
   const int block_size,
   const int x) {
-  const int token_idx = blockIdx.x;
-  const int slot_idx = slot_mapping[token_idx];
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
     // Padding token that should be ignored.
     return;
   }

-  const int block_idx = slot_idx / block_size;
-  const int block_offset = slot_idx % block_size;
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;

   const int n = num_heads * head_size;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int src_key_idx = token_idx * key_stride + i;
-    const int src_value_idx = token_idx * value_stride + i;
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;

     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
     const int x_idx = head_offset / x;
     const int x_offset = head_offset % x;

-    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                            + head_idx * (head_size / x) * block_size * x
-                            + x_idx * block_size * x
-                            + block_offset * x
-                            + x_offset;
-    const int tgt_value_idx = block_idx * num_heads * head_size * block_size
-                              + head_idx * head_size * block_size
-                              + head_offset * block_size
-                              + block_offset;
+    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                                + head_idx * (head_size / x) * block_size * x
+                                + x_idx * block_size * x
+                                + block_offset * x
+                                + x_offset;
+    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+                                  + head_idx * head_size * block_size
+                                  + head_offset * block_size
+                                  + block_offset;
     key_cache[tgt_key_idx] = key[src_key_idx];
     value_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
@@ -216,7 +216,7 @@ void reshape_and_cache(
       value.data_ptr<scalar_t>(),
       key_cache.data_ptr<scalar_t>(),
       value_cache.data_ptr<scalar_t>(),
-      slot_mapping.data_ptr<int>(),
+      slot_mapping.data_ptr<int64_t>(),
       key_stride,
       value_stride,
       num_heads,
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 31d78dd1bcf90..7c4a84d4c7d84 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -13,7 +13,7 @@
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 128  # Arbitrary values for testing
+NUM_BLOCKS = 40000  # Arbitrary values for testing
 PARTITION_SIZE = 512

 DTYPES = [torch.half, torch.bfloat16, torch.float]
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index b72dfbd6688e3..e15e7ba91bcb0 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -6,13 +6,13 @@
 from vllm import cache_ops

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
-NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_TOKENS = [83]  # Arbitrary values for testing
+NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024]  # Arbitrary values for testing
-NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
+NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]


@@ -69,9 +69,9 @@ def test_copy_blocks(
     for src, dsts in block_mapping.items():
         for dst in dsts:
             for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst] = cloned_key_cache[src]
+                cloned_key_cache[dst].copy_(cloned_key_cache[src])
             for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst] = cloned_value_cache[src]
+                cloned_value_cache[dst].copy_(cloned_value_cache[src])

     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
@@ -106,7 +106,7 @@ def test_reshape_and_cache(
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")

     qkv = torch.randn(num_tokens,
                       3,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index fd6faecccbfb2..d598a86cf0c1c 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -301,7 +301,7 @@ def _prepare_inputs(
                                         dtype=torch.long,
                                         device="cuda")
         slot_mapping_tensor = torch.tensor(padded_slot_mapping,
-                                           dtype=torch.int,
+                                           dtype=torch.long,
                                            device="cuda")
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
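Why the widening matters: the updated tests allocate many more cache blocks (NUM_BLOCKS = 40000 / 36000) precisely so that a block number multiplied by a per-block stride exceeds INT32_MAX. Below is a minimal standalone C++ sketch of that failure mode, separate from the patch itself; the sizes (40 KV heads, head_size 128, block_size 16) are illustrative assumptions, not values taken from the diff.

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative sizes (assumptions, not from the patch): 40 KV heads,
  // head_size 128, block_size 16 -> 81,920 cache elements per block.
  const int num_heads = 40;
  const int head_size = 128;
  const int block_size = 16;
  const int kv_block_stride = num_heads * head_size * block_size;  // 81,920

  // With NUM_BLOCKS = 40,000 in the tests, a block near the end of the cache:
  const int block_number = 40000;

  // 40,000 * 81,920 = 3,276,800,000 > INT32_MAX, so a 32-bit offset wraps.
  // (Computed in 64 bits and truncated here to avoid signed-overflow UB.)
  const int32_t bad_offset =
      static_cast<int32_t>(static_cast<int64_t>(block_number) * kv_block_stride);

  // The patch's approach: widen the block number before the multiply,
  // keeping the byte/element offset in range.
  const int64_t good_offset = static_cast<int64_t>(block_number) * kv_block_stride;

  std::printf("32-bit offset: %d\n64-bit offset: %lld\n",
              bad_offset, static_cast<long long>(good_offset));
  return 0;
}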