Skip to content

Commit

Permalink
Merge pull request #79 from bkmgit/importError
Browse files Browse the repository at this point in the history
Use importlib to check for installed packages
  • Loading branch information
rogerwwww authored Nov 13, 2023
2 parents dbfdad9 + 00a3c8f commit 546bfd7
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions pygmtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .multi_graph_solvers import cao, mgm_floyd, gamgm
from .neural_solvers import pca_gm, ipca_gm, cie, ngm, genn_astar
import pygmtools.utils as utils
import importlib.util
BACKEND = 'numpy'
__version__ = '0.4.2a2'
__author__ = 'ThinkLab at SJTU'
Expand All @@ -38,30 +39,34 @@ def env_report():
from pygmtools import __version__
print("pygmtools", __version__)

try:
found_torch = importlib.util.find_spec("torch")
if found_torch is not None:
import torch
print("Torch", torch.__version__)
except ImportError:
else:
print("Torch not installed")

try:
found_paddle = importlib.util.find_spec("paddle")
if found_paddle is not None:
import paddle
print("Paddle", paddle.__version__)
except ImportError:
else:
print("Paddle not installed")

try:
found_jittor = importlib.util.find_spec("jittor")
if found_jittor is not None:
import jittor
print("Jittor", jittor.__version__)
except ImportError:
else:
print("Jittor not installed")

try:
found_pynvml = importlib.util.find_spec("pynvml")
if found_pynvml is not None:
import pynvml
pynvml.nvmlInit()
print("NVIDIA Driver Version:", pynvml.nvmlSystemGetDriverVersion())
for i in range(pynvml.nvmlDeviceGetCount()):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
print("GPU", i, ":", pynvml.nvmlDeviceGetName(handle))
except ImportError:
else:
print('No GPU found. If you are using GPU, make sure to install pynvml: pip install pynvml')

0 comments on commit 546bfd7

Please sign in to comment.