Skip to content

Commit

Permalink
[GPU] Remove paddings on set_state() call (openvinotoolkit#23828)
Browse files Browse the repository at this point in the history
### Details:
 - Fix accuracy issue due to wrong pad size after set_state call
  • Loading branch information
vladimir-paramuzov authored and bbielawx committed Apr 12, 2024
1 parent b317f0c commit f6072ce
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/plugin/variable_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ void VariableState::set_layout(const cldnn::layout& new_layout) {

void VariableState::set_state(const ov::SoPtr<ov::ITensor>& state) {
m_layout.set_partial_shape(state->get_shape());
size_t rank = state->get_shape().size();
m_layout.data_padding = cldnn::padding(std::vector<int32_t>(rank, 0), std::vector<int32_t>(rank, 0), 0, m_layout.data_padding.get_dynamic_pad_dims());
update_device_buffer();
convert_and_copy(state._ptr.get(), m_memory, m_context->get_engine().get_service_stream());
set();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ class KVCacheTests: public ::testing::Test {
int64_t concat_axis = 2,
ov::element::Type model_element_type = ov::element::f16,
size_t num_iter = 10,
size_t num_groups = 1) {
size_t num_groups = 1,
bool set_state_on_each_iter = false) {
#if defined(ANDROID)
GTEST_SKIP();
#endif
Expand Down Expand Up @@ -437,6 +438,14 @@ class KVCacheTests: public ::testing::Test {
infer_request.infer();

compare_tensors({ ref_results[1] }, {matmul_out});

if (set_state_on_each_iter) {
auto state = infer_request.query_state()[0].get_state();
compare_tensors({ ref_kv_cache }, {state});
infer_request.query_state()[0].set_state(state);
auto state_1 = infer_request.query_state()[0].get_state();
compare_tensors({ ref_kv_cache }, {state_1});
}
}

auto state = infer_request.query_state()[0].get_state();
Expand Down Expand Up @@ -494,4 +503,8 @@ TEST_F(KVCacheTests, smoke_multipleIterations_stateful_same_shape_after_reset) {
this->test_smoke_multipleIterations_stateful(false, false, false, 1, 2, ov::element::f16, 0);
}

TEST_F(KVCacheTests, smoke_multipleIterations_stateful_with_set_state) {
this->test_smoke_multipleIterations_stateful(false, true, true, 1, 2, ov::element::f16, 5, 1, true);
}

} // namespace

0 comments on commit f6072ce

Please sign in to comment.