From 01dc53ac2deb437c644b63a633e0a7780ddf848c Mon Sep 17 00:00:00 2001
From: Paul Youngsoo Ahn
Date: Wed, 3 Jul 2024 23:10:36 -0700
Subject: [PATCH] [GPU] update shape for fused prims (#25363)

### Details:
 - *Update the shape of a user when the current node is an input of one of that user's fused prims, even if the user's shape has already been updated by another node*
 - *...*

### Tickets:
 - *145756*
---
 .../intel_gpu/src/graph/primitive_inst.cpp | 45 ++++++++++++++-----
 1 file changed, 34 insertions(+), 11 deletions(-)

diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
index 1e72002b56334e..d002972d19344b 100644
--- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
+++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -586,19 +586,42 @@ event::ptr primitive_inst::realloc_if_needed() {
                         user_insts.size(), " and ", user_insts_origin.size());
     }
     for (auto user : user_insts) {
+        auto is_fused_prim_of_user = [&](primitive_id id) -> bool {
+            for (auto& p : user->get_node().get_fused_primitives()) {
+                if (p.has_outer_dep()) {
+                    const auto start_idx = p.outer_dep_start_idx;
+                    // exclude fused_node from total_num_deps
+                    const auto end_idx = p.outer_dep_start_idx + p.total_num_deps -1;
+                    for (size_t idx = start_idx; idx < end_idx; idx++) {
+                        if (user->get_node().get_dependency(idx).id() == id) {
+                            return true;
+                        }
+                    }
+                }
+            }
+            return false;
+        };
         // Since fake alignment is applicable for input tensor as well, make sure we allocate enough memory
         // to prevent reading beyond the allocated memory bounds
-        if (user->get_node().is_type<fully_connected>() && user->is_dynamic() && user->_deps[0].first == this) {
-            GPU_DEBUG_TRACE_DETAIL << "Check fc user " << user->id() << "'s fake alignment-ed input size" << std::endl;
-            user->update_shape();
-            user->update_shape_done_by_other = true;
-
-            auto fc_impl_params = *user->_impl_params;
-            auto fc_input_layout = user->get_node().type()->get_fake_aligned_params(fc_impl_params).input_layouts[0];
-            if (fc_input_layout.bytes_count() > updated_layout.bytes_count()) {
-                GPU_DEBUG_TRACE_DETAIL << id() << ": increase output layout allocation size from " << actual_layout.to_short_string() << " -> "
-                                       << fc_input_layout.to_short_string() << " to meet the input buffer alignment requirements for FC\n";
-                updated_layout = fc_input_layout;
+        if (user->get_node().is_type<fully_connected>() && user->is_dynamic()) {
+            if (user->_deps[0].first == this) {
+                GPU_DEBUG_TRACE_DETAIL << "Check fc user " << user->id() << "'s fake alignment-ed input size" << std::endl;
+                user->update_shape();
+                user->update_shape_done_by_other = true;
+
+                auto fc_impl_params = *user->_impl_params;
+                auto fc_input_layout = user->get_node().type()->get_fake_aligned_params(fc_impl_params).input_layouts[0];
+                if (fc_input_layout.bytes_count() > updated_layout.bytes_count()) {
+                    GPU_DEBUG_TRACE_DETAIL << id() << ": increase output layout allocation size from " << actual_layout.to_short_string() << " -> "
+                                           << fc_input_layout.to_short_string() << " to meet the input buffer alignment requirements for FC\n";
+                    updated_layout = fc_input_layout;
+                }
+            } else if (is_fused_prim_of_user(id()) && user->update_shape_done_by_other) {
+                // Since the output layout of fused prim in user is determined after user's update_shape
+                // Rerun update_shape w/ new output layout of fused prim
+                user->update_shape_done_by_other = false;
+                user->update_shape();
+                user->update_shape_done_by_other = true;
             }
         }
     }