diff --git a/src/pytorch_lightning/accelerators/mps.py b/src/pytorch_lightning/accelerators/mps.py index 3a7178f0623c2..20a2e609fa54b 100644 --- a/src/pytorch_lightning/accelerators/mps.py +++ b/src/pytorch_lightning/accelerators/mps.py @@ -24,7 +24,7 @@ # For using the `MPSAccelerator`, user's machine should have `torch>=1.12`, Metal programming framework and # the ARM-based Apple Silicon processors. -_MPS_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.platform() == "arm" +_MPS_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() == "arm" class MPSAccelerator(Accelerator):