Shared onnx tests to API2.0 (#21726)
vurusovs authored Dec 19, 2023
1 parent 5b776e9 commit e6ab01c
Showing 2 changed files with 14 additions and 20 deletions.
@@ -5,19 +5,19 @@
 #pragma once
 
 #include <string>
-#include "shared_test_classes/base/layer_test_utils.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
 
 namespace ONNXTestsDefinitions {
 
 class QuantizedModelsTests : public testing::WithParamInterface<std::string>,
-                             virtual public LayerTestsUtils::LayerTestsCommon {
+                             virtual public ov::test::SubgraphBaseStaticTest {
 public:
     static std::string getTestCaseName(const testing::TestParamInfo<std::string>& obj);
 
 protected:
     void SetUp() override;
-    using LayerInputTypes = std::unordered_map<std::string, std::vector<ngraph::element::Type_t>>;
-    void runModel(const char* model, const LayerInputTypes& expected_layer_input_types, float thr);
+    using LayerInputTypes = std::unordered_map<std::string, std::vector<ov::element::Type>>;
+    void run_model(const char* model, const LayerInputTypes& expected_layer_input_types, float thr);
};
 
 } // namespace ONNXTestsDefinitions
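
For orientation, the new base class comes from OpenVINO's API 2.0 subgraph test utilities. Below is a minimal sketch (not code from this commit) of a fixture built on ov::test::SubgraphBaseStaticTest. It relies only on members that also appear in this diff (function, abs_threshold, core, run()); the targetDevice assignment, the model path, and the test body are assumptions about how such fixtures are usually written.

```cpp
// Minimal sketch of an API 2.0 subgraph test fixture; hypothetical, for illustration only.
#include <string>

#include "shared_test_classes/base/ov_subgraph.hpp"

namespace example {

class MySubgraphTest : public testing::WithParamInterface<std::string>,
                       virtual public ov::test::SubgraphBaseStaticTest {
protected:
    void SetUp() override {
        // Assumption: the string parameter is the target device name.
        targetDevice = GetParam();
        // Load (or build) the model under test into the inherited `function` member.
        function = core->read_model("some_model.onnx");  // hypothetical path
        // Absolute accuracy tolerance used when run() compares against the reference.
        abs_threshold = 1e-5;
    }
};

TEST_P(MySubgraphTest, CompileAndInfer) {
    run();  // compiles `function`, runs inference, and checks accuracy
}

}  // namespace example
```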
@@ -24,15 +24,13 @@ static std::string getModelFullPath(const char* path) {
         FileUtils::makePath<char>(ov::test::utils::getExecutableDirectory(), TEST_MODELS), path);
 }
 
-void QuantizedModelsTests::runModel(const char* model, const LayerInputTypes& expected_layer_input_types, float thr) {
-    threshold = thr;
-    auto ie = getCore();
-    auto network = ie->ReadNetwork(getModelFullPath(model));
-    function = network.getFunction();
-    Run();
-    auto runtime_function = executableNetwork.GetExecGraphInfo().getFunction();
+void QuantizedModelsTests::run_model(const char* model, const LayerInputTypes& expected_layer_input_types, float thr) {
+    abs_threshold = thr;
+    function = core->read_model(getModelFullPath(model));
+    ov::test::SubgraphBaseStaticTest::run();
+    auto runtime_model = compiledModel.get_runtime_model();
     int ops_found = 0;
-    for (const auto& node : runtime_function->get_ordered_ops()) {
+    for (const auto& node : runtime_model->get_ordered_ops()) {
         const auto& name = node->get_friendly_name();
         if (expected_layer_input_types.count(name)) {
             ops_found++;
@@ -47,25 +45,21 @@ void QuantizedModelsTests::runModel(const char* model, const LayerInputTypes& expected_layer_input_types, float thr) {
 }
 
 TEST_P(QuantizedModelsTests, MaxPoolQDQ) {
-    SKIP_IF_CURRENT_TEST_IS_DISABLED();
-    runModel("max_pool_qdq.onnx", {{"890_original", {ngraph::element::u8}}}, 1e-5);
+    run_model("max_pool_qdq.onnx", {{"890_original", {ov::element::u8}}}, 1e-5);
 }
 
 TEST_P(QuantizedModelsTests, MaxPoolFQ) {
-    SKIP_IF_CURRENT_TEST_IS_DISABLED();
-    runModel("max_pool_fq.onnx", {{"887_original", {ngraph::element::u8}}}, 1e-5);
+    run_model("max_pool_fq.onnx", {{"887_original", {ov::element::u8}}}, 1e-5);
 }
 
 TEST_P(QuantizedModelsTests, ConvolutionQDQ) {
-    SKIP_IF_CURRENT_TEST_IS_DISABLED();
     // activations have type uint8 and weights int8
-    runModel("convolution_qdq.onnx", {{"908_original", {ngraph::element::u8, ngraph::element::i8}}}, 1.5e-2);
+    run_model("convolution_qdq.onnx", {{"908_original", {ov::element::u8, ov::element::i8}}}, 1.5e-2);
 }
 
 TEST_P(QuantizedModelsTests, ConvolutionFQ) {
-    SKIP_IF_CURRENT_TEST_IS_DISABLED();
     // activations have type uint8 and weights int8
-    runModel("convolution_fq.onnx", {{"902_original", {ngraph::element::u8, ngraph::element::i8}}}, 1.5e-2);
+    run_model("convolution_fq.onnx", {{"902_original", {ov::element::u8, ov::element::i8}}}, 1.5e-2);
 }
 
 } // namespace ONNXTestsDefinitions
