Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
uniartisan committed Dec 20, 2024
1 parent 933036d commit 8cde331
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def __init__(
self._backward_sync_control = _DDPBackwardSyncControl()
self._ddp_kwargs = kwargs

self.device_type = self.root_device.type
if isinstance(self.accelerator, Accelerator):
self.device_type = self.accelerator.get_device_type()
else:
self.device_type = "cuda"
self.torch_lib = getattr(torch, self.device_type)

@property
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,10 @@ def __init__(

self._deepspeed_engine: Optional[DeepSpeedEngine] = None

self.device_type = self.root_device.type
if isinstance(self.accelerator, Accelerator):
self.device_type = self.accelerator.get_device_type()
else:
self.device_type = "cuda"
self.torch_lib = getattr(torch, self.device_type)

@property
Expand Down
6 changes: 4 additions & 2 deletions src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ def __init__(
self._timeout: Optional[timedelta] = timeout
self._start_method = start_method

self.device_type = self.root_device.type
self.torch_lib = getattr(torch, self.device_type)
try:
self.device_type = self.accelerator.get_device_type()
except Exception:
self.device_type = "cuda"

@property
def is_distributed(self) -> bool: # pragma: no-cover
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@ def __init__(
self.hysteresis = hysteresis
self.min_loss_scale = min_loss_scale

self.device_type = self.root_device.type
try:
self.device_type = self.accelerator.get_device_type()
except Exception:
self.device_type = "cuda"
self.torch_lib = getattr(torch, self.device_type)

@override
Expand Down

0 comments on commit 8cde331

Please sign in to comment.