Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make NNCF common utils code pass mypy checks #2780

Merged
merged 31 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d2478b4
Commit
anzr299 Apr 17, 2024
d243701
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Apr 17, 2024
80f91e7
Test to check Union
anzr299 Apr 17, 2024
cf1bc87
Changed all '|' to use Union[]
anzr299 Apr 17, 2024
fcd80bc
minor change
anzr299 Apr 17, 2024
468f262
Pre-commit checks pass
anzr299 Apr 17, 2024
15e0f05
black and isort reformatted
anzr299 Apr 17, 2024
940a603
test commit
anzr299 Apr 17, 2024
b2c2945
checked pre commit again
anzr299 Apr 17, 2024
c97e402
changing tuple to typing.Tuple
anzr299 Apr 18, 2024
6254b7f
Ignore matplotlib import error
anzr299 Apr 18, 2024
f74b857
change tuple to typing.Tuple
anzr299 Apr 18, 2024
74fd9de
Remove some casts, Add support for optional PIL Image Import
anzr299 Apr 18, 2024
2d8d42e
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Apr 19, 2024
e19b153
Add keyword arguments support in `_determine_compression_rate_step_va…
anzr299 Apr 19, 2024
6bc4c7e
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Apr 19, 2024
2d04ee0
Change list() to use cast(List[]) when converting compressed training…
anzr299 Apr 19, 2024
27a0175
Used cast(List[]) instead of list()
anzr299 Apr 19, 2024
ea4e4f6
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Apr 22, 2024
c2eb185
Fix issues with compression_rate_target
anzr299 Apr 23, 2024
8fef98e
Pytorch tests pass
anzr299 Apr 24, 2024
43ec062
Fix mypy errors regarding initialization
anzr299 Apr 24, 2024
724ce29
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Apr 27, 2024
fb5df42
Made some test passes.
anzr299 Apr 28, 2024
9a0bbc2
Merge branch 'openvinotoolkit:develop' into mypy-utils
anzr299 Jul 2, 2024
b8fd244
made changes
anzr299 Jul 2, 2024
923a2af
Final Commit mypy passes
anzr299 Jul 2, 2024
f4f4ccd
run pre-commit locally
anzr299 Jul 2, 2024
7cbbaee
Use Typing.tuple instead of tuple
anzr299 Jul 2, 2024
46fe918
Change timer test to accomodate for additional line
anzr299 Jul 3, 2024
ec2e8ed
Use typing.Any instead of object
anzr299 Jul 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/
files = nncf/common/sparsity, nncf/common/graph, nncf/common/accuracy_aware_training/, nncf/common/utils/
follow_imports = silent
strict = True

Expand Down
2 changes: 1 addition & 1 deletion nncf/common/accuracy_aware_training/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def remove_registry_prefix(algo_name: str) -> str:
)

return {
remove_registry_prefix(algo_name): controller_cls
remove_registry_prefix(algo_name): cast(CompressionAlgorithmController, controller_cls)
for algo_name, controller_cls in ADAPTIVE_COMPRESSION_CONTROLLERS.registry_dict.items()
}

Expand Down
2 changes: 1 addition & 1 deletion nncf/common/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from packaging import version


def warning_deprecated(msg):
def warning_deprecated(msg: str) -> None:
# Note: must use FutureWarning in order not to get suppressed by default
warnings.warn(msg, FutureWarning, stacklevel=2)

Expand Down
16 changes: 10 additions & 6 deletions nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam
if backend == BackendType.ONNX:
from nncf.onnx.hardware.fused_patterns import ONNX_HW_FUSED_PATTERNS

registry = ONNX_HW_FUSED_PATTERNS.registry_dict
registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], ONNX_HW_FUSED_PATTERNS.registry_dict)
return registry
if backend == BackendType.OPENVINO:
from nncf.openvino.hardware.fused_patterns import OPENVINO_HW_FUSED_PATTERNS

registry = OPENVINO_HW_FUSED_PATTERNS.registry_dict
registry = cast(
Dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = PT_HW_FUSED_PATTERNS.registry_dict
registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict)
return registry
raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.")

Expand All @@ -66,17 +68,19 @@ def _get_backend_ignored_patterns_map(
if backend == BackendType.ONNX:
from nncf.onnx.quantization.ignored_patterns import ONNX_IGNORED_PATTERNS

registry = ONNX_IGNORED_PATTERNS.registry_dict
registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], ONNX_IGNORED_PATTERNS.registry_dict)
return registry
if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.ignored_patterns import OPENVINO_IGNORED_PATTERNS

registry = OPENVINO_IGNORED_PATTERNS.registry_dict
registry = cast(
Dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = PT_IGNORED_PATTERNS.registry_dict
registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict)
return registry
raise ValueError(f"Ignored patterns not implemented for {backend} backend.")

Expand Down
6 changes: 3 additions & 3 deletions nncf/common/utils/api_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ class api:
def __init__(self, canonical_alias: str = None):
self._canonical_alias = canonical_alias

def __call__(self, obj):
def __call__(self, obj: object) -> object:
# The value of the marker will be useful in determining
# whether we are handling a base class or a derived one.
setattr(obj, api.API_MARKER_ATTR, obj.__name__)
setattr(obj, api.API_MARKER_ATTR, obj.__name__) # type: ignore
if self._canonical_alias is not None:
setattr(obj, api.CANONICAL_ALIAS_ATTR, self._canonical_alias)
return obj


def is_api(obj) -> bool:
def is_api(obj: object) -> bool:
return hasattr(obj, api.API_MARKER_ATTR)
10 changes: 5 additions & 5 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def is_torch_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of torch.nn.Module, otherwise False.
"""
import torch
import torch # type: ignore

return isinstance(model, torch.nn.Module)

Expand All @@ -68,7 +68,7 @@ def is_tensorflow_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of tensorflow.Module, otherwise False.
"""
import tensorflow
import tensorflow # type: ignore

return isinstance(model, tensorflow.Module)

Expand All @@ -80,7 +80,7 @@ def is_onnx_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of onnx.ModelProto, otherwise False.
"""
import onnx
import onnx # type: ignore

return isinstance(model, onnx.ModelProto)

Expand All @@ -92,7 +92,7 @@ def is_openvino_model(model: TModel) -> bool:
:param model: A target model.
:return: True if the model is an instance of openvino.runtime.Model, otherwise False.
"""
import openvino.runtime as ov
import openvino.runtime as ov # type: ignore

return isinstance(model, ov.Model)

Expand Down Expand Up @@ -147,7 +147,7 @@ def copy_model(model: TModel) -> TModel:
model_backend = get_backend(model)
if model_backend == BackendType.OPENVINO:
# TODO(l-bat): Remove after fixing ticket: 100919
return model.clone()
return model.clone() # type: ignore
if model_backend == BackendType.TENSORFLOW:
# deepcopy and tensorflow.keras.models.clone_model does not work correctly on 2.8.4 version
from nncf.tensorflow.graph.model_transformer import TFModelTransformer
Expand Down
7 changes: 4 additions & 3 deletions nncf/common/utils/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@

import logging
from contextlib import contextmanager
from typing import Generator

from nncf.common.logging import nncf_logger

DEBUG_LOG_DIR = "./nncf_debug"


def is_debug():
def is_debug() -> bool:
return nncf_logger.getEffectiveLevel() == logging.DEBUG


def set_debug_log_dir(dir_: str):
def set_debug_log_dir(dir_: str) -> None:
global DEBUG_LOG_DIR
DEBUG_LOG_DIR = dir_


@contextmanager
def nncf_debug():
def nncf_debug() -> Generator[None, None, None]:
from nncf.common.logging.logger import set_log_level

set_log_level(logging.DEBUG)
Expand Down
8 changes: 4 additions & 4 deletions nncf/common/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
# limitations under the License.

from importlib import import_module
from typing import Callable, List
from typing import Any, Callable, Dict, List

from nncf.common.logging import nncf_logger

IMPORTED_DEPENDENCIES = {}
IMPORTED_DEPENDENCIES: Dict[str, bool] = {}


def skip_if_dependency_unavailable(dependencies: List[str]) -> Callable:
def skip_if_dependency_unavailable(dependencies: List[str]) -> Callable[[Callable[..., None]], Callable[..., None]]:
"""
Decorator factory to skip a noreturn function if dependencies are not met.

Expand All @@ -26,7 +26,7 @@ def skip_if_dependency_unavailable(dependencies: List[str]) -> Callable:
"""

def wrap(func: Callable[..., None]) -> Callable[..., None]:
def wrapped_f(*args, **kwargs):
def wrapped_f(*args: Any, **kwargs: Any): # type: ignore
for libname in dependencies:
if libname in IMPORTED_DEPENDENCIES:
if IMPORTED_DEPENDENCIES[libname]:
Expand Down
14 changes: 7 additions & 7 deletions nncf/common/utils/dot_file_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from collections import defaultdict
from typing import Dict

import networkx as nx
import networkx as nx # type: ignore


def write_dot_graph(G: nx.DiGraph, path: pathlib.Path):
def write_dot_graph(G: nx.DiGraph, path: pathlib.Path) -> None:
# NOTE: writing dot files with colons even in labels or other node/edge/graph attributes leads to an
# error. See https://github.com/networkx/networkx/issues/5962. If `relabel` is True in this function,
# then the colons (:) will be replaced with (^) symbols.
Expand Down Expand Up @@ -47,29 +47,29 @@ def read_dot_graph(path: pathlib.Path) -> nx.DiGraph:
REPLACEMENT_CHAR = "^"


def _maybe_escape_colons_in_attrs(data: Dict):
def _maybe_escape_colons_in_attrs(data: Dict[str, str]) -> None:
for attr_name in data:
attr_val = str(data[attr_name])
if RESERVED_CHAR in attr_val and not (attr_val[0] == '"' or attr_val[-1] == '"'):
data[attr_name] = '"' + data[attr_name] + '"' # escaped colons are allowed


def _unescape_colons_in_attrs_with_colons(data: Dict):
def _unescape_colons_in_attrs_with_colons(data: Dict[str, str]) -> None:
for attr_name in data:
attr_val = data[attr_name]
if RESERVED_CHAR in attr_val and (attr_val[0] == '"' and attr_val[-1] == '"'):
data[attr_name] = data[attr_name][1:-1]


def _remove_cosmetic_labels(graph: nx.DiGraph):
def _remove_cosmetic_labels(graph: nx.DiGraph) -> None:
for node_name, node_data in graph.nodes(data=True):
if "label" in node_data:
label = node_data["label"]
if node_name == label or '"' + node_name + '"' == label:
del node_data["label"]


def _add_cosmetic_labels(graph: nx.DiGraph, relabeled_node_mapping: Dict[str, str]):
def _add_cosmetic_labels(graph: nx.DiGraph, relabeled_node_mapping: Dict[str, str]) -> None:
for original_name, dot_name in relabeled_node_mapping.items():
node_data = graph.nodes[dot_name]
if "label" not in node_data:
Expand Down Expand Up @@ -98,7 +98,7 @@ def relabel_graph_for_dot_visualization(nx_graph: nx.Graph, from_reference: bool
__CHARACTER_REPLACE_FROM = REPLACEMENT_CHAR
__CHARACTER_REPLACE_TO = RESERVED_CHAR

hits = defaultdict(lambda: 0)
hits: Dict[str, int] = defaultdict(lambda: 0)
mapping = {}
for original_name in nx_graph.nodes():
if not isinstance(original_name, str):
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def configure_accuracy_aware_paths(log_dir: Union[str, pathlib.Path]) -> Union[s
return acc_aware_log_dir


def product_dict(d: Dict[Hashable, List]) -> Dict:
def product_dict(d: Dict[Hashable, List[str]]) -> Iterable[Dict[Hashable, str]]:
"""
Generates dicts which enumerate the options for keys given in the input dict;
options are represented by list values in the input dict.
Expand Down
11 changes: 6 additions & 5 deletions nncf/common/utils/os.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import IO, Any, BinaryIO, Iterator, TextIO, Union

import psutil

import nncf


def fail_if_symlink(file: Path):
def fail_if_symlink(file: Path) -> None:
if file.is_symlink():
raise nncf.ValidationError("File {} is a symbolic link, aborting.".format(str(file)))


@contextmanager
def safe_open(file: Path, *args, **kwargs):
def safe_open(file: Path, *args, **kwargs) -> Iterator[Union[TextIO, BinaryIO, IO[Any]]]: # type: ignore
"""
Safe function to open file and return a stream.

Expand All @@ -38,11 +39,11 @@ def safe_open(file: Path, *args, **kwargs):
yield f


def is_windows():
def is_windows() -> bool:
return "win32" in sys.platform


def is_linux():
def is_linux() -> bool:
return "linux" in sys.platform


Expand All @@ -61,7 +62,7 @@ def get_available_cpu_count(logical: bool = True) -> int:
return 1


def get_available_memory_amount() -> int:
def get_available_memory_amount() -> float:
"""
:return: Available memory amount (bytes)
"""
Expand Down
Loading
Loading