From b1cd65d75ff061f9e4ad69ea174fe37e321073d2 Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Wed, 22 Feb 2023 12:23:14 +0100 Subject: [PATCH] refactor --- .../gather_sinking_transpose_reshape.cpp | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp b/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp index b277f2f24b0f35..54680bbd68ce49 100644 --- a/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp +++ b/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp @@ -22,49 +22,46 @@ namespace { using NodePtr = std::shared_ptr; using NodePair = std::pair; -std::vector> slice_by_increment_order(const Shape& transpose_order) { - if (transpose_order.empty()) - return {}; - - std::vector> partition; - std::vector sub_order; - - for (size_t i = 0; i < transpose_order.size(); ++i) { - if (!i || transpose_order[i] == transpose_order[i - 1] + 1) { - sub_order.push_back(transpose_order[i]); - continue; - } - if (!sub_order.empty()) { - partition.push_back(sub_order); - } - sub_order.clear(); - sub_order.push_back(transpose_order[i]); - } - if (!sub_order.empty()) { - partition.push_back(sub_order); +size_t FindEndOfSlice(const Shape& transpose_order, size_t start_idx) { + size_t slice_end = start_idx; + for (size_t i = start_idx + 1; i < transpose_order.size(); ++i) { + if (transpose_order[i] != transpose_order[slice_end] + 1) + break; + slice_end = i; } - - return partition; + return slice_end; } std::vector CreateGatherIndices(const Shape& transpose_input_shape, const Shape& reshape_output_shape, const Shape& transpose_order) { - const auto partition = slice_by_increment_order(transpose_order); - if (partition.size() != 3) - return {}; - const int64_t transpose_part_0 = std::accumulate(partition[2].begin(), partition[2].end(), 1, + + const size_t slice_0_end = FindEndOfSlice(transpose_order, 0); + const size_t slice_1_start = slice_0_end + 1; + const size_t slice_1_end = FindEndOfSlice(transpose_order, slice_1_start); + const size_t slice_2_start = slice_1_end + 1; + + if (slice_0_end >= transpose_input_shape.size() || + slice_1_start >= transpose_input_shape.size() || + slice_1_end >= transpose_input_shape.size() || + slice_2_start >= transpose_input_shape.size()) { + return {}; + } + + const int64_t transpose_part_0 = std::accumulate(transpose_order.begin() + slice_1_start, + transpose_order.begin() + slice_1_end + 1, 1, [&transpose_input_shape](int64_t result, int64_t order_value) { return result *= transpose_input_shape[order_value]; }); - const int64_t transpose_part_1 = std::accumulate(partition[1].begin(), partition[1].end(), 1, + const int64_t transpose_part_1 = std::accumulate(transpose_order.begin() + slice_2_start, + transpose_order.end(), 1, [&transpose_input_shape](int64_t result, int64_t order_value) { return result *= transpose_input_shape[order_value]; }); std::vector gather_indices_value(reshape_output_shape.back()); for (size_t i = 0; i < gather_indices_value.size(); ++i) { - gather_indices_value[i] = transpose_part_1 * (i % transpose_part_0) + i / transpose_part_0; + gather_indices_value[i] = transpose_part_0 * (i % transpose_part_1) + i / transpose_part_1; } return gather_indices_value;