diff --git a/docs/api/source/conf.py b/docs/api/source/conf.py index 19ec00be007..55a40d04220 100644 --- a/docs/api/source/conf.py +++ b/docs/api/source/conf.py @@ -75,7 +75,8 @@ def collect_api_entities() -> APIInfo: except Exception as e: skipped_modules[modname] = str(e) - from nncf.common.utils.api_marker import api + from nncf.common.utils.api_marker import API_MARKER_ATTR + from nncf.common.utils.api_marker import CANONICAL_ALIAS_ATTR canonical_imports_seen = set() @@ -86,7 +87,7 @@ def collect_api_entities() -> APIInfo: if ( objects_module == modname and (inspect.isclass(obj) or inspect.isfunction(obj)) - and hasattr(obj, api.API_MARKER_ATTR) + and hasattr(obj, API_MARKER_ATTR) ): marked_object_name = obj._nncf_api_marker # Check the actual name of the originally marked object @@ -95,8 +96,8 @@ def collect_api_entities() -> APIInfo: if marked_object_name != obj.__name__: continue fqn = f"{modname}.{obj_name}" - if hasattr(obj, api.CANONICAL_ALIAS_ATTR): - canonical_import_name = getattr(obj, api.CANONICAL_ALIAS_ATTR) + if hasattr(obj, CANONICAL_ALIAS_ATTR): + canonical_import_name = getattr(obj, CANONICAL_ALIAS_ATTR) if canonical_import_name in canonical_imports_seen: assert False, f"Duplicate canonical_alias detected: {canonical_import_name}" retval.fqn_vs_canonical_name[fqn] = canonical_import_name diff --git a/nncf/common/utils/api_marker.py b/nncf/common/utils/api_marker.py index 1b6b346231c..a56707b6703 100644 --- a/nncf/common/utils/api_marker.py +++ b/nncf/common/utils/api_marker.py @@ -8,24 +8,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Callable, TypeVar, Union -class api: - API_MARKER_ATTR = "_nncf_api_marker" - CANONICAL_ALIAS_ATTR = "_nncf_canonical_alias" +TObj = TypeVar("TObj", bound=Union[Callable[..., Any], type]) - def __init__(self, canonical_alias: str = None): - self._canonical_alias = canonical_alias +API_MARKER_ATTR = "_nncf_api_marker" +CANONICAL_ALIAS_ATTR = "_nncf_canonical_alias" - def __call__(self, obj: Any) -> Any: - # The value of the marker will be useful in determining - # whether we are handling a base class or a derived one. - setattr(obj, api.API_MARKER_ATTR, obj.__name__) - if self._canonical_alias is not None: - setattr(obj, api.CANONICAL_ALIAS_ATTR, self._canonical_alias) - return obj +def api(canonical_alias: str = None) -> Callable[[TObj], TObj]: + """ + Decorator function used to mark a object as an API. + + Example: + @api(canonical_alias="alias") + class Class: + pass + + @api(canonical_alias="alias") + def function(): + pass + + :param canonical_alias: The canonical alias for the API class. + """ + + def decorator(obj: TObj) -> TObj: + setattr(obj, API_MARKER_ATTR, obj.__name__) + if canonical_alias is not None: + setattr(obj, CANONICAL_ALIAS_ATTR, canonical_alias) + return obj -def is_api(obj: Any) -> bool: - return hasattr(obj, api.API_MARKER_ATTR) + return decorator diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 5d20a0d7ba6..d33ec9aa849 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -12,7 +12,7 @@ import functools import inspect from contextlib import contextmanager -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import torch import torch.utils.cpp_extension @@ -251,13 +251,13 @@ def get_torch_compile_wrapper(): """ @functools.wraps(_ORIG_TORCH_COMPILE) - def wrapper(model, *args, **kwargs): + def wrapper(model: Optional[Callable] = None, **kwargs): from nncf.torch.nncf_network import NNCFNetwork if isinstance(model, NNCFNetwork): raise TypeError("At the moment torch.compile() is not supported for models optimized by NNCF.") with disable_patching(): - return _ORIG_TORCH_COMPILE(model, *args, **kwargs) + return _ORIG_TORCH_COMPILE(model, **kwargs) return wrapper