Skip to content

Commit

Permalink
Fix logging message
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywonchung committed Sep 9, 2024
1 parent 697d327 commit eb00bdb
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ def torch_is_available(ensure_available: bool = False, ensure_cuda: bool = True)
try:
import torch

if ensure_cuda and not torch.cuda.is_available():
cuda_available = torch.cuda.is_available()
if ensure_cuda and not cuda_available:
raise RuntimeError("PyTorch is available but does not have CUDA support.")
MODULE_CACHE["torch"] = torch
logger.info("PyTorch with CUDA support is available.")
logger.info(
"PyTorch %s CUDA support is available.",
"with" if cuda_available else "without",
)
return True
except ImportError as e:
logger.info("PyTorch is not available.")
Expand All @@ -36,10 +40,13 @@ def jax_is_available(ensure_available: bool = False, ensure_cuda: bool = True):
try:
import jax # type: ignore

if ensure_cuda and not jax.devices("gpu"):
cuda_available = jax.devices("gpu")
if ensure_cuda and not cuda_available:
raise RuntimeError("JAX is available but does not have CUDA support.")
MODULE_CACHE["jax"] = jax
logger.info("JAX with CUDA support is available.")
logger.info(
"JAX %s CUDA support is available.", "with" if cuda_available else "without"
)
return True
except ImportError as e:
logger.info("JAX is not available")
Expand Down

0 comments on commit eb00bdb

Please sign in to comment.