Skip to content

Commit

Permalink
wrote unit tests draft
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Mar 6, 2023
1 parent ae63a5a commit 5dfba63
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include <openvino/cc/ngraph/itt.hpp>
#include <ngraph/rt_info.hpp>

#include "transformations/gather_sinking_transpose_reshape.hpp"

Expand Down Expand Up @@ -45,6 +46,9 @@ NodePair SinkForward(NodePtr transpose, NodePtr reshape) {

ov::replace_node(reshape, gather);

ov::copy_runtime_info({reshape}, {gather, gather_indices, gather_axis, reshape_new});
gather->set_friendly_name(reshape->get_friendly_name());

return std::make_pair(reshape_new, gather);
}

Expand Down Expand Up @@ -82,11 +86,17 @@ NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_con

ov::replace_node(transpose, reshape_new);

ov::copy_runtime_info({transpose}, {gather, gather_indices, gather_axis, reshape_new, reshape_const_new});
reshape_new->set_friendly_name(transpose->get_friendly_name());

return std::make_pair(transpose, reshape_new);
}

bool IsFlatten2D(const Output<Node>& output) {
std::shared_ptr<ov::Node> reshape_node = output.get_node_shared_ptr();
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
return false;
const Shape& input_shape = reshape_node->get_input_shape(0);
const Shape& output_shape = reshape_node->get_output_shape(0);
return (input_shape.size() == 3 &&
Expand All @@ -97,6 +107,9 @@ bool IsFlatten2D(const Output<Node>& output) {

bool IsUnflatten2D(const Output<Node>& output) {
std::shared_ptr<ov::Node> reshape_node = output.get_node_shared_ptr();
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
return false;
const Shape& input_shape = reshape_node->get_input_shape(0);
const Shape& output_shape = reshape_node->get_output_shape(0);
return (input_shape.size() == 2 &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,31 @@ void CompareOutput(std::shared_ptr<ov::Model> function, std::shared_ptr<ov::Mode
}
}
}

template <typename T>
std::shared_ptr<T> FindNode(std::shared_ptr<Model> model) {
for (auto op : model->get_ops()) {
auto node = as_type_ptr<T>(op);
if (node)
return node;
}
return {};
}

void PrintConstant(std::shared_ptr<Node> node) {
auto constant = as_type_ptr<Constant>(node);
if (!constant)
return;
auto value = constant->cast_vector<int>();
std::cout << "{ ";
for (int i = 0; i < value.size(); ++i) {
if (i)
std::cout << ", ";
std::cout << value[i];
}
std::cout << " }" << std::endl;
}

} // namespace

TEST(GatherSinkingTransposeReshape, ForwardSinking) {
Expand All @@ -98,15 +123,47 @@ TEST(GatherSinkingTransposeReshape, ForwardSinking) {
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> reference_function = function->clone();
std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
//manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeForward>();
manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
//manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

CompareOutput(function, orig_function);

CompareOutput(function, reference_function);
std::shared_ptr<Model> reference_function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto reshape_const = std::make_shared<Constant>(element::i64, Shape{2}, std::vector<int>{1, -1});
auto reshape = std::make_shared<Reshape>(tanh0, reshape_const, false);

auto generate_indices = []() -> std::vector<int64_t> {
std::vector<int64_t> indices;
for (int i = 0; i < 80; ++i) {
indices.push_back(i);
indices.push_back(i + 80);
indices.push_back(i + 160);
}
return indices;
};
auto gather_indices = generate_indices();
auto gather_indices_const = std::make_shared<Constant>(element::i64, Shape{gather_indices.size()}, gather_indices);
auto gather_axis_const = std::make_shared<Constant>(element::i64, Shape{}, 1);
auto gather = std::make_shared<Gather>(reshape, gather_indices_const, gather_axis_const);

auto tanh1 = std::make_shared<Tanh>(gather);
const auto result = std::make_shared<Result>(tanh1);
reference_function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid);
}

TEST(GatherSinkingTransposeReshape, BackwardSinking) {
Expand All @@ -126,16 +183,167 @@ TEST(GatherSinkingTransposeReshape, BackwardSinking) {
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> reference_function = function->clone();
std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
//manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeBackward>();
manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
//manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

CompareOutput(function, orig_function);

std::shared_ptr<Model> reference_function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 240});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto generate_indices = []() -> std::vector<int64_t> {
std::vector<int64_t> indices;
for (int i = 0; i < 80; ++i) {
indices.push_back(i);
indices.push_back(i + 80);
indices.push_back(i + 160);
}
return indices;
};
auto gather_indices = generate_indices();
auto gather_indices_const = std::make_shared<Constant>(element::i64, Shape{gather_indices.size()}, gather_indices);
auto gather_axis_const = std::make_shared<Constant>(element::i64, Shape{}, 1);
auto gather = std::make_shared<Gather>(tanh0, gather_indices_const, gather_axis_const);

auto reshape_const = std::make_shared<Constant>(element::i64, Shape{3}, std::vector<int>{1, 80, 3});
auto reshape = std::make_shared<Reshape>(gather, reshape_const, false);

auto tanh1 = std::make_shared<Tanh>(reshape);
const auto result = std::make_shared<Result>(tanh1);
reference_function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid) << result.message;
}

TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink1) {
std::shared_ptr<Model> function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto transpose_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{0, 2, 1});
auto transpose = std::make_shared<Transpose>(tanh0, transpose_order);

auto reshape_const = std::make_shared<Constant>(element::i64, Shape{4}, std::vector<int>{1, 3, 80, 1});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);

auto tanh1 = std::make_shared<Tanh>(reshape);
const auto result = std::make_shared<Result>(tanh1);
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeForward>();
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

CompareOutput(function, reference_function);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, orig_function);
ASSERT_TRUE(result.valid);
}

TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink2) {
std::shared_ptr<Model> function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 4, 80});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto transpose_order = std::make_shared<Constant>(element::u64, Shape{3}, Shape{0, 2, 1});
auto transpose = std::make_shared<Transpose>(tanh0, transpose_order);

auto reshape_const = std::make_shared<Constant>(element::i64, Shape{4}, std::vector<int>{1, 2, 80, 2});
auto reshape = std::make_shared<Reshape>(transpose, reshape_const, false);

auto tanh1 = std::make_shared<Tanh>(reshape);
const auto result = std::make_shared<Result>(tanh1);
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeForward>();
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, orig_function);
ASSERT_TRUE(result.valid);
}

TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink1) {
std::shared_ptr<Model> function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 240});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto reshape_const = std::make_shared<Constant>(element::i64, Shape{4}, std::vector<int>{1, 3, 80, 1});
auto reshape = std::make_shared<Reshape>(tanh0, reshape_const, false);

auto transpose_order = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 1, 3});
auto transpose = std::make_shared<Transpose>(reshape, transpose_order);

auto tanh1 = std::make_shared<Tanh>(transpose);
const auto result = std::make_shared<Result>(tanh1);
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
//manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeBackward>();
//manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, orig_function);
ASSERT_TRUE(result.valid) << result.message;
}

TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink2) {
std::shared_ptr<Model> function;
{
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 320});
auto tanh0 = std::make_shared<Tanh>(input_params);

auto reshape_const = std::make_shared<Constant>(element::i64, Shape{4}, std::vector<int>{1, 2, 80, 2});
auto reshape = std::make_shared<Reshape>(tanh0, reshape_const, false);

auto transpose_order = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 1, 3});
auto transpose = std::make_shared<Transpose>(reshape, transpose_order);

auto tanh1 = std::make_shared<Tanh>(transpose);
const auto result = std::make_shared<Result>(tanh1);
function = std::make_shared<Model>(OutputVector{result}, ParameterVector{input_params});
}

std::shared_ptr<Model> orig_function = function->clone();
ov::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
//manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeBackward>();
//manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
manager.run_passes(function);
ASSERT_NO_THROW(check_rt_info(function));

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, orig_function);
ASSERT_TRUE(result.valid) << result.message;
}

} // namespace testing

0 comments on commit 5dfba63

Please sign in to comment.