Skip to content

Commit

Permalink
🦒Nonzero-adjustment (openvinotoolkit#5863)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsocha authored and rnugmanx committed Aug 26, 2021
1 parent 4e0d9bd commit 6ea33b6
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/ops/condition/NonZero_3.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The output tensor has shape `[rank(input), num_non_zero]`. For example, for the

**Types**

* *T*: any numeric type.
* *T*: any type.

* *T_IND*: `int64` or `int32`.

Expand Down
6 changes: 1 addition & 5 deletions ngraph/core/src/op/non_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@ void op::v3::NonZero::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v3_NonZero_validate_and_infer_types);
const PartialShape& input_shape = get_input_partial_shape(0);
const auto input_et = get_input_element_type(0);

NODE_VALIDATION_CHECK(this,
input_et.is_integral_number() || input_et.is_real(),
"NonZero input data type needs to be a numeric type. Got: ",
input_et);
NODE_VALIDATION_CHECK(this,
m_output_type == element::i64 || m_output_type == element::i32,
"Output type must be i32 or i64");
Expand Down Expand Up @@ -154,6 +149,7 @@ namespace nonzero

switch (input->get_element_type())
{
NGRAPH_TYPE_CASE(evaluate_nonzero, boolean, input, output);
NGRAPH_TYPE_CASE(evaluate_nonzero, i8, input, output);
NGRAPH_TYPE_CASE(evaluate_nonzero, i16, input, output);
NGRAPH_TYPE_CASE(evaluate_nonzero, i32, input, output);
Expand Down
3 changes: 1 addition & 2 deletions ngraph/frontend/onnx_import/src/op/compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ namespace ngraph
OutputVector compress(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto condition = std::make_shared<default_opset::Convert>(
node.get_ng_inputs().at(1), element::u8);
auto condition = node.get_ng_inputs().at(1);

int64_t axis = 0;
if (node.has_attribute("axis"))
Expand Down
4 changes: 0 additions & 4 deletions ngraph/frontend/onnx_import/src/op/non_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ namespace ngraph
OutputVector non_zero(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
if (data.get_element_type() == element::boolean)
{
data = std::make_shared<default_opset::Convert>(data, element::u8);
}
return {std::make_shared<default_opset::NonZero>(data, element::i64)};
}

Expand Down

0 comments on commit 6ea33b6

Please sign in to comment.