diff --git a/src/accelerate/state.py b/src/accelerate/state.py index 889a816ba47..a8fa33e3e13 100644 --- a/src/accelerate/state.py +++ b/src/accelerate/state.py @@ -309,7 +309,11 @@ def __init__(self, cpu: bool = False, **kwargs): else: self.device = self.default_device else: - self.distributed_type = DistributedType.NO + self.distributed_type = ( + DistributedType.NO + if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "false" + else DistributedType.DEEPSPEED + ) self.num_processes = 1 self.process_index = self.local_process_index = 0