Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MulticlassNms/MatrixNms: ngraph python api #6573

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ngraph/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"ngraph.opset5",
"ngraph.opset6",
"ngraph.opset7",
"ngraph.opset8",
"ngraph.utils",
"ngraph.impl",
"ngraph.impl.op",
Expand Down
2 changes: 2 additions & 0 deletions ngraph/python/src/ngraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@
from ngraph.opset8 import lstm_cell
from ngraph.opset8 import lstm_sequence
from ngraph.opset8 import matmul
from ngraph.opset8 import matrix_nms
from ngraph.opset8 import max_pool
from ngraph.opset8 import maximum
from ngraph.opset8 import minimum
from ngraph.opset8 import mish
from ngraph.opset8 import mod
from ngraph.opset8 import multiclass_nms
from ngraph.opset8 import multiply
from ngraph.opset8 import mvn
from ngraph.opset8 import negative
Expand Down
2 changes: 2 additions & 0 deletions ngraph/python/src/ngraph/opset8/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@
from ngraph.opset4.ops import lstm_cell
from ngraph.opset1.ops import lstm_sequence
from ngraph.opset1.ops import matmul
from ngraph.opset8.ops import matrix_nms
from ngraph.opset1.ops import max_pool
from ngraph.opset1.ops import maximum
from ngraph.opset1.ops import minimum
from ngraph.opset4.ops import mish
from ngraph.opset1.ops import mod
from ngraph.opset8.ops import multiclass_nms
from ngraph.opset1.ops import multiply
from ngraph.opset6.ops import mvn
from ngraph.opset1.ops import negative
Expand Down
117 changes: 117 additions & 0 deletions ngraph/python/src/ngraph/opset8/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,120 @@ def deformable_convolution(
"bilinear_interpolation_pad": bilinear_interpolation_pad
},
)


@nameable_op
def multiclass_nms(
boxes: NodeInput,
scores: NodeInput,
sort_result_type: str = "none",
sort_result_across_batch: bool = False,
output_type: str = "i64",
iou_threshold: float = 0.0,
score_threshold: float = 0.0,
nms_top_k: int = -1,
keep_top_k: int = -1,
background_class: int = -1,
nms_eta: float = 1.0,
normalized: bool = True
) -> Node:
"""Return a node which performs MulticlassNms.

@param boxes: Tensor with box coordinates.
@param scores: Tensor with box scores.
@param sort_result_type: Specifies order of output elements, possible values:
'class': sort selected boxes by class id (ascending)
'score': sort selected boxes by score (descending)
'none': do not guarantee the order.
@param sort_result_across_batch: Specifies whenever it is necessary to sort selected boxes
across batches or not
@param output_type: Specifies the output tensor type, possible values:
'i64', 'i32'
@param iou_threshold: Specifies intersection over union threshold
@param score_threshold: Specifies minimum score to consider box for the processing
@param nms_top_k: Specifies maximum number of boxes to be selected per class, -1 meaning
to keep all boxes
@param keep_top_k: Specifies maximum number of boxes to be selected per batch element, -1
meaning to keep all boxes
@param background_class: Specifies the background class id, -1 meaning to keep all classes
@param nms_eta: Specifies eta parameter for adpative NMS, in close range [0, 1.0]
@param normalized: Specifies whether boxes are normalized or not
@return: The new node which performs MuticlassNms
"""
inputs = as_nodes(boxes, scores)

attributes = {
"sort_result_type": sort_result_type,
"sort_result_across_batch": sort_result_across_batch,
"output_type": output_type,
"iou_threshold": iou_threshold,
"score_threshold": score_threshold,
"nms_top_k": nms_top_k,
"keep_top_k": keep_top_k,
"background_class": background_class,
"nms_eta": nms_eta,
"normalized": normalized
}

return _get_node_factory_opset8().create("MulticlassNms", inputs, attributes)


@nameable_op
def matrix_nms(
boxes: NodeInput,
scores: NodeInput,
sort_result_type: str = "none",
luo-cheng2021 marked this conversation as resolved.
Show resolved Hide resolved
sort_result_across_batch: bool = False,
output_type: str = "i64",
score_threshold: float = 0.0,
nms_top_k: int = -1,
keep_top_k: int = -1,
background_class: int = -1,
decay_function: str = "linear",
gaussian_sigma: float = 2.0,
post_threshold: float = 0.0,
normalized: bool = True
) -> Node:
"""Return a node which performs MatrixNms.

@param boxes: Tensor with box coordinates.
@param scores: Tensor with box scores.
@param sort_result_type: Specifies order of output elements, possible values:
'class': sort selected boxes by class id (ascending)
'score': sort selected boxes by score (descending)
'none': do not guarantee the order.
@param sort_result_across_batch: Specifies whenever it is necessary to sort selected boxes
across batches or not
@param output_type: Specifies the output tensor type, possible values:
'i64', 'i32'
@param score_threshold: Specifies minimum score to consider box for the processing
@param nms_top_k: Specifies maximum number of boxes to be selected per class, -1 meaning
to keep all boxes
@param keep_top_k: Specifies maximum number of boxes to be selected per batch element, -1
meaning to keep all boxes
@param background_class: Specifies the background class id, -1 meaning to keep all classes
@param decay_function: Specifies decay function used to decay scores, possible values:
'gaussian', 'linear'
@param gaussian_sigma: Specifies gaussian_sigma parameter for gaussian decay_function
@param post_threshold: Specifies threshold to filter out boxes with low confidence score
after decaying
@param normalized: Specifies whether boxes are normalized or not
@return: The new node which performs MatrixNms
"""
inputs = as_nodes(boxes, scores)

