Skip to content

Commit

Permalink
Make NNCF common utils code pass mypy checks (#2780)
Browse files Browse the repository at this point in the history
### Changes

Made all the mypy checks pass for nncf/common/utils

### Related tickets

Closes Issue #2491
  • Loading branch information
anzr299 authored Jul 5, 2024
1 parent 2552f1b commit 28d99a0
Show file tree
Hide file tree
Showing 17 changed files with 91 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/, nncf/common/utils/
follow_imports = silent
strict = True

Expand Down
2 changes: 1 addition & 1 deletion nncf/common/accuracy_aware_training/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def remove_registry_prefix(algo_name: str) -> str:
)

return {
remove_registry_prefix(algo_name): controller_cls
remove_registry_prefix(algo_name): cast(CompressionAlgorithmController, controller_cls)
for algo_name, controller_cls in ADAPTIVE_COMPRESSION_CONTROLLERS.registry_dict.items()
}

Expand Down
2 changes: 1 addition & 1 deletion nncf/common/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from packaging import version


def warning_deprecated(msg):
def warning_deprecated(msg: str) -> None:
# Note: must use FutureWarning in order not to get suppressed by default
warnings.warn(msg, FutureWarning, stacklevel=2)

Expand Down
16 changes: 10 additions & 6 deletions nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam
if backend == BackendType.ONNX:
from nncf.onnx.hardware.fused_patterns import ONNX_HW_FUSED_PATTERNS

registry = ONNX_HW_FUSED_PATTERNS.registry_dict
registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], ONNX_HW_FUSED_PATTERNS.registry_dict)
return registry
if backend == BackendType.OPENVINO:
from nncf.openvino.hardware.fused_patterns import OPENVINO_HW_FUSED_PATTERNS

registry = OPENVINO_HW_FUSED_PATTERNS.registry_dict
registry = cast(
Dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = PT_HW_FUSED_PATTERNS.registry_dict
registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict)
return registry
raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.")

Expand All @@ -66,17 +68,19 @@ def _get_backend_ignored_patterns_map(
if backend == BackendType.ONNX:
from nncf.onnx.quantization.ignored_patterns import ONNX_IGNORED_PATTERNS

registry = ONNX_IGNORED_PATTERNS.registry_dict
registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], ONNX_IGNORED_PATTERNS.registry_dict)
return registry
if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.ignored_patterns import OPENVINO_IGNORED_PATTERNS

registry = OPENVINO_IGNORED_PATTERNS.registry_dict
registry = cast(
Dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = PT_IGNORED_PATTERNS.registry_dict
registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict)
return registry
raise ValueError(f"Ignored patterns not implemented for {backend} backend.")

Expand Down
5 changes: 3 additions & 2 deletions nncf/common/utils/api_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any


class api:
Expand All @@ -17,7 +18,7 @@ class api:
def __init__(self, canonical_alias: str = None):
self._canonical_alias = canonical_alias

def __call__(self, obj):
def __call__(self, obj: Any) -> Any:
# The value of the marker will be useful in determining
# whether we are handling a base class or a derived one.
setattr(obj, api.API_MARKER_ATTR, obj.__name__)
Expand All @@ -26,5 +27,5 @@ def __call__(self, obj):
return obj


def is_api(obj) -> bool:
def is_api(obj: Any) -> bool:
return hasattr(obj, api.API_MARKER_ATTR)
10 changes: 5 additions & 5 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def is_torch_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of torch.nn.Module, otherwise False.
"""
import torch
import torch # type: ignore

return isinstance(model, torch.nn.Module)

Expand All @@ -68,7 +68,7 @@ def is_tensorflow_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of tensorflow.Module, otherwise False.
"""
import tensorflow
import tensorflow # type: ignore

return isinstance(model, tensorflow.Module)

Expand All @@ -80,7 +80,7 @@ def is_onnx_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of onnx.ModelProto, otherwise False.
"""
import onnx
import onnx # type: ignore

return isinstance(model, onnx.ModelProto)

Expand All @@ -92,7 +92,7 @@ def is_openvino_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of openvino.runtime.Model, otherwise False.
"""
import openvino.runtime as ov
import openvino.runtime as ov # type: ignore

return isinstance(model, ov.Model)

Expand Down Expand Up @@ -147,7 +147,7 @@ def copy_model(model: TModel) -> TModel:
model_backend = get_backend(model)
if model_backend == BackendType.OPENVINO:
# TODO(l-bat): Remove after fixing ticket: 100919
return model.clone()
return model.clone() # type: ignore
if model_backend == BackendType.TENSORFLOW:
# deepcopy and tensorflow.keras.models.clone_model does not work correctly on 2.8.4 version
from nncf.tensorflow.graph.model_transformer import TFModelTransformer
Expand Down
7 changes: 4 additions & 3 deletions nncf/common/utils/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@

import logging
from contextlib import contextmanager
from typing import Generator

from nncf.common.logging import nncf_logger

DEBUG_LOG_DIR = "./nncf_debug"


def is_debug():
def is_debug() -> bool:
return nncf_logger.getEffectiveLevel() == logging.DEBUG


def set_debug_log_dir(dir_: str):
def set_debug_log_dir(dir_: str) -> None:
global DEBUG_LOG_DIR
DEBUG_LOG_DIR = dir_


@contextmanager
def nncf_debug():
def nncf_debug() -> Generator[None, None, None]:
from nncf.common.logging.logger import set_log_level

set_log_level(logging.DEBUG)
Expand Down
8 changes: 4 additions & 4 deletions nncf/common/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
# limitations under the License.

