diff --git a/nncf/common/initialization/batchnorm_adaptation.py b/nncf/common/initialization/batchnorm_adaptation.py index 8b8e2ce8d8b..f3febc0ce7c 100644 --- a/nncf/common/initialization/batchnorm_adaptation.py +++ b/nncf/common/initialization/batchnorm_adaptation.py @@ -12,7 +12,7 @@ import math from abc import ABC from abc import abstractmethod -from typing import Optional +from typing import Optional, Type from nncf.api.compression import TModel from nncf.common.initialization.dataloader import NNCFDataLoader @@ -83,6 +83,7 @@ def run(self, model: TModel) -> None: :param model: A model for which the algorithm will be applied. """ backend = get_backend(model) + impl_cls: Type[BatchnormAdaptationAlgorithmImpl] if backend is BackendType.TORCH: from nncf.torch.batchnorm_adaptation import PTBatchnormAdaptationAlgorithmImpl diff --git a/nncf/common/initialization/dataloader.py b/nncf/common/initialization/dataloader.py index 07fa4b98ab9..5934121a338 100644 --- a/nncf/common/initialization/dataloader.py +++ b/nncf/common/initialization/dataloader.py @@ -11,6 +11,7 @@ """Interface for user-defined data usage during the compression algorithm initialization process.""" from abc import ABC from abc import abstractmethod +from typing import Any from nncf.common.utils.api_marker import api @@ -31,7 +32,7 @@ def batch_size(self) -> int: """ @abstractmethod - def __iter__(self): + def __iter__(self) -> Any: """ Creates an iterator for the elements of a custom data source. The returned iterator implements the Python Iterator protocol. diff --git a/nncf/torch/extensions/__init__.py b/nncf/torch/extensions/__init__.py index 59edea1ca24..2dcedab4da0 100644 --- a/nncf/torch/extensions/__init__.py +++ b/nncf/torch/extensions/__init__.py @@ -24,7 +24,6 @@ import nncf from nncf.common.logging import nncf_logger -from nncf.common.logging.logger import extension_is_loading_info_log from nncf.common.utils.api_marker import api from nncf.common.utils.registry import Registry @@ -95,27 +94,27 @@ def get(self, fn_name: str) -> Callable: if self._loaded_namespace is None: timeout = int(os.environ.get(EXTENSION_LOAD_TIMEOUT_ENV_VAR, DEFAULT_EXTENSION_LOAD_TIMEOUT)) timeout = timeout if timeout > 0 else None - - with extension_is_loading_info_log(self._loader.name()): - try: - pool = ThreadPool(processes=1) - async_result = pool.apply_async(self._loader.load) - self._loaded_namespace = async_result.get(timeout=timeout) - except MPTimeoutError as error: - msg = textwrap.dedent( - f"""\ - The extension load function failed to execute within {timeout} seconds. - This may be due to leftover lock files from the PyTorch C++ extension build process. - If this is the case, running the following command should help: - rm -rf {self._loader.get_build_dir()} - For a machine with poor performance, you may try increasing the time limit by setting the environment variable: - {EXTENSION_LOAD_TIMEOUT_ENV_VAR}=180 - Or disable timeout by set: - {EXTENSION_LOAD_TIMEOUT_ENV_VAR}=0 - For more information, see FAQ entry at: https://github.com/openvinotoolkit/nncf/blob/develop/docs/FAQ.md#importing-anything-from-nncftorch-hangs - """ # noqa: E501 - ) - raise ExtensionLoaderTimeoutException(msg) from error + nncf_logger.info(f"Compiling and loading torch extension: {self._loader.name()}...") + try: + pool = ThreadPool(processes=1) + async_result = pool.apply_async(self._loader.load) + self._loaded_namespace = async_result.get(timeout=timeout) + except MPTimeoutError as error: + msg = textwrap.dedent( + f"""\ + The extension load function failed to execute within {timeout} seconds. + This may be due to leftover lock files from the PyTorch C++ extension build process. + If this is the case, running the following command should help: + rm -rf {self._loader.get_build_dir()} + For a machine with poor performance, you may try increasing the time limit by setting the environment variable: + {EXTENSION_LOAD_TIMEOUT_ENV_VAR}=180 + Or disable timeout by set: + {EXTENSION_LOAD_TIMEOUT_ENV_VAR}=0 + For more information, see FAQ entry at: https://github.com/openvinotoolkit/nncf/blob/develop/docs/FAQ.md#importing-anything-from-nncftorch-hangs + """ # noqa: E501 + ) + raise ExtensionLoaderTimeoutException(msg) from error + nncf_logger.info(f"Finished loading torch extension: {self._loader.name()}") return getattr(self._loaded_namespace, fn_name) diff --git a/pyproject.toml b/pyproject.toml index f9601087a1c..2bbfe807db4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,8 +107,10 @@ files = [ "nncf/common/stateful_classes_registry.py", "nncf/common/strip.py", "nncf/common/tensor.py", + "nncf/common/collector.py", "nncf/common/accuracy_aware_training", "nncf/common/graph", + "nncf/common/initialization", "nncf/common/sparsity", "nncf/common/tensor_statistics", "nncf/common/utils",