[ONNX] Handle optional outputs for Dropout and MaxPool #4143

Merged

Changes shown from 22 of 26 commits:
1de9ee0 - first version of dropout and maxpool impl, added ignoring optional out… (Feb 1, 2021)
e681f32 - more tests, impl refactor (Feb 2, 2021)
b20a4af - Added tests to dropout in opsets<12 (Feb 2, 2021)
3ad9eca - added tests for MaxPool (Feb 2, 2021)
35dbe66 - update xfail list (Feb 2, 2021)
d7e4390 - move dropout impl to cpp (Feb 2, 2021)
9da00f5 - fixed is_test bug (Feb 2, 2021)
8ef8b5b - added dropout in opset 7 (Feb 2, 2021)
a767c5d - typo (Feb 2, 2021)
f814085 - added no const ratio test (Feb 5, 2021)
1db819f - Merge remote-tracking branch 'upstream/master' into mbencer/HandleOnn… (Feb 5, 2021)
3621d50 - remove checking legacy attribute (Feb 8, 2021)
7626b23 - Merge remote-tracking branch 'upstream/master' into mbencer/HandleOnn… (Feb 8, 2021)
ab41cd3 - removed not needed code (Feb 8, 2021)
053df15 - enable default mask path (Feb 8, 2021)
cc9a07d - Ignore ratio in training mode (Feb 8, 2021)
64a0a3f - update test backend list (Feb 8, 2021)
733b3a7 - fixed constant bool network, setting precision to output blobs (Feb 10, 2021)
7a29227 - Merge remote-tracking branch 'upstream/master' into mbencer/HandleOnn… (Feb 10, 2021)
869864e - ignore not used test values (Feb 10, 2021)
566f7f5 - Merge remote-tracking branch 'upstream/master' into mbencer/HandleOnn… (Feb 10, 2021)
38f3bb4 - Merge remote-tracking branch 'upstream/master' into mbencer/HandleOnn… (Feb 10, 2021)
a0478fd - removed check constant->get_output_size() (Feb 10, 2021)
0d8c1e9 - dropout review remarks (Feb 11, 2021)
c316bcc - Merge remote-tracking branch 'upstream/master' into mbencer/HandleOnn… (Feb 11, 2021)
c762729 - Merge remote-tracking branch 'upstream/master' into mbencer/HandleOnn… (Feb 12, 2021)
@@ -412,7 +412,7 @@ bool fuse_type_to_constant(std::shared_ptr<Node> & node, element::Type to, const
}

new_const->validate_and_infer_types();
-if (constant->get_output_target_inputs(0).size() == consumers.size()) {
+if (constant->get_output_size() == consumers.size()) {
new_const->set_friendly_name(constant->get_friendly_name());
}
}
@@ -565,8 +565,10 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
template <typename From, typename To>
void constant_convert_test(element::Type_t type_from, element::Type_t type_to, From value, To expected) {
std::shared_ptr<ngraph::Function> f(nullptr);
+std::string expected_friendly_name;
{
auto c = opset4::Constant::create(type_from, Shape{}, {value});
+expected_friendly_name = c->get_friendly_name();
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});

pass::Manager manager;
@@ -576,6 +578,7 @@ void constant_convert_test(element::Type_t type_from, element::Type_t type_to, F
auto ops = f->get_ordered_ops();
auto c = std::dynamic_pointer_cast<opset4::Constant>(ops[0]);
ASSERT_NE(c, nullptr);
+ASSERT_EQ(c->get_friendly_name(), expected_friendly_name);

auto actual = c->cast_vector<To>()[0];
ASSERT_EQ(expected, actual);
@@ -622,3 +625,8 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MaxToI32) {
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32ToI32) {
constant_convert_test(element::Type_t::u32, element::Type_t::i32, 42, 42);
}

+TEST(TransformationTests, ConvertPrecision_ConstantConversion_BoolToU8) {
+constant_convert_test(element::Type_t::boolean, element::Type_t::u8, true, 1);
+constant_convert_test(element::Type_t::boolean, element::Type_t::u8, false, 0);
+}
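The assertions added in this file pin down two behaviours: a boolean constant converts to a u8 constant with true mapping to 1 and false to 0, and (per the friendly-name check above) the re-typed constant keeps the original node's name. A quick standalone check of the value mapping, written in plain numpy purely for illustration:

```python
import numpy as np

# Conversion semantics pinned by the tests above: boolean True/False
# become u8 1/0, and u32 42 survives a cast to i32 unchanged.
assert np.uint8(True) == 1 and np.uint8(False) == 0
assert np.int32(np.uint32(42)) == 42
```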
6 changes: 5 additions & 1 deletion ngraph/frontend/onnx_import/src/core/graph.cpp
@@ -223,7 +223,11 @@ namespace ngraph
OutputVector results;
for (const auto& output : m_graph_proto->output())
{
-results.emplace_back(get_ng_node_from_cache(output.name()));
+const auto& ng_output = get_ng_node_from_cache(output.name());
+if (!ngraph::op::is_null(ng_output)) // ignore optional outputs
+{
+results.emplace_back(ng_output);
+}
}
return results;
}
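In ONNX, a node opts out of an optional output by declaring fewer output names (or an empty string) in its output list, and the importer represents each missing output as a NullNode placeholder; the loop above filters those placeholders out of the graph's result set so they never become real graph outputs. A minimal sketch, assuming the `onnx` Python package, of a Dropout node that requests only its data output:

```python
import onnx
from onnx import TensorProto, helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 2])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 2])

# Only one output name is declared, so the optional "mask" output is omitted;
# the importer has nothing to materialize for it.
node = helper.make_node("Dropout", inputs=["x"], outputs=["y"])
graph = helper.make_graph([node], "dropout_no_mask", [x], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])
onnx.checker.check_model(model)
```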
110 changes: 110 additions & 0 deletions ngraph/frontend/onnx_import/src/op/dropout.cpp
@@ -0,0 +1,110 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>

#include "core/null_node.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "op/dropout.hpp"

namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace
{
OutputVector build_dropout(const Node& node, bool training_mode)
{
CHECK_VALID_NODE(
node, !training_mode, "Training mode is not supported for Dropout op");

const auto input_data = node.get_ng_inputs().at(0);
const bool return_mask = node.get_outputs_size() > 1;

if (return_mask)
{
const auto mask = std::make_shared<default_opset::Broadcast>(
default_opset::Constant::create(
ngraph::element::boolean, Shape{}, {true}),
std::make_shared<default_opset::ShapeOf>(input_data));
return {input_data, mask};
}
else
{
return {input_data};
}
}
}

namespace set_12
{
OutputVector dropout(const Node& node)
{
const auto ng_inputs = node.get_ng_inputs();
// seed attribute and ratio input are ignored because training mode is not
// supported anyway
bool training_mode = false; // default value
if (ng_inputs.size() > 2)
{
if (!ngraph::op::is_null(ng_inputs.at(2)))
{
CHECK_VALID_NODE(
node,
ngraph::op::is_constant(ng_inputs.at(2).get_node_shared_ptr()),
"training_mode input must be constant (or omitted).");
training_mode = as_type_ptr<default_opset::Constant>(
ng_inputs.at(2).get_node_shared_ptr())
->cast_vector<bool>()[0];
}
}
return build_dropout(node, training_mode);
}
} // namespace set_12

namespace set_7
{
OutputVector dropout(const Node& node)
{
// "is_test" attribute was removed
// ratio attribute is ignored because training mode is not supported
const bool training_mode = false;

return build_dropout(node, training_mode);
}
} // namespace set_7

namespace set_1
{
OutputVector dropout(const Node& node)
{
// legacy consumed_inputs attribute ignored
// ratio attribute is ignored because training mode is not supported
const bool training_mode = !node.get_attribute_value<int64_t>("is_test", 0);

return build_dropout(node, training_mode);
}
} // namespace set_1

} // namespace op

} // namespace onnx_import

} // namespace ngraph
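The inference-time behaviour implemented by `build_dropout` above is deliberately trivial: the data output is the input unchanged, and the optional mask, when requested, is an all-true tensor built by broadcasting a scalar boolean constant over `ShapeOf(input)`. A plain-numpy sketch of the same semantics (`dropout_inference` is a hypothetical helper for illustration, not part of this PR):

```python
import numpy as np

def dropout_inference(x: np.ndarray):
    """Mimic ONNX Dropout with training_mode=False: identity plus all-true mask."""
    y = x                                # inference-mode Dropout is a no-op
    mask = np.ones(x.shape, dtype=bool)  # broadcast of a scalar `true` over the input shape
    return y, mask

x = np.arange(6, dtype=np.float32).reshape(2, 3)
y, mask = dropout_inference(x)
assert np.array_equal(y, x) and mask.all()
```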
22 changes: 11 additions & 11 deletions ngraph/frontend/onnx_import/src/op/dropout.hpp
@@ -16,10 +16,6 @@

#pragma once

-#include <memory>
-
-#include "core/null_node.hpp"
-#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"

namespace ngraph
@@ -28,15 +24,19 @@ namespace ngraph
{
namespace op
{
+namespace set_12
+{
+OutputVector dropout(const Node& node);
+} // namespace set_12
+
+namespace set_7
+{
+OutputVector dropout(const Node& node);
+} // namespace set_7
+
namespace set_1
{
-inline OutputVector dropout(const Node& node)
-{
-// First value is actual output of Dropout,
-// the second one is just a placeholder for optional trailing output.
-return {node.get_ng_inputs().at(0).get_node_shared_ptr(),
-std::make_shared<NullNode>()};
-}
+OutputVector dropout(const Node& node);
} // namespace set_1

} // namespace op
6 changes: 6 additions & 0 deletions ngraph/frontend/onnx_import/src/op/max_pool.cpp
@@ -17,6 +17,7 @@
#include <memory>

#include "core/null_node.hpp"
+#include "ngraph/log.hpp"
#include "ngraph/op/max_pool.hpp"
#include "op/max_pool.hpp"
#include "utils/pooling_factory.hpp"
@@ -31,6 +32,11 @@
{
OutputVector max_pool(const Node& node)
{
+if (node.get_outputs_size() > 1)
+{
+NGRAPH_WARN
+<< "Indices output is not supported for MaxPooling and was ignored";
+}
auto max_pool = pooling::PoolingFactory(node).make_max_pool();
max_pool.emplace_back(std::make_shared<NullNode>()); // Indices (optional)
return max_pool;
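A NullNode is still appended for the Indices slot so the op's output count matches what the ONNX node declares; combined with the graph.cpp filtering above, a model that requests Indices now imports with a warning instead of failing. A sketch, assuming the `onnx` package, of a MaxPool node that would trigger that warning:

```python
from onnx import TensorProto, helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 1, 4, 4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 1, 2, 2])
i = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, 1, 2, 2])

# Two declared outputs: the second one ("indices") is the optional output
# that the importer replaces with a NullNode after logging the warning.
node = helper.make_node(
    "MaxPool", inputs=["x"], outputs=["y", "indices"],
    kernel_shape=[2, 2], strides=[2, 2],
)
graph = helper.make_graph([node], "maxpool_with_indices", [x], [y, i])
```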
2 changes: 2 additions & 0 deletions ngraph/frontend/onnx_import/src/ops_bridge.cpp
@@ -343,6 +343,8 @@ namespace ngraph
REGISTER_OPERATOR("Div", 1, div);
REGISTER_OPERATOR("Div", 7, div);
REGISTER_OPERATOR("Dropout", 1, dropout);
+REGISTER_OPERATOR("Dropout", 7, dropout);
+REGISTER_OPERATOR("Dropout", 12, dropout);
REGISTER_OPERATOR("Elu", 1, elu);
REGISTER_OPERATOR("Equal", 1, equal);
REGISTER_OPERATOR("Erf", 1, erf);
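With these registrations the bridge resolves an operator to the handler with the highest registered since-version that does not exceed the model's declared opset: Dropout in an opset 1-6 model uses set_1, opset 7-11 uses set_7, and opset 12+ uses set_12. An illustrative Python model of that resolution rule (the dict and names are hypothetical; the real registry is C++):

```python
# Hypothetical stand-in for the ops-bridge operator registry.
REGISTRY = {"Dropout": {1: "set_1::dropout", 7: "set_7::dropout", 12: "set_12::dropout"}}

def resolve(op_type: str, model_opset: int) -> str:
    """Pick the highest registered since-version not exceeding the model opset."""
    versions = [v for v in REGISTRY[op_type] if v <= model_opset]
    if not versions:
        raise KeyError(f"{op_type} is unsupported before opset {min(REGISTRY[op_type])}")
    return REGISTRY[op_type][max(versions)]

assert resolve("Dropout", 11) == "set_7::dropout"
assert resolve("Dropout", 13) == "set_12::dropout"
```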
5 changes: 2 additions & 3 deletions ngraph/python/tests/__init__.py
@@ -107,8 +107,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True):
xfail_issue_38699 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"ai.onnx.preview.training.Gradient")
xfail_issue_38701 = xfail_test(reason="RuntimeError: unsupported element type: STRING")
-xfail_issue_38705 = xfail_test(reason="IndexError: deque::_M_range_check: __n (which is 0)"
-                               ">= this->size() (which is 0)")
xfail_issue_38706 = xfail_test(reason="RuntimeError: output_3.0 has zero dimension which is not allowed")
xfail_issue_38707 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"SoftmaxCrossEntropyLoss")
@@ -152,7 +150,7 @@
"ai.onnx.preview.training.Adagrad")
xfail_issue_38736 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"NegativeLogLikelihoodLoss")
-xfail_issue_45177 = xfail_test(reason="RuntimeError: axes has zero dimension which is not allowed")
+xfail_issue_48052 = xfail_test(reason="Dropout op is not supported in training mode")
xfail_issue_45180 = xfail_test(reason="RuntimeError: Unsupported dynamic op: ReduceSum")
xfail_issue_44839 = xfail_test(reason="Huge computation mismatch")
xfail_issue_44848 = xfail_test(reason="E Unsupported dynamic op: Range")
@@ -176,6 +174,7 @@
xfail_issue_47330 = xfail_test(reason="RuntimeError: Eltwise node with name `[name]` doesn't support "
"FP64 precision.")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
+xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")

# Model MSFT issues:
xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
44 changes: 34 additions & 10 deletions ngraph/python/tests/runtime.py
@@ -18,10 +18,10 @@
from typing import Dict, List, Union

import numpy as np
-from openvino.inference_engine import IECore, IENetwork, Blob
+from openvino.inference_engine import IECore, IENetwork, Blob, DataPtr

from ngraph.exceptions import UserInputError
-from ngraph.impl import Function, Node, PartialShape
+from ngraph.impl import Function, Node, PartialShape, Type
from ngraph.opset1.ops import result
from ngraph.utils.types import NumericData, get_shape, get_dtype

@@ -55,6 +55,18 @@ def _convert_inputs(cnn_network: IENetwork) -> None:
pass


+def apply_ng_type(output: DataPtr, ng_type: Type):
+    ng_ie_supported_type_map = {
+        Type.boolean.get_type_name(): "BOOL",
+        Type.f32.get_type_name(): "FP32",
+        Type.i8.get_type_name(): "I8",
+        Type.i32.get_type_name(): "I32",
+        Type.u8.get_type_name(): "U8",
+    }
+    if ng_type.get_type_name() in ng_ie_supported_type_map:
+        output.precision = ng_ie_supported_type_map[ng_type.get_type_name()]


class Runtime(object):
"""Represents an nGraph runtime environment."""

@@ -103,18 +115,30 @@ def __repr__(self) -> str:
params_string = ", ".join([param.name for param in self.parameters])
return "<Computation: {}({})>".format(self.function.get_name(), params_string)

-    def __get_ie_output_blob_buffer(self, output_blobs: Dict[str, Blob], ng_result: result) -> np.ndarray:
+    def __get_ie_output_blob_name(self, outputs: Dict, ng_result: result) -> str:
         if len(self.results) == 1:
-            return next(iter(output_blobs.values())).buffer
+            return next(iter(outputs.keys()))
         else:
             prev_layer = ng_result.input(0).get_source_output()
             out_name = prev_layer.get_node().get_friendly_name()
             if prev_layer.get_node().get_output_size() != 1:
                 out_name += "." + str(prev_layer.get_index())
-            return output_blobs[out_name].buffer
+            return out_name
+
+    def __get_ie_output_blob_buffer(self, output_blobs: Dict[str, Blob], ng_result: result) -> np.ndarray:
+        out_name = self.__get_ie_output_blob_name(output_blobs, ng_result)
+        return output_blobs[out_name].buffer

     def __call__(self, *input_values: NumericData) -> List[NumericData]:
         """Run computation on input values and return result."""
+        # Input validation
+        if len(input_values) < len(self.parameters):
+            raise UserInputError(
+                "Expected at least %s input values, received %s.", len(self.parameters), len(input_values)
+            )
+        # ignore unneeded extra input values
+        input_values = input_values[:len(self.parameters)]

input_values = [np.array(input_value) for input_value in input_values]
input_shapes = [get_shape(input_value) for input_value in input_values]

@@ -131,13 +155,13 @@ def __call__(self, *input_values: NumericData) -> List[NumericData]:
else:
cnn_network = self.network_cache[str(input_shapes)]

+        # set output blob precision based on nGraph results
+        for ng_result in self.results:
+            ie_out_name = self.__get_ie_output_blob_name(cnn_network.outputs, ng_result)
+            apply_ng_type(cnn_network.outputs[ie_out_name], ng_result.get_output_element_type(0))

executable_network = self.runtime.backend.load_network(cnn_network, self.runtime.backend_name)

-        # Input validation
-        if len(input_values) != len(self.parameters):
-            raise UserInputError(
-                "Expected %s parameters, received %s.", len(self.parameters), len(input_values)
-            )
for parameter, input in zip(self.parameters, input_values):
parameter_shape = parameter.get_output_partial_shape(0)
input_shape = PartialShape(input.shape)
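These runtime changes exist because Inference Engine output blobs otherwise keep a default precision that can disagree with the nGraph result type; commit 733b3a7 hit this with constant boolean networks, whose outputs must be read back as BOOL rather than a floating-point default. A standalone sketch of the lookup performed by `apply_ng_type` (plain strings and a hypothetical helper name here; the real code keys the table by `Type.get_type_name()`):

```python
from typing import Optional

# Standalone model of apply_ng_type's table: nGraph element type name -> IE precision.
NG_TO_IE_PRECISION = {
    "boolean": "BOOL",
    "f32": "FP32",
    "i8": "I8",
    "i32": "I32",
    "u8": "U8",
}

def ie_precision_for(ng_type_name: str) -> Optional[str]:
    """Return the IE precision for a supported nGraph type name, else None."""
    return NG_TO_IE_PRECISION.get(ng_type_name)

assert ie_precision_for("boolean") == "BOOL"
assert ie_precision_for("f16") is None  # unmapped types keep the backend default
```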
12 changes: 6 additions & 6 deletions ngraph/python/tests/test_ngraph/test_basic.py
@@ -80,15 +80,15 @@ def test_simple_computation_on_ndarrays(dtype):

value_a = np.array([[1, 2], [3, 4]], dtype=dtype)
value_b = np.array([[5, 6], [7, 8]], dtype=dtype)
-    value_c = np.array([[9, 10], [11, 12]], dtype=dtype)
+    value_c = np.array([[2, 3], [4, 5]], dtype=dtype)
     result = computation(value_a, value_b, value_c)
-    assert np.allclose(result, np.array([[54, 80], [110, 144]], dtype=dtype))
+    assert np.allclose(result, np.array([[12, 24], [40, 60]], dtype=dtype))

-    value_a = np.array([[13, 14], [15, 16]], dtype=dtype)
-    value_b = np.array([[17, 18], [19, 20]], dtype=dtype)
-    value_c = np.array([[21, 22], [23, 24]], dtype=dtype)
+    value_a = np.array([[9, 10], [11, 12]], dtype=dtype)
+    value_b = np.array([[13, 14], [15, 16]], dtype=dtype)
+    value_c = np.array([[5, 4], [3, 2]], dtype=dtype)
     result = computation(value_a, value_b, value_c)
-    assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype))
+    assert np.allclose(result, np.array([[110, 96], [78, 56]], dtype=dtype))


def test_serialization():
1 change: 0 additions & 1 deletion ngraph/python/tests/test_ngraph/test_ops_reshape.py
@@ -223,7 +223,6 @@ def test_reshape_v1():
assert np.allclose(result, expected)


-@xfail_issue_40957
def test_shape_of():
input_tensor = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
