Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix annotation for tracked_function #3163

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions nncf/telemetry/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.
import functools
import inspect
from typing import Callable, List, Union
from typing import Any, Callable, List, Optional, TypeVar, Union

from nncf.telemetry.events import MODEL_BASED_CATEGORY
from nncf.telemetry.events import get_current_category
Expand All @@ -21,6 +21,8 @@
from nncf.telemetry.extractors import VerbatimTelemetryExtractor
from nncf.telemetry.wrapper import telemetry

TFunction = TypeVar("TFunction", bound=Callable[..., Any])


class tracked_function:
"""
Expand All @@ -29,7 +31,7 @@ class tracked_function:
function execution. The category of the session and events will be determined by parameters to the decorator.
"""

def __init__(self, category: str = None, extractors: List[Union[str, TelemetryExtractor]] = None):
def __init__(self, category: str = None, extractors: Optional[List[Union[str, TelemetryExtractor]]] = None) -> None:
"""
:param category: A category to be attributed to the events. If set to None, no events will be sent.
:param extractors: Add argument names in this list as string values to send an event with an "action" equal to
Expand All @@ -44,11 +46,11 @@ def __init__(self, category: str = None, extractors: List[Union[str, TelemetryEx
else:
self._collectors = []

def __call__(self, fn: Callable) -> Callable:
def __call__(self, fn: TFunction) -> TFunction:
fn_signature = inspect.signature(fn)

@functools.wraps(fn)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> Any:
bound_args = fn_signature.bind(*args, **kwargs)
bound_args.apply_defaults()

Expand All @@ -59,7 +61,7 @@ def wrapped(*args, **kwargs):
events: List[CollectedEvent] = []
for collector in self._collectors:
argname = collector.argname
argvalue = bound_args.arguments[argname] if argname is not None else None
argvalue = bound_args.arguments[argname] if argname else None
event = collector.extract(argvalue)
events.append(event)

Expand All @@ -82,4 +84,4 @@ def wrapped(*args, **kwargs):
telemetry.end_session(self._category)
return retval

return wrapped
return wrapped # type: ignore[return-value]
8 changes: 4 additions & 4 deletions nncf/telemetry/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from contextlib import contextmanager
from typing import Optional, TypeVar
from typing import Generator, Optional, TypeVar

from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
Expand All @@ -25,12 +25,12 @@
# Dynamic categories
MODEL_BASED_CATEGORY = "model_based"

CURRENT_CATEGORY = None
CURRENT_CATEGORY: Optional[str] = None

TModel = TypeVar("TModel")


def _set_current_category(category: str):
def _set_current_category(category: Optional[str]) -> None:
global CURRENT_CATEGORY
CURRENT_CATEGORY = category

Expand All @@ -56,7 +56,7 @@ def get_model_based_category(model: TModel) -> str:


@contextmanager
def telemetry_category(category: str) -> str:
def telemetry_category(category: Optional[str]) -> Generator[Optional[str], None, None]:
previous_category = get_current_category()
_set_current_category(category)
yield category
Expand Down
8 changes: 4 additions & 4 deletions nncf/telemetry/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from enum import Enum
from typing import Any, Optional, Union

SerializableData = Union[str, Enum]
SerializableData = Union[str, Enum, bool]


@dataclass
Expand All @@ -26,16 +26,16 @@ class CollectedEvent:
"""

name: str
data: SerializableData = None # GA limitations
int_data: int = None
data: Optional[SerializableData] = None # GA limitations
int_data: Optional[int] = None


class TelemetryExtractor(ABC):
"""
Interface for custom telemetry extractors, to be used with the `nncf.telemetry.tracked_function` decorator.
"""

def __init__(self, argname: Optional[str] = None):
def __init__(self, argname: str = ""):
self._argname = argname

@property
Expand Down
28 changes: 14 additions & 14 deletions nncf/telemetry/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
from abc import ABC
from abc import abstractmethod
from typing import Callable, Optional
from typing import Any, Callable, Optional
from unittest.mock import MagicMock

from nncf.common.logging import nncf_logger
Expand All @@ -29,7 +29,7 @@ class ITelemetry(ABC):
# https://support.google.com/analytics/answer/1033068

@abstractmethod
def start_session(self, category: str, **kwargs):
def start_session(self, category: str, **kwargs: Any) -> None:
"""
Sends a message about starting of a new session.

Expand All @@ -45,9 +45,9 @@ def send_event(
event_action: str,
event_label: str,
event_value: Optional[int] = None,
force_send=False,
**kwargs,
):
force_send: bool = False,
**kwargs: Any,
) -> None:
"""
Send single event.

Expand All @@ -61,7 +61,7 @@ def send_event(
"""

@abstractmethod
def end_session(self, category: str, **kwargs):
def end_session(self, category: str, **kwargs: Any) -> None:
"""
Sends a message about ending of the current session.

Expand All @@ -78,7 +78,7 @@ def skip_if_raised(func: Callable[..., None]) -> Callable[..., None]:
"""

@functools.wraps(func)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> None:
try:
func(*args, **kwargs)

Expand All @@ -91,7 +91,7 @@ def wrapped(*args, **kwargs):
class NNCFTelemetry(ITelemetry):
MEASUREMENT_ID = "G-W5E9RNLD4H"

def __init__(self):
def __init__(self) -> None:
self._app_name = "nncf"
self._app_version = __version__
try:
Expand All @@ -108,7 +108,7 @@ def __init__(self):
nncf_logger.debug(f"Failed to instantiate telemetry object: exception {e}")

@skip_if_raised
def start_session(self, category: str, **kwargs):
def start_session(self, category: str, **kwargs: Any) -> None:
self._impl.start_session(category, **kwargs)

@skip_if_raised
Expand All @@ -118,9 +118,9 @@ def send_event(
event_action: str,
event_label: str,
event_value: Optional[int] = None,
force_send=False,
**kwargs,
):
force_send: bool = False,
**kwargs: Any,
) -> None:
if event_value is None:
event_value = 1
self._impl.send_event(
Expand All @@ -135,12 +135,12 @@ def send_event(
)

@skip_if_raised
def end_session(self, category: str, **kwargs):
def end_session(self, category: str, **kwargs: Any) -> None:
self._impl.end_session(category, **kwargs)


try:
from openvino_telemetry import Telemetry
from openvino_telemetry import Telemetry # type: ignore

telemetry = NNCFTelemetry()
except ImportError:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ files = [
"nncf/common/utils/",
"nncf/common/tensor_statistics",
"nncf/experimental/torch2",
"nncf/telemetry/",
]

[tool.ruff]
Expand Down