Skip to content

Commit

Permalink
Generalize TORCH_CUDA_ARCH_LIST warning suppression for NNCF extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
devesh-2002 committed Dec 25, 2024
1 parent 289bd27 commit 7c60117
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
11 changes: 0 additions & 11 deletions examples/post_training_quantization/torch/mobilenet_v2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import subprocess
from functools import partial
Expand All @@ -35,14 +34,6 @@
DATASET_CLASSES = 10


def set_cuda_arch_list() -> None:
    """Derive ``TORCH_CUDA_ARCH_LIST`` from the compiled-in CUDA arch list.

    PyTorch's extension builder warns when ``TORCH_CUDA_ARCH_LIST`` is not
    set; seeding it from ``torch.cuda.get_arch_list()`` suppresses that
    warning. A value already present in the environment is left untouched
    so users keep explicit control.
    """
    if "TORCH_CUDA_ARCH_LIST" in os.environ:
        return
    versions = []
    for arch in torch.cuda.get_arch_list():
        # Entries look like "sm_75" or "compute_90a"; keep only the digits so
        # suffixed arch names (e.g. "sm_90a") do not crash int().
        digits = "".join(ch for ch in arch.split("_")[-1] if ch.isdigit())
        if digits:
            versions.append(str(int(digits) / 10.0))
    # An empty arch list (e.g. CPU-only torch build) would otherwise set an
    # empty string, which breaks extension builds — only set when non-empty.
    if versions:
        os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(versions)


def download_dataset() -> Path:
    """Fetch the dataset archive and return the path to the extracted data."""
    loader = FastDownload(
        base=DATASET_PATH.resolve(),
        archive="downloaded",
        data="extracted",
    )
    return loader.get(DATASET_URL)
Expand Down Expand Up @@ -103,8 +94,6 @@ def get_model_size(ir_path: Path, m_type: str = "Mb") -> float:

###############################################################################
# Create a PyTorch model and dataset
set_cuda_arch_list()

dataset_path = download_dataset()

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Expand Down
19 changes: 16 additions & 3 deletions nncf/torch/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import enum
import os
import textwrap
import warnings
from abc import ABC
from abc import abstractmethod
from multiprocessing.context import TimeoutError as MPTimeoutError
Expand Down Expand Up @@ -47,6 +48,14 @@ def get_build_directory_for_extension(name: str) -> Path:
return build_dir


def set_cuda_arch_list() -> None:
    """Derive ``TORCH_CUDA_ARCH_LIST`` from the compiled-in CUDA arch list.

    PyTorch's extension builder warns when ``TORCH_CUDA_ARCH_LIST`` is not
    set; seeding it from ``torch.cuda.get_arch_list()`` suppresses that
    warning. A value already present in the environment is left untouched
    so users keep explicit control.
    """
    if "TORCH_CUDA_ARCH_LIST" in os.environ:
        return
    versions = []
    for arch in torch.cuda.get_arch_list():
        # Entries look like "sm_75" or "compute_90a"; keep only the digits so
        # suffixed arch names (e.g. "sm_90a") do not crash int().
        digits = "".join(ch for ch in arch.split("_")[-1] if ch.isdigit())
        if digits:
            versions.append(str(int(digits) / 10.0))
    # An empty arch list (e.g. CPU-only torch build) would otherwise set an
    # empty string, which breaks extension builds — only set when non-empty.
    if versions:
        os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(versions)


class ExtensionLoader(ABC):
@classmethod
@abstractmethod
Expand Down Expand Up @@ -98,9 +107,13 @@ def get(self, fn_name: str) -> Callable:

with extension_is_loading_info_log(self._loader.name()):
try:
pool = ThreadPool(processes=1)
async_result = pool.apply_async(self._loader.load)
self._loaded_namespace = async_result.get(timeout=timeout)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="TORCH_CUDA_ARCH_LIST is not set")
set_cuda_arch_list()

pool = ThreadPool(processes=1)
async_result = pool.apply_async(self._loader.load)
self._loaded_namespace = async_result.get(timeout=timeout)
except MPTimeoutError as error:
msg = textwrap.dedent(
f"""\
Expand Down

0 comments on commit 7c60117

Please sign in to comment.