From 8fd882c89db8477ac62d607177983581b2e6a31c Mon Sep 17 00:00:00 2001
From: Andrii Staikov <andrii.staikov@intel.com>
Date: Thu, 19 Dec 2024 14:01:10 +0100
Subject: [PATCH] [TRANSFORMATIONS] Derive 'scale' from hidden_dim directly in
 SDPAToPA (#28091)

Currently 'scale' is obtained using a ShapeOf expression, because the
hidden_dim may be dynamic in some cases and not propagated, so it cannot
be used directly to create a 'scale' Constant.

Check whether hidden_dim is static and, if so, use it to calculate
'scale' directly, omitting the ShapeOf expression.

Ticket:
* [CVS-158394](https://jira.devtools.intel.com/browse/CVS-158394)

Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>
---
 .../state_management_pattern.cpp              | 28 ++++++++++++-------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
index 28e7cd90019b34..b55c3d73316120 100644
--- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
@@ -310,20 +310,28 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
         auto v_reshape =
             std::make_shared<v1::Reshape>(v_target_layout, v0::Constant::create(element::i64, Shape{2}, {0, -1}), true);
 
-        auto hidden_shape = std::make_shared<v3::ShapeOf>(real_q);
-        auto hidden_dim = std::make_shared<v8::Gather>(hidden_shape,
-                                                       v0::Constant::create(element::i64, Shape{}, {-1}),
-                                                       v0::Constant::create(element::i64, Shape{}, {0}));
         std::shared_ptr<Node> scale;
         if (pattern_map.count(scale_input)) {
             scale = pattern_map.at(scale_input).get_node_shared_ptr();
         } else {
-            // most likely `scale` below will always be a constant in real inference, but dynamic dimension
-            // propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built
-            // instead of just a constant node representing one of the dimensions.
-            scale = std::make_shared<v1::Divide>(
-                v0::Constant::create(element::f32, Shape{}, {1}),
-                std::make_shared<v0::Sqrt>(std::make_shared<v0::Convert>(hidden_dim, element::f32)));
+            auto real_q_ps = real_q.get_partial_shape();
+
+            bool rank_is_static = real_q_ps.rank().is_static();
+            if (rank_is_static && real_q_ps[real_q_ps.rank().get_length() - 1].is_static()) {
+                auto hidden_dim_len = static_cast<float>(real_q_ps[real_q_ps.rank().get_length() - 1].get_length());
+                scale = v0::Constant::create(element::f32, Shape{}, {1.0 / std::sqrt(hidden_dim_len)});
+            } else {
+                // most likely `scale` below will always be a constant in real inference, but dynamic dimension
+                // propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built
+                // instead of just a constant node representing one of the dimensions.
+                auto hidden_shape = std::make_shared<v3::ShapeOf>(real_q);
+                auto hidden_dim = std::make_shared<v8::Gather>(hidden_shape,
+                                                               v0::Constant::create(element::i64, Shape{}, {-1}),
+                                                               v0::Constant::create(element::i64, Shape{}, {0}));
+                scale = std::make_shared<v1::Divide>(
+                    v0::Constant::create(element::f32, Shape{}, {1}),
+                    std::make_shared<v0::Sqrt>(std::make_shared<v0::Convert>(hidden_dim, element::f32)));
+            }
         }
 
         std::shared_ptr<Node> alibi_slopes;