diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index 1e5cd749..b321e0e0 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -142,7 +142,7 @@ def all_reduce( if jax.process_count() == 1: return object - raise NotImplementedError("JAX all-reduce not yet implemented") + raise NotImplementedError("JAX distributed all-reduce not yet implemented") raise RuntimeError("No framework is available.") @@ -155,4 +155,3 @@ def is_distributed() -> bool: if jax_is_available(): jax = MODULE_CACHE["jax"] return jax.process_count() > 1 - return False