Skip to content

Commit

Permalink
some more
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Dec 30, 2024
1 parent b4d6a00 commit f0023c3
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 51 deletions.
25 changes: 13 additions & 12 deletions nncf/common/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,20 @@
# limitations under the License.

import functools
import inspect
import warnings
from typing import Callable, Type, TypeVar
from types import FunctionType
from typing import Any, Callable, TypeVar, cast

from packaging import version

ClassOrFn = TypeVar("ClassOrFn")


def warning_deprecated(msg: str) -> None:
# Note: must use FutureWarning in order not to get suppressed by default
warnings.warn(msg, FutureWarning, stacklevel=2)


ClassOrFn = TypeVar("ClassOrFn", Callable, Type)


class deprecated:
"""
A decorator for marking function calls or class instantiations as deprecated. A call to the marked function or an
Expand All @@ -41,15 +40,17 @@ def __init__(self, msg: str = None, start_version: str = None, end_version: str
self.end_version = version.parse(end_version) if end_version is not None else None

def __call__(self, fn_or_class: ClassOrFn) -> ClassOrFn:
name = fn_or_class.__module__ + "." + fn_or_class.__name__
if inspect.isclass(fn_or_class):
fn_or_class.__init__ = self._get_wrapper(fn_or_class.__init__, name)
return fn_or_class
return self._get_wrapper(fn_or_class, name)
if isinstance(fn_or_class, type):
name = f"{fn_or_class.__module__}.{fn_or_class.__name__}"
fn_or_class.__init__ = self._get_wrapper(fn_or_class.__init__, name) # type: ignore[misc]
return cast(ClassOrFn, fn_or_class)
if isinstance(fn_or_class, FunctionType):
return cast(ClassOrFn, self._get_wrapper(fn_or_class, f"{fn_or_class.__module__}.{fn_or_class.__name__}"))
raise TypeError("Unsupported type for @deprecated decorator")

def _get_wrapper(self, fn_to_wrap: Callable, name: str) -> Callable:
def _get_wrapper(self, fn_to_wrap: Callable[..., Any], name: str) -> Callable[..., Any]:
@functools.wraps(fn_to_wrap)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> Any:
msg = f"Usage of {name} is deprecated "
if self.start_version is not None:
msg += f"starting from NNCF v{str(self.start_version)} "
Expand Down
12 changes: 6 additions & 6 deletions nncf/common/hook_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import weakref
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union


class HookHandle:
Expand All @@ -24,7 +24,7 @@ def __init__(self, hooks_registry: Dict[Any, Any], hook_id: str):
:param hooks_registry: A dictionary of hooks, indexed by hook `id`.
:param hook_id: Hook id to use as key in the dictionary of hooks.
"""
self.hooks_registry_ref = weakref.ref(hooks_registry)
self.hooks_registry_ref: Optional[weakref.ReferenceType[Dict[Any, Any]]] = weakref.ref(hooks_registry)
self._hook_id = hook_id

@property
Expand Down Expand Up @@ -58,8 +58,8 @@ def add_op_to_registry(hooks_registry: Dict[Any, Any], op: Any) -> HookHandle:
"""
if hooks_registry:
hook_id = max(map(int, hooks_registry))
hook_id = str(hook_id + 1)
hook_id_str = str(hook_id + 1)
else:
hook_id = "0"
hooks_registry[hook_id] = op
return HookHandle(hooks_registry, hook_id)
hook_id_str = "0"
hooks_registry[hook_id_str] = op
return HookHandle(hooks_registry, hook_id_str)
10 changes: 5 additions & 5 deletions nncf/common/insertion_point_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from enum import Enum
from typing import Dict, List, Set

import networkx as nx
import networkx as nx # type: ignore

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNodeName
Expand All @@ -36,19 +36,19 @@ def __init__(self, target_node_name: str, input_port_id: int):
self.target_node_name = target_node_name
self.input_port_id = input_port_id

def __str__(self):
def __str__(self) -> str:
return str(self.input_port_id) + " " + self.target_node_name


class PostHookInsertionPoint:
def __init__(self, target_node_name: str):
self.target_node_name = target_node_name

def __str__(self):
def __str__(self) -> str:
return self.target_node_name


class InsertionPointGraph(nx.DiGraph):
class InsertionPointGraph(nx.DiGraph): # type: ignore
"""
This graph is built from the NNCFGraph representation of the model control flow graph and adds ephemeral
"insertion point nodes" into the NNCF model graph representation corresponding to operator pre- and
Expand Down Expand Up @@ -304,7 +304,7 @@ def get_merged_node_from_single_node_key(self, node_key: str) -> str:
if data[InsertionPointGraph.IS_MERGED_NODE_ATTR]:
for nncf_node in data[InsertionPointGraph.MERGED_NNCF_NODE_LIST_NODE_ATTR]:
if node_key == nncf_node.node_key:
return node
return node # type: ignore
return node_key

