Skip to content

Commit

Permalink
mypy checks for common graph (#2600)
Browse files Browse the repository at this point in the history
### Changes

This PR closes issue #2495 by addressing various mypy checks and
enhancing type safety in the codebase.

- Resolved specific mypy errors related to type inconsistencies.
- Utilized `# type:ignore` for cases requiring significant refactoring
due to untyped packages like `networkx`.
- Added the directory `nncf/common/graph` to `.mypy.ini` to include
additional files for type checking.

### Related Tickets

N/A

### Tests

Pytests were run to ensure that the changes did not modify the common
logic and that the codebase remained functional.
  • Loading branch information
DaniAffCH authored Mar 26, 2024
1 parent f2f3bb7 commit b28c3fe
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
files = nncf/common/sparsity
files = nncf/common/sparsity, nncf/common/graph
follow_imports = silent
strict = True

Expand Down
94 changes: 50 additions & 44 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
from collections import defaultdict
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, KeysView, List, Optional, Tuple, Type, ValuesView
from typing import Any, Callable, Dict, Generator, KeysView, List, Optional, Tuple, Type, ValuesView, cast

import networkx as nx
import networkx.algorithms.isomorphism as iso
import networkx as nx # type:ignore
import networkx.algorithms.isomorphism as iso # type:ignore
from networkx.classes.reportviews import OutEdgeView # type:ignore

import nncf
from nncf.common.graph.graph_matching import find_subgraphs_matching_pattern
Expand Down Expand Up @@ -46,7 +48,7 @@ class NNCFNode:
IS_INTEGER_INPUT_NODE_ATTR = "is_integer_input"
IS_SHARED_ATTR = "is_shared"

def __init__(self, attributes: Dict[str, Any]):
def __init__(self, attributes: Dict[str, Any]) -> None:
self._attributes = attributes

@property
Expand All @@ -55,23 +57,23 @@ def attributes(self) -> Dict[str, Any]:

@property
def node_id(self) -> int:
return self._attributes[NNCFNode.ID_NODE_ATTR]
return cast(int, self._attributes[NNCFNode.ID_NODE_ATTR])

@property
def node_key(self) -> str:
return self._attributes[NNCFNode.KEY_NODE_ATTR]
return cast(str, self._attributes[NNCFNode.KEY_NODE_ATTR])

@property
def node_name(self) -> NNCFNodeName:
return self._attributes[NNCFNode.NODE_NAME_ATTR]
return cast(NNCFNodeName, self._attributes[NNCFNode.NODE_NAME_ATTR])

@property
def metatype(self) -> Type[OperatorMetatype]:
return self._attributes[NNCFNode.METATYPE_ATTR]
return cast(Type[OperatorMetatype], self._attributes[NNCFNode.METATYPE_ATTR])

@property
def node_type(self) -> str:
return self._attributes[NNCFNode.NODE_TYPE_ATTR]
return cast(str, self._attributes[NNCFNode.NODE_TYPE_ATTR])

@property
def layer_name(self) -> Optional[LayerName]:
Expand All @@ -91,27 +93,27 @@ def layer_attributes(self, value: BaseLayerAttributes) -> None:

@property
def ignored_algorithms(self) -> List[str]:
return self._attributes[NNCFNode.IGNORED_ALGOS_ATTR]
return cast(List[str], self._attributes[NNCFNode.IGNORED_ALGOS_ATTR])

def is_in_iteration_scope(self) -> bool:
return self._attributes[NNCFNode.IS_IN_ITERATION_SCOPE_NODE_ATTR]
return cast(bool, self._attributes[NNCFNode.IS_IN_ITERATION_SCOPE_NODE_ATTR])

def is_integer_input(self) -> bool:
return self._attributes[NNCFNode.IS_INTEGER_INPUT_NODE_ATTR]
return cast(bool, self._attributes[NNCFNode.IS_INTEGER_INPUT_NODE_ATTR])

def is_shared(self) -> bool:
return self._attributes[NNCFNode.IS_SHARED_ATTR]
return cast(bool, self._attributes[NNCFNode.IS_SHARED_ATTR])

def __repr__(self):
def __repr__(self) -> str:
return str(self)

def __str__(self):
def __str__(self) -> str:
return " ".join([str(self.node_id), self.node_name, self.node_type])

def __hash__(self):
def __hash__(self) -> int:
return hash(str(self))

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return isinstance(other, NNCFNode) and self.attributes == other.attributes


Expand All @@ -131,7 +133,7 @@ def __init__(
tensor_shape: List[int],
dtype: Dtype,
parallel_input_port_ids: List[int],
):
) -> None:
"""
:param from_node: An NNCFNode that sources the directed edge.
:param to_node: An NNCFNode that sinks the directed edge.
Expand All @@ -144,14 +146,14 @@ def __init__(
self.to_node = to_node
self.input_port_id = input_port_id
self.output_port_id = output_port_id
self.tensor_shape: Tuple[int] = tuple(tensor_shape)
self.tensor_shape: Tuple[int, ...] = tuple(tensor_shape)
self.dtype = dtype
self.parallel_input_port_ids = parallel_input_port_ids

def __str__(self):
def __str__(self) -> str:
return f"{self.from_node}:{self.output_port_id} -> {self.tensor_shape} -> {self.to_node}:{self.input_port_id}"

def __hash__(self):
def __hash__(self) -> int:
return hash(
(
self.from_node,
Expand All @@ -164,7 +166,7 @@ def __hash__(self):
)
)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return isinstance(other, NNCFGraphEdge) and self.__dict__ == other.__dict__


Expand All @@ -190,9 +192,9 @@ class NNCFGraph:
DTYPE_EDGE_ATTR = "dtype"
PARALLEL_INPUT_PORT_IDS_ATTR = "parallel_input_ports"

def __init__(self):
def __init__(self) -> None:
self._nx_graph = nx.DiGraph()
self._node_id_to_key_dict = {}
self._node_id_to_key_dict: Dict[int, str] = {}
self._nodes: Dict[str, NNCFNode] = {}
self._input_nncf_nodes: Dict[int, NNCFNode] = {}
self._output_nncf_nodes: Dict[int, NNCFNode] = {}
Expand Down Expand Up @@ -288,18 +290,20 @@ def get_all_simple_paths(
end_node = self.get_node_by_name(end_node_name)
start_node_key = self.get_node_key_by_id(start_node.node_id)
end_node_key = self.get_node_key_by_id(end_node.node_id)
return nx.all_simple_paths(self._nx_graph, start_node_key, end_node_key)
return cast(
Generator[List[NNCFNodeName], None, None], nx.all_simple_paths(self._nx_graph, start_node_key, end_node_key)
)

@staticmethod
def _get_edge_boundaries(
match: List[str], graph: nx.DiGraph
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
) -> Tuple[List[Tuple[str, str, Dict[str, Any]]], List[Tuple[str, str, Dict[str, Any]]]]:
out_edge_boundary = list(nx.edge_boundary(graph, match, data=True))
complement = list(filter(lambda x: x not in match, graph.nodes.keys()))
in_edge_boundary = list(nx.edge_boundary(graph, complement, data=True))
return sorted(in_edge_boundary), sorted(out_edge_boundary) # must be sorted for determinism

def get_node_key_by_id(self, node_id: id) -> str:
def get_node_key_by_id(self, node_id: int) -> str:
"""
Returns node key (node_name) by provided id.
Expand Down Expand Up @@ -369,7 +373,7 @@ def _get_edges(self, from_node: NNCFNode, to_node: NNCFNode) -> List[NNCFGraphEd
to_node=edge.to_node,
input_port_id=input_port_id,
output_port_id=edge.output_port_id,
tensor_shape=edge.tensor_shape,
tensor_shape=list(edge.tensor_shape),
dtype=edge.dtype,
parallel_input_port_ids=[],
)
Expand All @@ -381,7 +385,7 @@ def traverse_graph(
curr_node: NNCFNode,
traverse_function: Callable[[NNCFNode, List[Any]], Tuple[bool, List[Any]]],
traverse_forward: bool = True,
):
) -> List[Any]:
"""
Traverses graph up or down starting form `curr_node` node.
Expand All @@ -390,7 +394,7 @@ def traverse_graph(
:param traverse_forward: Flag specifying direction of traversal.
:return:
"""
output = []
output: List[Any] = []
return self._traverse_graph_recursive_helper(curr_node, traverse_function, output, traverse_forward)

