Skip to content

Commit

Permalink
Fix typing in OV node_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 28, 2023
1 parent bf9c611 commit 1b5ff64
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS
from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS
Expand Down Expand Up @@ -161,14 +162,15 @@ def get_ov_model_reduce_node_name(output_name: str, reduce_node_name: str, port_


def get_inplace_reduce_op(
op: Type[ov.Node], reduce_node_name: str, reduction_axes: Optional[Tuple[int, ...]], use_abs: bool
op: Type[ov.Node], reduce_node_name: str, reduction_axes: Optional[ReductionAxes], use_abs: bool
) -> InplaceInsertionFnType:
"""
Returns inplace insertion function that adds reduce node to a passed node.
:param op: OpenVINO reduction operation type to insert.
:param reduce_node_name: Reduce node name.
:param reduction_axes: Target reduction axes for the reduction node.
Reduce along all axes in case reduction_axes are None.
:param use_abs: Wheather reduce absolute values of input tensors or not.
:returns: Inplace insertion function to use in ModelTransformer.
"""
Expand Down Expand Up @@ -200,34 +202,40 @@ def get_reduce_op(node: ov.Node, output_port_id: int) -> ov.Node:
return get_reduce_op


def get_inplace_min_op(node_name: str, reduction_axes: Tuple[int, ...]) -> InplaceInsertionFnType:
def get_inplace_min_op(node_name: str, reduction_axes: Optional[ReductionAxes]) -> InplaceInsertionFnType:
"""
Returns inplace min function that adds reduce min node to a passed node.
:param node_name: Min reduce node name.
:param reduction_axes: Target reduction axes for the reduction node.
Reduce along all axes in case reduction_axes are None.
:returns: Inplace insertion function to use in ModelTransformer.
"""
return get_inplace_reduce_op(opset.reduce_min, node_name, reduction_axes, False)


def get_inplace_max_op(node_name: str, reduction_axes: Tuple[int, ...], use_abs_max: bool) -> InplaceInsertionFnType:
def get_inplace_max_op(
node_name: str, reduction_axes: Optional[ReductionAxes], use_abs_max: bool
) -> InplaceInsertionFnType:
"""
Returns inplace max function that adds reduce max node to a passed node.
:param node_name: Max reduce node name.
:param reduction_axes: Target reduction axes for the reduction node.
Reduce along all axes in case reduction_axes are None.
:param use_abs: Wheather reduce absolute values of input tensors or not.
:returns: Inplace insertion function to use in ModelTransformer.
"""
return get_inplace_reduce_op(opset.reduce_max, node_name, reduction_axes, use_abs_max)


def get_inplace_mean_op(node_name: str, reduction_axes: Tuple[int, ...]) -> InplaceInsertionFnType:
def get_inplace_mean_op(node_name: str, reduction_axes: Optional[ReductionAxes]) -> InplaceInsertionFnType:
"""
Returns inplace mean function that adds reduce mean node to a passed node.
:param node_name: Mean reduce node name.
:param reduction_axes: Target reduction axes for the reduction node.
Reduce along all axes in case reduction_axes are None.
:returns: Inplace insertion function to use in ModelTransformer.
"""
return get_inplace_reduce_op(opset.reduce_mean, node_name, reduction_axes, False)
Expand Down Expand Up @@ -302,7 +310,7 @@ def get_reduce_op(node: ov.Node, output_port_id: int) -> ov.Node:
return get_reduce_op


def get_partial_shape_safe(node, port_id) -> int:
def get_partial_shape_safe(node, port_id) -> Tuple[int, ...]:
partial_shape = node.get_output_partial_shape(port_id)
if partial_shape.rank.is_dynamic or not partial_shape.all_non_negative:
raise RuntimeError(
Expand Down Expand Up @@ -373,7 +381,7 @@ def get_matmul_channel_axes(weights_port_id: int, ndims: int, transpose: bool) -
return channel_axes


def get_channel_agnostic_reduction_axes(channel_axes: List[int], shape: List[int]) -> Tuple[int]:
def get_channel_agnostic_reduction_axes(channel_axes: List[int], shape: List[int]) -> Optional[ReductionAxes]:
"""
Returns filtered reduction axes without axes that corresponds channels.
Expand Down

0 comments on commit 1b5ff64

Please sign in to comment.