Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Dec 30, 2024
1 parent f0023c3 commit 81d1199
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 24 deletions.
3 changes: 2 additions & 1 deletion nncf/common/initialization/batchnorm_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import math
from abc import ABC
from abc import abstractmethod
from typing import Optional
from typing import Optional, Type

from nncf.api.compression import TModel
from nncf.common.initialization.dataloader import NNCFDataLoader
Expand Down Expand Up @@ -83,6 +83,7 @@ def run(self, model: TModel) -> None:
:param model: A model for which the algorithm will be applied.
"""
backend = get_backend(model)
impl_cls: Type[BatchnormAdaptationAlgorithmImpl]
if backend is BackendType.TORCH:
from nncf.torch.batchnorm_adaptation import PTBatchnormAdaptationAlgorithmImpl

Expand Down
3 changes: 2 additions & 1 deletion nncf/common/initialization/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""Interface for user-defined data usage during the compression algorithm initialization process."""
from abc import ABC
from abc import abstractmethod
from typing import Any

from nncf.common.utils.api_marker import api

Expand All @@ -31,7 +32,7 @@ def batch_size(self) -> int:
"""

@abstractmethod
def __iter__(self):
def __iter__(self) -> Any:
"""
Creates an iterator for the elements of a custom data source.
The returned iterator implements the Python Iterator protocol.
Expand Down
43 changes: 21 additions & 22 deletions nncf/torch/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import nncf
from nncf.common.logging import nncf_logger
from nncf.common.logging.logger import extension_is_loading_info_log
from nncf.common.utils.api_marker import api
from nncf.common.utils.registry import Registry

Expand Down Expand Up @@ -95,27 +94,27 @@ def get(self, fn_name: str) -> Callable:
if self._loaded_namespace is None:
timeout = int(os.environ.get(EXTENSION_LOAD_TIMEOUT_ENV_VAR, DEFAULT_EXTENSION_LOAD_TIMEOUT))
timeout = timeout if timeout > 0 else None

with extension_is_loading_info_log(self._loader.name()):
try:
pool = ThreadPool(processes=1)
async_result = pool.apply_async(self._loader.load)
self._loaded_namespace = async_result.get(timeout=timeout)
except MPTimeoutError as error:
msg = textwrap.dedent(
f"""\
The extension load function failed to execute within {timeout} seconds.
This may be due to leftover lock files from the PyTorch C++ extension build process.
If this is the case, running the following command should help:
rm -rf {self._loader.get_build_dir()}
For a machine with poor performance, you may try increasing the time limit by setting the environment variable:
{EXTENSION_LOAD_TIMEOUT_ENV_VAR}=180
Or disable timeout by set:
{EXTENSION_LOAD_TIMEOUT_ENV_VAR}=0
For more information, see FAQ entry at: https://github.com/openvinotoolkit/nncf/blob/develop/docs/FAQ.md#importing-anything-from-nncftorch-hangs
""" # noqa: E501
)
raise ExtensionLoaderTimeoutException(msg) from error
nncf_logger.info(f"Compiling and loading torch extension: {self._loader.name()}...")
try:
pool = ThreadPool(processes=1)
async_result = pool.apply_async(self._loader.load)
self._loaded_namespace = async_result.get(timeout=timeout)
except MPTimeoutError as error:
msg = textwrap.dedent(
f"""\
The extension load function failed to execute within {timeout} seconds.
This may be due to leftover lock files from the PyTorch C++ extension build process.
If this is the case, running the following command should help:
rm -rf {self._loader.get_build_dir()}
For a machine with poor performance, you may try increasing the time limit by setting the environment variable:
{EXTENSION_LOAD_TIMEOUT_ENV_VAR}=180
Or disable timeout by set:
{EXTENSION_LOAD_TIMEOUT_ENV_VAR}=0
For more information, see FAQ entry at: https://github.com/openvinotoolkit/nncf/blob/develop/docs/FAQ.md#importing-anything-from-nncftorch-hangs
""" # noqa: E501
)
raise ExtensionLoaderTimeoutException(msg) from error
nncf_logger.info(f"Finished loading torch extension: {self._loader.name()}")

return getattr(self._loaded_namespace, fn_name)

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ files = [
"nncf/common/stateful_classes_registry.py",
"nncf/common/strip.py",
"nncf/common/tensor.py",
"nncf/common/collector.py",
"nncf/common/accuracy_aware_training",
"nncf/common/graph",
"nncf/common/initialization",
"nncf/common/sparsity",
"nncf/common/tensor_statistics",
"nncf/common/utils",
Expand Down

0 comments on commit 81d1199

Please sign in to comment.