def _traverse_graph_recursive_helper(
Expand All @@ -399,7 +403,7 @@ def _traverse_graph_recursive_helper(
traverse_function: Callable[[NNCFNode, List[Any]], Tuple[bool, List[Any]]],
output: List[Any],
traverse_forward: bool,
):
) -> List[Any]:
is_finished, output = traverse_function(curr_node, output)
get_nodes_fn = self.get_next_nodes if traverse_forward else self.get_previous_nodes
if not is_finished:
Expand Down Expand Up @@ -450,7 +454,7 @@ def add_nncf_node(
if node_id_override is not None:
node_id = node_id_override
else:
node_ids = self.get_all_node_ids()
node_ids = list(self.get_all_node_ids())
if node_ids:
node_id = max(self.get_all_node_ids()) + 1
else:
Expand Down Expand Up @@ -508,7 +512,7 @@ def add_edge_between_nncf_nodes(
output_port_id: int,
dtype: Dtype,
parallel_input_port_ids: Optional[List[int]] = None,
):
) -> None:
"""
Adds a directed edge between two `NNCFNode`s that are already present in the graph.
The edge represents an activation tensor, produced or consumed by an operation (which is represented by a node)
Expand Down Expand Up @@ -559,12 +563,12 @@ def topological_sort(self) -> List[NNCFNode]:
)
]

def dump_graph(self, path: str):
write_dot_graph(self.get_graph_for_structure_analysis(), path)
def dump_graph(self, path: str) -> None:
write_dot_graph(self.get_graph_for_structure_analysis(), pathlib.Path(path))

def visualize_graph(self, path: str):
def visualize_graph(self, path: str) -> None:
out_graph = self._get_graph_for_visualization()
write_dot_graph(out_graph, path)
write_dot_graph(out_graph, pathlib.Path(path))

def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph:
"""
Expand Down Expand Up @@ -633,7 +637,7 @@ def _get_graph_for_visualization(self) -> nx.DiGraph:
mapping = {k: v["label"] for k, v in out_graph.nodes.items()}
out_graph = nx.relabel_nodes(out_graph, mapping)
for node in out_graph.nodes.values():
node.pop("label")
node.pop("label") # type: ignore

