Skip to content

Commit

Permalink
fixed tests failing
Browse files Browse the repository at this point in the history
  • Loading branch information
parthraut committed Dec 12, 2024
1 parent 636c45d commit 3ab379a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def all_reduce(
If running in a distributed setting, the objects are reduced across all replicas.
If running in a non-distributed setting, the operation is just done on the single object.
"""
if torch_is_available():
if torch_is_available(ensure_cuda=False):
torch = MODULE_CACHE["torch"]

# wrap object in a tensor if it is not already
Expand Down Expand Up @@ -146,7 +146,7 @@ def all_reduce(

def is_distributed() -> bool:
"""Check if the current execution is distributed across multiple devices."""
if torch_is_available():
if torch_is_available(ensure_cuda=False):
torch = MODULE_CACHE["torch"]
return torch.distributed.is_available() and torch.distributed.is_initialized()
if jax_is_available():
Expand Down

0 comments on commit 3ab379a

Please sign in to comment.