def get_ip_graph_with_merged_hw_optimized_operations(
Expand Down
3 changes: 2 additions & 1 deletion nncf/common/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Generator


@contextmanager
def noninteractive_plotting():
def noninteractive_plotting() -> Generator[None, None, None]:
from matplotlib import pyplot as plt

backend = plt.get_backend()
Expand Down
4 changes: 2 additions & 2 deletions nncf/common/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __call__(self, epoch: int, step: Optional[int] = None, steps_per_epoch: Opti
else:
value = self.initial_value + (self.target_value - self.initial_value) * np.power(progress, self.power)

return value
return float(value)


class MultiStepSchedule:
Expand Down Expand Up @@ -141,7 +141,7 @@ def __call__(self, epoch: int) -> float:
if self.target_epoch == 0:
return self.target_value

value = self.initial_value * np.power(self.decay_rate, epoch / self.target_epoch)
value = self.initial_value * float(np.power(self.decay_rate, epoch / self.target_epoch))
return max(value, self.target_value)


Expand Down
4 changes: 3 additions & 1 deletion nncf/common/scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def should_consider_scope(
)


def get_not_matched_scopes(scope: Union[List[str], str, IgnoredScope], nodes: List[NNCFNode]) -> List[str]:
def get_not_matched_scopes(scope: Optional[Union[List[str], str, IgnoredScope]], nodes: List[NNCFNode]) -> List[str]:
"""
Return list of scope that do not match node list.
Expand All @@ -82,6 +82,8 @@ def get_not_matched_scopes(scope: Union[List[str], str, IgnoredScope], nodes: Li
:return : List of not matched scopes.
"""
if scope is None:
return []

if isinstance(scope, str):
patterns = [scope]
Expand Down
22 changes: 11 additions & 11 deletions nncf/common/stateful_classes_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ class StatefulClassesRegistry:

REQUIRED_METHOD_NAME = "from_state"

def __init__(self):
self._name_vs_class_map: Dict[str, object] = {}
self._class_vs_name_map: Dict[object, str] = {}
def __init__(self) -> None:
self._name_vs_class_map: Dict[str, type] = {}
self._class_vs_name_map: Dict[type, str] = {}

def register(self, name: str = None) -> Callable:
def register(self, name: str = None) -> Callable[[type], type]:
"""
Decorator to map class with some name - specified in the argument or name of the class.
:param name: The registration name. By default, it's name of the class.
:return: The inner function for registration.
"""

def decorator(cls):
def decorator(cls: type) -> type:
registered_name = name if name is not None else cls.__name__

if registered_name in self._name_vs_class_map:
Expand All @@ -59,7 +59,7 @@ def decorator(cls):

return decorator

def get_registered_class(self, registered_name: str) -> object:
def get_registered_class(self, registered_name: str) -> type:
"""
Provides a class that was registered with the given name.
Expand All @@ -70,7 +70,7 @@ def get_registered_class(self, registered_name: str) -> object:
return self._name_vs_class_map[registered_name]
raise KeyError("No registered stateful classes with {} name".format(registered_name))

def get_registered_name(self, stateful_cls: object) -> str:
def get_registered_name(self, stateful_cls: type) -> str:
"""
Provides a name that was used to register the given stateful class.
Expand All @@ -88,23 +88,23 @@ class CommonStatefulClassesRegistry:
"""

@staticmethod
def register(name: str = None) -> Callable:
def register(name: str = None) -> Callable[[type], type]:
"""
Decorator to map class with some name - specified in the argument or name of the class.
:param name: The registration name. By default, it's name of the class.
:return: The inner function for registration.
"""

def decorator(cls):
def decorator(cls: type) -> type:
PT_STATEFUL_CLASSES.register(name)(cls)
TF_STATEFUL_CLASSES.register(name)(cls)
return cls

return decorator

@staticmethod
def get_registered_class(registered_name: str) -> object:
def get_registered_class(registered_name: str) -> type:
"""
Provides a class that was registered with the given name.
Expand All @@ -114,7 +114,7 @@ def get_registered_class(registered_name: str) -> object:
return PT_STATEFUL_CLASSES.get_registered_class(registered_name)

@staticmethod
def get_registered_name(stateful_cls: object) -> str:
def get_registered_name(stateful_cls: type) -> str:
"""
Provides a name that was used to register the given stateful class.
Expand Down
4 changes: 2 additions & 2 deletions nncf/common/strip.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def strip(model: TModel, do_copy: bool = True) -> TModel:
"""
model_backend = get_backend(model)
if model_backend == BackendType.TORCH:
from nncf.torch import strip as strip_pt
from nncf.torch.strip import strip as strip_pt

return strip_pt(model, do_copy)
return strip_pt(model, do_copy) # type: ignore

raise nncf.UnsupportedBackendError(f"Method `strip` does not support for {model_backend.value} backend.")
16 changes: 9 additions & 7 deletions nncf/common/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, TypeVar
from typing import List, TypeVar

import nncf

Expand All @@ -23,24 +23,26 @@ class NNCFTensor:
An interface of framework specific tensors for common NNCF algorithms.
"""

def __init__(self, tensor: Optional[TensorType]):
def __init__(self, tensor: TensorType):
self._tensor = tensor

def __eq__(self, other: "NNCFTensor") -> bool:
return self._tensor == other.tensor
def __eq__(self, other: object) -> bool:
if not isinstance(other, NNCFTensor):
raise nncf.InternalError("Attempt to compare NNCFTensor with a non-NNCFTensor object")
return bool(self._tensor == other.tensor)

@property
def tensor(self) -> TensorType:
def tensor(self) -> TensorType: # type: ignore
return self._tensor

@property
def shape(self) -> List[int]:
if self._tensor is None:
raise nncf.InternalError("Attempt to get shape of empty NNCFTensor")
return self._tensor.shape
return self._tensor.shape # type: ignore

@property
def device(self) -> DeviceType:
def device(self) -> DeviceType: # type: ignore
raise NotImplementedError

def is_empty(self) -> bool:
Expand Down
20 changes: 16 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,36 @@ known_third_party = "datasets"

[tool.mypy]
follow_imports = "silent"
# disable_error_code = ["import-untyped"]
plugins = "numpy.typing.mypy_plugin"
strict = true
# should be removed later
# mypy recommends the following tool as an autofix:
# https://github.com/hauntsaninja/no_implicit_optional
implicit_optional = true
files = [
"nncf/api",
"nncf/common/sparsity",
"nncf/common/deprecation.py",
"nncf/common/engine.py",
"nncf/common/hook_handle.py",
"nncf/common/insertion_point_graph.py",
"nncf/common/plotting.py",
"nncf/common/schedulers.py",
"nncf/common/scopes.py",
"nncf/common/stateful_classes_registry.py",
"nncf/common/strip.py",
"nncf/common/tensor.py",
"nncf/common/accuracy_aware_training",
"nncf/common/graph",
"nncf/common/accuracy_aware_training/",
"nncf/common/utils/",
"nncf/common/sparsity",
"nncf/common/tensor_statistics",
"nncf/common/utils",
"nncf/experimental/torch2",
"nncf/quantization/passes.py",
"nncf/quantization/advanced_parameters.py",
"nncf/quantization/range_estimator.py",
"nncf/quantization/telemetry_extractors.py",
"nncf/telemetry/",
"nncf/telemetry",
]

[tool.ruff]
Expand Down

0 comments on commit f0023c3

Please sign in to comment.