return out_graph

Expand All @@ -647,14 +651,16 @@ def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode:
node_key = f"{node_ids[0]} {name}"
return self._nodes[node_key]

def __eq__(self, other: "NNCFGraph"):
def __eq__(self, other: object) -> bool:
nm = iso.categorical_node_match(
[NNCFNode.ID_NODE_ATTR, NNCFNode.KEY_NODE_ATTR, NNCFNode.LAYER_ATTRIBUTES], [None, None, None]
)
em = iso.categorical_edge_match(
[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR, NNCFGraph.INPUT_PORT_ID_EDGE_ATTR], [None, None]
)
return nx.is_isomorphic(self._nx_graph, other._nx_graph, node_match=nm, edge_match=em)
return isinstance(other, NNCFGraph) and bool(
nx.is_isomorphic(self._nx_graph, other._nx_graph, node_match=nm, edge_match=em)
)

def get_nx_graph_copy(self) -> nx.DiGraph:
return deepcopy(self._nx_graph)
Expand Down Expand Up @@ -697,13 +703,13 @@ def get_nncf_graph_pattern_io(self, match: List[str]) -> NNCFGraphPatternIO:

return NNCFGraphPatternIO(input_nncf_edges, output_nncf_edges)

def get_nx_edge(self, node_u: NNCFNode, node_v: NNCFNode):
def get_nx_edge(self, node_u: NNCFNode, node_v: NNCFNode) -> OutEdgeView:
nx_node_u = self._nx_graph.nodes[self._node_id_to_key_dict[node_u.node_id]]
nx_node_v = self._nx_graph.nodes[self._node_id_to_key_dict[node_v.node_id]]
return self._nx_graph.edges[nx_node_u["key"], nx_node_v["key"]]

