diff --git a/docs/ops/condition/NonZero_3.md b/docs/ops/condition/NonZero_3.md index acf75ae0886684..44bd96690ddba4 100644 --- a/docs/ops/condition/NonZero_3.md +++ b/docs/ops/condition/NonZero_3.md @@ -29,7 +29,7 @@ The output tensor has shape `[rank(input), num_non_zero]`. For example, for the **Types** -* *T*: any numeric type. +* *T*: any type. * *T_IND*: `int64` or `int32`. diff --git a/ngraph/core/src/op/non_zero.cpp b/ngraph/core/src/op/non_zero.cpp index 19e52f77fe98db..1e11aad1b2aa59 100644 --- a/ngraph/core/src/op/non_zero.cpp +++ b/ngraph/core/src/op/non_zero.cpp @@ -48,12 +48,7 @@ void op::v3::NonZero::validate_and_infer_types() { NGRAPH_OP_SCOPE(v3_NonZero_validate_and_infer_types); const PartialShape& input_shape = get_input_partial_shape(0); - const auto input_et = get_input_element_type(0); - NODE_VALIDATION_CHECK(this, - input_et.is_integral_number() || input_et.is_real(), - "NonZero input data type needs to be a numeric type. Got: ", - input_et); NODE_VALIDATION_CHECK(this, m_output_type == element::i64 || m_output_type == element::i32, "Output type must be i32 or i64"); @@ -154,6 +149,7 @@ namespace nonzero switch (input->get_element_type()) { + NGRAPH_TYPE_CASE(evaluate_nonzero, boolean, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i8, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i16, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i32, input, output); diff --git a/ngraph/frontend/onnx_import/src/op/compress.cpp b/ngraph/frontend/onnx_import/src/op/compress.cpp index f7658a5e7aafb2..0cb7b77f442fee 100644 --- a/ngraph/frontend/onnx_import/src/op/compress.cpp +++ b/ngraph/frontend/onnx_import/src/op/compress.cpp @@ -19,8 +19,7 @@ namespace ngraph OutputVector compress(const Node& node) { auto data = node.get_ng_inputs().at(0); - auto condition = std::make_shared( - node.get_ng_inputs().at(1), element::u8); + auto condition = node.get_ng_inputs().at(1); int64_t axis = 0; if (node.has_attribute("axis")) diff --git a/ngraph/frontend/onnx_import/src/op/non_zero.cpp b/ngraph/frontend/onnx_import/src/op/non_zero.cpp index 5f580111e0d9aa..1798d12ba1f40a 100644 --- a/ngraph/frontend/onnx_import/src/op/non_zero.cpp +++ b/ngraph/frontend/onnx_import/src/op/non_zero.cpp @@ -18,10 +18,6 @@ namespace ngraph OutputVector non_zero(const Node& node) { auto data = node.get_ng_inputs().at(0); - if (data.get_element_type() == element::boolean) - { - data = std::make_shared(data, element::u8); - } return {std::make_shared(data, element::i64)}; }