From 5371819c66c2deb059c4dd8ffb8c0bb04c6d8975 Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Thu, 24 Jun 2021 11:31:57 +0300 Subject: [PATCH] add tests Variant #4 Convolution -> Activation -> MaxPool --- .../gna_reorder_activation_and_pooling.cpp | 236 +++++++++++++++++- 1 file changed, 234 insertions(+), 2 deletions(-) diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_reorder_activation_and_pooling.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_reorder_activation_and_pooling.cpp index 0c46a57ed61654..4f0ad8c3b4a1aa 100644 --- a/inference-engine/tests/unit/gna/ngraph/transformations/gna_reorder_activation_and_pooling.cpp +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_reorder_activation_and_pooling.cpp @@ -57,7 +57,9 @@ std::shared_ptr createReferenceFunctionVariant1(Args&& ... arg { auto input_params_convolution = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 3, 64, 64}); - + + auto input_params_add = std::make_shared(ngraph::element::f32, + ngraph::Shape{1, 3, 64, 64}); auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3, 3, 1, 1}, {1}); @@ -83,7 +85,7 @@ std::shared_ptr createReferenceFunctionVariant1(Args&& ... arg auto result = std::make_shared(activation); return std::make_shared(ngraph::ResultVector{result}, - ngraph::ParameterVector{input_params_convolution}); + ngraph::ParameterVector{input_params_convolution, input_params_add}); } TEST(TransformationTests, ReorderActivationAndPoolingTestVariant1ActivationRelu) { @@ -440,4 +442,234 @@ TEST(TransformationTests, ReorderActivationAndPoolingTestVariant3) { ASSERT_TRUE(result.valid); } +// Variant #4 Convolution -> Activation -> MaxPool + +template +std::shared_ptr createFunctionVariant4(Args&& ... args) +{ + auto input_params_convolution = std::make_shared(ngraph::element::f32, + ngraph::Shape{1, 3, 64, 64}); + + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, + ngraph::Shape{3, 3, 1, 1}, {1}); + auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, + ngraph::Shape{3, 1, 1}, {1}); + auto convolution_operation = std::make_shared(input_params_convolution, + weights, + ngraph::Strides{1, 1}, + ngraph::CoordinateDiff{0, 0}, + ngraph::CoordinateDiff{0, 0}, + ngraph::Strides{1, 1}); + + auto activation = std::make_shared(convolution_operation, std::forward(args) ... ); + + auto max_pool_operation = std::make_shared(activation, + ngraph::Strides{1, 1}, + ngraph::Shape{1, 1}, + ngraph::Shape{1, 1}, + ngraph::Shape{1, 1}); + + auto result = std::make_shared(max_pool_operation); + return std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params_convolution}); +} + +template +std::shared_ptr createReferenceFunctionVariant4(Args&& ... args) +{ + auto input_params_convolution = std::make_shared(ngraph::element::f32, + ngraph::Shape{1, 3, 64, 64}); + + auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, + ngraph::Shape{3, 3, 1, 1}, {1}); + auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, + ngraph::Shape{3, 1, 1}, {1}); + auto convolution_operation = std::make_shared(input_params_convolution, + weights, + ngraph::Strides{1, 1}, + ngraph::CoordinateDiff{0, 0}, + ngraph::CoordinateDiff{0, 0}, + ngraph::Strides{1, 1}); + + auto max_pool_operation = std::make_shared(convolution_operation, + ngraph::Strides{1, 1}, + ngraph::Shape{1, 1}, + ngraph::Shape{1, 1}, + ngraph::Shape{1, 1}); + + auto activation = std::make_shared(max_pool_operation, std::forward(args) ... ); + + auto result = std::make_shared(activation); + return std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params_convolution}); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationRelu) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationSigmoid) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationTanh) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationAbs) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationLog) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationExp) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationSign) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationClamp) { + std::shared_ptr func(nullptr), reference_func(nullptr); + + { + func = createFunctionVariant4(0.1, 0.2); + + ngraph::pass::Manager m; + m.register_pass(); + + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = createReferenceFunctionVariant4(0.1, 0.2); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + } // namespace testing