From 8847180dafdb6a8256b89fe7ebc0ce1c796347ca Mon Sep 17 00:00:00 2001 From: Patryk Elszkowski Date: Mon, 19 Jul 2021 11:24:57 +0200 Subject: [PATCH] add builder to template tests --- .../tests/functional/op_reference/acosh.cpp | 34 ++--------- .../op_reference/base_reference_test.cpp | 4 ++ .../op_reference/base_reference_test.hpp | 56 +++++++++++++++++++ .../tests/functional/op_reference/convert.cpp | 1 + 4 files changed, 66 insertions(+), 29 deletions(-) diff --git a/docs/template_plugin/tests/functional/op_reference/acosh.cpp b/docs/template_plugin/tests/functional/op_reference/acosh.cpp index a9ca9b53a8b965..e854c98b7e0f7a 100644 --- a/docs/template_plugin/tests/functional/op_reference/acosh.cpp +++ b/docs/template_plugin/tests/functional/op_reference/acosh.cpp @@ -14,42 +14,17 @@ using namespace ngraph; +namespace reference_tests { namespace { -struct Tensor { - Tensor() = default; - - Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const InferenceEngine::Blob::Ptr& data): shape {shape}, type {type}, data {data} {} - - template - Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const std::vector& data_elements) - : Tensor {shape, type, CreateBlob(type, data_elements)} {} - - ngraph::Shape shape; - ngraph::element::Type type; - InferenceEngine::Blob::Ptr data; -}; - struct AcoshParams { Tensor input; Tensor expected; }; -struct Builder { - AcoshParams params; - - operator AcoshParams() const { - return params; - } - -#define ADD_SET_PARAM(set_p) \ - Builder& set_p(decltype(params.set_p) t) { \ - params.set_p = std::move(t); \ - return *this; \ - } - ADD_SET_PARAM(input); - ADD_SET_PARAM(expected); -#undef ADD_SET_PARAM +struct Builder : ParamsBuilder { + REFERENCE_TESTS_ADD_SET_PARAM(Builder, input); + REFERENCE_TESTS_ADD_SET_PARAM(Builder, expected); }; class ReferenceAcoshLayerTest : public testing::TestWithParam, public CommonReferenceTest { @@ -103,3 +78,4 @@ INSTANTIATE_TEST_SUITE_P( .input({{8}, element::u64, std::vector {1, 2, 3, 4, 5, 10, 100, 1000}}) .expected({{8}, element::u64, std::vector {0, 1, 2, 2, 2, 3, 5, 8}})), ReferenceAcoshLayerTest::getTestCaseName); +} // namespace reference_tests diff --git a/docs/template_plugin/tests/functional/op_reference/base_reference_test.cpp b/docs/template_plugin/tests/functional/op_reference/base_reference_test.cpp index 51af4d2ea1a221..f2d2cf68aa39a2 100644 --- a/docs/template_plugin/tests/functional/op_reference/base_reference_test.cpp +++ b/docs/template_plugin/tests/functional/op_reference/base_reference_test.cpp @@ -9,6 +9,8 @@ using namespace InferenceEngine; +namespace reference_tests { + CommonReferenceTest::CommonReferenceTest(): targetDevice("TEMPLATE") { core = PluginCache::get().ie(targetDevice); } @@ -171,3 +173,5 @@ void CommonReferenceTest::ValidateBlobs(const InferenceEngine::Blob::Ptr& refBlo FAIL() << "Comparator for " << precision << " precision isn't supported"; } } + +} // namespace reference_tests diff --git a/docs/template_plugin/tests/functional/op_reference/base_reference_test.hpp b/docs/template_plugin/tests/functional/op_reference/base_reference_test.hpp index 6e3fd942a9e722..a9827eb07b8e35 100644 --- a/docs/template_plugin/tests/functional/op_reference/base_reference_test.hpp +++ b/docs/template_plugin/tests/functional/op_reference/base_reference_test.hpp @@ -5,8 +5,12 @@ #include #include #include +#include +#include #include +namespace reference_tests { + class CommonReferenceTest { public: CommonReferenceTest(); @@ -51,3 +55,55 @@ InferenceEngine::Blob::Ptr CreateBlob(const ngraph::element::Type& element_type, return blob; } +/// +/// Class which should help to build data for single input +/// +struct Tensor { + Tensor() = default; + + Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const InferenceEngine::Blob::Ptr& data): shape {shape}, type {type}, data {data} {} + + template + Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const std::vector& data_elements) + : Tensor {shape, type, CreateBlob(type, data_elements)} {} + + ngraph::Shape shape; + ngraph::element::Type type; + InferenceEngine::Blob::Ptr data; +}; + +/// +/// Class which should helps build test parameters. +/// +/// e.g.: +/// struct Params { +/// Tensor i,o; +/// int mul; +/// }; +/// struct TestParamsBuilder : ParamsBuilder +/// REFERENCE_TESTS_ADD_SET_PARAM(TestParamsBuilder, i); +/// REFERENCE_TESTS_ADD_SET_PARAM(TestParamsBuilder, o); +/// REFERENCE_TESTS_ADD_SET_PARAM(TestParamsBuilder, mul); +/// }; +/// +/// const Params p = TestParamsBuilder{} +/// .i(Tensor{{0}, i32, {1}}) +/// .o(Tensor{{0}, i32, {1}}) +/// .mul(10); +template +class ParamsBuilder { +protected: + Params params; + +public: + operator Params() const { + return params; + } +}; +#define REFERENCE_TESTS_ADD_SET_PARAM(builder_type, parma_to_set) \ + builder_type& parma_to_set(decltype(params.parma_to_set) t) { \ + params.parma_to_set = std::move(t); \ + return *this; \ + } + +} // namespace reference_tests diff --git a/docs/template_plugin/tests/functional/op_reference/convert.cpp b/docs/template_plugin/tests/functional/op_reference/convert.cpp index fb32fda4cbbfd8..b8e6f5846f7408 100644 --- a/docs/template_plugin/tests/functional/op_reference/convert.cpp +++ b/docs/template_plugin/tests/functional/op_reference/convert.cpp @@ -12,6 +12,7 @@ #include "base_reference_test.hpp" +using namespace reference_tests; using namespace ngraph; using namespace InferenceEngine;