Add paged attention unit tests
sshlyapn committed Dec 26, 2024
1 parent 704a663 commit 06dab62
Showing 5 changed files with 47 additions and 26 deletions.
@@ -323,12 +323,13 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {

execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode);

std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
if (stage == PagedAttentionStage::PREFILL) {
std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
execute_stage(dep_events, instance, res_events, Stage::SDPA, is_mixed_mode);
}

if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || has_scores_output) {
std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
execute_stage(dep_events, instance, res_events, Stage::PA_SDPA, is_mixed_mode);
}

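Read as a whole, the hunk moves the dep_events declaration into each branch and extends the PA_SDPA condition with has_scores_output, so PA_SDPA now also runs for prefill requests that produce the scores output. A self-contained toy model of the resulting dispatch (illustrative only, not the plugin's code; names mirror the diff):

    // Toy model of the stage dispatch after this change: KV_CACHE_UPDATE always
    // runs, SDPA runs for PREFILL, and PA_SDPA runs for GENERATE, MIXED, or
    // whenever the scores output is requested.
    #include <string>
    #include <vector>

    enum class PagedAttentionStage { GENERATE, PREFILL, MIXED, UNKNOWN };

    std::vector<std::string> stages_to_run(PagedAttentionStage stage, bool has_scores_output) {
        std::vector<std::string> stages{"KV_CACHE_UPDATE"};
        if (stage == PagedAttentionStage::PREFILL)
            stages.push_back("SDPA");
        if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || has_scores_output)
            stages.push_back("PA_SDPA");
        return stages;
    }

    // e.g. stages_to_run(PagedAttentionStage::PREFILL, /*has_scores_output=*/true)
    // returns {"KV_CACHE_UPDATE", "SDPA", "PA_SDPA"}
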
62 changes: 41 additions & 21 deletions src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -105,37 +105,67 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
}

void paged_attention_inst::on_execute() {
auto stage = get_paged_attention_stage(*_impl_params);
const auto& desc = _impl_params->typed_desc<paged_attention>();
const bool has_scores_output = desc->has_scores_output();
const auto stage = get_paged_attention_stage(*_impl_params);

if (stage == PagedAttentionStage::UNKNOWN ||
stage == PagedAttentionStage::GENERATE)
if ((stage == PagedAttentionStage::UNKNOWN) ||
(stage == PagedAttentionStage::GENERATE && !has_scores_output))
return;

auto& stream = get_network().get_stream();
const auto past_lens_mem = past_lens_memory_ptr();
const auto subsequence_begins_mem = subsequence_begins_memory_ptr();
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;

if (has_scores_output) {
const size_t subsequence_offsets_idx = 4;

OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx,
"[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");

auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx];
subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
}

if (stage == PagedAttentionStage::GENERATE) {
// For the generate stage it's not necessary to configure any other intermediate
// buffers. Simply calculate the offsets and exit
size_t subsequence_offsets_acc = 0;
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
const auto past_len = past_lens_mem_lock[i];
const auto seq_start = subsequence_begins_mem_lock[i];
const auto seq_end = subsequence_begins_mem_lock[i + 1];
const auto seq_length = seq_end - seq_start;

if (subsequence_offsets_lock) {
subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
subsequence_offsets_acc += seq_length + past_len;
}
}

return;
}

OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");

const auto blocks_indexes_start_idx = 0;
const auto blocks_indexes_end_idx = 1;
const auto blocked_gws_subseq_mapping_idx = 2;

const auto past_lens_mem = past_lens_memory_ptr();
auto subsequence_begins_mem = subsequence_begins_memory_ptr();
auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx];

OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);

auto& stream = get_network().get_stream();
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;

const auto& desc = _impl_params->typed_desc<paged_attention>();
const bool has_scores_output = desc->has_scores_output();
if (stage == PagedAttentionStage::MIXED) {
const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;

@@ -146,16 +176,6 @@ void paged_attention_inst::on_execute() {
sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
}

if (has_scores_output) {
const size_t subsequence_offsets_idx = 4;

OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx,
"[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");

auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx];
subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
}

size_t index = 0;
size_t subsequence_offsets_acc = 0;
const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
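
The net effect of the new early GENERATE branch is that the per-subsequence score offsets are now filled in even when only token generation is performed, rather than returning before they are computed. Each offset is the running total of (sequence length + past length) over all preceding subsequences, presumably marking where each subsequence's attention scores start in the scores output buffer. A self-contained restatement of that loop (illustrative only; the input names are assumptions):

    // offsets[i] = sum over j < i of (seq_len[j] + past_len[j])
    #include <cstdint>
    #include <vector>

    std::vector<int32_t> subsequence_offsets(const std::vector<int32_t>& subsequence_begins,
                                             const std::vector<int32_t>& past_lens) {
        std::vector<int32_t> offsets(past_lens.size());
        size_t acc = 0;
        for (size_t i = 0; i + 1 < subsequence_begins.size(); i++) {
            const auto seq_length = subsequence_begins[i + 1] - subsequence_begins[i];
            offsets[i] = static_cast<int32_t>(acc);
            acc += seq_length + past_lens[i];
        }
        return offsets;
    }

    // e.g. subsequence_begins = {0, 1, 2}, past_lens = {7, 12}
    // gives offsets = {0, 8}: the second subsequence's scores start after 1 + 7 values
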
@@ -617,6 +617,9 @@ KERNEL(pa_sdpa_scores_calculation)(
slm_exp_sums[head_idx] = adjusted_exp_sum;
global_exp_sum += adjusted_exp_sum;
}

global_exp_sum = sub_group_reduce_add(global_exp_sum);

slm_global_exp_sum[head_idx] = global_exp_sum;
}

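The added sub_group_reduce_add call sums the per-work-item partial values of global_exp_sum across the sub-group before the result is written to slm_global_exp_sum, so the stored value reflects every lane's contribution instead of only the lane that performs the write. sub_group_reduce_add is the standard OpenCL sub-group reduction built-in; a host-side emulation of its effect (illustrative only):

    // Every lane contributes its partial sum and every lane receives the same total.
    #include <numeric>
    #include <vector>

    float sub_group_reduce_add_emulated(const std::vector<float>& per_lane_partial_sums) {
        return std::accumulate(per_lane_partial_sums.begin(), per_lane_partial_sums.end(), 0.0f);
    }
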
@@ -1250,9 +1250,6 @@ KERNEL(sdpa_opt)(
partition_idx;
exp_sums[exp_sums_output_offset] = exp_sum_new;
max_logits[exp_sums_output_offset] = qk_max_new;
const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
num_heads_dim * aligned_max_context_len +
partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE;
}

const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
@@ -167,7 +167,7 @@ void KVCacheUpdateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const {

const auto indexes_dt = Datatype::INT32;
const auto target_seq_len_block_size = 16;
const auto target_seq_len = prim_params.conf.paged_attention_aligned_seq_len;
const auto target_seq_len = std::max(prim_params.conf.paged_attention_aligned_seq_len, static_cast<int64_t>(1));
const auto indexes_buf_size = CeilDiv(target_seq_len, target_seq_len_block_size) * BytesPerElement(indexes_dt);

kd.internalBufferSizes.clear();
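
The one-line change clamps the aligned sequence length to at least 1 before sizing the internal index buffers, presumably so that a zero aligned length cannot request zero-byte internal buffers. Illustrative arithmetic under that assumption (CeilDiv is re-declared here only for the sketch; the real helper lives in the kernel selector):

    #include <algorithm>
    #include <cstdint>

    constexpr int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int64_t indexes_buf_size(int64_t aligned_seq_len) {
        const int64_t target_seq_len_block_size = 16;
        const int64_t bytes_per_element = 4;  // INT32 indexes
        const int64_t target_seq_len = std::max(aligned_seq_len, static_cast<int64_t>(1));
        return CeilDiv(target_seq_len, target_seq_len_block_size) * bytes_per_element;
    }

    // indexes_buf_size(0) == 4 (one block) with the clamp, instead of 0 without it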
