Skip to content

Commit

Permalink
rewrite backward
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Mar 6, 2023
1 parent d127fe6 commit 1cba14e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace {
using NodePtr = std::shared_ptr<ov::Node>;
using NodePair = std::pair<NodePtr, NodePtr>;

std::vector<std::vector<size_t>> partition_by_increment_order(const Shape& transpose_order) {
std::vector<std::vector<size_t>> slice_by_increment_order(const Shape& transpose_order) {
if (transpose_order.empty())
return {};

Expand All @@ -47,10 +47,10 @@ std::vector<std::vector<size_t>> partition_by_increment_order(const Shape& trans
return partition;
}

std::vector<int64_t> CreateForwardSinkingGatherIndices(const Shape& transpose_input_shape,
std::vector<int64_t> CreateGatherIndices(const Shape& transpose_input_shape,
const Shape& reshape_output_shape,
const Shape& transpose_order) {
const auto partition = partition_by_increment_order(transpose_order);
const auto partition = slice_by_increment_order(transpose_order);
if (partition.size() != 3)
return {};
const int64_t transpose_part_0 = std::accumulate(partition[2].begin(), partition[2].end(), 1,
Expand All @@ -71,7 +71,7 @@ std::vector<int64_t> CreateForwardSinkingGatherIndices(const Shape& transpose_in
}

NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
const auto gather_indices_value = CreateForwardSinkingGatherIndices(transpose->get_input_shape(0),
const auto gather_indices_value = CreateGatherIndices(transpose->get_input_shape(0),
reshape->get_output_shape(0),
transpose_constant->get_axis_vector_val());
const int64_t gather_axis_value = reshape->get_output_shape(0).size() - 1;
Expand Down Expand Up @@ -99,27 +99,18 @@ Shape TransposeShape(const Shape& shape, AxisVector transpose_axis) {
}

NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
const Shape& pattern_input_shape = reshape->get_input_shape(0);
const Shape& pattern_output_shape = transpose->get_output_shape(0);
auto compare_shapes = [](const Shape& first, const Shape& second) { return first.size() < second.size(); };
const Shape& max_shape = std::max(pattern_input_shape, pattern_output_shape, compare_shapes);
const Shape& min_shape = std::min(pattern_input_shape, pattern_output_shape, compare_shapes);
const int64_t gather_axis_value = reshape->get_input_shape(0).size() - 1;

const int64_t gather_axis_value = min_shape.size() - 1;

const Shape transposed_max_shape = TransposeShape(max_shape, transpose_constant->get_axis_vector_val());
const Shape transposed_shape_part(transposed_max_shape.end() - 2, transposed_max_shape.end());

std::vector<int64_t> gather_indices_value(min_shape.back());
for (size_t i = 0; i < gather_indices_value.size(); ++i) {
gather_indices_value[i] = transposed_shape_part[1] * (i % transposed_shape_part[0]) + i / transposed_shape_part[0];
}
const auto gather_indices_value = CreateGatherIndices(TransposeShape(transpose->get_output_shape(0),
transpose_constant->get_axis_vector_val()),
reshape->get_input_shape(0),
transpose_constant->get_axis_vector_val());

auto gather_axis = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value);
auto gather_indices = std::make_shared<Constant>(element::i64, Shape{gather_indices_value.size()}, gather_indices_value);
auto gather = std::make_shared<Gather>(reshape->input_value(0), gather_indices, gather_axis);

auto reshape_const_new = std::make_shared<Constant>(element::i64, Shape{max_shape.size()}, max_shape);
auto reshape_const_new = std::make_shared<Constant>(element::i64, Shape{transpose->get_output_shape(0).size()}, transpose->get_output_shape(0));
auto reshape_new = std::make_shared<Reshape>(gather, reshape_const_new, false);

ov::replace_node(transpose, reshape_new);
Expand All @@ -129,7 +120,7 @@ NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_con

return std::make_pair(transpose, reshape_new);
}

#if 0
bool IsFlatten2D(const Output<Node>& output) {
std::shared_ptr<ov::Node> reshape_node = output.get_node_shared_ptr();
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
Expand All @@ -155,7 +146,7 @@ bool IsUnflatten2D(const Output<Node>& output) {
output_shape[0] == input_shape[0] &&
input_shape[1] == output_shape[1] * output_shape[2]);
}

#endif
} // namespace

// working with situation when we transpose dims that are flatten/unflatten
Expand Down Expand Up @@ -189,7 +180,7 @@ GatherSinkingTransposeReshapeForward::GatherSinkingTransposeReshapeForward() {
GatherSinkingTransposeReshapeBackward::GatherSinkingTransposeReshapeBackward() {
MATCHER_SCOPE(GatherSinkingTransposeReshapeBackward);

auto reshape_label = wrap_type<Reshape>({any_input(), any_input()}, IsUnflatten2D/*check if it is sinkable */);
auto reshape_label = wrap_type<Reshape>({any_input(), any_input()}/*, IsUnflatten2D*//*check if it is sinkable */); // TODO: IsUnflatten2D
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({reshape_label, transpose_const_label});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ TEST(GatherSinkingTransposeReshape, BackwardSinking) {
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid) << result.message;
}
#if 0

TEST(GatherSinkingTransposeReshape, BackwardSinking3D) {
std::shared_ptr<Model> function;
{
Expand Down Expand Up @@ -283,10 +283,10 @@ TEST(GatherSinkingTransposeReshape, BackwardSinking3D) {

auto generate_indices = []() -> std::vector<int64_t> {
std::vector<int64_t> indices;
for (int i = 0; i < 80; ++i) { // FIXME
indices.push_back(i);
indices.push_back(i + 80);
indices.push_back(i + 160);
for (int j = 0; j < 4; ++j) {
for (int i = 0; i < 14; ++i) {
indices.push_back(j + 4 * i);
}
}
return indices;
};
Expand All @@ -306,8 +306,9 @@ TEST(GatherSinkingTransposeReshape, BackwardSinking3D) {
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid) << result.message;
CompareOutput(function, reference_function); // DEBUG
}
#endif

#if 0
TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSinkOnes) {
std::shared_ptr<Model> function;
Expand Down

0 comments on commit 1cba14e

Please sign in to comment.