From 00a3c8fb38f04e875454f0ec32e46d2434a497a1 Mon Sep 17 00:00:00 2001 From: Benson Muite Date: Mon, 11 Sep 2023 16:10:46 +0300 Subject: [PATCH] Use importlib to check for installed packages --- pygmtools/__init__.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pygmtools/__init__.py b/pygmtools/__init__.py index 0453380..6b1e025 100644 --- a/pygmtools/__init__.py +++ b/pygmtools/__init__.py @@ -14,6 +14,7 @@ from .multi_graph_solvers import cao, mgm_floyd, gamgm from .neural_solvers import pca_gm, ipca_gm, cie, ngm, genn_astar import pygmtools.utils as utils +import importlib.util BACKEND = 'numpy' __version__ = '0.4.0' __author__ = 'ThinkLab at SJTU' @@ -38,30 +39,34 @@ def env_report(): from pygmtools import __version__ print("pygmtools", __version__) - try: + found_torch = importlib.util.find_spec("torch") + if found_torch is not None: import torch print("Torch", torch.__version__) - except ImportError: + else: print("Torch not installed") - try: + found_paddle = importlib.util.find_spec("paddle") + if found_paddle is not None: import paddle print("Paddle", paddle.__version__) - except ImportError: + else: print("Paddle not installed") - try: + found_jittor = importlib.util.find_spec("jittor") + if found_jittor is not None: import jittor print("Jittor", jittor.__version__) - except ImportError: + else: print("Jittor not installed") - try: + found_pynvml = importlib.util.find_spec("pynvml") + if found_pynvml is not None: import pynvml pynvml.nvmlInit() print("NVIDIA Driver Version:", pynvml.nvmlSystemGetDriverVersion()) for i in range(pynvml.nvmlDeviceGetCount()): handle = pynvml.nvmlDeviceGetHandleByIndex(i) print("GPU", i, ":", pynvml.nvmlDeviceGetName(handle)) - except ImportError: + else: print('No GPU found. If you are using GPU, make sure to install pynvml: pip install pynvml')