From d821ec47ce0a17856c1127fcc776a8fa57ff5142 Mon Sep 17 00:00:00 2001 From: Mateusz Bencer Date: Mon, 15 Feb 2021 11:14:20 +0100 Subject: [PATCH] [ONNX] Handle optional outputs for Dropout and MaxPool (#4143) --- .../src/transformations/convert_precision.cpp | 4 +- .../transformations/convert_precision.cpp | 8 ++ .../frontend/onnx_import/src/core/graph.cpp | 6 +- .../frontend/onnx_import/src/op/dropout.cpp | 107 +++++++++++++++ .../frontend/onnx_import/src/op/dropout.hpp | 22 +-- .../frontend/onnx_import/src/op/max_pool.cpp | 6 + .../frontend/onnx_import/src/ops_bridge.cpp | 2 + ngraph/python/tests/__init__.py | 5 +- ngraph/python/tests/runtime.py | 44 ++++-- ngraph/python/tests/test_ngraph/test_basic.py | 12 +- .../tests/test_ngraph/test_ops_reshape.py | 1 - ngraph/python/tests/test_onnx/test_backend.py | 36 ++--- ...pout12_no_training_no_return_mask.prototxt | 56 ++++++++ ...dropout12_no_training_return_mask.prototxt | 86 ++++++++++++ ...ropout12_no_traning_no_const_rato.prototxt | 101 ++++++++++++++ ...dropout12_not_const_training_mode.prototxt | 122 +++++++++++++++++ .../onnx/dropout12_training_mode.prototxt | 128 ++++++++++++++++++ ...opout1_no_training_no_return_mask.prototxt | 56 ++++++++ .../dropout1_no_training_return_mask.prototxt | 92 +++++++++++++ .../onnx/dropout7_no_return_mask.prototxt | 51 +++++++ .../max_pool_with_indices_output.prototxt | 94 +++++++++++++ ngraph/test/onnx/onnx_import.in.cpp | 115 ++++++++++++++++ .../test/onnx/onnx_import_dyn_shapes.in.cpp | 25 ++++ 23 files changed, 1120 insertions(+), 59 deletions(-) create mode 100644 ngraph/frontend/onnx_import/src/op/dropout.cpp create mode 100644 ngraph/test/models/onnx/dropout12_no_training_no_return_mask.prototxt create mode 100644 ngraph/test/models/onnx/dropout12_no_training_return_mask.prototxt create mode 100644 ngraph/test/models/onnx/dropout12_no_traning_no_const_rato.prototxt create mode 100644 ngraph/test/models/onnx/dropout12_not_const_training_mode.prototxt create mode 100644 ngraph/test/models/onnx/dropout12_training_mode.prototxt create mode 100644 ngraph/test/models/onnx/dropout1_no_training_no_return_mask.prototxt create mode 100644 ngraph/test/models/onnx/dropout1_no_training_return_mask.prototxt create mode 100644 ngraph/test/models/onnx/dropout7_no_return_mask.prototxt create mode 100644 ngraph/test/models/onnx/dynamic_shapes/max_pool_with_indices_output.prototxt diff --git a/inference-engine/src/transformations/src/transformations/convert_precision.cpp b/inference-engine/src/transformations/src/transformations/convert_precision.cpp index 0492bd4a9759d6..460ffcd028779d 100644 --- a/inference-engine/src/transformations/src/transformations/convert_precision.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_precision.cpp @@ -404,9 +404,7 @@ bool fuse_type_to_constant(std::shared_ptr & node, element::Type to, const } new_const->validate_and_infer_types(); - if (constant->get_output_target_inputs(0).size() == consumers.size()) { - new_const->set_friendly_name(constant->get_friendly_name()); - } + new_const->set_friendly_name(constant->get_friendly_name()); } return false; } diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp index d67d67d29ae932..35d63367718028 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp @@ -565,8 +565,10 @@ TEST(TransformationTests, ConvertPrecision_Variables) { template void constant_convert_test(element::Type_t type_from, element::Type_t type_to, From value, To expected) { std::shared_ptr f(nullptr); + std::string expected_friendly_name; { auto c = opset4::Constant::create(type_from, Shape{}, {value}); + expected_friendly_name = c->get_friendly_name(); f = std::make_shared(NodeVector{c}, ParameterVector{}); pass::Manager manager; @@ -576,6 +578,7 @@ void constant_convert_test(element::Type_t type_from, element::Type_t type_to, F auto ops = f->get_ordered_ops(); auto c = std::dynamic_pointer_cast(ops[0]); ASSERT_NE(c, nullptr); + ASSERT_EQ(c->get_friendly_name(), expected_friendly_name); auto actual = c->cast_vector()[0]; ASSERT_EQ(expected, actual); @@ -622,3 +625,8 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MaxToI32) { TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32ToI32) { constant_convert_test(element::Type_t::u32, element::Type_t::i32, 42, 42); } + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_BoolToU8) { + constant_convert_test(element::Type_t::boolean, element::Type_t::u8, true, 1); + constant_convert_test(element::Type_t::boolean, element::Type_t::u8, false, 0); +} diff --git a/ngraph/frontend/onnx_import/src/core/graph.cpp b/ngraph/frontend/onnx_import/src/core/graph.cpp index a9543e440e51ad..2d2f1e77ade374 100644 --- a/ngraph/frontend/onnx_import/src/core/graph.cpp +++ b/ngraph/frontend/onnx_import/src/core/graph.cpp @@ -223,7 +223,11 @@ namespace ngraph OutputVector results; for (const auto& output : m_graph_proto->output()) { - results.emplace_back(get_ng_node_from_cache(output.name())); + const auto& ng_output = get_ng_node_from_cache(output.name()); + if (!ngraph::op::is_null(ng_output)) // ignore optional outputs + { + results.emplace_back(ng_output); + } } return results; } diff --git a/ngraph/frontend/onnx_import/src/op/dropout.cpp b/ngraph/frontend/onnx_import/src/op/dropout.cpp new file mode 100644 index 00000000000000..fd4f8f93f133ef --- /dev/null +++ b/ngraph/frontend/onnx_import/src/op/dropout.cpp @@ -0,0 +1,107 @@ +//***************************************************************************** +// Copyright 2017-2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +#include "core/null_node.hpp" +#include "default_opset.hpp" +#include "exceptions.hpp" +#include "ngraph/log.hpp" +#include "ngraph/node.hpp" +#include "op/dropout.hpp" + +namespace ngraph +{ + namespace onnx_import + { + namespace op + { + namespace + { + OutputVector build_dropout(const Node& node, bool training_mode) + { + CHECK_VALID_NODE( + node, !training_mode, "Training mode is not supported for Dropout op"); + + const auto input_data = node.get_ng_inputs().at(0); + const bool return_mask = node.get_outputs_size() > 1; + + if (return_mask) + { + const auto mask = std::make_shared( + default_opset::Constant::create( + ngraph::element::boolean, Shape{}, {true}), + std::make_shared(input_data)); + return {input_data, mask}; + } + else + { + return {input_data}; + } + } + } + + namespace set_12 + { + OutputVector dropout(const Node& node) + { + const auto ng_inputs = node.get_ng_inputs(); + // seed attribute and ratio input are ignored because traning mode is not + // supported anyway + bool training_mode = false; // default value + if (ng_inputs.size() > 2 && !ngraph::op::is_null(ng_inputs.at(2))) + { + CHECK_VALID_NODE( + node, + ngraph::op::is_constant(ng_inputs.at(2).get_node_shared_ptr()), + "Non-constant training_mode input is not supported."); + training_mode = as_type_ptr( + ng_inputs.at(2).get_node_shared_ptr()) + ->cast_vector()[0]; + } + return build_dropout(node, training_mode); + } + } // namespace set_12 + + namespace set_7 + { + OutputVector dropout(const Node& node) + { + // "is_test" attribute was removed + // ratio attribute is ignored because traning mode is not supported + const bool training_mode = false; + + return build_dropout(node, training_mode); + } + } // namespace set_7 + + namespace set_1 + { + OutputVector dropout(const Node& node) + { + // legacy consumed_inputs attribute ignored + // ratio attribute is ignored because traning mode is not supported + const bool training_mode = !node.get_attribute_value("is_test", 0); + + return build_dropout(node, training_mode); + } + } // namespace set_1 + + } // namespace op + + } // namespace onnx_import + +} // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/op/dropout.hpp b/ngraph/frontend/onnx_import/src/op/dropout.hpp index 5b6ee83b731d5e..a1a083cb71912f 100644 --- a/ngraph/frontend/onnx_import/src/op/dropout.hpp +++ b/ngraph/frontend/onnx_import/src/op/dropout.hpp @@ -16,10 +16,6 @@ #pragma once -#include - -#include "core/null_node.hpp" -#include "ngraph/node.hpp" #include "onnx_import/core/node.hpp" namespace ngraph @@ -28,15 +24,19 @@ namespace ngraph { namespace op { + namespace set_12 + { + OutputVector dropout(const Node& node); + } // namespace set_12 + + namespace set_7 + { + OutputVector dropout(const Node& node); + } // namespace set_7 + namespace set_1 { - inline OutputVector dropout(const Node& node) - { - // First value is actual output of Dropout, - // the second one is just a placeholder for optional trailing output. - return {node.get_ng_inputs().at(0).get_node_shared_ptr(), - std::make_shared()}; - } + OutputVector dropout(const Node& node); } // namespace set_1 } // namespace op diff --git a/ngraph/frontend/onnx_import/src/op/max_pool.cpp b/ngraph/frontend/onnx_import/src/op/max_pool.cpp index c53c4cea6bb207..3533c30b577dbc 100644 --- a/ngraph/frontend/onnx_import/src/op/max_pool.cpp +++ b/ngraph/frontend/onnx_import/src/op/max_pool.cpp @@ -17,6 +17,7 @@ #include #include "core/null_node.hpp" +#include "ngraph/log.hpp" #include "ngraph/op/max_pool.hpp" #include "op/max_pool.hpp" #include "utils/pooling_factory.hpp" @@ -31,6 +32,11 @@ namespace ngraph { OutputVector max_pool(const Node& node) { + if (node.get_outputs_size() > 1) + { + NGRAPH_WARN + << "Indices output is not supported for MaxPooling and was ignored"; + } auto max_pool = pooling::PoolingFactory(node).make_max_pool(); max_pool.emplace_back(std::make_shared()); // Indices (optional) return max_pool; diff --git a/ngraph/frontend/onnx_import/src/ops_bridge.cpp b/ngraph/frontend/onnx_import/src/ops_bridge.cpp index d276ac4ed39e6d..5a580b82eeb37d 100644 --- a/ngraph/frontend/onnx_import/src/ops_bridge.cpp +++ b/ngraph/frontend/onnx_import/src/ops_bridge.cpp @@ -343,6 +343,8 @@ namespace ngraph REGISTER_OPERATOR("Div", 1, div); REGISTER_OPERATOR("Div", 7, div); REGISTER_OPERATOR("Dropout", 1, dropout); + REGISTER_OPERATOR("Dropout", 7, dropout); + REGISTER_OPERATOR("Dropout", 12, dropout); REGISTER_OPERATOR("Elu", 1, elu); REGISTER_OPERATOR("Equal", 1, equal); REGISTER_OPERATOR("Erf", 1, erf); diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py index 72d32983a51bd6..410ac724d48ab0 100644 --- a/ngraph/python/tests/__init__.py +++ b/ngraph/python/tests/__init__.py @@ -107,8 +107,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): xfail_issue_38699 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:" "ai.onnx.preview.training.Gradient") xfail_issue_38701 = xfail_test(reason="RuntimeError: unsupported element type: STRING") -xfail_issue_38705 = xfail_test(reason="IndexError: deque::_M_range_check: __n (which is 0)" - ">= this->size() (which is 0)") xfail_issue_38706 = xfail_test(reason="RuntimeError: output_3.0 has zero dimension which is not allowed") xfail_issue_38707 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:" "SoftmaxCrossEntropyLoss") @@ -152,7 +150,7 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): "ai.onnx.preview.training.Adagrad") xfail_issue_38736 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:" "NegativeLogLikelihoodLoss") -xfail_issue_45177 = xfail_test(reason="RuntimeError: axes has zero dimension which is not allowed") +xfail_issue_48052 = xfail_test(reason="Dropout op is not supported in traning mode") xfail_issue_45180 = xfail_test(reason="RuntimeError: Unsupported dynamic op: ReduceSum") xfail_issue_44839 = xfail_test(reason="Huge computation missmatch") xfail_issue_44848 = xfail_test(reason="E Unsupported dynamic op: Range") @@ -176,6 +174,7 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): xfail_issue_47330 = xfail_test(reason="RuntimeError: Eltwise node with name `[name]` doesn't support " "FP64 precision.") xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot") +xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output") # Model MSFT issues: xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:" diff --git a/ngraph/python/tests/runtime.py b/ngraph/python/tests/runtime.py index 5397d7874b34f2..aadbe0c96d9342 100644 --- a/ngraph/python/tests/runtime.py +++ b/ngraph/python/tests/runtime.py @@ -18,10 +18,10 @@ from typing import Dict, List, Union import numpy as np -from openvino.inference_engine import IECore, IENetwork, Blob +from openvino.inference_engine import IECore, IENetwork, Blob, DataPtr from ngraph.exceptions import UserInputError -from ngraph.impl import Function, Node, PartialShape +from ngraph.impl import Function, Node, PartialShape, Type from ngraph.opset1.ops import result from ngraph.utils.types import NumericData, get_shape, get_dtype @@ -55,6 +55,18 @@ def _convert_inputs(cnn_network: IENetwork) -> None: pass +def apply_ng_type(output: DataPtr, ng_type: Type): + ng_ie_supported_type_map = { + Type.boolean.get_type_name(): "BOOL", + Type.f32.get_type_name(): "FP32", + Type.i8.get_type_name(): "I8", + Type.i32.get_type_name(): "I32", + Type.u8.get_type_name(): "U8", + } + if ng_type.get_type_name() in ng_ie_supported_type_map: + output.precision = ng_ie_supported_type_map[ng_type.get_type_name()] + + class Runtime(object): """Represents an nGraph runtime environment.""" @@ -103,18 +115,30 @@ def __repr__(self) -> str: params_string = ", ".join([param.name for param in self.parameters]) return "".format(self.function.get_name(), params_string) - def __get_ie_output_blob_buffer(self, output_blobs: Dict[str, Blob], ng_result: result) -> np.ndarray: + def __get_ie_output_blob_name(self, outputs: Dict, ng_result: result) -> str: if len(self.results) == 1: - return next(iter(output_blobs.values())).buffer + return next(iter(outputs.keys())) else: prev_layer = ng_result.input(0).get_source_output() out_name = prev_layer.get_node().get_friendly_name() if prev_layer.get_node().get_output_size() != 1: out_name += "." + str(prev_layer.get_index()) - return output_blobs[out_name].buffer + return out_name + + def __get_ie_output_blob_buffer(self, output_blobs: Dict[str, Blob], ng_result: result) -> np.ndarray: + out_name = self.__get_ie_output_blob_name(output_blobs, ng_result) + return output_blobs[out_name].buffer def __call__(self, *input_values: NumericData) -> List[NumericData]: """Run computation on input values and return result.""" + # Input validation + if len(input_values) < len(self.parameters): + raise UserInputError( + "Expected %s params, received not enough %s values.", len(self.parameters), len(input_values) + ) + # ignore not needed input values + input_values = input_values[:len(self.parameters)] + input_values = [np.array(input_value) for input_value in input_values] input_shapes = [get_shape(input_value) for input_value in input_values] @@ -131,13 +155,13 @@ def __call__(self, *input_values: NumericData) -> List[NumericData]: else: cnn_network = self.network_cache[str(input_shapes)] + # set output blobs precission based on nG results + for ng_result in self.results: + ie_out_name = self.__get_ie_output_blob_name(cnn_network.outputs, ng_result) + apply_ng_type(cnn_network.outputs[ie_out_name], ng_result.get_output_element_type(0)) + executable_network = self.runtime.backend.load_network(cnn_network, self.runtime.backend_name) - # Input validation - if len(input_values) != len(self.parameters): - raise UserInputError( - "Expected %s parameters, received %s.", len(self.parameters), len(input_values) - ) for parameter, input in zip(self.parameters, input_values): parameter_shape = parameter.get_output_partial_shape(0) input_shape = PartialShape(input.shape) diff --git a/ngraph/python/tests/test_ngraph/test_basic.py b/ngraph/python/tests/test_ngraph/test_basic.py index 16240bccdbe439..3c234061edc80f 100644 --- a/ngraph/python/tests/test_ngraph/test_basic.py +++ b/ngraph/python/tests/test_ngraph/test_basic.py @@ -80,15 +80,15 @@ def test_simple_computation_on_ndarrays(dtype): value_a = np.array([[1, 2], [3, 4]], dtype=dtype) value_b = np.array([[5, 6], [7, 8]], dtype=dtype) - value_c = np.array([[9, 10], [11, 12]], dtype=dtype) + value_c = np.array([[2, 3], [4, 5]], dtype=dtype) result = computation(value_a, value_b, value_c) - assert np.allclose(result, np.array([[54, 80], [110, 144]], dtype=dtype)) + assert np.allclose(result, np.array([[12, 24], [40, 60]], dtype=dtype)) - value_a = np.array([[13, 14], [15, 16]], dtype=dtype) - value_b = np.array([[17, 18], [19, 20]], dtype=dtype) - value_c = np.array([[21, 22], [23, 24]], dtype=dtype) + value_a = np.array([[9, 10], [11, 12]], dtype=dtype) + value_b = np.array([[13, 14], [15, 16]], dtype=dtype) + value_c = np.array([[5, 4], [3, 2]], dtype=dtype) result = computation(value_a, value_b, value_c) - assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype)) + assert np.allclose(result, np.array([[110, 96], [78, 56]], dtype=dtype)) def test_serialization(): diff --git a/ngraph/python/tests/test_ngraph/test_ops_reshape.py b/ngraph/python/tests/test_ngraph/test_ops_reshape.py index f0aa63bd4160a7..b74658fcd0b086 100644 --- a/ngraph/python/tests/test_ngraph/test_ops_reshape.py +++ b/ngraph/python/tests/test_ngraph/test_ops_reshape.py @@ -223,7 +223,6 @@ def test_reshape_v1(): assert np.allclose(result, expected) -@xfail_issue_40957 def test_shape_of(): input_tensor = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32) diff --git a/ngraph/python/tests/test_onnx/test_backend.py b/ngraph/python/tests/test_onnx/test_backend.py index 6d4072da03248c..090339387c9c0f 100644 --- a/ngraph/python/tests/test_onnx/test_backend.py +++ b/ngraph/python/tests/test_onnx/test_backend.py @@ -44,7 +44,6 @@ xfail_issue_38701, xfail_issue_33595, xfail_issue_33651, - xfail_issue_38705, xfail_issue_38706, xfail_issue_38736, xfail_issue_38707, @@ -69,7 +68,6 @@ xfail_issue_38732, xfail_issue_38734, xfail_issue_38735, - xfail_issue_45177, xfail_issue_45180, xfail_issue_43742, xfail_issue_44839, @@ -89,6 +87,8 @@ xfail_issue_47317, xfail_issue_47323, xfail_issue_47330, + xfail_issue_48052, + xfail_issue_33593, xfail_issue_47337) @@ -198,14 +198,6 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_constant_cpu", "OnnxBackendNodeModelTest.test_eyelike_populate_off_main_diagonal_cpu", "OnnxBackendNodeModelTest.test_eyelike_without_dtype_cpu", - "OnnxBackendNodeModelTest.test_shape_cpu", - "OnnxBackendNodeModelTest.test_shape_example_cpu", - "OnnxBackendNodeModelTest.test_size_cpu", - "OnnxBackendNodeModelTest.test_size_example_cpu", - "OnnxBackendNodeModelTest.test_dropout_default_ratio_cpu", - "OnnxBackendNodeModelTest.test_training_dropout_default_cpu", - "OnnxBackendNodeModelTest.test_training_dropout_zero_ratio_cpu", - "OnnxBackendNodeModelTest.test_training_dropout_cpu", "OnnxBackendNodeModelTest.test_eyelike_with_dtype_cpu"), (xfail_issue_35915, "OnnxBackendNodeModelTest.test_min_uint8_cpu"), @@ -287,14 +279,6 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_tfidfvectorizer_tf_only_bigrams_skip0_cpu", "OnnxBackendNodeModelTest.test_tfidfvectorizer_tf_batch_uniandbigrams_skip5_cpu", "OnnxBackendNodeModelTest.test_tfidfvectorizer_tf_onlybigrams_skip5_cpu"), - (xfail_issue_38705, - "OnnxBackendNodeModelTest.test_training_dropout_mask_cpu", - "OnnxBackendNodeModelTest.test_training_dropout_default_mask_cpu", - "OnnxBackendNodeModelTest.test_training_dropout_zero_ratio_mask_cpu", - "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_strides_cpu", - "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_pads_cpu", - "OnnxBackendNodeModelTest.test_dropout_default_mask_cpu", - "OnnxBackendNodeModelTest.test_dropout_default_mask_ratio_cpu"), (xfail_issue_38706, "OnnxBackendNodeModelTest.test_split_zero_size_splits_cpu"), (xfail_issue_38736, @@ -618,12 +602,13 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None (xfail_issue_38735, "OnnxBackendNodeModelTest.test_adagrad_multiple_cpu", "OnnxBackendNodeModelTest.test_adagrad_cpu"), - (xfail_issue_45177, - "OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_example_cpu", - "OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_random_cpu", - "OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_example_cpu", - "OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_random_cpu", - "OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_random_cpu"), + (xfail_issue_48052, + "OnnxBackendNodeModelTest.test_training_dropout_cpu", + "OnnxBackendNodeModelTest.test_training_dropout_mask_cpu", + "OnnxBackendNodeModelTest.test_training_dropout_default_cpu", + "OnnxBackendNodeModelTest.test_training_dropout_zero_ratio_cpu", + "OnnxBackendNodeModelTest.test_training_dropout_default_mask_cpu", + "OnnxBackendNodeModelTest.test_training_dropout_zero_ratio_mask_cpu"), (xfail_issue_45180, "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu", "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu", @@ -682,6 +667,9 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",), (xfail_issue_44976, "OnnxBackendNodeModelTest.test_quantizelinear_axis_cpu",), + (xfail_issue_33593, + "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_strides_cpu", + "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_pads_cpu",) ] for test_group in tests_expected_to_fail: diff --git a/ngraph/test/models/onnx/dropout12_no_training_no_return_mask.prototxt b/ngraph/test/models/onnx/dropout12_no_training_no_return_mask.prototxt new file mode 100644 index 00000000000000..e28fb21b4e7bb5 --- /dev/null +++ b/ngraph/test/models/onnx/dropout12_no_training_no_return_mask.prototxt @@ -0,0 +1,56 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + op_type: "Dropout" + attribute { + name: "seed" + i: 0 + type: INT + } + } + name: "test_dropout_default_mask" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 12 +} diff --git a/ngraph/test/models/onnx/dropout12_no_training_return_mask.prototxt b/ngraph/test/models/onnx/dropout12_no_training_return_mask.prototxt new file mode 100644 index 00000000000000..51046ebb8f4636 --- /dev/null +++ b/ngraph/test/models/onnx/dropout12_no_training_return_mask.prototxt @@ -0,0 +1,86 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "Dropout" + attribute { + name: "seed" + i: 0 + type: INT + } + } + node { + input: "z" + op_type: "Cast" + output: "z_out" + attribute { + name: "to" + i: 6 + type: INT + } + } + name: "test_dropout_default_mask" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "z_out" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 12 +} diff --git a/ngraph/test/models/onnx/dropout12_no_traning_no_const_rato.prototxt b/ngraph/test/models/onnx/dropout12_no_traning_no_const_rato.prototxt new file mode 100644 index 00000000000000..a286df855933cc --- /dev/null +++ b/ngraph/test/models/onnx/dropout12_no_traning_no_const_rato.prototxt @@ -0,0 +1,101 @@ +ir_version: 7 +producer_name: "onnx-importer-test" +graph { + node { + output: "N" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 1 + float_data: 1.0 + name: "const_tensor_N" + } + type: TENSOR + } + } + node { + output: "T" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 9 + int32_data: 0 + } + type: TENSOR + } + } + node { + input: "X" + output: "A" + op_type: "Relu" + } + node { + input: "A" + input: "N" + output: "B" + op_type: "Pow" + } + node { + input: "B" + input: "R" + input: "T" + output: "C" + op_type: "Dropout" + } + node { + input: "C" + output: "Y" + op_type: "Relu" + } + name: "test-model" + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + domain: "" + version: 12 +} diff --git a/ngraph/test/models/onnx/dropout12_not_const_training_mode.prototxt b/ngraph/test/models/onnx/dropout12_not_const_training_mode.prototxt new file mode 100644 index 00000000000000..780f8f3d7246c8 --- /dev/null +++ b/ngraph/test/models/onnx/dropout12_not_const_training_mode.prototxt @@ -0,0 +1,122 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "x" + input: "ratio" + input: "training_mode" + output: "y" + output: "z" + op_type: "Dropout" + attribute { + name: "seed" + i: 0 + type: INT + } + } + name: "test_dropout_default_mask" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "ratio" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "training_mode" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +initializer { + dims: 1 + data_type: 1 + float_data: 3 + name: "ratio" +} +} +opset_import { + version: 12 +} diff --git a/ngraph/test/models/onnx/dropout12_training_mode.prototxt b/ngraph/test/models/onnx/dropout12_training_mode.prototxt new file mode 100644 index 00000000000000..518a0e1af4fff7 --- /dev/null +++ b/ngraph/test/models/onnx/dropout12_training_mode.prototxt @@ -0,0 +1,128 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "x" + input: "ratio" + input: "training_mode" + output: "y" + output: "z" + op_type: "Dropout" + attribute { + name: "seed" + i: 0 + type: INT + } + } + name: "test_dropout_default_mask" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "ratio" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "training_mode" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +initializer { + dims: 1 + data_type: 1 + float_data: 3 + name: "ratio" +} +initializer { + dims: 1 + data_type: 9 + int32_data: 00000001 + name: "training_mode" +} +} +opset_import { + version: 12 +} diff --git a/ngraph/test/models/onnx/dropout1_no_training_no_return_mask.prototxt b/ngraph/test/models/onnx/dropout1_no_training_no_return_mask.prototxt new file mode 100644 index 00000000000000..9c106663844e9b --- /dev/null +++ b/ngraph/test/models/onnx/dropout1_no_training_no_return_mask.prototxt @@ -0,0 +1,56 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + op_type: "Dropout" + attribute { + name: "is_test" + i: 1 + type: INT + } + } + name: "test_dropout_default_mask" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 1 +} diff --git a/ngraph/test/models/onnx/dropout1_no_training_return_mask.prototxt b/ngraph/test/models/onnx/dropout1_no_training_return_mask.prototxt new file mode 100644 index 00000000000000..abc400dcdd175c --- /dev/null +++ b/ngraph/test/models/onnx/dropout1_no_training_return_mask.prototxt @@ -0,0 +1,92 @@ + +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "Dropout" + attribute { + name: "is_test" + i: 1 + type: INT + } + attribute { + name: "ratio" + f: 0.1 + type: FLOAT + } + } + node { + input: "z" + op_type: "Cast" + output: "z_out" + attribute { + name: "to" + i: 6 + type: INT + } + } + name: "test_dropout_default_mask" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "z_out" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 6 +} diff --git a/ngraph/test/models/onnx/dropout7_no_return_mask.prototxt b/ngraph/test/models/onnx/dropout7_no_return_mask.prototxt new file mode 100644 index 00000000000000..ced7fbca21ea13 --- /dev/null +++ b/ngraph/test/models/onnx/dropout7_no_return_mask.prototxt @@ -0,0 +1,51 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + op_type: "Dropout" + } + name: "test_dropout_default_mask" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 7 +} diff --git a/ngraph/test/models/onnx/dynamic_shapes/max_pool_with_indices_output.prototxt b/ngraph/test/models/onnx/dynamic_shapes/max_pool_with_indices_output.prototxt new file mode 100644 index 00000000000000..6105d792f4582f --- /dev/null +++ b/ngraph/test/models/onnx/dynamic_shapes/max_pool_with_indices_output.prototxt @@ -0,0 +1,94 @@ +ir_version: 3 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "MaxPool" + attribute { + name: "kernel_shape" + ints: 5 + ints: 5 + type: INTS + } + attribute { + name: "pads" + ints: 2 + ints: 2 + ints: 2 + ints: 2 + type: INTS + } + } + name: "test_maxpool_with_argmax_2d_precomputed_pads" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index a03e49ffa2aaec..de8abb7b56040b 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -3834,3 +3834,118 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_mvn_v6) 1.2906139, 1.1860244, -0.92945826, 0.0721334, -0.38174, -1.7799333}); test_case.run(); } + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout1_no_training_no_return_mask) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dropout1_no_training_no_return_mask.prototxt")); + + auto test_case = test::TestCase(function); + const std::vector data(3 * 4 * 5, 2.0f); + test_case.add_input(data); + test_case.add_expected_output(Shape{3, 4, 5}, data); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout1_no_training_return_mask) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dropout1_no_training_return_mask.prototxt")); + + auto test_case = test::TestCase(function); + const std::vector data(3 * 4 * 5, 2.0f); + test_case.add_input(data); + test_case.add_expected_output(Shape{3, 4, 5}, data); + test_case.add_expected_output( + Shape{3, 4, 5}, std::vector(3 * 4 * 5, 1)); // // bool converted to i32 + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout7_no_return_mask) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dropout7_no_return_mask.prototxt")); + + auto test_case = test::TestCase(function); + const std::vector data(3 * 4 * 5, 2.0f); + test_case.add_input(data); + test_case.add_expected_output(Shape{3, 4, 5}, data); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout12_no_training_no_return_mask) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dropout12_no_training_no_return_mask.prototxt")); + + auto test_case = test::TestCase(function); + const std::vector data(3 * 4 * 5, 2.0f); + test_case.add_input(data); + test_case.add_expected_output(Shape{3, 4, 5}, data); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout12_no_training_return_mask) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dropout12_no_training_return_mask.prototxt")); + + auto test_case = test::TestCase(function); + const std::vector data(3 * 4 * 5, 2.0f); + test_case.add_input(data); + test_case.add_expected_output(Shape{3, 4, 5}, data); + test_case.add_expected_output( + Shape{3, 4, 5}, std::vector(3 * 4 * 5, 1)); // bool converted to i32 + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout12_no_traning_no_const_rato) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dropout12_no_traning_no_const_rato.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input({1, 2, 3, 4}); + // test_case.add_input(Shape{}, {0.5}); // ratio input is ignored + + test_case.add_expected_output(Shape{1, 4}, {1., 2., 3., 4.}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout12_training_mode) +{ + try + { + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dropout12_training_mode.prototxt")); + FAIL() << "Expected exception was not thrown"; + } + catch (const ngraph::ngraph_error& e) + { + EXPECT_HAS_SUBSTRING(e.what(), + std::string("Training mode is not supported for Dropout op")); + } + catch (...) + { + FAIL() << "Expected ngraph_error exception was not thrown"; + } +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_dropout12_not_const_training_mode) +{ + try + { + auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/dropout12_not_const_training_mode.prototxt")); + FAIL() << "Expected exception was not thrown"; + } + catch (const ngraph::ngraph_error& e) + { + EXPECT_HAS_SUBSTRING(e.what(), + std::string("Non-constant training_mode input is not supported.")); + } + catch (...) + { + FAIL() << "Expected ngraph_error exception was not thrown"; + } +} diff --git a/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp b/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp index c93026cde92090..10a9882e2f679d 100644 --- a/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp +++ b/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp @@ -330,6 +330,31 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_shapes_max_pool_dyn_shape) test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_shapes_max_pool_with_indices_output) +{ + const auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/dynamic_shapes/max_pool_with_indices_output.prototxt")); + + auto test_case = test::TestCase(function); + + const Shape shape{1, 1, 5, 5}; + std::vector input_values(shape_size(shape)); + std::iota(input_values.begin(), input_values.end(), 1.f); + + test_case.add_input(shape, input_values); + + std::vector expected_values{13.f, 14.f, 15.f, 15.f, 15.f, 18.f, 19.f, 20.f, 20.f, + 20.f, 23.f, 24.f, 25.f, 25.f, 25.f, 23.f, 24.f, 25.f, + 25.f, 25.f, 23.f, 24.f, 25.f, 25.f, 25.f}; + test_case.add_expected_output(Shape{1, 1, 5, 5}, expected_values); + + // indices output is not supported and is ingored in current implementation + // std::vector expected_indices{12, 13, 14, 14, 14, 17, 18, 19, 19, 19, 22, 23, 24, 24, + // 24, 22, 23, 24, 24, 24, 22, 23, 24, 24, 24}; + // test_case.add_expected_output(Shape{1, 1, 5, 5}, expected_indices); + test_case.run(); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_shapes_global_avg_pool_dyn_shape) { const auto function = onnx_import::import_onnx_model(file_util::path_join(