Skip to content

Commit

Permalink
MulticlassNms/MatrixNms: ngraph python api (openvinotoolkit#6573)
Browse files Browse the repository at this point in the history
* nms python api

* fix python code style

* fix python code style

* apply review comments

* apply review comments
  • Loading branch information
luo-cheng2021 authored Jul 14, 2021
1 parent 3710c0e commit 2378593
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 2 deletions.
1 change: 1 addition & 0 deletions ngraph/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"ngraph.opset5",
"ngraph.opset6",
"ngraph.opset7",
"ngraph.opset8",
"ngraph.utils",
"ngraph.impl",
"ngraph.impl.op",
Expand Down
2 changes: 2 additions & 0 deletions ngraph/python/src/ngraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@
from ngraph.opset8 import lstm_cell
from ngraph.opset8 import lstm_sequence
from ngraph.opset8 import matmul
from ngraph.opset8 import matrix_nms
from ngraph.opset8 import max_pool
from ngraph.opset8 import maximum
from ngraph.opset8 import minimum
from ngraph.opset8 import mish
from ngraph.opset8 import mod
from ngraph.opset8 import multiclass_nms
from ngraph.opset8 import multiply
from ngraph.opset8 import mvn
from ngraph.opset8 import negative
Expand Down
2 changes: 2 additions & 0 deletions ngraph/python/src/ngraph/opset8/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@
from ngraph.opset4.ops import lstm_cell
from ngraph.opset1.ops import lstm_sequence
from ngraph.opset1.ops import matmul
from ngraph.opset8.ops import matrix_nms
from ngraph.opset1.ops import max_pool
from ngraph.opset1.ops import maximum
from ngraph.opset1.ops import minimum
from ngraph.opset4.ops import mish
from ngraph.opset1.ops import mod
from ngraph.opset8.ops import multiclass_nms
from ngraph.opset1.ops import multiply
from ngraph.opset6.ops import mvn
from ngraph.opset1.ops import negative
Expand Down
117 changes: 117 additions & 0 deletions ngraph/python/src/ngraph/opset8/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,120 @@ def deformable_convolution(
"bilinear_interpolation_pad": bilinear_interpolation_pad
},
)


@nameable_op
def multiclass_nms(
boxes: NodeInput,
scores: NodeInput,
sort_result_type: str = "none",
sort_result_across_batch: bool = False,
output_type: str = "i64",
iou_threshold: float = 0.0,
score_threshold: float = 0.0,
nms_top_k: int = -1,
keep_top_k: int = -1,
background_class: int = -1,
nms_eta: float = 1.0,
normalized: bool = True
) -> Node:
"""Return a node which performs MulticlassNms.
@param boxes: Tensor with box coordinates.
@param scores: Tensor with box scores.
@param sort_result_type: Specifies order of output elements, possible values:
'class': sort selected boxes by class id (ascending)
'score': sort selected boxes by score (descending)
'none': do not guarantee the order.
@param sort_result_across_batch: Specifies whenever it is necessary to sort selected boxes
across batches or not
@param output_type: Specifies the output tensor type, possible values:
'i64', 'i32'
@param iou_threshold: Specifies intersection over union threshold
@param score_threshold: Specifies minimum score to consider box for the processing
@param nms_top_k: Specifies maximum number of boxes to be selected per class, -1 meaning
to keep all boxes
@param keep_top_k: Specifies maximum number of boxes to be selected per batch element, -1
meaning to keep all boxes
@param background_class: Specifies the background class id, -1 meaning to keep all classes
@param nms_eta: Specifies eta parameter for adpative NMS, in close range [0, 1.0]
@param normalized: Specifies whether boxes are normalized or not
@return: The new node which performs MuticlassNms
"""
inputs = as_nodes(boxes, scores)

attributes = {
"sort_result_type": sort_result_type,
"sort_result_across_batch": sort_result_across_batch,
"output_type": output_type,
"iou_threshold": iou_threshold,
"score_threshold": score_threshold,
"nms_top_k": nms_top_k,
"keep_top_k": keep_top_k,
"background_class": background_class,
"nms_eta": nms_eta,
"normalized": normalized
}

return _get_node_factory_opset8().create("MulticlassNms", inputs, attributes)


@nameable_op
def matrix_nms(
boxes: NodeInput,
scores: NodeInput,
sort_result_type: str = "none",
sort_result_across_batch: bool = False,
output_type: str = "i64",
score_threshold: float = 0.0,
nms_top_k: int = -1,
keep_top_k: int = -1,
background_class: int = -1,
decay_function: str = "linear",
gaussian_sigma: float = 2.0,
post_threshold: float = 0.0,
normalized: bool = True
) -> Node:
"""Return a node which performs MatrixNms.
@param boxes: Tensor with box coordinates.
@param scores: Tensor with box scores.
@param sort_result_type: Specifies order of output elements, possible values:
'class': sort selected boxes by class id (ascending)
'score': sort selected boxes by score (descending)
'none': do not guarantee the order.
@param sort_result_across_batch: Specifies whenever it is necessary to sort selected boxes
across batches or not
@param output_type: Specifies the output tensor type, possible values:
'i64', 'i32'
@param score_threshold: Specifies minimum score to consider box for the processing
@param nms_top_k: Specifies maximum number of boxes to be selected per class, -1 meaning
to keep all boxes
@param keep_top_k: Specifies maximum number of boxes to be selected per batch element, -1
meaning to keep all boxes
@param background_class: Specifies the background class id, -1 meaning to keep all classes
@param decay_function: Specifies decay function used to decay scores, possible values:
'gaussian', 'linear'
@param gaussian_sigma: Specifies gaussian_sigma parameter for gaussian decay_function
@param post_threshold: Specifies threshold to filter out boxes with low confidence score
after decaying
@param normalized: Specifies whether boxes are normalized or not
@return: The new node which performs MatrixNms
"""
inputs = as_nodes(boxes, scores)

