Skip to content

Commit

Permalink
[GPU] Update PagedAttention output shape, add dynamic paddings support for mixed kernel mode execution
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Jan 24, 2025
1 parent cfbc998 commit 1e9f429
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 25 deletions.
3 changes: 2 additions & 1 deletion src/plugins/intel_gpu/src/graph/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no
template<typename ShapeType>
std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
auto data_layout = impl_param.get_input_layout(0);
data_layout.data_padding = padding();

const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
Expand All @@ -71,7 +72,7 @@ std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_no
total_size += past_lens_mem_lock[i];
}

total_size += static_cast<long int>(impl_param.get_input_layout(0).get_shape()[0]);
total_size += static_cast<long int>(data_layout.get_shape()[0]);

output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx});
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ KERNEL(pa_sdpa_opt)(
{
#if STORE_QUERY_TO_SLM
const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid;
const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
const uint query_idx = INPUT0_OFFSET +
seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
head_num_idx * HEAD_SIZE +
query_idx_local;

Expand All @@ -137,7 +138,8 @@ KERNEL(pa_sdpa_opt)(
#else
INPUT0_TYPE q_val[HEAD_SIZE / SUBGROUP_SIZE];
unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
const uint query_idx = INPUT0_OFFSET +
seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
head_num_idx * HEAD_SIZE +
i * SUBGROUP_SIZE;
q_val[i] = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,35 @@ struct PagedAttentionTest : public ::testing::TestWithParam<T> {
rotation_deltas_layout.set_partial_shape(ov::PartialShape{ -1, -1 });
rotation_trig_lut_layout.set_partial_shape(ov::PartialShape{ -1, p.head_size });

if (p.dynamic_paddings) {
const auto padding_axis = 1;
const auto pad_before = p.head_size;
const auto pad_after = p.head_size * 2;

query_layout.data_padding._dynamic_dims_mask[padding_axis] = 1;

auto query_data_layout = query_mem->get_layout();
auto padded_query_data_layout = query_data_layout;
padded_query_data_layout.data_padding._lower_size[padding_axis] = pad_before;
padded_query_data_layout.data_padding._upper_size[padding_axis] = pad_after;

auto new_query_memory = get_test_engine().allocate_memory(padded_query_data_layout, false);

mem_lock<ov::float16> query_mem_lock(query_mem, get_test_stream());
mem_lock<ov::float16> new_query_mem_lock(new_query_memory, get_test_stream());

auto query_data_shape = query_data_layout.get_shape();
for (size_t b = 0; b < query_data_shape[0]; b++) {
for (size_t f = 0; f < query_data_shape[1]; f++) {
auto input_offset = query_data_layout.get_linear_offset(cldnn::tensor(b, f, 0, 0, 0, 0));
auto output_offset = padded_query_data_layout.get_linear_offset(cldnn::tensor(b, f, 0, 0, 0, 0));

new_query_mem_lock[output_offset] = query_mem_lock[input_offset];
}
}
query_mem = new_query_memory;
}

std::vector<input_info> pa_inputs = {
input_info("query"),
input_info("key"),
Expand Down Expand Up @@ -857,6 +886,7 @@ struct paged_attention_test_params {
int num_heads;
int head_size;
int block_size;
bool dynamic_paddings;
bool scores_output;
CacheRotationDescriptor rotation_config;
};
Expand All @@ -873,31 +903,34 @@ const auto DISABLE_SCORES = false;
const auto PER_BLOCK_ROTATION = CacheRotationDescriptor{ true, true };
const auto PER_TOKEN_ROTATION = CacheRotationDescriptor{ true, false };
const auto DISABLE_ROTATION = CacheRotationDescriptor{ false, false };
const auto STATIC_INPUT_PAD = false;
const auto DYNAMIC_INPUT_PAD = true;

INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector<paged_attention_test_params>{
/* with scores output */
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
/* without scores output */
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
/* without scores output, dynamic input query paddings */
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
/* with scores, per_block rotation */
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
/* with scores, per_token rotation */
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
}));

0 comments on commit 1e9f429

Please sign in to comment.