From f0023c320f79598973589088f5a75cb3ee27c2f2 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Mon, 30 Dec 2024 05:24:58 +0200 Subject: [PATCH] some more --- nncf/common/deprecation.py | 25 ++++++++++++------------ nncf/common/hook_handle.py | 12 ++++++------ nncf/common/insertion_point_graph.py | 10 +++++----- nncf/common/plotting.py | 3 ++- nncf/common/schedulers.py | 4 ++-- nncf/common/scopes.py | 4 +++- nncf/common/stateful_classes_registry.py | 22 ++++++++++----------- nncf/common/strip.py | 4 ++-- nncf/common/tensor.py | 16 ++++++++------- pyproject.toml | 20 +++++++++++++++---- 10 files changed, 69 insertions(+), 51 deletions(-) diff --git a/nncf/common/deprecation.py b/nncf/common/deprecation.py index 240a22a2eda..c0041feb3b2 100644 --- a/nncf/common/deprecation.py +++ b/nncf/common/deprecation.py @@ -10,21 +10,20 @@ # limitations under the License. import functools -import inspect import warnings -from typing import Callable, Type, TypeVar +from types import FunctionType +from typing import Any, Callable, TypeVar, cast from packaging import version +ClassOrFn = TypeVar("ClassOrFn") + def warning_deprecated(msg: str) -> None: # Note: must use FutureWarning in order not to get suppressed by default warnings.warn(msg, FutureWarning, stacklevel=2) -ClassOrFn = TypeVar("ClassOrFn", Callable, Type) - - class deprecated: """ A decorator for marking function calls or class instantiations as deprecated. A call to the marked function or an @@ -41,15 +40,17 @@ def __init__(self, msg: str = None, start_version: str = None, end_version: str self.end_version = version.parse(end_version) if end_version is not None else None def __call__(self, fn_or_class: ClassOrFn) -> ClassOrFn: - name = fn_or_class.__module__ + "." + fn_or_class.__name__ - if inspect.isclass(fn_or_class): - fn_or_class.__init__ = self._get_wrapper(fn_or_class.__init__, name) - return fn_or_class - return self._get_wrapper(fn_or_class, name) + if isinstance(fn_or_class, type): + name = f"{fn_or_class.__module__}.{fn_or_class.__name__}" + fn_or_class.__init__ = self._get_wrapper(fn_or_class.__init__, name) # type: ignore[misc] + return cast(ClassOrFn, fn_or_class) + if isinstance(fn_or_class, FunctionType): + return cast(ClassOrFn, self._get_wrapper(fn_or_class, f"{fn_or_class.__module__}.{fn_or_class.__name__}")) + raise TypeError("Unsupported type for @deprecated decorator") - def _get_wrapper(self, fn_to_wrap: Callable, name: str) -> Callable: + def _get_wrapper(self, fn_to_wrap: Callable[..., Any], name: str) -> Callable[..., Any]: @functools.wraps(fn_to_wrap) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: msg = f"Usage of {name} is deprecated " if self.start_version is not None: msg += f"starting from NNCF v{str(self.start_version)} " diff --git a/nncf/common/hook_handle.py b/nncf/common/hook_handle.py index 0dde089455a..4bbe4841dfb 100644 --- a/nncf/common/hook_handle.py +++ b/nncf/common/hook_handle.py @@ -10,7 +10,7 @@ # limitations under the License. import weakref -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union class HookHandle: @@ -24,7 +24,7 @@ def __init__(self, hooks_registry: Dict[Any, Any], hook_id: str): :param hooks_registry: A dictionary of hooks, indexed by hook `id`. :param hook_id: Hook id to use as key in the dictionary of hooks. """ - self.hooks_registry_ref = weakref.ref(hooks_registry) + self.hooks_registry_ref: Optional[weakref.ReferenceType[Dict[Any, Any]]] = weakref.ref(hooks_registry) self._hook_id = hook_id @property @@ -58,8 +58,8 @@ def add_op_to_registry(hooks_registry: Dict[Any, Any], op: Any) -> HookHandle: """ if hooks_registry: hook_id = max(map(int, hooks_registry)) - hook_id = str(hook_id + 1) + hook_id_str = str(hook_id + 1) else: - hook_id = "0" - hooks_registry[hook_id] = op - return HookHandle(hooks_registry, hook_id) + hook_id_str = "0" + hooks_registry[hook_id_str] = op + return HookHandle(hooks_registry, hook_id_str) diff --git a/nncf/common/insertion_point_graph.py b/nncf/common/insertion_point_graph.py index 7bb1368dadf..d607a3c67e1 100644 --- a/nncf/common/insertion_point_graph.py +++ b/nncf/common/insertion_point_graph.py @@ -14,7 +14,7 @@ from enum import Enum from typing import Dict, List, Set -import networkx as nx +import networkx as nx # type: ignore from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNodeName @@ -36,7 +36,7 @@ def __init__(self, target_node_name: str, input_port_id: int): self.target_node_name = target_node_name self.input_port_id = input_port_id - def __str__(self): + def __str__(self) -> str: return str(self.input_port_id) + " " + self.target_node_name @@ -44,11 +44,11 @@ class PostHookInsertionPoint: def __init__(self, target_node_name: str): self.target_node_name = target_node_name - def __str__(self): + def __str__(self) -> str: return self.target_node_name -class InsertionPointGraph(nx.DiGraph): +class InsertionPointGraph(nx.DiGraph): # type: ignore """ This graph is built from the NNCFGraph representation of the model control flow graph and adds ephemeral "insertion point nodes" into the NNCF model graph representation corresponding to operator pre- and @@ -304,7 +304,7 @@ def get_merged_node_from_single_node_key(self, node_key: str) -> str: if data[InsertionPointGraph.IS_MERGED_NODE_ATTR]: for nncf_node in data[InsertionPointGraph.MERGED_NNCF_NODE_LIST_NODE_ATTR]: if node_key == nncf_node.node_key: - return node + return node # type: ignore return node_key def get_ip_graph_with_merged_hw_optimized_operations( diff --git a/nncf/common/plotting.py b/nncf/common/plotting.py index b6fb2d76268..bc990ca9890 100644 --- a/nncf/common/plotting.py +++ b/nncf/common/plotting.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager +from typing import Generator @contextmanager -def noninteractive_plotting(): +def noninteractive_plotting() -> Generator[None, None, None]: from matplotlib import pyplot as plt backend = plt.get_backend() diff --git a/nncf/common/schedulers.py b/nncf/common/schedulers.py index fe064d4a1fd..36f838787c0 100644 --- a/nncf/common/schedulers.py +++ b/nncf/common/schedulers.py @@ -69,7 +69,7 @@ def __call__(self, epoch: int, step: Optional[int] = None, steps_per_epoch: Opti else: value = self.initial_value + (self.target_value - self.initial_value) * np.power(progress, self.power) - return value + return float(value) class MultiStepSchedule: @@ -141,7 +141,7 @@ def __call__(self, epoch: int) -> float: if self.target_epoch == 0: return self.target_value - value = self.initial_value * np.power(self.decay_rate, epoch / self.target_epoch) + value = self.initial_value * float(np.power(self.decay_rate, epoch / self.target_epoch)) return max(value, self.target_value) diff --git a/nncf/common/scopes.py b/nncf/common/scopes.py index 218a058653b..d884705ebb5 100644 --- a/nncf/common/scopes.py +++ b/nncf/common/scopes.py @@ -73,7 +73,7 @@ def should_consider_scope( ) -def get_not_matched_scopes(scope: Union[List[str], str, IgnoredScope], nodes: List[NNCFNode]) -> List[str]: +def get_not_matched_scopes(scope: Optional[Union[List[str], str, IgnoredScope]], nodes: List[NNCFNode]) -> List[str]: """ Return list of scope that do not match node list. @@ -82,6 +82,8 @@ def get_not_matched_scopes(scope: Union[List[str], str, IgnoredScope], nodes: Li :return : List of not matched scopes. """ + if scope is None: + return [] if isinstance(scope, str): patterns = [scope] diff --git a/nncf/common/stateful_classes_registry.py b/nncf/common/stateful_classes_registry.py index f608254d488..3cbf811c80e 100644 --- a/nncf/common/stateful_classes_registry.py +++ b/nncf/common/stateful_classes_registry.py @@ -20,11 +20,11 @@ class StatefulClassesRegistry: REQUIRED_METHOD_NAME = "from_state" - def __init__(self): - self._name_vs_class_map: Dict[str, object] = {} - self._class_vs_name_map: Dict[object, str] = {} + def __init__(self) -> None: + self._name_vs_class_map: Dict[str, type] = {} + self._class_vs_name_map: Dict[type, str] = {} - def register(self, name: str = None) -> Callable: + def register(self, name: str = None) -> Callable[[type], type]: """ Decorator to map class with some name - specified in the argument or name of the class. @@ -32,7 +32,7 @@ def register(self, name: str = None) -> Callable: :return: The inner function for registration. """ - def decorator(cls): + def decorator(cls: type) -> type: registered_name = name if name is not None else cls.__name__ if registered_name in self._name_vs_class_map: @@ -59,7 +59,7 @@ def decorator(cls): return decorator - def get_registered_class(self, registered_name: str) -> object: + def get_registered_class(self, registered_name: str) -> type: """ Provides a class that was registered with the given name. @@ -70,7 +70,7 @@ def get_registered_class(self, registered_name: str) -> object: return self._name_vs_class_map[registered_name] raise KeyError("No registered stateful classes with {} name".format(registered_name)) - def get_registered_name(self, stateful_cls: object) -> str: + def get_registered_name(self, stateful_cls: type) -> str: """ Provides a name that was used to register the given stateful class. @@ -88,7 +88,7 @@ class CommonStatefulClassesRegistry: """ @staticmethod - def register(name: str = None) -> Callable: + def register(name: str = None) -> Callable[[type], type]: """ Decorator to map class with some name - specified in the argument or name of the class. @@ -96,7 +96,7 @@ def register(name: str = None) -> Callable: :return: The inner function for registration. """ - def decorator(cls): + def decorator(cls: type) -> type: PT_STATEFUL_CLASSES.register(name)(cls) TF_STATEFUL_CLASSES.register(name)(cls) return cls @@ -104,7 +104,7 @@ def decorator(cls): return decorator @staticmethod - def get_registered_class(registered_name: str) -> object: + def get_registered_class(registered_name: str) -> type: """ Provides a class that was registered with the given name. @@ -114,7 +114,7 @@ def get_registered_class(registered_name: str) -> object: return PT_STATEFUL_CLASSES.get_registered_class(registered_name) @staticmethod - def get_registered_name(stateful_cls: object) -> str: + def get_registered_name(stateful_cls: type) -> str: """ Provides a name that was used to register the given stateful class. diff --git a/nncf/common/strip.py b/nncf/common/strip.py index c2cd592ff46..545ac628f4b 100644 --- a/nncf/common/strip.py +++ b/nncf/common/strip.py @@ -33,8 +33,8 @@ def strip(model: TModel, do_copy: bool = True) -> TModel: """ model_backend = get_backend(model) if model_backend == BackendType.TORCH: - from nncf.torch import strip as strip_pt + from nncf.torch.strip import strip as strip_pt - return strip_pt(model, do_copy) + return strip_pt(model, do_copy) # type: ignore raise nncf.UnsupportedBackendError(f"Method `strip` does not support for {model_backend.value} backend.") diff --git a/nncf/common/tensor.py b/nncf/common/tensor.py index af6fddd46c4..df2386ce353 100644 --- a/nncf/common/tensor.py +++ b/nncf/common/tensor.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, TypeVar +from typing import List, TypeVar import nncf @@ -23,24 +23,26 @@ class NNCFTensor: An interface of framework specific tensors for common NNCF algorithms. """ - def __init__(self, tensor: Optional[TensorType]): + def __init__(self, tensor: TensorType): self._tensor = tensor - def __eq__(self, other: "NNCFTensor") -> bool: - return self._tensor == other.tensor + def __eq__(self, other: object) -> bool: + if not isinstance(other, NNCFTensor): + raise nncf.InternalError("Attempt to compare NNCFTensor with a non-NNCFTensor object") + return bool(self._tensor == other.tensor) @property - def tensor(self) -> TensorType: + def tensor(self) -> TensorType: # type: ignore return self._tensor @property def shape(self) -> List[int]: if self._tensor is None: raise nncf.InternalError("Attempt to get shape of empty NNCFTensor") - return self._tensor.shape + return self._tensor.shape # type: ignore @property - def device(self) -> DeviceType: + def device(self) -> DeviceType: # type: ignore raise NotImplementedError def is_empty(self) -> bool: diff --git a/pyproject.toml b/pyproject.toml index 883b0b50ee5..f9601087a1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,8 @@ known_third_party = "datasets" [tool.mypy] follow_imports = "silent" +# disable_error_code = ["import-untyped"] +plugins = "numpy.typing.mypy_plugin" strict = true # should be removed later # mypy recommends the following tool as an autofix: @@ -95,17 +97,27 @@ strict = true implicit_optional = true files = [ "nncf/api", - "nncf/common/sparsity", + "nncf/common/deprecation.py", + "nncf/common/engine.py", + "nncf/common/hook_handle.py", + "nncf/common/insertion_point_graph.py", + "nncf/common/plotting.py", + "nncf/common/schedulers.py", + "nncf/common/scopes.py", + "nncf/common/stateful_classes_registry.py", + "nncf/common/strip.py", + "nncf/common/tensor.py", + "nncf/common/accuracy_aware_training", "nncf/common/graph", - "nncf/common/accuracy_aware_training/", - "nncf/common/utils/", + "nncf/common/sparsity", "nncf/common/tensor_statistics", + "nncf/common/utils", "nncf/experimental/torch2", "nncf/quantization/passes.py", "nncf/quantization/advanced_parameters.py", "nncf/quantization/range_estimator.py", "nncf/quantization/telemetry_extractors.py", - "nncf/telemetry/", + "nncf/telemetry", ] [tool.ruff]