Skip to content

Commit

Permalink
calculate concat/slicing using output layout instead of layout in mem…
Browse files Browse the repository at this point in the history
…ory (#21951)
  • Loading branch information
ahnyoung-paul authored Jan 4, 2024
1 parent 241aed4 commit 1646625
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/plugins/intel_gpu/src/graph/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ loop_inst::concatenated_memory_mapping::ptr loop_inst::create_concat_memory_map(
<< " to " << updated_sliced_layout.to_string() << std::endl;
sliced_layout.set_partial_shape(updated_sliced_layout);
out_mem_ptr = engine.allocate_memory(sliced_layout);
prim->set_output_layout(sliced_layout, internal_id.idx);
}

// When num_iterations is -1, allocate first sliced_mem and allocate sliced memory if additional sliced mem is required
Expand Down Expand Up @@ -783,9 +784,11 @@ void loop_inst::concatenated_memory_mapping::slice_mem(const int64_t num_iterati
") should be same with sliced_mems.size(", sliced_mems.size(), ")");
OPENVINO_ASSERT(concatenated_mem != nullptr, "concatenated_mem should not be nullptr");

auto elem_size = ov::element::Type(concatenated_mem->get_layout().data_type).size();
auto concat_mem_shape = concatenated_mem->get_layout().get_shape();
auto sliced_mem_shape = sliced_mems.front()->get_layout().get_shape();
auto concat_layout = concat_data_prim->get_output_layout(io_prim_map.external_id.idx);
auto sliced_layout = sliced_data_prim->get_output_layout(io_prim_map.internal_id.idx);
auto concat_mem_shape = concat_layout.get_shape();
auto sliced_mem_shape = sliced_layout.get_shape();
auto elem_size = ov::element::Type(concat_layout.data_type).size();
const auto stride = io_prim_map.stride;
const auto axis = io_prim_map.axis;
const auto step = std::abs(stride);
Expand All @@ -801,10 +804,8 @@ void loop_inst::concatenated_memory_mapping::slice_mem(const int64_t num_iterati
}

char* concat_data = reinterpret_cast<char*>(concatenated_mem->lock(stream, cldnn::mem_lock_type::read));

auto concate_layout = concatenated_mem->get_layout();
auto dims = concat_mem_shape.size();
if (!format::is_default_format(concate_layout.format) || dims == 1 || concate_layout.data_padding) {
if (!format::is_default_format(concat_layout.format) || dims == 1 || concat_layout.data_padding) {
// BE CAREFUL: ov::reference::split is extremely slow.
// If we encounter any case where this code path is executed, we need to optimize it
ov::reference::split(concat_data, concat_mem_shape, elem_size, axis, num_iters, pointers_to_data.data());
Expand Down Expand Up @@ -864,9 +865,11 @@ void loop_inst::concatenated_memory_mapping::concat_mem(const int64_t curent_ite
") should be less than the number of sliced_mems(", sliced_mems.size(), ")");
OPENVINO_ASSERT(concatenated_mem != nullptr, "concatenated_mem should not be nullptr");

auto elem_size = ov::element::Type(concatenated_mem->get_layout().data_type).size();
auto concat_mem_shape = concatenated_mem->get_layout().get_shape();
auto sliced_mem_shape = sliced_mems.front()->get_layout().get_shape();
auto concat_layout = concat_data_prim->get_output_layout(io_prim_map.external_id.idx);
auto sliced_layout = sliced_data_prim->get_output_layout(io_prim_map.internal_id.idx);
auto concat_mem_shape = concat_layout.get_shape();
auto sliced_mem_shape = sliced_layout.get_shape();
auto elem_size = ov::element::Type(concat_layout.data_type).size();
const auto stride = io_prim_map.stride;
const auto axis = io_prim_map.axis;
const auto step = std::abs(stride);
Expand Down

0 comments on commit 1646625

Please sign in to comment.