Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Mar 6, 2023
1 parent dab693e commit b1cd65d
Showing 1 changed file with 25 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,46 @@ namespace {
using NodePtr = std::shared_ptr<ov::Node>;
using NodePair = std::pair<NodePtr, NodePtr>;

std::vector<std::vector<size_t>> slice_by_increment_order(const Shape& transpose_order) {
if (transpose_order.empty())
return {};

std::vector<std::vector<size_t>> partition;
std::vector<size_t> sub_order;

for (size_t i = 0; i < transpose_order.size(); ++i) {
if (!i || transpose_order[i] == transpose_order[i - 1] + 1) {
sub_order.push_back(transpose_order[i]);
continue;
}
if (!sub_order.empty()) {
partition.push_back(sub_order);
}
sub_order.clear();
sub_order.push_back(transpose_order[i]);
}
if (!sub_order.empty()) {
partition.push_back(sub_order);
size_t FindEndOfSlice(const Shape& transpose_order, size_t start_idx) {
size_t slice_end = start_idx;
for (size_t i = start_idx + 1; i < transpose_order.size(); ++i) {
if (transpose_order[i] != transpose_order[slice_end] + 1)
break;
slice_end = i;
}

return partition;
return slice_end;
}

std::vector<int64_t> CreateGatherIndices(const Shape& transpose_input_shape,
const Shape& reshape_output_shape,
const Shape& transpose_order) {
const auto partition = slice_by_increment_order(transpose_order);
if (partition.size() != 3)
return {};
const int64_t transpose_part_0 = std::accumulate(partition[2].begin(), partition[2].end(), 1,

const size_t slice_0_end = FindEndOfSlice(transpose_order, 0);
const size_t slice_1_start = slice_0_end + 1;
const size_t slice_1_end = FindEndOfSlice(transpose_order, slice_1_start);
const size_t slice_2_start = slice_1_end + 1;

if (slice_0_end >= transpose_input_shape.size() ||
slice_1_start >= transpose_input_shape.size() ||
slice_1_end >= transpose_input_shape.size() ||
slice_2_start >= transpose_input_shape.size()) {
return {};
}

const int64_t transpose_part_0 = std::accumulate(transpose_order.begin() + slice_1_start,
transpose_order.begin() + slice_1_end + 1, 1,
[&transpose_input_shape](int64_t result, int64_t order_value) {
return result *= transpose_input_shape[order_value];
});
const int64_t transpose_part_1 = std::accumulate(partition[1].begin(), partition[1].end(), 1,
const int64_t transpose_part_1 = std::accumulate(transpose_order.begin() + slice_2_start,
transpose_order.end(), 1,
[&transpose_input_shape](int64_t result, int64_t order_value) {
return result *= transpose_input_shape[order_value];
});

std::vector<int64_t> gather_indices_value(reshape_output_shape.back());
for (size_t i = 0; i < gather_indices_value.size(); ++i) {
gather_indices_value[i] = transpose_part_1 * (i % transpose_part_0) + i / transpose_part_0;
gather_indices_value[i] = transpose_part_0 * (i % transpose_part_1) + i / transpose_part_1;
}

return gather_indices_value;
Expand Down

0 comments on commit b1cd65d

Please sign in to comment.