attributes = {
"sort_result_type": sort_result_type,
"sort_result_across_batch": sort_result_across_batch,
"output_type": output_type,
"score_threshold": score_threshold,
"nms_top_k": nms_top_k,
"keep_top_k": keep_top_k,
"background_class": background_class,
"decay_function": decay_function,
"gaussian_sigma": gaussian_sigma,
"post_threshold": post_threshold,
"normalized": normalized
}

return _get_node_factory_opset8().create("MatrixNms", inputs, attributes)
2 changes: 1 addition & 1 deletion ngraph/python/src/pyngraph/node_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ namespace
return it->second();
}

const ngraph::OpSet& m_opset = ngraph::get_opset7();
const ngraph::OpSet& m_opset = ngraph::get_opset8();
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> m_variables;
};
} // namespace
Expand Down
52 changes: 51 additions & 1 deletion ngraph/python/tests/test_ngraph/test_create_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import pytest
from _pyngraph import PartialShape
from _pyngraph import PartialShape, Dimension

import ngraph as ng
import ngraph.opset1 as ng_opset1
Expand Down Expand Up @@ -1846,3 +1846,53 @@ def test_rnn_sequence_operator_forward(dtype):

assert node.get_type_name() == "RNNSequence"
assert node.get_output_size() == 2


def test_multiclass_nms():
boxes_data = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.1, 1.0, 1.1,
0.0, -0.1, 1.0, 0.9, 0.0, 10.0, 1.0, 11.0,
0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0], dtype="float32")
boxes_data = boxes_data.reshape([1, 6, 4])
box = ng.constant(boxes_data, dtype=np.float)
scores_data = np.array([0.9, 0.75, 0.6, 0.95, 0.5, 0.3,
0.95, 0.75, 0.6, 0.80, 0.5, 0.3], dtype="float32")
scores_data = scores_data.reshape([1, 2, 6])
score = ng.constant(scores_data, dtype=np.float)

nms_node = ng.multiclass_nms(box, score, output_type="i32", nms_top_k=3,
iou_threshold=0.5, score_threshold=0.0, sort_result_type="classid",
nms_eta=1.0)

assert nms_node.get_type_name() == "MulticlassNms"
assert nms_node.get_output_size() == 3
assert nms_node.outputs()[0].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(6)])
assert nms_node.outputs()[1].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(1)])
assert list(nms_node.outputs()[2].get_shape()) == [1, ]
assert nms_node.get_output_element_type(0) == Type.f32
assert nms_node.get_output_element_type(1) == Type.i32
assert nms_node.get_output_element_type(2) == Type.i32


def test_matrix_nms():
boxes_data = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.1, 1.0, 1.1,
0.0, -0.1, 1.0, 0.9, 0.0, 10.0, 1.0, 11.0,
0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0], dtype="float32")
boxes_data = boxes_data.reshape([1, 6, 4])
box = ng.constant(boxes_data, dtype=np.float)
scores_data = np.array([0.9, 0.75, 0.6, 0.95, 0.5, 0.3,
0.95, 0.75, 0.6, 0.80, 0.5, 0.3], dtype="float32")
scores_data = scores_data.reshape([1, 2, 6])
score = ng.constant(scores_data, dtype=np.float)

nms_node = ng.matrix_nms(box, score, output_type="i32", nms_top_k=3,
score_threshold=0.0, sort_result_type="score", background_class=0,
decay_function="linear", gaussian_sigma=2.0, post_threshold=0.0)

assert nms_node.get_type_name() == "MatrixNms"
assert nms_node.get_output_size() == 3
assert nms_node.outputs()[0].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(6)])
assert nms_node.outputs()[1].get_partial_shape() == PartialShape([Dimension(0, 6), Dimension(1)])
assert list(nms_node.outputs()[2].get_shape()) == [1, ]
assert nms_node.get_output_element_type(0) == Type.f32
assert nms_node.get_output_element_type(1) == Type.i32
assert nms_node.get_output_element_type(2) == Type.i32
2 changes: 2 additions & 0 deletions ngraph/test/util/engine/ie_engines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ std::set<NodeTypeInfo> test::IE_Engine::get_ie_ops() const
ie_ops.insert(opset6.begin(), opset6.end());
const auto& opset7 = get_opset7().get_type_info_set();
ie_ops.insert(opset7.begin(), opset7.end());
const auto& opset8 = get_opset8().get_type_info_set();
ie_ops.insert(opset8.begin(), opset8.end());
return ie_ops;
}

Expand Down

0 comments on commit 2378593

Please sign in to comment.