Fix reset output memory issue (#27853)
### Details:
 - *Fix an accuracy issue caused by stale output memory of runtime-optimized (skipped) primitives not being reset*
 - Original PRs:
     - #27695
     - #27439
     - #27517

### Tickets:
 - *154591*
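The scenario behind the accuracy issue, in short: a runtime-skipped primitive does not execute and instead aliases its dependency's output buffer, and that alias is cached across inference iterations. Because the skip decision depends on runtime shapes, it can flip between iterations, leaving downstream skipped users pointing at a buffer that no longer carries the current data. The standalone sketch below models only this idea in plain C++; the Prim struct, the run() helper, and the std::vector buffers standing in for device memory are illustrative assumptions, not plugin code, while reset_user_output_memory() mirrors the helper of the same name added to realloc_if_needed() in this commit.

// Standalone model (not plugin code) of the stale-alias problem fixed here.
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

using Memory = std::shared_ptr<std::vector<float>>;

struct Prim {
    std::string id;
    Prim* dep = nullptr;            // single data dependency
    std::vector<Prim*> users;       // downstream users
    bool skipped = false;           // optimized out this iteration
    Memory output;                  // cached output buffer
};

// Counterpart of reset_user_output_memory(): recursively clear the cached
// outputs of skipped users that still alias the stale buffer.
void reset_user_output_memory(Prim* prim, const Memory& stale) {
    for (Prim* user : prim->users) {
        if (user->skipped && user->output == stale) {
            user->output = nullptr;                 // clear_output_memory()
            reset_user_output_memory(user, stale);
        }
    }
}

// A skipped prim reuses the dependency's buffer but only re-resolves the alias
// when its cache was cleared; an executed prim adds `delta` into its own buffer.
void run(Prim* p, float delta) {
    if (p->skipped) {
        if (!p->output) p->output = p->dep->output;   // re-bind only after a reset
        return;
    }
    if (!p->output || p->output == p->dep->output)
        p->output = std::make_shared<std::vector<float>>(p->dep->output->size());
    (*p->output)[0] = (*p->dep->output)[0] + delta;
}

int main() {
    Prim in{"input"}, n1{"scatter"}, n2{"reshape"}, n3{"consumer"};
    in.users = {&n1}; n1.dep = &in; n1.users = {&n2};
    n2.dep = &n1;     n2.users = {&n3}; n3.dep = &n2;
    in.output = std::make_shared<std::vector<float>>(1, 1.f);

    // iter0: the scatter has nothing to update -> skipped; reshape skipped too.
    n1.skipped = true; n2.skipped = true;
    run(&n1, 0.f); run(&n2, 0.f); run(&n3, 0.f);      // consumer sees 1.0

    // iter1: the scatter now has real work -> executed, writes +10 into a new buffer.
    Memory stale = n1.output;                         // previously shared input buffer
    n1.skipped = false;
    n1.output = nullptr;                              // clear_output_memory() for n1 itself
    reset_user_output_memory(&n1, stale);             // the fix; comment this out to see
                                                      // the consumer read a stale 1.0
    run(&n1, 10.f); run(&n2, 0.f); run(&n3, 0.f);
    std::printf("consumer output: %.1f (expected 11.0)\n", (*n3.output)[0]);
    return 0;
}

Commenting out the reset_user_output_memory() call in this toy reproduces the class of stale-result mismatch the commit addresses.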
ahnyoung-paul authored Dec 3, 2024
1 parent 37a4e59 commit 662a2e6
Showing 11 changed files with 477 additions and 35 deletions.
@@ -38,6 +38,7 @@ struct kernel_impl_params final {
std::shared_ptr<const primitive> desc;
size_t unique_id;
bool _can_be_optimized = false;
bool _runtime_skippable = false;
std::vector<layout> input_layouts;
std::vector<layout> output_layouts;
std::vector<tensor> input_offsets;
@@ -145,6 +146,10 @@ struct kernel_impl_params final {
return _can_be_optimized;
}

bool runtime_skippable() const {
return _runtime_skippable;
}

template <class PType>
std::shared_ptr<const PType> typed_desc() const { return std::static_pointer_cast<const PType>(desc); }

@@ -14,6 +14,9 @@
#include "non_zero_inst.h"
#include "non_max_suppression_inst.h"
#include "unique_inst.hpp"
#include "scatter_elements_update_inst.h"
#include "scatter_update_inst.h"
#include "scatter_nd_update_inst.h"
#include "program_helpers.h"

using namespace cldnn;
@@ -201,5 +204,56 @@ void mark_runtime_skippable_nodes::run(program& p) {
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});

program_helpers::do_for_types<scatter_elements_update>(*node, [](scatter_elements_update_node & node){
auto impl_params = node.get_kernel_impl_params();

if ((node.is_output() && node.get_dependency(0).is_input())
|| node.has_fused_primitives()
|| (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format)
|| (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type))
return;

if (node.is_dynamic()) {
node.can_be_optimized(true);
// Set runtime_skippable only when the node has finally been marked as can_be_optimized.
node.set_runtime_skippable(true);
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});

program_helpers::do_for_types<scatter_update>(*node, [](scatter_update_node & node){
auto impl_params = node.get_kernel_impl_params();

if ((node.is_output() && node.get_dependency(0).is_input())
|| node.has_fused_primitives()
|| (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format)
|| (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type))
return;

if (node.is_dynamic()) {
node.can_be_optimized(true);
// Set runtime_skippable only when the node has finally been marked as can_be_optimized.
node.set_runtime_skippable(true);
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});

program_helpers::do_for_types<scatter_nd_update>(*node, [](scatter_nd_update_node & node){
auto impl_params = node.get_kernel_impl_params();

if ((node.is_output() && node.get_dependency(0).is_input())
|| node.has_fused_primitives()
|| (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format)
|| (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type))
return;

if (node.is_dynamic()) {
node.can_be_optimized(true);
// Set runtime_skippable only when the node has finally been marked as can_be_optimized.
node.set_runtime_skippable(true);
GPU_DEBUG_TRACE_DETAIL << "[mark_runtime_skippable_nodes] : " << node.id() << " can_be_optimized" << std::endl;
}
});
}
}
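For context on what these three handlers enable: the pass only marks the scatter variants as runtime-skippable when the node is dynamic and its input/output format and data type match; the per-iteration decision is taken later in do_runtime_skip_scatter_update() (added to primitive_inst.cpp below), which disables the skip only when both the indices and updates tensors are non-empty. With zero-sized indices/updates a scatter writes nothing, so it is an identity on its data input. A minimal plain-C++ illustration of that property follows; scatter_update_ref is a hypothetical reference helper, not code from this change.

// Hedged illustration of why an empty scatter update is safe to skip.
#include <cassert>
#include <cstddef>
#include <vector>

// Naive 1D scatter_update along axis 0: out[indices[i]] = updates[i].
std::vector<float> scatter_update_ref(std::vector<float> data,
                                      const std::vector<std::size_t>& indices,
                                      const std::vector<float>& updates) {
    for (std::size_t i = 0; i < indices.size(); ++i)
        data[indices[i]] = updates[i];
    return data;
}

int main() {
    const std::vector<float> data{1.f, 2.f, 3.f, 4.f};

    // Empty indices/updates (count() == 0): the op is an identity on its data
    // input, so the primitive can be replaced by memory reuse at runtime.
    assert(scatter_update_ref(data, {}, {}) == data);

    // Non-empty indices/updates: the result differs, so the primitive must
    // actually execute (do_runtime_skip_scatter_update() keeps it enabled).
    assert(scatter_update_ref(data, {1}, {9.f}) != data);
    return 0;
}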
10 changes: 5 additions & 5 deletions src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp
@@ -22,6 +22,9 @@
#include "permute_inst.h"
#include "strided_slice_inst.h"
#include "broadcast_inst.h"
#include "scatter_update_inst.h"
#include "scatter_elements_update_inst.h"
#include "scatter_nd_update_inst.h"

#include <vector>
#include <list>
@@ -89,11 +92,8 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
// concat buffer fusing for dynamic shape is adaptively applied at runtime. So we need to build dynamic impl at build time.
if (impl_param.can_be_optimized() &&
!((impl_param.is_type<concatenation>() ||
impl_param.is_type<gather>() ||
impl_param.is_type<permute>() ||
impl_param.is_type<strided_slice>() ||
impl_param.is_type<broadcast>() ||
impl_param.is_type<crop>()) && impl_param.is_dynamic())) {
impl_param.is_type<crop>() ||
impl_param.runtime_skippable()) && impl_param.is_dynamic())) {
return make_unique<ImplType>(kernel_selector::kernel_data{});
}
auto kernel_params = ImplType::get_kernel_params(ImplType::static_canonicalize_shapes(impl_param));
5 changes: 4 additions & 1 deletion src/plugins/intel_gpu/src/graph/include/primitive_inst.h
@@ -261,6 +261,7 @@ class primitive_inst {
void do_runtime_in_place_concat();
void do_runtime_in_place_kv_cache();
void do_runtime_in_place_crop();
void do_runtime_skip_scatter_update();
void configure_shape_of_dependencies();

memory::ptr fused_memory(size_t dep_id) const {
@@ -422,7 +423,7 @@ class primitive_inst {
bool use_async_compilation();
// if primitive_inst doesn't replace impl with new impl (static impl with opt kernel or dynamic impl), return false
bool update_impl(bool use_async_compilation);
event::ptr realloc_if_needed();
event::ptr realloc_if_needed(bool prev_execution_skipped = false);

cldnn::network::ptr get_unfused_subgraph();

@@ -476,6 +477,8 @@ class primitive_inst {
return false;
}

void clear_output_memory();

// This could be implemented via a single map std::unordered_map<instrumentation::perf_counter_key, std::tuple<int64_t, size_t>>
// but the overhead of using perf_counter_key as the map key is too big, thus we use its hash as the map key
// and store a mapping onto the original perf_counter_key for further data analysis and dumps
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/include/program_node.h
@@ -131,6 +131,7 @@ struct program_node {
get_unique_id(), in_layouts, out_layouts, get_fused_primitives()));
params->memory_deps = get_const_memory_deps();
params->_can_be_optimized = this->optimized;
params->_runtime_skippable = this->runtime_skippable;
params->in_port_to_shape_info_offset = get_input_port_to_shape_info_offset_map();
params->out_port_to_shape_info_offset = get_output_port_to_shape_info_offset_map();
auto deps = get_dependencies();
111 changes: 101 additions & 10 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -25,6 +25,9 @@
#include "shape_of_inst.h"
#include "softmax_inst.h"
#include "strided_slice_inst.h"
#include "scatter_elements_update_inst.h"
#include "scatter_nd_update_inst.h"
#include "scatter_update_inst.h"
#include "gemm_inst.h"
#include "assign_inst.h"
#include "read_value_inst.h"
@@ -550,7 +553,12 @@ bool primitive_inst::all_dependencies_cpu_impl() const {
return check_all_deps_cpu(this);
}

event::ptr primitive_inst::realloc_if_needed() {
void primitive_inst::clear_output_memory() {
_outputs[0] = nullptr;
_max_output_layout_count[0] = 0;
}

event::ptr primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("realloc_if_needed: " + id()));
GPU_DEBUG_GET_INSTANCE(debug_config);
GPU_DEBUG_PROFILED_STAGE(instrumentation::pipeline_stage::memory_allocation);
@@ -740,15 +748,55 @@ event::ptr primitive_inst::realloc_if_needed() {

// Clear out memory if it was previously reused, but now the primitive can't be optimized
if (!_node->is_type<concatenation>() && (_node->is_runtime_skippable() || _node->is_type<crop>())) {
std::function<void(cldnn::primitive_inst*, cldnn::memory::ptr)> reset_user_output_memory
= [&](cldnn::primitive_inst* curr_inst, cldnn::memory::ptr target_mem_ptr) {
for (auto& user_inst : curr_inst->get_user_insts()) {
auto curr_output_memory_ptr = user_inst->output_memory_ptr(0);
if (user_inst->can_be_optimized()
&& (curr_output_memory_ptr
&& get_network().get_engine().is_the_same_buffer(*curr_output_memory_ptr, *target_mem_ptr))) {
user_inst->clear_output_memory();
reset_user_output_memory(user_inst, target_mem_ptr);
}
}
};
if (can_be_optimized()) {
_max_output_layout_count = _deps[0].first->_max_output_layout_count;
GPU_DEBUG_PROFILED_STAGE_MEMALLOC_INFO("can_be_optimized");
// If the inst is optimized out but it executed at the previous iteration,
// reset all output memory of users which was optimized out at the previous iteration.
// Ex.
// * iter0: node1(executed) -> node2(skipped) -> node3(skipped)
// * iter1: node1(skipped) -> node2(skipped) -> node3(executed)
if (_outputs[0] && dep_memory_ptr(0)
&& !_network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0))) {
reset_user_output_memory(this, dep_memory_ptr(0));
}
return ev;
} else if (_outputs[0] && dep_memory_ptr(0) &&
_network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0))) {
// Clear out memory if it was previously reused, but now the primitive can't be optimized
_outputs[0] = nullptr;
_max_output_layout_count[0] = 0;
if (mem_allocated()) {
get_network().get_memory_pool().release_memory(_outputs[0].get(),
get_node().get_unique_id(), id(), get_network_id());
_mem_allocated = false;
}
clear_output_memory();
// Check users recursively: if a user is can_be_optimized && runtime_skippable
// && its output memory is the same as the current input memory,
// then reset that user's output memory too.
// Ex.
// * iter0: node1(skipped) -> node2(skipped) -> node3(skipped)
// * iter1: node1(executed) -> node2(skipped) -> node3(executed)
reset_user_output_memory(this, dep_memory_ptr(0));
} else {
// When this inst was not executed at the previous iteration,
// reset the output memory because the current output memory is invalid.
if (prev_execution_skipped) {
if (_outputs[0]) {
reset_user_output_memory(this, _outputs[0]);
}
clear_output_memory();
}
}
}

@@ -1347,7 +1395,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
void primitive_inst::do_runtime_skip_gather() {
// Check pattern
if (!get_node().is_type<gather>()
|| !get_node().can_be_optimized()
|| !get_node().is_runtime_skippable()
|| _impl_params->has_fused_primitives()
|| _impl_params->get_input_layout(0).data_type != _impl_params->get_output_layout().data_type
|| get_node().get_dependency(1).is_constant() || get_node().get_dependency(1).is_type<data>())
@@ -1419,7 +1467,6 @@ void primitive_inst::do_runtime_skip_permute() {
// Check pattern
if (!get_node().is_type<permute>()
|| is_output()
|| !get_node().can_be_optimized()
|| !get_node().is_runtime_skippable()
|| _impl_params->has_fused_primitives()
|| _impl_params->get_input_layout(0).data_type != _impl_params->get_output_layout().data_type)
@@ -1459,7 +1506,7 @@ void primitive_inst::do_runtime_skip_permute() {
void primitive_inst::do_runtime_skip_strided_slice() {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_skip_strided_slice: " + id()));
// Check pattern
if (!get_node().is_type<strided_slice>() || !get_node().can_be_optimized())
if (!get_node().is_type<strided_slice>() || !get_node().is_runtime_skippable())
return;

GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_strided_slice] " << id() << " : check optimizability" << std::endl;
@@ -1483,7 +1530,7 @@ void primitive_inst::do_runtime_skip_strided_slice() {
void primitive_inst::do_runtime_skip_broadcast() {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_skip_broadcast: " + id()));
// Check pattern
if (!get_node().is_type<broadcast>() || !get_node().can_be_optimized())
if (!get_node().is_type<broadcast>() || !get_node().is_runtime_skippable())
return;

GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_broadcast] " << id() << " : check optimizability" << std::endl;
@@ -1586,6 +1633,44 @@ void primitive_inst::do_runtime_in_place_concat() {
GPU_DEBUG_TRACE_DETAIL << "[In place concat] " << concat_inst->id() << ": can_be_optimized " << std::endl;
}

void primitive_inst::do_runtime_skip_scatter_update() {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_skip_scatter_update: " + id()));
// Check pattern
if (!(get_node().is_type<scatter_update>()
|| get_node().is_type<scatter_elements_update>()
|| get_node().is_type<scatter_nd_update>())
|| !get_node().is_runtime_skippable())
return;

GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_scatter_update] " << id() << " : check optimizability" << std::endl;
auto input_layout = _impl_params->get_input_layout(0);
auto output_layout = _impl_params->get_output_layout();
auto idx_layout = _impl_params->get_input_layout(1);
auto update_layout = _impl_params->get_input_layout(2);

if (idx_layout.count() > 0 && update_layout.count() > 0) {
// set shape_change to realloc memory for same input shapes
if (can_be_optimized()) {
set_shape_change();
}
set_can_be_optimized(false);
GPU_DEBUG_TRACE_DETAIL << "--- Cannot optimize because idx_layout (" << idx_layout.to_short_string()
<< ") and update_layout(" << update_layout.to_short_string() << ") are not zero" << std::endl;
return;
}

GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_scatter_update] " << id() << " : can_be_optimized" << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Input layout : " << _impl_params->get_input_layout(0).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Idx layout : " << _impl_params->get_input_layout(1).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Update layout : " << _impl_params->get_input_layout(2).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Output layout : " << _impl_params->get_output_layout().to_short_string() << std::endl;
// set shape_change to realloc memory for same input shapes
if (!can_be_optimized()) {
set_shape_change();
}
set_can_be_optimized(true);
}

void primitive_inst::do_runtime_in_place_crop() {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("do_runtime_in_place_crop: " + id()));
GPU_DEBUG_GET_INSTANCE(debug_config);
@@ -1670,6 +1755,11 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
GPU_DEBUG_TRACE_DETAIL << "-----------------------------------------------------------------" << std::endl;
bool need_args_update = false;
_mem_changed = false;

// If this inst was optimized out or was skipped due to an empty (zero-element) output at the previous iteration,
// set this flag to true so that the output memory is reset in realloc_if_needed.
const bool prev_execution_skipped = can_be_optimized()
|| (_impl_params->output_layouts[0].is_static() && _impl_params->output_layouts[0].count() == 0);
const auto orig_outputs = _outputs;
std::vector<event::ptr> dependencies;
if ((is_dynamic() || _node->is_in_shape_of_subgraph()) && !has_inner_networks()) {
@@ -1715,6 +1805,7 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
do_runtime_skip_permute();
do_runtime_skip_strided_slice();
do_runtime_skip_broadcast();
do_runtime_skip_scatter_update();
do_runtime_in_place_crop();

if (!is_valid_fusion()) {
@@ -1761,7 +1852,7 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
auto ev = update_weights();
if (ev)
dependencies.push_back(ev);
auto ev_reset = realloc_if_needed();
auto ev_reset = realloc_if_needed(prev_execution_skipped);
if (ev_reset)
dependencies.push_back(ev_reset);

@@ -1775,7 +1866,7 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
_impl->update(*this, *_impl_params);

need_args_update = true;
auto ev_reset = realloc_if_needed();
auto ev_reset = realloc_if_needed(prev_execution_skipped);
if (ev_reset)
dependencies.push_back(ev_reset);
}
14 changes: 8 additions & 6 deletions src/plugins/intel_gpu/src/graph/scatter_elements_update.cpp
@@ -53,16 +53,18 @@ std::string scatter_elements_update_inst::to_string(scatter_elements_update_node
return primitive_description.str();
}

scatter_elements_update_inst::typed_primitive_inst(network& network, scatter_elements_update_node const& node) : parent(network, node) {}
void scatter_elements_update_inst::on_execute() {
auto input1_shape = _impl_params->input_layouts[1].get_partial_shape();
auto input2_shape = _impl_params->input_layouts[2].get_partial_shape();
scatter_elements_update_inst::typed_primitive_inst(network& network, scatter_elements_update_node const& node) : parent(network, node) {
update_output_memory();
}

if ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0))
update_output_memory();
void scatter_elements_update_inst::on_execute() {
update_output_memory();
}

void scatter_elements_update_inst::update_output_memory() {
if (!can_be_optimized() || _impl_params->is_dynamic())
return;

if (_outputs.size() > 0 && static_cast<bool>(_outputs[0])
&& _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;
14 changes: 7 additions & 7 deletions src/plugins/intel_gpu/src/graph/scatter_nd_update.cpp
@@ -64,18 +64,18 @@ std::string scatter_nd_update_inst::to_string(scatter_nd_update_node const& node
return primitive_description.str();
}

scatter_nd_update_inst::typed_primitive_inst(network& network, scatter_nd_update_node const& node) : parent(network, node) {}
scatter_nd_update_inst::typed_primitive_inst(network& network, scatter_nd_update_node const& node) : parent(network, node) {
update_output_memory();
}

void scatter_nd_update_inst::on_execute() {
auto input1_shape = _impl_params->input_layouts[1].get_partial_shape();
auto input2_shape = _impl_params->input_layouts[2].get_partial_shape();
auto same_layouts = _impl_params->input_layouts[0] == _impl_params->output_layouts[0];

if (same_layouts && ((ov::shape_size(input1_shape.to_shape()) == 0) || (ov::shape_size(input2_shape.to_shape()) == 0)))
update_output_memory();
update_output_memory();
}

void scatter_nd_update_inst::update_output_memory() {
if (!can_be_optimized() || _impl_params->is_dynamic())
return;

if (_outputs.size() > 0 && static_cast<bool>(_outputs[0])
&& _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;
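Both scatter_*_update instances above now call update_output_memory() unconditionally from the constructor and on_execute(), relying on its guards instead of re-checking the index/update shapes in on_execute(). Below is a hedged, standalone sketch of the buffer-binding rule those guards implement; the Inst struct is a stand-in for illustration only, and dynamic shapes are assumed to be handled at realloc time as in the primitive_inst.cpp changes above.

// Standalone sketch of the update_output_memory() guard logic (not plugin code).
#include <memory>
#include <vector>

using Memory = std::shared_ptr<std::vector<float>>;

struct Inst {
    bool can_be_optimized = false;   // set by the runtime skip decision
    bool is_dynamic = false;         // dynamic shapes are bound elsewhere
    Memory input;
    Memory output;

    void update_output_memory() {
        if (!can_be_optimized || is_dynamic)
            return;                              // nothing to share in these cases
        if (output && output == input)
            return;                              // already bound to the input buffer
        output = input;                          // reuse the input buffer as the output
    }
};

int main() {
    Inst skipped;
    skipped.can_be_optimized = true;
    skipped.input = std::make_shared<std::vector<float>>(4, 1.f);
    skipped.update_output_memory();              // output now aliases the input buffer
    return skipped.output == skipped.input ? 0 : 1;
}

The early returns keep the binding idempotent across iterations: an already-aliased output is left untouched, and non-skippable or dynamic cases never grab the input buffer here.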
(The remaining changed files were not loaded in this view.)
