Skip to content

Commit

Permalink
[GPU] Fix remaininig issue to calculate present layout's padding for …
Browse files Browse the repository at this point in the history
…KVCache (#25706)

### Details:
- Follow up remaining issue from
#25682
- Fix issue where kvcache was optimized out even if calculated present
layout's padding was negative

### Tickets:
 - 146876
  • Loading branch information
andrew-k-park authored Jul 24, 2024
1 parent ff54835 commit eeb8fe9
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1208,14 +1208,11 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
GPU_DEBUG_TRACE_DETAIL << "[do runtime kv_cache opt] " << id() << " initial present_layout : " << present_layout.to_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << "[do runtime kv_cache opt] " << id() << " initial past_layout : " << past_layout.to_string() << std::endl;
auto max_pad = kv_cache_inst::get_max_pad(past_layout, _deps[0].first->_max_output_layout_count[0], sequence_axis_legacy, "past_layout");

if (max_pad > 0) {
const auto new_seq_len = static_cast<int64_t>(new_layout.get_shape()[sequence_axis]);
if (max_pad - new_seq_len >= 0) {
kv_cache_inst::update_pad(present_layout, max_pad - new_seq_len, sequence_axis_legacy);
GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_layout's pad : "
<< present_layout.to_string() << std::endl;
}
const auto new_seq_len = static_cast<int64_t>(new_layout.get_shape()[sequence_axis]);
// In chatbot scenario, when chat history must be stored in kvcache, new_seq_len may not be 1 even if max_pad is greater than 0
if (max_pad - new_seq_len >= 0) {
kv_cache_inst::update_pad(present_layout, max_pad - new_seq_len, sequence_axis_legacy);
GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_layout's pad : " << present_layout.to_string() << std::endl;
auto& variable = get_network().get_variable(desc->variable_info.variable_id);
variable.set_layout(present_layout);
GPU_DEBUG_TRACE_DETAIL << "[do_runtime_in_place_kv_cache] " << id() << "Updated variable with present_layout"
Expand Down

0 comments on commit eeb8fe9

Please sign in to comment.