From 3ab379a8f68def4510679a9af637d8a7caeaedea Mon Sep 17 00:00:00 2001
From: Parth Raut
Date: Wed, 11 Dec 2024 23:59:02 -0500
Subject: [PATCH] fixed tests failing

---
 zeus/utils/framework.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py
index 0b7e605e..3459eda2 100644
--- a/zeus/utils/framework.py
+++ b/zeus/utils/framework.py
@@ -112,7 +112,7 @@ def all_reduce(
     If running in a distributed setting, the objects are reduced across all replicas.
     If running in a non-distributed setting, the operation is just done on the single object.
     """
-    if torch_is_available():
+    if torch_is_available(ensure_cuda=False):
         torch = MODULE_CACHE["torch"]

         # wrap object in a tensor if it is not already
@@ -146,7 +146,7 @@ def all_reduce(

 def is_distributed() -> bool:
     """Check if the current execution is distributed across multiple devices."""
-    if torch_is_available():
+    if torch_is_available(ensure_cuda=False):
         torch = MODULE_CACHE["torch"]
         return torch.distributed.is_available() and torch.distributed.is_initialized()
     if jax_is_available():
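
Context for the change: the diff shows that torch_is_available() accepts an ensure_cuda keyword, and that calling it with the default was making all_reduce and is_distributed fail on machines without CUDA, which is consistent with the "fixed tests failing" subject. Passing ensure_cuda=False keeps the PyTorch availability check but drops the CUDA requirement. The following is a minimal, hypothetical sketch of such a helper, not the actual zeus/utils/framework.py implementation; only the ensure_cuda keyword and the MODULE_CACHE["torch"] usage come from the patch above, everything else is assumed for illustration.

    # Hypothetical sketch of a torch_is_available(ensure_cuda=...) helper.
    # Not the real zeus implementation; the signature and MODULE_CACHE usage
    # are inferred from the patch, the body is assumed.
    from typing import Any

    MODULE_CACHE: dict[str, Any] = {}

    def torch_is_available(ensure_cuda: bool = True) -> bool:
        """Return True if torch is importable (and CUDA is usable when required)."""
        try:
            import torch
        except ImportError:
            return False
        if ensure_cuda and not torch.cuda.is_available():
            # A CPU-only install fails the check unless the caller opts out,
            # which is what the ensure_cuda=False calls in the patch do.
            return False
        MODULE_CACHE["torch"] = torch
        return True

    # Usage mirroring the patched call sites: safe on a CPU-only CI machine.
    if torch_is_available(ensure_cuda=False):
        torch = MODULE_CACHE["torch"]
        print(torch.__version__)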