Skip to content

Commit

Permalink
add builder to template tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pelszkow committed Jul 19, 2021
1 parent 3653844 commit a16ab8b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 29 deletions.
34 changes: 5 additions & 29 deletions docs/template_plugin/tests/functional/op_reference/acosh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,17 @@

using namespace ngraph;

namespace reference_tests {
namespace {

struct Tensor {
Tensor() = default;

Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const InferenceEngine::Blob::Ptr& data): shape {shape}, type {type}, data {data} {}

template <typename T>
Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const std::vector<T>& data_elements)
: Tensor {shape, type, CreateBlob(type, data_elements)} {}

ngraph::Shape shape;
ngraph::element::Type type;
InferenceEngine::Blob::Ptr data;
};

struct AcoshParams {
Tensor input;
Tensor expected;
};

struct Builder {
AcoshParams params;

operator AcoshParams() const {
return params;
}

#define ADD_SET_PARAM(set_p) \
Builder& set_p(decltype(params.set_p) t) { \
params.set_p = std::move(t); \
return *this; \
}
ADD_SET_PARAM(input);
ADD_SET_PARAM(expected);
#undef ADD_SET_PARAM
struct Builder : ParamsBuilder<AcoshParams> {
REFERENCE_TESTS_ADD_SET_PARAM(Builder, input);
REFERENCE_TESTS_ADD_SET_PARAM(Builder, expected);
};

class ReferenceAcoshLayerTest : public testing::TestWithParam<AcoshParams>, public CommonReferenceTest {
Expand Down Expand Up @@ -103,3 +78,4 @@ INSTANTIATE_TEST_SUITE_P(
.input({{8}, element::u64, std::vector<uint64_t> {1, 2, 3, 4, 5, 10, 100, 1000}})
.expected({{8}, element::u64, std::vector<uint64_t> {0, 1, 2, 2, 2, 3, 5, 8}})),
ReferenceAcoshLayerTest::getTestCaseName);
} // namespace reference_tests
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

using namespace InferenceEngine;

namespace reference_tests {

CommonReferenceTest::CommonReferenceTest(): targetDevice("TEMPLATE") {
core = PluginCache::get().ie(targetDevice);
}
Expand Down Expand Up @@ -171,3 +173,5 @@ void CommonReferenceTest::ValidateBlobs(const InferenceEngine::Blob::Ptr& refBlo
FAIL() << "Comparator for " << precision << " precision isn't supported";
}
}

} // namespace reference_tests
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
#include <ie_core.hpp>
#include <ie_ngraph_utils.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/shape.hpp>
#include <ngraph/type/element_type.hpp>
#include <shared_test_classes/base/layer_test_utils.hpp>

namespace reference_tests {

class CommonReferenceTest {
public:
CommonReferenceTest();
Expand Down Expand Up @@ -51,3 +55,55 @@ InferenceEngine::Blob::Ptr CreateBlob(const ngraph::element::Type& element_type,
return blob;
}

///
/// Class which should help to build data for single input
///
struct Tensor {
Tensor() = default;

Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const InferenceEngine::Blob::Ptr& data): shape {shape}, type {type}, data {data} {}

template <typename T>
Tensor(const ngraph::Shape& shape, ngraph::element::Type type, const std::vector<T>& data_elements)
: Tensor {shape, type, CreateBlob(type, data_elements)} {}

ngraph::Shape shape;
ngraph::element::Type type;
InferenceEngine::Blob::Ptr data;
};

///
/// Class which should help in building test inputs.
///
/// e.g.:
/// struct Params {
/// Tensor i,o;
/// int mul;
/// };
/// struct TestParamsBuilder : ParamsBuilder<Params>
/// REFERENCE_TESTS_ADD_SET_PARAM(TestParamsBuilder, i);
/// REFERENCE_TESTS_ADD_SET_PARAM(TestParamsBuilder, o);
/// REFERENCE_TESTS_ADD_SET_PARAM(TestParamsBuilder, mul);
/// };
///
/// const Params p = TestParamsBuilder{}
/// .i(Tensor{{0}, i32, {1}})
/// .o(Tensor{{0}, i32, {1}})
/// .mul(10);
template <typename Params>
class ParamsBuilder {
protected:
Params params;

public:
operator Params() const {
return params;
}
};
#define REFERENCE_TESTS_ADD_SET_PARAM(builder_type, parma_to_set) \
builder_type& parma_to_set(decltype(params.parma_to_set) t) { \
params.parma_to_set = std::move(t); \
return *this; \
}

} // namespace reference_tests
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "base_reference_test.hpp"

using namespace reference_tests;
using namespace ngraph;
using namespace InferenceEngine;

Expand Down

0 comments on commit a16ab8b

Please sign in to comment.