Skip to content

Commit

Permalink
Follow-up code review
Browse files Browse the repository at this point in the history
  • Loading branch information
ahnyoung-paul committed Dec 29, 2023
1 parent e631e6a commit aee3c36
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 17 deletions.
6 changes: 2 additions & 4 deletions src/plugins/intel_gpu/src/graph/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,12 @@ void gather_inst::on_execute() {
void gather_inst::update_output_memory() {
if (!can_be_optimized())
return;
if (static_cast<bool>(_outputs[0]) && _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;

// need to run build_deps before checking for shared memory between input and output, because deps are not set yet in some cases.
if (_node != nullptr)
build_deps();

if (static_cast<bool>(_outputs[0]) && _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
return;

GPU_DEBUG_TRACE_DETAIL << id() << " : update_output_memory with mem of input " << get_node().get_dependency(0).id()
<< " : " << input_memory_ptr()->buffer_ptr() << std::endl;
_outputs[0] = input_memory_ptr();
Expand Down
4 changes: 4 additions & 0 deletions src/plugins/intel_gpu/src/graph/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,10 @@ void network::allocate_primitives() {
for (auto const& node : po) {
if (node->can_be_optimized() && !node->is_dynamic()) {
auto opt_inst = _primitives.at(node->id());
// build deps when prim_inst has not updated its dependencies yet.
if (!node->get_dependencies().empty() && opt_inst->dependencies().empty()) {
opt_inst->build_deps();
}
opt_inst->update_output_memory();
}
}
Expand Down
16 changes: 3 additions & 13 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,7 @@ event::ptr primitive_inst::realloc_if_needed() {

// Clear out memory if if was previously reused, but now primitive can't be optimized
if (_node->is_type<gather>() && !can_be_optimized() && _outputs[0]
&& (_impl_params->input_layouts[0].count() != 0 // check if layout is zero dimension before using dep_memory()
&& _network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0)))) {
&& dep_memory_ptr(0) && _network.get_engine().is_the_same_buffer(dep_memory(0), output_memory(0))) {
_outputs[0] = nullptr;
max_output_layout_size = 0;
}
Expand Down Expand Up @@ -920,6 +919,8 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
void primitive_inst::do_runtime_skip_gather() {
// Check pattern
if (!get_node().is_type<gather>()
|| !get_node().can_be_optimized()
|| (_impl_params->get_output_layout().count() == 0)
|| _impl_params->has_fused_primitives()
|| _impl_params->get_input_layout(0).data_type != _impl_params->get_output_layout().data_type
|| get_node().get_dependency(1).is_constant() || get_node().get_dependency(1).is_type<data>())
Expand All @@ -932,17 +933,6 @@ void primitive_inst::do_runtime_skip_gather() {
auto idx_shape = _impl_params->get_input_layout(1).get_shape();
auto idx_rank = idx_shape.size();

// If all input layout count is zero, then it can be optimized because all input will be skipped for empty output.
if (_impl_params->get_input_layout(0).count() == 0 && _impl_params->get_input_layout(1).count() == 0) {
GPU_DEBUG_TRACE_DETAIL << "[do_runtime_skip_gather] " << id() << " : can_be_optimized" << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Input layout : " << _impl_params->get_input_layout(0).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Indices layout : " << _impl_params->get_input_layout(1).to_short_string() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Gather axis : " << axis << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - Output layout : " << _impl_params->get_output_layout().to_short_string() << std::endl;
set_can_be_optimized(true);
return;
}

if (idx_rank != 1) {
GPU_DEBUG_TRACE_DETAIL << "-- Cannot optimize becuase of its indices rank " << idx_rank << std::endl;
return;
Expand Down

0 comments on commit aee3c36

Please sign in to comment.