Skip to content

Commit

Permalink
add tests Variant #4 Convolution -> Activation -> MaxPool
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Jun 24, 2021
1 parent 260c10c commit 8f4530b
Showing 1 changed file with 234 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ std::shared_ptr<ngraph::Function> createReferenceFunctionVariant1(Args&& ... arg
{
auto input_params_convolution = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32,
ngraph::Shape{1, 3, 64, 64});


auto input_params_add = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32,
ngraph::Shape{1, 3, 64, 64});

auto weights = ngraph::opset1::Constant::create(ngraph::element::f32,
ngraph::Shape{3, 3, 1, 1}, {1});
Expand All @@ -83,7 +85,7 @@ std::shared_ptr<ngraph::Function> createReferenceFunctionVariant1(Args&& ... arg

auto result = std::make_shared<ngraph::opset7::Result>(activation);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params_convolution});
ngraph::ParameterVector{input_params_convolution, input_params_add});
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant1ActivationRelu) {
Expand Down Expand Up @@ -440,4 +442,234 @@ TEST(TransformationTests, ReorderActivationAndPoolingTestVariant3) {
ASSERT_TRUE(result.valid);
}

// Variant #4 Convolution -> Activation -> MaxPool

template <typename ActivationT, typename ... Args>
std::shared_ptr<ngraph::Function> createFunctionVariant4(Args&& ... args)
{
auto input_params_convolution = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32,
ngraph::Shape{1, 3, 64, 64});

auto weights = ngraph::opset1::Constant::create(ngraph::element::f32,
ngraph::Shape{3, 3, 1, 1}, {1});
auto bias = ngraph::opset1::Constant::create(ngraph::element::f32,
ngraph::Shape{3, 1, 1}, {1});
auto convolution_operation = std::make_shared<ngraph::opset7::Convolution>(input_params_convolution,
weights,
ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 0},
ngraph::CoordinateDiff{0, 0},
ngraph::Strides{1, 1});

auto activation = std::make_shared<ActivationT>(convolution_operation, std::forward<Args>(args) ... );

auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(activation,
ngraph::Strides{1, 1},
ngraph::Shape{1, 1},
ngraph::Shape{1, 1},
ngraph::Shape{1, 1});

auto result = std::make_shared<ngraph::opset7::Result>(max_pool_operation);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params_convolution});
}

template <typename ActivationT, typename ... Args>
std::shared_ptr<ngraph::Function> createReferenceFunctionVariant4(Args&& ... args)
{
auto input_params_convolution = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32,
ngraph::Shape{1, 3, 64, 64});

auto weights = ngraph::opset1::Constant::create(ngraph::element::f32,
ngraph::Shape{3, 3, 1, 1}, {1});
auto bias = ngraph::opset1::Constant::create(ngraph::element::f32,
ngraph::Shape{3, 1, 1}, {1});
auto convolution_operation = std::make_shared<ngraph::opset7::Convolution>(input_params_convolution,
weights,
ngraph::Strides{1, 1},
ngraph::CoordinateDiff{0, 0},
ngraph::CoordinateDiff{0, 0},
ngraph::Strides{1, 1});

auto max_pool_operation = std::make_shared<ngraph::opset7::MaxPool>(convolution_operation,
ngraph::Strides{1, 1},
ngraph::Shape{1, 1},
ngraph::Shape{1, 1},
ngraph::Shape{1, 1});

auto activation = std::make_shared<ActivationT>(max_pool_operation, std::forward<Args>(args) ... );

auto result = std::make_shared<ngraph::opset7::Result>(activation);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params_convolution});
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationRelu) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Relu>();

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Relu>();

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationSigmoid) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Sigmoid>();

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Sigmoid>();

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationTanh) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Tanh>();

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Tanh>();

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationAbs) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Abs>();

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Abs>();

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationLog) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Log>();

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Log>();

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationExp) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Exp>();

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Exp>();

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationSign) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Sign>();

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Sign>();

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, ReorderActivationAndPoolingTestVariant4ActivationClamp) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = createFunctionVariant4<ngraph::opset7::Clamp>(0.1, 0.2);

ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();

m.register_pass<GNAPluginNS::ReorderActivationAndPooling>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = createReferenceFunctionVariant4<ngraph::opset7::Clamp>(0.1, 0.2);

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

} // namespace testing

0 comments on commit 8f4530b

Please sign in to comment.