attributes = {
"sort_result_type": sort_result_type,
"sort_result_across_batch": sort_result_across_batch,
"output_type": output_type,
"score_threshold": score_threshold,
"nms_top_k": nms_top_k,
"keep_top_k": keep_top_k,
"background_class": background_class,
"decay_function": decay_function,
"gaussian_sigma": gaussian_sigma,
"post_threshold": post_threshold,
"normalized": normalized
}

return _get_node_factory_opset8().create("MatrixNms", inputs, attributes)
2 changes: 1 addition & 1 deletion ngraph/python/src/pyngraph/node_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ namespace
return it->second();
}

const ngraph::OpSet& m_opset = ngraph::get_opset7();
const ngraph::OpSet& m_opset = ngraph::get_opset8();
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> m_variables;
};
} // namespace
Expand Down
52 changes: 51 additions & 1 deletion ngraph/python/tests/test_ngraph/test_create_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import pytest
from _pyngraph import PartialShape
from _pyngraph import PartialShape, Dimension

import ngraph as ng
import ngraph.opset1 as ng_opset1
Expand Down Expand Up @@ -1846,3 +1846,53 @@ def test_rnn_sequence_operator_forward(dtype):

assert node.get_type_name() == "RNNSequence"
assert node.get_output_size() == 2


def test_multiclass_nms():
boxes_data = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.1, 1.0, 1.1,
0.0, -0.1, 1.0, 0.9, 0.0, 10.0, 1.0, 11.0,
0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0], dtype="float32")
boxes_data = boxes_data.reshape([1, 6, 4])
box = ng.constant(boxes_data, dtype=np.float)
scores_data = np.array([0.9, 0.75, 0.6, 0.95, 0.5, 0.3,
0.95, 0.75, 0.6, 0.80, 0.5, 0.3], dtype="float32")
scores_data = scores_data.reshape([1, 2, 6])
score = ng.constant(scores_data, dtype=np.float)

nms_node = ng.multiclass_nms(box, score, output_type="i32", nms_top_k=3,
iou_threshold=0.5, score_threshold=0.0, sort_result_type="classid",
nms_eta=1.0)

assert nms_node.get_type_name() == "MulticlassNms"
assert nms_node.get_output_size() == 3
assert nms_node.outputs()[0].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(6)])
assert nms_node.outputs()[1].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(1)])
assert list(nms_node.outputs()[2].get_shape()) == [1, ]
assert nms_node.get_output_element_type(0) == Type.f32
assert nms_node.get_output_element_type(1) == Type.i32
assert nms_node.get_output_element_type(2) == Type.i32


def test_matrix_nms():
boxes_data = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.1, 1.0, 1.1,
0.0, -0.1, 1.0, 0.9, 0.0, 10.0, 1.0, 11.0,
0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0], dtype="float32")
boxes_data = boxes_data.reshape([1, 6, 4])
box = ng.constant(boxes_data, dtype=np.float)
scores_data = np.array([0.9, 0.75, 0.6, 0.95, 0.5, 0.3,
0.95, 0.75, 0.6, 0.80, 0.5, 0.3], dtype="float32")
scores_data = scores_data.reshape([1, 2, 6])
score = ng.constant(scores_data, dtype=np.float)

nms_node = ng.matrix_nms(box, score, output_type="i32", nms_top_k=3,
score_threshold=0.0, sort_result_type="score", background_class=0,
decay_function="linear", gaussian_sigma=2.0, post_threshold=0.0)

assert nms_node.get_type_name() == "MatrixNms"
assert nms_node.get_output_size() == 3
assert nms_node.outputs()[0].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(6)])
assert nms_node.outputs()[1].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(1)])
assert list(nms_node.outputs()[2].get_shape()) == [1, ]
assert nms_node.get_output_element_type(0) == Type.f32
assert nms_node.get_output_element_type(1) == Type.i32
assert nms_node.get_output_element_type(2) == Type.i32
2 changes: 2 additions & 0 deletions ngraph/test/util/engine/ie_engines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ std::set<NodeTypeInfo> test::IE_Engine::get_ie_ops() const
ie_ops.insert(opset6.begin(), opset6.end());
const auto& opset7 = get_opset7().get_type_info_set();
ie_ops.insert(opset7.begin(), opset7.end());
const auto& opset8 = get_opset8().get_type_info_set();
ie_ops.insert(opset8.begin(), opset8.end());
return ie_ops;
}

Expand Down