diff --git a/examples/post_training_quantization/torch/mobilenet_v2/main.py b/examples/post_training_quantization/torch/mobilenet_v2/main.py
index ad984143eec..898ac3fb34d 100644
--- a/examples/post_training_quantization/torch/mobilenet_v2/main.py
+++ b/examples/post_training_quantization/torch/mobilenet_v2/main.py
@@ -9,7 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import re
 import subprocess
 from functools import partial
@@ -35,14 +34,6 @@
 DATASET_CLASSES = 10
 
 
-def set_cuda_arch_list():
-    if "TORCH_CUDA_ARCH_LIST" not in os.environ:
-        arch_list = torch.cuda.get_arch_list()
-        formatted_arch_list = [str(int(arch.split("_")[1]) / 10.0) for arch in arch_list]
-        arch_string = ";".join(formatted_arch_list)
-        os.environ["TORCH_CUDA_ARCH_LIST"] = arch_string
-
-
 def download_dataset() -> Path:
     downloader = FastDownload(base=DATASET_PATH.resolve(), archive="downloaded", data="extracted")
     return downloader.get(DATASET_URL)
@@ -103,8 +94,6 @@
 ###############################################################################
 # Create a PyTorch model and dataset
 
-set_cuda_arch_list()
-
 dataset_path = download_dataset()
 
 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
diff --git a/nncf/torch/extensions/__init__.py b/nncf/torch/extensions/__init__.py
index 59edea1ca24..81cd90af8a3 100644
--- a/nncf/torch/extensions/__init__.py
+++ b/nncf/torch/extensions/__init__.py
@@ -12,6 +12,7 @@
 import enum
 import os
 import textwrap
+import warnings
 from abc import ABC
 from abc import abstractmethod
 from multiprocessing.context import TimeoutError as MPTimeoutError
@@ -47,6 +48,14 @@ def get_build_directory_for_extension(name: str) -> Path:
     return build_dir
 
 
+def set_cuda_arch_list():
+    if "TORCH_CUDA_ARCH_LIST" not in os.environ:
+        arch_list = torch.cuda.get_arch_list()
+        formatted_arch_list = [str(int(arch.split("_")[1]) / 10.0) for arch in arch_list]
+        arch_string = ";".join(formatted_arch_list)
+        os.environ["TORCH_CUDA_ARCH_LIST"] = arch_string
+
+
 class ExtensionLoader(ABC):
     @classmethod
     @abstractmethod
@@ -98,9 +107,13 @@
     def get(self, fn_name: str) -> Callable:
         with extension_is_loading_info_log(self._loader.name()):
             try:
-                pool = ThreadPool(processes=1)
-                async_result = pool.apply_async(self._loader.load)
-                self._loaded_namespace = async_result.get(timeout=timeout)
+                with warnings.catch_warnings():
+                    warnings.filterwarnings("ignore", message="TORCH_CUDA_ARCH_LIST is not set")
+                    set_cuda_arch_list()
+
+                    pool = ThreadPool(processes=1)
+                    async_result = pool.apply_async(self._loader.load)
+                    self._loaded_namespace = async_result.get(timeout=timeout)
             except MPTimeoutError as error:
                 msg = textwrap.dedent(
                     f"""\