[GPU] Force sdpa use onednn path for prefill and cl path for generation. (#27387)

### Details:
- *[GPU] Force SDPA use oneDNN path for prefill and clDNN path for generation on ARL-H platform* (see the conceptual sketch below)
 - *...*

### Tickets:
 - *CVS-158461*

---------

Co-authored-by: Sergey Shlyapnikov <[email protected]>
ceciliapeng2011 and sshlyapn authored Dec 4, 2024
1 parent b2bfd85 commit 571e98d
Showing 3 changed files with 56 additions and 11 deletions.
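In effect, the commit keeps one extra SDPA kernel in `_kernels_data` on ARL (gfx_ver 12.74) and selects it for the token-generation phase (query sequence length of 1), while prefill stays on the default oneDNN-based micro path. Below is a minimal, hypothetical sketch of that dispatch idea; the class, names, and structure are illustrative only and are not the plugin's actual API:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-ins; the real code stores kernel_selector kernel data and
// dispatches via execute_stage(events, instance, stage_index).
struct KernelData { std::string name; };

struct SdpaImplSketch {
    std::vector<KernelData> kernels;  // [0] default, [1] indirect, last = sdpa_opt (ARL only)
    bool is_arl = false;              // stands in for gfx_ver.major == 12 && gfx_ver.minor == 74

    // Hypothetical build step: append the extra clDNN sdpa_opt kernel on ARL only.
    void build() {
        kernels.push_back({"sdpa_micro (oneDNN path)"});
        kernels.push_back({"sdpa_indirect"});
        if (is_arl)
            kernels.push_back({"sdpa_opt (clDNN path)"});
    }

    // Hypothetical dispatch: generation (query length == 1) uses the last kernel
    // when it exists; prefill stays on the default path.
    const KernelData& pick(bool indirect, int64_t query_len) const {
        if (indirect)
            return kernels[1];
        if (query_len == 1 && kernels.size() == 3)
            return kernels.back();
        return kernels[0];
    }
};

int main() {
    SdpaImplSketch impl;
    impl.is_arl = true;
    impl.build();
    std::cout << impl.pick(false, 1024).name << "\n";  // prefill    -> sdpa_micro (oneDNN path)
    std::cout << impl.pick(false, 1).name << "\n";     // generation -> sdpa_opt (clDNN path)
}
```

The real implementation dispatches through execute_stage() with a stage index, and the opt kernel, when present, is always appended last, which is why the generation path in the diff uses _kernels_data.size() - 1.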
@@ -44,10 +44,14 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
             auto& kernel_selector = kernel_selector_t::Instance();
             auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[default_sdpa].kernelName);
             kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[default_sdpa]);
-            if (_kernels_data.size() == 2) {
+            if (_kernels_data.size() >= 2) {
                 auto bt_kernel_impl = kernel_selector.GetImplementation(_kernels_data[indirect_sdpa].kernelName);
                 bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[indirect_sdpa]);
             }
+            if (_kernels_data.size() == 3) {
+                auto bt_kernel_impl = kernel_selector.GetImplementation(_kernels_data[2].kernelName);
+                bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[2]);
+            }
         }
     }

@@ -58,13 +62,15 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         // buffers number and its' sizes (since update_dispatch_data is called for both kernels too), and
         // do not double memory allocations during reallocate_if_needed() function call
         std::vector<layout> layouts;
-        if (_kernels_data.size() > 0 && !_kernels_data[0].internalBufferSizes.empty()) {
-            auto dtype = from_data_type(_kernels_data[0].internalBufferDataType);
-            const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : _kernels_data[0].internalBufferSizes) {
-                layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
-                layouts.push_back(inbuf_layout);
+        for (size_t i = 0; i < _kernels_data.size(); i++) {
+            if (!_kernels_data[i].internalBufferSizes.empty()) {
+                auto dtype = from_data_type(_kernels_data[i].internalBufferDataType);
+                const auto bpp = data_type_traits::size_of(dtype);
+                for (auto size : _kernels_data[i].internalBufferSizes) {
+                    layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
+                                           {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                    layouts.push_back(inbuf_layout);
+                }
             }
         }

@@ -176,11 +182,37 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         return !is_prefill;
     }
 
+    bool need_sdpa_opt_load(const scaled_dot_product_attention_inst& instance) const {
+        if (_kernels_data.size() < 2)
+            return false;
+
+        if (instance.has_indirect_inputs() && _kernels_data.size() < 3)
+            return false;
+
+        const auto& query_layout = instance.get_impl_params()->get_input_layout(0);
+
+        auto get_reordered_dimension = [](const ov::PartialShape& pshape, const std::vector<int64_t>& order, size_t idx) -> const ov::Dimension& {
+            if (order.empty())
+                return pshape[idx];
+
+            return pshape[order[idx]];
+        };
+
+        const auto& desc = instance.get_impl_params()->typed_desc<scaled_dot_product_attention>();
+        const auto dim_L = get_reordered_dimension(query_layout.get_partial_shape(), desc->input_q_transpose_order, 2 /* y */);
+
+        bool is_generate = dim_L.get_length() == 1; // L
+        return is_generate;
+    }
+
     event::ptr execute_impl(const std::vector<event::ptr>& events, scaled_dot_product_attention_inst& instance) override {
-        if (need_indirect_load(instance))
+        if (need_indirect_load(instance)) {
             return execute_stage(events, instance, indirect_sdpa);
-        else
+        } else if (need_sdpa_opt_load(instance)) {
+            return execute_stage(events, instance, _kernels_data.size() -1 /* the last */);
+        } else {
             return execute_stage(events, instance, default_sdpa);
+        }
     }
 
     static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
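For reference, the get_reordered_dimension lambda above reads the query's sequence-length dimension (L) through the optional transpose order, and a length of 1 marks the generation phase. A standalone illustration with made-up shapes; std::vector<int64_t> stands in for ov::PartialShape here and the order values are hypothetical:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the get_reordered_dimension lambda above: pick dimension `idx` of a
// shape, honoring an optional transpose order.
static int64_t reordered_dim(const std::vector<int64_t>& shape,
                             const std::vector<int64_t>& order,
                             size_t idx) {
    return order.empty() ? shape[idx] : shape[static_cast<size_t>(order[idx])];
}

int main() {
    // Hypothetical query shapes in [batch, heads, L, head_size] layout.
    const std::vector<int64_t> prefill_q  = {1, 28, 1024, 128};  // L = 1024 -> prefill
    const std::vector<int64_t> generate_q = {1, 28, 1,    128};  // L = 1    -> token generation

    // An empty order means the layout is already [B, H, L, S]; a non-empty order
    // such as {0, 2, 1, 3} would remap index 2 to a [B, L, H, S] input's L axis.
    const std::vector<int64_t> no_order;
    std::cout << "prefill  L = " << reordered_dim(prefill_q,  no_order, 2) << "\n";
    std::cout << "generate L = " << reordered_dim(generate_q, no_order, 2) << "\n";
    // A length of 1 is what need_sdpa_opt_load() treats as the generation phase,
    // routing execution to the last (sdpa_opt) kernel in _kernels_data.
}
```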
@@ -317,6 +349,12 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
             kernels_data.push_back(kernel_selector.get_best_kernel(indirect_kernel_params));
         }
 
+        const auto& gfx_ver = impl_param.get_program().get_engine().get_device_info().gfx_ver;
+        if (gfx_ver.major == 12 && gfx_ver.minor == 74) { // ARL only
+            sdpa_kernel_params.should_use_sdpa_opt = true;
+            kernels_data.push_back(kernel_selector.get_best_kernel(sdpa_kernel_params));
+        }
+
         return cldnn::make_unique<scaled_dot_product_attention_impl>(kernels_data);
     }

@@ -328,13 +366,16 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         update_shapes(*_kernels_data[default_sdpa].params, impl_param);
         (_kernels_data[default_sdpa].update_dispatch_data_func)(*_kernels_data[default_sdpa].params, _kernels_data[default_sdpa]);
 
-        if (_kernels_data.size() == 2) {
+        if (_kernels_data.size() >= 2) {
             if (_kernels_data[indirect_sdpa].params == nullptr) {
                 _kernels_data[indirect_sdpa].params = std::make_shared<kernel_params_t>(get_kernel_params(impl_param, true));
             }
             update_shapes(*_kernels_data[indirect_sdpa].params, impl_param);
             (_kernels_data[indirect_sdpa].update_dispatch_data_func)(*_kernels_data[indirect_sdpa].params, _kernels_data[indirect_sdpa]);
         }
+        if (_kernels_data.size() == 3) {
+            (_kernels_data[2].update_dispatch_data_func)(*_kernels_data[default_sdpa].params, _kernels_data[2]);
+        }
     }
 };

@@ -120,6 +120,7 @@ struct sdpa_params : public base_params {
     DataTensor value_cache_comp_zp;
 
     sdpa_configuration conf;
+    bool should_use_sdpa_opt = false;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -316,6 +316,9 @@ bool SDPAKernelMicro::Validate(const Params& p) const {
 
     const sdpa_params& params = static_cast<const sdpa_params&>(p);
 
+    if (params.should_use_sdpa_opt)
+        return false;
+
     if (params.conf.is_paged_attention)
         return false;
 
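This last hunk is what makes the ARL-only kernel request land on the clDNN path: when should_use_sdpa_opt is set, SDPAKernelMicro::Validate() rejects the parameters, the micro (oneDNN-based) kernel drops out of the candidate list, and get_best_kernel() falls back to the sdpa_opt implementation. A simplified, self-contained sketch of that veto pattern follows; it is not the kernel_selector's actual code and all names are illustrative:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Params { bool should_use_sdpa_opt = false; };  // stands in for sdpa_params

struct KernelBase {
    virtual ~KernelBase() = default;
    virtual std::string name() const = 0;
    virtual bool Validate(const Params& p) const = 0;
};

struct MicroSdpaKernel : KernelBase {      // stands in for SDPAKernelMicro
    std::string name() const override { return "sdpa_micro"; }
    bool Validate(const Params& p) const override { return !p.should_use_sdpa_opt; }
};

struct OptSdpaKernel : KernelBase {        // stands in for the clDNN sdpa_opt kernel
    std::string name() const override { return "sdpa_opt"; }
    bool Validate(const Params&) const override { return true; }
};

// Candidates are ordered by priority; the first one that validates wins.
const KernelBase* select_best(const std::vector<std::unique_ptr<KernelBase>>& candidates,
                              const Params& p) {
    for (const auto& k : candidates)
        if (k->Validate(p))
            return k.get();
    return nullptr;
}

int main() {
    std::vector<std::unique_ptr<KernelBase>> candidates;
    candidates.emplace_back(std::make_unique<MicroSdpaKernel>());
    candidates.emplace_back(std::make_unique<OptSdpaKernel>());

    Params generation_on_arl{ /*should_use_sdpa_opt=*/true };
    std::cout << select_best(candidates, generation_on_arl)->name() << "\n";  // prints "sdpa_opt"
}
```

Taken together with the first file, prefill keeps the micro/oneDNN kernel at index 0 of _kernels_data, while generation on ARL executes the appended sdpa_opt kernel.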
