Skip to content

Commit

Permalink
add unit tests; fix
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Mar 6, 2023
1 parent 668da96 commit 2d5383d
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ std::vector<int64_t> GetNormalizedGatherIndices(const std::shared_ptr<Constant>&
return NormalizeGatherIndices(indices->cast_vector<int64_t>());
}

std::vector<int64_t> ApplyGatherPermutation(const std::vector<int64_t>& input_gather_indices, const std::vector<int64_t>& output_gather_indices) {
std::vector<int64_t> CombineGatherPermutations(const std::vector<int64_t>& input_gather_indices, const std::vector<int64_t>& output_gather_indices) {
if (input_gather_indices.size() != output_gather_indices.size())
return {};
std::vector<int64_t> result(input_gather_indices.size());
for (size_t i = 0; i < result.size(); ++i) {
result[output_gather_indices[i]] = input_gather_indices[i];
result[i] = input_gather_indices[output_gather_indices[i]];
}

return result;
Expand All @@ -107,9 +107,9 @@ bool IsPointlessPermutation(const std::vector<int64_t>& indices) {
std::shared_ptr<Gather> FuseGatherNodes(TransformationInfo& info) {
const std::vector<int64_t> input_gather_indices = GetNormalizedGatherIndices(info.input_indices_const);
const std::vector<int64_t> output_gather_indices = GetNormalizedGatherIndices(info.output_indices_const);
const std::vector<int64_t> result_gather_indices = ApplyGatherPermutation(input_gather_indices, output_gather_indices);
const std::vector<int64_t> result_gather_indices = CombineGatherPermutations(input_gather_indices, output_gather_indices);
if (IsPointlessPermutation(result_gather_indices)) {
info.input_gather->input_value(0).replace(info.output_gather->output(0));
ov::replace_output_update_name(info.output_gather->output(0), info.input_gather->input_value(0));
return {};
}

Expand All @@ -120,7 +120,7 @@ std::shared_ptr<Gather> FuseGatherNodes(TransformationInfo& info) {
auto new_axis_const = info.output_axis_const->clone_with_new_inputs({});
auto new_gather = std::make_shared<Gather>(info.input_gather->input_value(0), new_indices_const, new_axis_const);

info.input_gather->input(0).replace_source_output(new_gather->output(0));
ov::replace_node(info.output_gather, new_gather);
copy_runtime_info(info.input_gather, {new_gather, new_indices_const, new_axis_const});
new_gather->set_friendly_name(info.output_gather->get_friendly_name());

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/gather_sinking_fuse.hpp"

#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/init_node_info.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"

using namespace ov;
using namespace ov::opset9;

using NodePtr = std::shared_ptr<ov::Node>;

namespace {

std::shared_ptr<Gather> MakeGather(NodePtr input_node, const std::vector<size_t>& indices, size_t axis) {
const ov::Shape& input_shape = input_node->get_output_shape(0);
auto gather_indexes_node = Constant::create(ngraph::element::i64, ov::Shape{indices.size()}, indices);

auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis});

return std::make_shared<Gather>(input_node, gather_indexes_node, gather_axis_node);
}

} // namespace

TEST(GatherSinkingFuse, Remove) {
std::shared_ptr<Model> function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto input_gather = MakeGather(tanh0, std::vector<size_t>{2, 0 ,1}, /* axis */ 1);
auto output_gather = MakeGather(input_gather, std::vector<size_t>{1, 2, 0}, /* axis */ 1);

auto tanh1 = std::make_shared<Tanh>(output_gather);
const auto result = std::make_shared<Result>(tanh1);
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingFuse>();
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

std::shared_ptr<Model> reference_function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto tanh1 = std::make_shared<Tanh>(tanh0);
const auto result = std::make_shared<Result>(tanh1);
reference_function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

const FunctionsComparator func_comparator =
FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid) << result.message;
}

TEST(GatherSinkingFuse, DifferentAxis) {
std::shared_ptr<Model> function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto input_gather = MakeGather(tanh0, std::vector<size_t>{2, 0 ,1}, /* axis */ 1);
auto output_gather = MakeGather(input_gather, std::vector<size_t>{1, 2, 0}, /* axis */ 2);

auto tanh1 = std::make_shared<Tanh>(output_gather);
const auto result = std::make_shared<Result>(tanh1);
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingFuse>();
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

std::shared_ptr<Model> reference_function = function->clone();

const FunctionsComparator func_comparator =
FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid) << result.message;
}

TEST(GatherSinkingFuse, Combine) {
std::shared_ptr<Model> function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto input_gather = MakeGather(tanh0, std::vector<size_t>{2, 0 ,1}, /* axis */ 1);
auto output_gather = MakeGather(input_gather, std::vector<size_t>{1, 0, 2}, /* axis */ 1);

auto tanh1 = std::make_shared<Tanh>(output_gather);
const auto result = std::make_shared<Result>(tanh1);
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingFuse>();
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

std::shared_ptr<Model> reference_function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto gather = MakeGather(tanh0, std::vector<size_t>{0, 2, 1}, /* axis */ 1);

auto tanh1 = std::make_shared<Tanh>(gather);
const auto result = std::make_shared<Result>(tanh1);
reference_function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

const FunctionsComparator func_comparator =
FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid) << result.message;
}

0 comments on commit 2d5383d

Please sign in to comment.