
Commit

cleanup tests
evkotov committed Mar 6, 2023
1 parent 5dfba63 commit 030dab3
Showing 1 changed file with 4 additions and 102 deletions.
@@ -20,92 +20,6 @@ using namespace ov::opset9;

namespace testing {

namespace {

// Fill a heap-allocated float buffer with an arithmetic progression starting at
// initial_value and increasing by delta; used as deterministic input for Model::evaluate.
std::shared_ptr<void> GenerateFloatInput(size_t size, float initial_value, float delta) {
    float* array = new float[size];
    float value = initial_value;
    for (size_t i = 0; i < size; ++i) {
        array[i] = value;
        value += delta;
    }
    // A custom deleter is needed: the default deleter would call delete on a void*,
    // which is undefined behaviour and does not match the array new[] above.
    return std::shared_ptr<void>(array, [](void* p) { delete[] static_cast<float*>(p); });
}
// Throw with a descriptive message when two values compared with operator== differ.
template <typename T>
void AssertEq(const T& first, const T& second, const std::string& first_name, const std::string& second_name) {
    if (first == second)
        return;
    std::ostringstream ss;
    ss << "[EMUTEX ASSERT] " << first_name << " (" << first << ") != " << second_name << " (" << second << ")";
    throw std::runtime_error(ss.str());
}
#define EMUTEX_DEBUG_ASSERT_EQ(first, second) AssertEq(first, second, #first, #second)

// Throw when two floating-point values differ by more than the given tolerance.
template <typename T, typename T1>
void AssertEqPrecision(const T& first, const T& second, const T1& delta,
                       const std::string& first_name, const std::string& second_name) {
    if (std::abs(first - second) <= delta)
        return;
    std::ostringstream ss;
    ss << "[EMUTEX ASSERT] " << first_name << " (" << first << ") != " << second_name << " (" << second
       << ") with precision " << delta;
    throw std::runtime_error(ss.str());
}
#define EMUTEX_DEBUG_ASSERT_EQ_PRECISION(first, second, delta) AssertEqPrecision(first, second, delta, #first, #second)
// Feed identical synthetic input into both models and verify that all outputs match
// element-wise within a small tolerance.
void CompareOutput(std::shared_ptr<ov::Model> function, std::shared_ptr<ov::Model> function_ref) {
    auto function_input = function->input(0).get_node_shared_ptr();
    auto function_ref_input = function_ref->input(0).get_node_shared_ptr();
    const auto& function_input_shape = function_input->get_output_shape(0);
    const auto& function_ref_input_shape = function_ref_input->get_output_shape(0);
    bool rc = std::equal(function_input_shape.begin(), function_input_shape.end(), function_ref_input_shape.begin());
    if (!rc)
        throw std::runtime_error("function_input_shape != function_ref_input_shape");
    const size_t n_outputs = function->outputs().size();
    ov::TensorVector result(n_outputs), result_ref(n_outputs);
    const size_t input_shape_product =
        std::accumulate(function_input_shape.begin(), function_input_shape.end(), size_t{1}, std::multiplies<size_t>());
    auto inputs = GenerateFloatInput(input_shape_product, 0.0f, 0.1f);
    ov::Tensor input{ov::element::f32, function_input_shape, inputs.get()};
    rc = function->evaluate(result, ov::TensorVector{input});
    if (!rc)
        throw std::runtime_error("function->evaluate");
    rc = function_ref->evaluate(result_ref, ov::TensorVector{input});
    if (!rc)
        throw std::runtime_error("function_ref->evaluate");
    EMUTEX_DEBUG_ASSERT_EQ(result.size(), result_ref.size());
    for (size_t output_idx = 0; output_idx < n_outputs; ++output_idx) {
        EMUTEX_DEBUG_ASSERT_EQ(result[output_idx].get_element_type(), result_ref[output_idx].get_element_type());
        EMUTEX_DEBUG_ASSERT_EQ(result[output_idx].get_shape(), result_ref[output_idx].get_shape());
        EMUTEX_DEBUG_ASSERT_EQ(result[output_idx].get_size(), result_ref[output_idx].get_size());
        const float* result_data = result[output_idx].data<float>();
        const float* expected_result = result_ref[output_idx].data<float>();
        for (size_t i = 0; i < result[output_idx].get_size(); ++i) {
            EMUTEX_DEBUG_ASSERT_EQ_PRECISION(result_data[i], expected_result[i], 0.000001);
        }
    }
}

// Return the first node of type T found in the model, or an empty pointer if none exists.
template <typename T>
std::shared_ptr<T> FindNode(std::shared_ptr<Model> model) {
    for (const auto& op : model->get_ops()) {
        auto node = as_type_ptr<T>(op);
        if (node)
            return node;
    }
    return {};
}

// Debug helper: print the integer contents of a Constant node as an initializer list.
void PrintConstant(std::shared_ptr<Node> node) {
    auto constant = as_type_ptr<Constant>(node);
    if (!constant)
        return;
    auto value = constant->cast_vector<int>();
    std::cout << "{ ";
    for (size_t i = 0; i < value.size(); ++i) {
        if (i)
            std::cout << ", ";
        std::cout << value[i];
    }
    std::cout << " }" << std::endl;
}

} // namespace

TEST(GatherSinkingTransposeReshape, ForwardSinking) {
    std::shared_ptr<Model> function;
    {
@@ -126,14 +40,10 @@ TEST(GatherSinkingTransposeReshape, ForwardSinking) {
    std::shared_ptr<Model> orig_function = function->clone();
    ov::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
    manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeForward>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

    CompareOutput(function, orig_function);

    std::shared_ptr<Model> reference_function;
    {
        auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
@@ -186,14 +96,10 @@ TEST(GatherSinkingTransposeReshape, BackwardSinking) {
    std::shared_ptr<Model> orig_function = function->clone();
    ov::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
    manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeBackward>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

    CompareOutput(function, orig_function);

    std::shared_ptr<Model> reference_function;
    {
        auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 240});
@@ -226,7 +132,7 @@ TEST(GatherSinkingTransposeReshape, BackwardSinking) {
    ASSERT_TRUE(result.valid) << result.message;
}

TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink1) {
TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSinkOnes) {
    std::shared_ptr<Model> function;
    {
        auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 3, 80});
@@ -255,7 +161,7 @@ TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink1) {
    ASSERT_TRUE(result.valid);
}

TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink2) {
TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSinkNot2d) {
    std::shared_ptr<Model> function;
    {
        auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 4, 80});
@@ -284,7 +190,7 @@ TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink2) {
    ASSERT_TRUE(result.valid);
}

TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink1) {
TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSinkOnes) {
    std::shared_ptr<Model> function;
    {
        auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 240});
@@ -304,9 +210,7 @@ TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink1) {
    std::shared_ptr<Model> orig_function = function->clone();
    ov::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
    manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeBackward>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

@@ -315,7 +219,7 @@ TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink1) {
    ASSERT_TRUE(result.valid) << result.message;
}

TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink2) {
TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSinkNot2d) {
    std::shared_ptr<Model> function;
    {
        auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 320});
@@ -335,9 +239,7 @@ TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink2) {
    std::shared_ptr<Model> orig_function = function->clone();
    ov::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./0before.png");
    manager.register_pass<ov::intel_gna::pass::GatherSinkingTransposeReshapeBackward>();
    //manager.register_pass<ngraph::pass::VisualizeTree>("./1after.png");
    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

