diff --git a/.mypy.ini b/.mypy.ini index 42b6329f72c..015bbb5f890 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] -files = nncf/common/sparsity +files = nncf/common/sparsity, nncf/common/graph follow_imports = silent strict = True diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py index 1f9795fc721..13ea932d921 100644 --- a/nncf/common/graph/graph.py +++ b/nncf/common/graph/graph.py @@ -8,12 +8,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pathlib from collections import defaultdict from copy import deepcopy -from typing import Any, Callable, Dict, Generator, KeysView, List, Optional, Tuple, Type, ValuesView +from typing import Any, Callable, Dict, Generator, KeysView, List, Optional, Tuple, Type, ValuesView, cast -import networkx as nx -import networkx.algorithms.isomorphism as iso +import networkx as nx # type:ignore +import networkx.algorithms.isomorphism as iso # type:ignore +from networkx.classes.reportviews import OutEdgeView # type:ignore import nncf from nncf.common.graph.graph_matching import find_subgraphs_matching_pattern @@ -46,7 +48,7 @@ class NNCFNode: IS_INTEGER_INPUT_NODE_ATTR = "is_integer_input" IS_SHARED_ATTR = "is_shared" - def __init__(self, attributes: Dict[str, Any]): + def __init__(self, attributes: Dict[str, Any]) -> None: self._attributes = attributes @property @@ -55,23 +57,23 @@ def attributes(self) -> Dict[str, Any]: @property def node_id(self) -> int: - return self._attributes[NNCFNode.ID_NODE_ATTR] + return cast(int, self._attributes[NNCFNode.ID_NODE_ATTR]) @property def node_key(self) -> str: - return self._attributes[NNCFNode.KEY_NODE_ATTR] + return cast(str, self._attributes[NNCFNode.KEY_NODE_ATTR]) @property def node_name(self) -> NNCFNodeName: - return self._attributes[NNCFNode.NODE_NAME_ATTR] + return cast(NNCFNodeName, self._attributes[NNCFNode.NODE_NAME_ATTR]) @property def metatype(self) -> Type[OperatorMetatype]: - return self._attributes[NNCFNode.METATYPE_ATTR] + return cast(Type[OperatorMetatype], self._attributes[NNCFNode.METATYPE_ATTR]) @property def node_type(self) -> str: - return self._attributes[NNCFNode.NODE_TYPE_ATTR] + return cast(str, self._attributes[NNCFNode.NODE_TYPE_ATTR]) @property def layer_name(self) -> Optional[LayerName]: @@ -91,27 +93,27 @@ def layer_attributes(self, value: BaseLayerAttributes) -> None: @property def ignored_algorithms(self) -> List[str]: - return self._attributes[NNCFNode.IGNORED_ALGOS_ATTR] + return cast(List[str], self._attributes[NNCFNode.IGNORED_ALGOS_ATTR]) def is_in_iteration_scope(self) -> bool: - return self._attributes[NNCFNode.IS_IN_ITERATION_SCOPE_NODE_ATTR] + return cast(bool, self._attributes[NNCFNode.IS_IN_ITERATION_SCOPE_NODE_ATTR]) def is_integer_input(self) -> bool: - return self._attributes[NNCFNode.IS_INTEGER_INPUT_NODE_ATTR] + return cast(bool, self._attributes[NNCFNode.IS_INTEGER_INPUT_NODE_ATTR]) def is_shared(self) -> bool: - return self._attributes[NNCFNode.IS_SHARED_ATTR] + return cast(bool, self._attributes[NNCFNode.IS_SHARED_ATTR]) - def __repr__(self): + def __repr__(self) -> str: return str(self) - def __str__(self): + def __str__(self) -> str: return " ".join([str(self.node_id), self.node_name, self.node_type]) - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, NNCFNode) and self.attributes == other.attributes @@ -131,7 +133,7 @@ def __init__( tensor_shape: List[int], dtype: Dtype, parallel_input_port_ids: List[int], - ): + ) -> None: """ :param from_node: An NNCFNode that sources the directed edge. :param to_node: An NNCFNode that sinks the directed edge. @@ -144,14 +146,14 @@ def __init__( self.to_node = to_node self.input_port_id = input_port_id self.output_port_id = output_port_id - self.tensor_shape: Tuple[int] = tuple(tensor_shape) + self.tensor_shape: Tuple[int, ...] = tuple(tensor_shape) self.dtype = dtype self.parallel_input_port_ids = parallel_input_port_ids - def __str__(self): + def __str__(self) -> str: return f"{self.from_node}:{self.output_port_id} -> {self.tensor_shape} -> {self.to_node}:{self.input_port_id}" - def __hash__(self): + def __hash__(self) -> int: return hash( ( self.from_node, @@ -164,7 +166,7 @@ def __hash__(self): ) ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, NNCFGraphEdge) and self.__dict__ == other.__dict__ @@ -190,9 +192,9 @@ class NNCFGraph: DTYPE_EDGE_ATTR = "dtype" PARALLEL_INPUT_PORT_IDS_ATTR = "parallel_input_ports" - def __init__(self): + def __init__(self) -> None: self._nx_graph = nx.DiGraph() - self._node_id_to_key_dict = {} + self._node_id_to_key_dict: Dict[int, str] = {} self._nodes: Dict[str, NNCFNode] = {} self._input_nncf_nodes: Dict[int, NNCFNode] = {} self._output_nncf_nodes: Dict[int, NNCFNode] = {} @@ -288,18 +290,20 @@ def get_all_simple_paths( end_node = self.get_node_by_name(end_node_name) start_node_key = self.get_node_key_by_id(start_node.node_id) end_node_key = self.get_node_key_by_id(end_node.node_id) - return nx.all_simple_paths(self._nx_graph, start_node_key, end_node_key) + return cast( + Generator[List[NNCFNodeName], None, None], nx.all_simple_paths(self._nx_graph, start_node_key, end_node_key) + ) @staticmethod def _get_edge_boundaries( match: List[str], graph: nx.DiGraph - ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]: + ) -> Tuple[List[Tuple[str, str, Dict[str, Any]]], List[Tuple[str, str, Dict[str, Any]]]]: out_edge_boundary = list(nx.edge_boundary(graph, match, data=True)) complement = list(filter(lambda x: x not in match, graph.nodes.keys())) in_edge_boundary = list(nx.edge_boundary(graph, complement, data=True)) return sorted(in_edge_boundary), sorted(out_edge_boundary) # must be sorted for determinism - def get_node_key_by_id(self, node_id: id) -> str: + def get_node_key_by_id(self, node_id: int) -> str: """ Returns node key (node_name) by provided id. @@ -369,7 +373,7 @@ def _get_edges(self, from_node: NNCFNode, to_node: NNCFNode) -> List[NNCFGraphEd to_node=edge.to_node, input_port_id=input_port_id, output_port_id=edge.output_port_id, - tensor_shape=edge.tensor_shape, + tensor_shape=list(edge.tensor_shape), dtype=edge.dtype, parallel_input_port_ids=[], ) @@ -381,7 +385,7 @@ def traverse_graph( curr_node: NNCFNode, traverse_function: Callable[[NNCFNode, List[Any]], Tuple[bool, List[Any]]], traverse_forward: bool = True, - ): + ) -> List[Any]: """ Traverses graph up or down starting form `curr_node` node. @@ -390,7 +394,7 @@ def traverse_graph( :param traverse_forward: Flag specifying direction of traversal. :return: """ - output = [] + output: List[Any] = [] return self._traverse_graph_recursive_helper(curr_node, traverse_function, output, traverse_forward) def _traverse_graph_recursive_helper( @@ -399,7 +403,7 @@ def _traverse_graph_recursive_helper( traverse_function: Callable[[NNCFNode, List[Any]], Tuple[bool, List[Any]]], output: List[Any], traverse_forward: bool, - ): + ) -> List[Any]: is_finished, output = traverse_function(curr_node, output) get_nodes_fn = self.get_next_nodes if traverse_forward else self.get_previous_nodes if not is_finished: @@ -450,7 +454,7 @@ def add_nncf_node( if node_id_override is not None: node_id = node_id_override else: - node_ids = self.get_all_node_ids() + node_ids = list(self.get_all_node_ids()) if node_ids: node_id = max(self.get_all_node_ids()) + 1 else: @@ -508,7 +512,7 @@ def add_edge_between_nncf_nodes( output_port_id: int, dtype: Dtype, parallel_input_port_ids: Optional[List[int]] = None, - ): + ) -> None: """ Adds a directed edge between two `NNCFNode`s that are already present in the graph. The edge represents an activation tensor, produced or consumed by an operation (which is represented by a node) @@ -559,12 +563,12 @@ def topological_sort(self) -> List[NNCFNode]: ) ] - def dump_graph(self, path: str): - write_dot_graph(self.get_graph_for_structure_analysis(), path) + def dump_graph(self, path: str) -> None: + write_dot_graph(self.get_graph_for_structure_analysis(), pathlib.Path(path)) - def visualize_graph(self, path: str): + def visualize_graph(self, path: str) -> None: out_graph = self._get_graph_for_visualization() - write_dot_graph(out_graph, path) + write_dot_graph(out_graph, pathlib.Path(path)) def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph: """ @@ -633,7 +637,7 @@ def _get_graph_for_visualization(self) -> nx.DiGraph: mapping = {k: v["label"] for k, v in out_graph.nodes.items()} out_graph = nx.relabel_nodes(out_graph, mapping) for node in out_graph.nodes.values(): - node.pop("label") + node.pop("label") # type: ignore return out_graph @@ -647,14 +651,16 @@ def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode: node_key = f"{node_ids[0]} {name}" return self._nodes[node_key] - def __eq__(self, other: "NNCFGraph"): + def __eq__(self, other: object) -> bool: nm = iso.categorical_node_match( [NNCFNode.ID_NODE_ATTR, NNCFNode.KEY_NODE_ATTR, NNCFNode.LAYER_ATTRIBUTES], [None, None, None] ) em = iso.categorical_edge_match( [NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR, NNCFGraph.INPUT_PORT_ID_EDGE_ATTR], [None, None] ) - return nx.is_isomorphic(self._nx_graph, other._nx_graph, node_match=nm, edge_match=em) + return isinstance(other, NNCFGraph) and bool( + nx.is_isomorphic(self._nx_graph, other._nx_graph, node_match=nm, edge_match=em) + ) def get_nx_graph_copy(self) -> nx.DiGraph: return deepcopy(self._nx_graph) @@ -697,13 +703,13 @@ def get_nncf_graph_pattern_io(self, match: List[str]) -> NNCFGraphPatternIO: return NNCFGraphPatternIO(input_nncf_edges, output_nncf_edges) - def get_nx_edge(self, node_u: NNCFNode, node_v: NNCFNode): + def get_nx_edge(self, node_u: NNCFNode, node_v: NNCFNode) -> OutEdgeView: nx_node_u = self._nx_graph.nodes[self._node_id_to_key_dict[node_u.node_id]] nx_node_v = self._nx_graph.nodes[self._node_id_to_key_dict[node_v.node_id]] return self._nx_graph.edges[nx_node_u["key"], nx_node_v["key"]] - def get_nodes_count(self): - return self._nx_graph.number_of_nodes() + def get_nodes_count(self) -> int: + return int(self._nx_graph.number_of_nodes()) def get_edge(self, from_node: NNCFNode, to_node: NNCFNode) -> NNCFGraphEdge: """ @@ -741,7 +747,7 @@ def remove_nodes_from(self, nodes: List[NNCFNode]) -> None: self._node_id_to_key_dict = {} for node_key, node in self._nx_graph.nodes.items(): - self._node_id_to_key_dict[node["id"]] = node_key + self._node_id_to_key_dict[node["id"]] = node_key # type:ignore def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) -> List[List[NNCFNode]]: """ diff --git a/nncf/common/graph/graph_matching.py b/nncf/common/graph/graph_matching.py index a4af34edf43..d2bc0886ffe 100644 --- a/nncf/common/graph/graph_matching.py +++ b/nncf/common/graph/graph_matching.py @@ -10,15 +10,15 @@ # limitations under the License. from typing import Dict, List -import networkx as nx -import networkx.algorithms.isomorphism as ism +import networkx as nx # type:ignore +import networkx.algorithms.isomorphism as ism # type:ignore from nncf.common.graph.patterns import GraphPattern ATTRS_TO_SKIP = [GraphPattern.LABEL_ATTR, GraphPattern.PATTERN_NODE_TO_EXCLUDE] -def _are_nodes_matched(node_1, node_2) -> bool: +def _are_nodes_matched(node_1, node_2) -> bool: # type:ignore for attr in node_2: if attr in ATTRS_TO_SKIP: continue diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index 06da97313b8..bf45dae94f6 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -43,7 +43,7 @@ def __init__(self, axis: int, num_inputs: Optional[int] = None): class MultipleOutputLayerAttributes(BaseLayerAttributes): - def __init__(self, chunks: Union[int, List], axis: int): + def __init__(self, chunks: Union[int, List[Any]], axis: int): """ :param chunks: Number of chunks (outputs). diff --git a/nncf/common/graph/model_transformer.py b/nncf/common/graph/model_transformer.py index 18874405efd..c63ba078cd3 100644 --- a/nncf/common/graph/model_transformer.py +++ b/nncf/common/graph/model_transformer.py @@ -29,7 +29,7 @@ def __init__(self, model: TModel): """ self._model = model - def transform(self, transformation_layout: TransformationLayout) -> TModel: + def transform(self, transformation_layout: TransformationLayout) -> TModel: # type:ignore """ Applies transformations to the model. diff --git a/nncf/common/graph/operator_metatypes.py b/nncf/common/graph/operator_metatypes.py index e65f659471f..27fc60213a4 100644 --- a/nncf/common/graph/operator_metatypes.py +++ b/nncf/common/graph/operator_metatypes.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Set, Type +from typing import Callable, Dict, List, Optional, Set, Type import nncf from nncf.common.graph.definitions import NNCFGraphNodeType @@ -77,9 +77,9 @@ def __init__(self, name: str): :param name: The registry name. """ super().__init__(name) - self._op_name_to_op_meta_dict = {} + self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {} - def register(self, name: Optional[str] = None): + def register(self, name: Optional[str] = None) -> Callable[..., Type[OperatorMetatype]]: """ Decorator for registering operator metatypes. @@ -89,7 +89,7 @@ def register(self, name: Optional[str] = None): name_ = name super_register = super()._register - def wrap(obj: Type[OperatorMetatype]): + def wrap(obj: Type[OperatorMetatype]) -> Type[OperatorMetatype]: """ Inner function for registering operator metatypes. diff --git a/nncf/common/graph/patterns/manager.py b/nncf/common/graph/patterns/manager.py index 0139ae12e50..824a3b4a4c5 100644 --- a/nncf/common/graph/patterns/manager.py +++ b/nncf/common/graph/patterns/manager.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, Optional, Union, cast from nncf.common.graph.patterns.patterns import GraphPattern from nncf.common.graph.patterns.patterns import HWFusedPatternNames @@ -34,18 +34,22 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam :param backend: BackendType instance. :return: Dictionary with the HWFusedPatternNames instance as keys and creator function as a value. """ + registry: Dict[HWFusedPatternNames, Callable[[], GraphPattern]] = {} if backend == BackendType.ONNX: from nncf.onnx.hardware.fused_patterns import ONNX_HW_FUSED_PATTERNS - return ONNX_HW_FUSED_PATTERNS.registry_dict + registry = ONNX_HW_FUSED_PATTERNS.registry_dict + return registry if backend == BackendType.OPENVINO: from nncf.openvino.hardware.fused_patterns import OPENVINO_HW_FUSED_PATTERNS - return OPENVINO_HW_FUSED_PATTERNS.registry_dict + registry = OPENVINO_HW_FUSED_PATTERNS.registry_dict + return registry if backend == BackendType.TORCH: from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS - return PT_HW_FUSED_PATTERNS.registry_dict + registry = PT_HW_FUSED_PATTERNS.registry_dict + return registry raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.") @staticmethod @@ -58,23 +62,29 @@ def _get_backend_ignored_patterns_map( :param backend: BackendType instance. :return: Dictionary with the HWFusedPatternNames instance as keys and creator function as a value. """ + registry: Dict[IgnoredPatternNames, Callable[[], GraphPattern]] = {} if backend == BackendType.ONNX: from nncf.onnx.quantization.ignored_patterns import ONNX_IGNORED_PATTERNS - return ONNX_IGNORED_PATTERNS.registry_dict + registry = ONNX_IGNORED_PATTERNS.registry_dict + return registry if backend == BackendType.OPENVINO: from nncf.openvino.quantization.ignored_patterns import OPENVINO_IGNORED_PATTERNS - return OPENVINO_IGNORED_PATTERNS.registry_dict + registry = OPENVINO_IGNORED_PATTERNS.registry_dict + return registry if backend == BackendType.TORCH: from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS - return PT_IGNORED_PATTERNS.registry_dict + registry = PT_IGNORED_PATTERNS.registry_dict + return registry raise ValueError(f"Ignored patterns not implemented for {backend} backend.") @staticmethod def _filter_patterns( - patterns_to_filter: Dict[PatternNames, Callable[[], GraphPattern]], device: TargetDevice, model_type: ModelType + patterns_to_filter: Dict[PatternNames, Callable[[], GraphPattern]], + device: TargetDevice, + model_type: Optional[ModelType] = None, ) -> Dict[PatternNames, Callable[[], GraphPattern]]: """ Returns all patterns from patterns_to_filter that are satisfied device and model_type parameters. @@ -98,7 +108,7 @@ def _filter_patterns( def _get_full_pattern_graph( backend_patterns_map: Dict[PatternNames, Callable[[], GraphPattern]], device: TargetDevice, - model_type: ModelType, + model_type: Optional[ModelType] = None, ) -> GraphPattern: """ Filters patterns and returns GraphPattern with registered filtered patterns. @@ -127,7 +137,9 @@ def get_full_hw_pattern_graph( :param model_type: ModelType instance. :return: Completed GraphPattern based on the backend, device & model_type. """ - backend_patterns_map = PatternsManager._get_backend_hw_patterns_map(backend) + backend_patterns_map = cast( + Dict[PatternNames, Callable[[], GraphPattern]], PatternsManager._get_backend_hw_patterns_map(backend) + ) return PatternsManager._get_full_pattern_graph(backend_patterns_map, device, model_type) @staticmethod @@ -143,5 +155,7 @@ def get_full_ignored_pattern_graph( :param model_type: ModelType instance. :return: Completed GraphPattern with registered value based on the backend, device & model_type. """ - backend_patterns_map = PatternsManager._get_backend_ignored_patterns_map(backend) + backend_patterns_map = cast( + Dict[PatternNames, Callable[[], GraphPattern]], PatternsManager._get_backend_ignored_patterns_map(backend) + ) return PatternsManager._get_full_pattern_graph(backend_patterns_map, device, model_type) diff --git a/nncf/common/graph/patterns/patterns.py b/nncf/common/graph/patterns/patterns.py index 98da197dda7..89f7f2cd749 100644 --- a/nncf/common/graph/patterns/patterns.py +++ b/nncf/common/graph/patterns/patterns.py @@ -11,12 +11,13 @@ import copy import itertools as it import os +import pathlib from dataclasses import dataclass from enum import Enum -from typing import Dict, Hashable, List, Optional, Tuple +from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, cast -import networkx as nx -import networkx.algorithms.isomorphism as ism +import networkx as nx # type: ignore +import networkx.algorithms.isomorphism as ism # type: ignore import nncf from nncf.common.utils.dot_file_rw import write_dot_graph @@ -32,8 +33,8 @@ class Patterns: during the quantization algorithm. """ - def __init__(self): - self._patterns_dict = {} + def __init__(self) -> None: + self._patterns_dict: Dict[str, GraphPattern] = {} self._full_pattern_graph = GraphPattern() def register(self, pattern: "GraphPattern", name: str, match: bool = True) -> None: @@ -83,7 +84,7 @@ class GraphPattern: NON_PATTERN_NODE_TYPE = "NON_PATTERN_NODE" PATTERN_NODE_TO_EXCLUDE = "PATTERN_NODE_TO_EXCLUDE" - def __init__(self): + def __init__(self) -> None: self._graph = nx.DiGraph() self._node_counter = 0 @@ -130,8 +131,9 @@ def __or__(self, other: "GraphPattern") -> "GraphPattern": new_pattern._unite_with_copy_of_graph(other.graph) return new_pattern - def __eq__(self, other: "GraphPattern") -> bool: - return ism.is_isomorphic(self._graph, other.graph) + def __eq__(self, other: object) -> bool: + is_isomorphic: Callable[[Any, Any], bool] = ism.is_isomorphic + return isinstance(other, GraphPattern) and is_isomorphic(self._graph, other.graph) @property def graph(self) -> nx.DiGraph: @@ -232,27 +234,27 @@ def join_patterns(self, other: "GraphPattern", edges: Optional[List[Tuple[Hashab remapped_edges.append(new_edge) self._graph.add_edges_from(remapped_edges) - def add_node(self, **attrs) -> int: + def add_node(self, **attrs: Dict[str, Any]) -> int: if GraphPattern.METATYPE_ATTR in attrs and not isinstance(attrs[GraphPattern.METATYPE_ATTR], list): - attrs[GraphPattern.METATYPE_ATTR] = [attrs[GraphPattern.METATYPE_ATTR]] + attrs[GraphPattern.METATYPE_ATTR] = cast(Any, [attrs[GraphPattern.METATYPE_ATTR]]) self._graph.add_node(self._node_counter, **attrs) self._node_counter += 1 return self._node_counter - 1 - def add_edge(self, u_name, v_name) -> None: + def add_edge(self, u_name: str, v_name: str) -> None: self._graph.add_edge(u_name, v_name) - def add_edges_from(self, ebunch_to_add, **attr) -> None: + def add_edges_from(self, ebunch_to_add: List[Any], **attr: Dict[str, Any]) -> None: self._graph.add_edges_from(ebunch_to_add, **attr) def get_weakly_connected_subgraphs(self) -> List[nx.DiGraph]: return [self._graph.subgraph(c) for c in nx.weakly_connected_components(self._graph)] def dump_graph(self, path: str) -> None: - write_dot_graph(self._graph, path) + write_dot_graph(self._graph, pathlib.Path(path)) -def merge_two_types_of_operations(first_op: Dict, second_op: Dict, label: str) -> Dict: +def merge_two_types_of_operations(first_op: Dict[str, Any], second_op: Dict[str, Any], label: str) -> Dict[str, Any]: if GraphPattern.METATYPE_ATTR in first_op and GraphPattern.METATYPE_ATTR in second_op: res = {GraphPattern.METATYPE_ATTR: first_op[GraphPattern.METATYPE_ATTR]} res[GraphPattern.METATYPE_ATTR].extend(second_op[GraphPattern.METATYPE_ATTR]) @@ -277,7 +279,7 @@ class PatternDesc: name: str devices: Optional[List[TargetDevice]] = None - model_types: Optional[List[TargetDevice]] = None + model_types: Optional[List[ModelType]] = None class HWFusedPatternNames(Enum): diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py index c03b9987e7c..7128f2793c0 100644 --- a/nncf/common/graph/transformations/commands.py +++ b/nncf/common/graph/transformations/commands.py @@ -163,7 +163,7 @@ def get_state(self) -> Dict[str, Any]: """ return {self._state_names.TARGET_TYPE: self._target_type.get_state()} - def is_weight_target_point(self): + def is_weight_target_point(self) -> bool: return self._target_type == TargetType.OPERATION_WITH_WEIGHTS @classmethod diff --git a/nncf/common/graph/transformations/layout.py b/nncf/common/graph/transformations/layout.py index 2698315075b..c6e9334af6a 100644 --- a/nncf/common/graph/transformations/layout.py +++ b/nncf/common/graph/transformations/layout.py @@ -23,11 +23,11 @@ class TransformationLayout: addresses these issues. """ - def __init__(self): + def __init__(self) -> None: """ Initialize Transformation Layout. """ - self._transformations = [] + self._transformations: List[TransformationCommand] = [] @property def transformations(self) -> List[TransformationCommand]: diff --git a/nncf/common/graph/utils.py b/nncf/common/graph/utils.py index 60a89d0615d..b21c3b013da 100644 --- a/nncf/common/graph/utils.py +++ b/nncf/common/graph/utils.py @@ -10,7 +10,7 @@ # limitations under the License. from functools import partial -from typing import List, Set, Tuple, Union +from typing import List, Set, Tuple, Type, Union from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -89,7 +89,9 @@ def get_split_axis(input_shapes: List[List[int]], output_shapes: List[List[int]] def get_number_of_quantized_ops( - graph: NNCFGraph, quantizer_metatypes: List[OperatorMetatype], quantizable_metatypes: List[OperatorMetatype] + graph: NNCFGraph, + quantizer_metatypes: List[Type[OperatorMetatype]], + quantizable_metatypes: List[Type[OperatorMetatype]], ) -> int: """ Returns the number of quantized operations in the graph. diff --git a/nncf/common/utils/registry.py b/nncf/common/utils/registry.py index b5633bd9ffa..e165a3b60f6 100644 --- a/nncf/common/utils/registry.py +++ b/nncf/common/utils/registry.py @@ -25,7 +25,7 @@ def registry_dict(self): def values(self): return self._registry_dict.values() - def _register(self, obj, name): + def _register(self, obj, name: str): if name in self._registry_dict: raise KeyError("{} is already registered in {}".format(name, self._name)) self._registry_dict[name] = obj