def get_nodes_count(self):
return self._nx_graph.number_of_nodes()
def get_nodes_count(self) -> int:
return int(self._nx_graph.number_of_nodes())

def get_edge(self, from_node: NNCFNode, to_node: NNCFNode) -> NNCFGraphEdge:
"""
Expand Down Expand Up @@ -741,7 +747,7 @@ def remove_nodes_from(self, nodes: List[NNCFNode]) -> None:

self._node_id_to_key_dict = {}
for node_key, node in self._nx_graph.nodes.items():
self._node_id_to_key_dict[node["id"]] = node_key
self._node_id_to_key_dict[node["id"]] = node_key # type:ignore

def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) -> List[List[NNCFNode]]:
"""
Expand Down
6 changes: 3 additions & 3 deletions nncf/common/graph/graph_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
# limitations under the License.
from typing import Dict, List

import networkx as nx
import networkx.algorithms.isomorphism as ism
import networkx as nx # type:ignore
import networkx.algorithms.isomorphism as ism # type:ignore

from nncf.common.graph.patterns import GraphPattern

ATTRS_TO_SKIP = [GraphPattern.LABEL_ATTR, GraphPattern.PATTERN_NODE_TO_EXCLUDE]


def _are_nodes_matched(node_1, node_2) -> bool:
def _are_nodes_matched(node_1, node_2) -> bool: # type:ignore
for attr in node_2:
if attr in ATTRS_TO_SKIP:
continue
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, axis: int, num_inputs: Optional[int] = None):


class MultipleOutputLayerAttributes(BaseLayerAttributes):
def __init__(self, chunks: Union[int, List], axis: int):
def __init__(self, chunks: Union[int, List[Any]], axis: int):
"""
:param chunks: Number of chunks (outputs).
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, model: TModel):
"""
self._model = model

def transform(self, transformation_layout: TransformationLayout) -> TModel:
def transform(self, transformation_layout: TransformationLayout) -> TModel: # type:ignore
"""
Applies transformations to the model.
Expand Down
8 changes: 4 additions & 4 deletions nncf/common/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Set, Type
from typing import Callable, Dict, List, Optional, Set, Type

import nncf
from nncf.common.graph.definitions import NNCFGraphNodeType
Expand Down Expand Up @@ -77,9 +77,9 @@ def __init__(self, name: str):
:param name: The registry name.
"""
super().__init__(name)
self._op_name_to_op_meta_dict = {}
self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}

def register(self, name: Optional[str] = None):
def register(self, name: Optional[str] = None) -> Callable[..., Type[OperatorMetatype]]:
"""
Decorator for registering operator metatypes.
Expand All @@ -89,7 +89,7 @@ def register(self, name: Optional[str] = None):
name_ = name
super_register = super()._register

def wrap(obj: Type[OperatorMetatype]):
def wrap(obj: Type[OperatorMetatype]) -> Type[OperatorMetatype]:
"""
Inner function for registering operator metatypes.
Expand Down
Loading

0 comments on commit b28c3fe

Please sign in to comment.