from importlib import import_module
from typing import Callable, List
from typing import Any, Callable, Dict, List

from nncf.common.logging import nncf_logger

IMPORTED_DEPENDENCIES = {}
IMPORTED_DEPENDENCIES: Dict[str, bool] = {}


def skip_if_dependency_unavailable(dependencies: List[str]) -> Callable:
def skip_if_dependency_unavailable(dependencies: List[str]) -> Callable[[Callable[..., None]], Callable[..., None]]:
"""
Decorator factory to skip a noreturn function if dependencies are not met.
Expand All @@ -26,7 +26,7 @@ def skip_if_dependency_unavailable(dependencies: List[str]) -> Callable:
"""

def wrap(func: Callable[..., None]) -> Callable[..., None]:
def wrapped_f(*args, **kwargs):
def wrapped_f(*args: Any, **kwargs: Any): # type: ignore
for libname in dependencies:
if libname in IMPORTED_DEPENDENCIES:
if IMPORTED_DEPENDENCIES[libname]:
Expand Down
14 changes: 7 additions & 7 deletions nncf/common/utils/dot_file_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from collections import defaultdict
from typing import Dict

import networkx as nx
import networkx as nx # type: ignore


def write_dot_graph(G: nx.DiGraph, path: pathlib.Path):
def write_dot_graph(G: nx.DiGraph, path: pathlib.Path) -> None:
# NOTE: writing dot files with colons even in labels or other node/edge/graph attributes leads to an
# error. See https://github.com/networkx/networkx/issues/5962. If `relabel` is True in this function,
# then the colons (:) will be replaced with (^) symbols.
Expand Down Expand Up @@ -47,29 +47,29 @@ def read_dot_graph(path: pathlib.Path) -> nx.DiGraph:
REPLACEMENT_CHAR = "^"


def _maybe_escape_colons_in_attrs(data: Dict):
def _maybe_escape_colons_in_attrs(data: Dict[str, str]) -> None:
for attr_name in data:
attr_val = str(data[attr_name])
if RESERVED_CHAR in attr_val and not (attr_val[0] == '"' or attr_val[-1] == '"'):
data[attr_name] = '"' + data[attr_name] + '"' # escaped colons are allowed


def _unescape_colons_in_attrs_with_colons(data: Dict):
def _unescape_colons_in_attrs_with_colons(data: Dict[str, str]) -> None:
for attr_name in data:
attr_val = data[attr_name]
if RESERVED_CHAR in attr_val and (attr_val[0] == '"' and attr_val[-1] == '"'):
data[attr_name] = data[attr_name][1:-1]


def _remove_cosmetic_labels(graph: nx.DiGraph):
def _remove_cosmetic_labels(graph: nx.DiGraph) -> None:
for node_name, node_data in graph.nodes(data=True):
if "label" in node_data:
label = node_data["label"]
if node_name == label or '"' + node_name + '"' == label:
del node_data["label"]


def _add_cosmetic_labels(graph: nx.DiGraph, relabeled_node_mapping: Dict[str, str]):
def _add_cosmetic_labels(graph: nx.DiGraph, relabeled_node_mapping: Dict[str, str]) -> None:
for original_name, dot_name in relabeled_node_mapping.items():
node_data = graph.nodes[dot_name]
if "label" not in node_data:
Expand Down Expand Up @@ -98,7 +98,7 @@ def relabel_graph_for_dot_visualization(nx_graph: nx.Graph, from_reference: bool
__CHARACTER_REPLACE_FROM = REPLACEMENT_CHAR
__CHARACTER_REPLACE_TO = RESERVED_CHAR

hits = defaultdict(lambda: 0)
hits: Dict[str, int] = defaultdict(lambda: 0)
mapping = {}
for original_name in nx_graph.nodes():
if not isinstance(original_name, str):
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def configure_accuracy_aware_paths(log_dir: Union[str, pathlib.Path]) -> Union[s
return acc_aware_log_dir


def product_dict(d: Dict[Hashable, List]) -> Dict:
def product_dict(d: Dict[Hashable, List[str]]) -> Iterable[Dict[Hashable, str]]:
"""
Generates dicts which enumerate the options for keys given in the input dict;
options are represented by list values in the input dict.
Expand Down
11 changes: 6 additions & 5 deletions nncf/common/utils/os.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import IO, Any, BinaryIO, Iterator, TextIO, Union

import psutil

import nncf


def fail_if_symlink(file: Path):
def fail_if_symlink(file: Path) -> None:
if file.is_symlink():
raise nncf.ValidationError("File {} is a symbolic link, aborting.".format(str(file)))


@contextmanager
def safe_open(file: Path, *args, **kwargs):
def safe_open(file: Path, *args, **kwargs) -> Iterator[Union[TextIO, BinaryIO, IO[Any]]]: # type: ignore
"""
Safe function to open file and return a stream.
Expand All @@ -38,11 +39,11 @@ def safe_open(file: Path, *args, **kwargs):
yield f


def is_windows():
def is_windows() -> bool:
return "win32" in sys.platform


def is_linux():
def is_linux() -> bool:
return "linux" in sys.platform


Expand All @@ -61,7 +62,7 @@ def get_available_cpu_count(logical: bool = True) -> int:
return 1


def get_available_memory_amount() -> int:
def get_available_memory_amount() -> float:
"""
:return: Available memory amount (bytes)
"""
Expand Down
Loading

0 comments on commit 28d99a0

Please sign in to comment.