MPS Accelerator #13123

Merged · 101 commits · Jun 24, 2022

Commits
b1aea3d  update accelerator and device parsing (justusschock, May 21, 2022)
17523f3  update runif checks (justusschock, May 21, 2022)
1ed3895  update runif usages (justusschock, May 21, 2022)
020e036  more occurences of min_gpus (justusschock, May 21, 2022)
13da4d0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 21, 2022)
aaa5967  fix mps accelerator (justusschock, May 21, 2022)
06ff778  fix imports and device parser (justusschock, May 21, 2022)
7a2dd98  trainer integration (justusschock, May 21, 2022)
f27a7cf  docs (justusschock, May 21, 2022)
4460caf  fix runif (justusschock, May 21, 2022)
4787724  update mps tests (justusschock, May 21, 2022)
46eaa7c  update accelerator connector to reflect mps changes (justusschock, May 21, 2022)
1fe3a94  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 21, 2022)
2a08c61  add import (justusschock, May 22, 2022)
1458acc  fix accelerator registry test (justusschock, May 22, 2022)
b34c447  fix some gpu tests (justusschock, May 22, 2022)
3f17012  fix gpu intput normalization (justusschock, May 22, 2022)
b89efac  adjust runif for mps to also allow requiring that it is not available (justusschock, May 23, 2022)
1fa485e  separate auto choice tests for mps and gpu (justusschock, May 23, 2022)
934b46a  pep8 (justusschock, May 23, 2022)
70003d6  remove unnecessary block in docs (justusschock, May 23, 2022)
b5ba668  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 23, 2022)
6b00f0f  update docs (justusschock, May 23, 2022)
e406b97  docs (Borda, May 23, 2022)
c41d059  mypy (justusschock, May 23, 2022)
5a95095  change wording from mps to apple silicon (justusschock, May 23, 2022)
56d5181  update mps tests (justusschock, May 23, 2022)
6f263d6  update device parser (justusschock, May 23, 2022)
0f832e0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 23, 2022)
ccfac24  Update pytorch_lightning/accelerators/mps.py (justusschock, May 23, 2022)
064f3af  Update pytorch_lightning/utilities/device_parser.py (justusschock, May 23, 2022)
b93d1e4  update callback tests (justusschock, May 24, 2022)
a1c71e1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
ea340c9  update core (justusschock, May 24, 2022)
b88ec78  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
2e995b5  upate lite tests (justusschock, May 24, 2022)
e643dee  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
4ffc08e  update loops (justusschock, May 24, 2022)
deeb60b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
936fc0f  models (justusschock, May 24, 2022)
477377d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
98ada0b  strategies (justusschock, May 24, 2022)
c48351c  trainer (justusschock, May 24, 2022)
394799e  utilities (justusschock, May 24, 2022)
16eb921  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
3a93dc2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
f7e2005  fix pre-commit (justusschock, May 24, 2022)
6516826  fix typo (justusschock, May 24, 2022)
11d7a2a  fix accelerator devices to always contain index (justusschock, May 24, 2022)
6f89c3e  add mps accelerator type for lite (justusschock, May 24, 2022)
bc6da48  add accelerator type mps (justusschock, May 24, 2022)
3df39d2  fix callback tests (justusschock, May 24, 2022)
12a22c4  fix core tests (justusschock, May 24, 2022)
c594c83  fix lite tests (justusschock, May 24, 2022)
f705cc9  fix loop tests (justusschock, May 24, 2022)
0d9f758  fix model tests (justusschock, May 24, 2022)
18c9e23  fix strategy tests (justusschock, May 24, 2022)
0dac2d8  fix logging tests (justusschock, May 24, 2022)
62190c3  fix trainer tests (justusschock, May 24, 2022)
0afeb46  fix utility tests (justusschock, May 24, 2022)
638d137  Merge branch 'master' into mps_accelerator (justusschock, May 24, 2022)
5e6c7ba  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
99ee709  make all mps device tests to accept string and lazily construct devic… (justusschock, May 24, 2022)
b37056e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
dc151c5  stupid bug (justusschock, May 24, 2022)
62936c3  handle different torch versions (justusschock, May 24, 2022)
49ab319  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
c76fd40  time to get some sleep (justusschock, May 24, 2022)
b791e47  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 24, 2022)
d64cbf9  remove foreach (justusschock, May 24, 2022)
ee0000b  make cpu tests pass (hopefully) (justusschock, May 24, 2022)
8e16631  remove leftover breakpoint (justusschock, May 24, 2022)
0028737  remove lefvtover from debugging (justusschock, May 25, 2022)
31a60b3  move _MPS_AVAILABLE to accelerator file (justusschock, May 25, 2022)
3e56f1f  update printing logic trainer (justusschock, May 25, 2022)
e8e42ad  add warning (justusschock, May 25, 2022)
3f77e57  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 25, 2022)
85031f8  pep8 (justusschock, May 25, 2022)
ef6fa81  Add changelog entry (justusschock, May 25, 2022)
2ae9b8c  Merge branch 'master' into mps_accelerator (justusschock, May 25, 2022)
eec7ddb  Update pytorch_lightning/trainer/trainer.py (justusschock, May 25, 2022)
d21ed47  Update tests/helpers/runif.py (justusschock, May 25, 2022)
578e4eb  adress comments from @carmocca (justusschock, May 25, 2022)
6f75a8b  mypy (justusschock, May 25, 2022)
45aa2e0  Update pytorch_lightning/utilities/device_parser.py (justusschock, May 25, 2022)
717a53c  Update tests/plugins/test_amp_plugins.py (justusschock, May 25, 2022)
cf6e7b0  add link (justusschock, May 25, 2022)
e44265f  Update pytorch_lightning/accelerators/mps.py (justusschock, May 25, 2022)
e74af6f  Update pytorch_lightning/accelerators/mps.py (justusschock, May 31, 2022)
3b252d7  Update tests/accelerators/test_mps.py (justusschock, May 31, 2022)
beaf933  comments from @awaelchi (justusschock, Jun 7, 2022)
5c364dd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 7, 2022)
7af2f87  fix wrong indent (justusschock, Jun 7, 2022)
858708d  resolve merge conflicts (justusschock, Jun 23, 2022)
25d148b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 23, 2022)
d267d9c  fix imports (justusschock, Jun 23, 2022)
7e9c9f4  fix test imports (justusschock, Jun 23, 2022)
e072d92  fix imports (justusschock, Jun 23, 2022)
5b5a58e  fix pep8 (justusschock, Jun 23, 2022)
c783a6a  Apply suggestions from code review (justusschock, Jun 24, 2022)
92e00b2  Update docs/source-pytorch/accelerators/mps_basic.rst (justusschock, Jun 24, 2022)
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -65,6 +65,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `teardown()` method to `Accelerator` ([#11935](https://github.com/PyTorchLightning/pytorch-lightning/pull/11935))


- Added Apple Silicon Support via `MPSAccelerator` ([#13123](https://github.com/PyTorchLightning/pytorch-lightning/pull/13123))


-


### Changed

- Enable validation during overfitting ([#12527](https://github.com/PyTorchLightning/pytorch-lightning/pull/12527))
32 changes: 32 additions & 0 deletions docs/source/accelerators/mps.rst
@@ -0,0 +1,32 @@
.. _mps:

Accelerator: Apple Silicon training
===================================

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. Add callout items below this line

.. displayitem::
    :header: Prepare your code (Optional)
    :description: Prepare your code to run on any hardware
    :col_css: col-md-4
    :button_link: accelerator_prepare.html
    :height: 150
    :tag: basic

.. displayitem::
    :header: Basic
    :description: Learn the basics of Apple silicon gpu training.
    :col_css: col-md-4
    :button_link: mps_basic.html
    :height: 150
    :tag: basic

.. raw:: html

        </div>
    </div>
48 changes: 48 additions & 0 deletions docs/source/accelerators/mps_basic.rst
@@ -0,0 +1,48 @@
:orphan:

.. _mps_basic:

MPS training (basic)
====================
**Audience:** Users looking to train on their Apple silicon GPUs.

.. warning::

    The MPS accelerator and the underlying PyTorch MPS backend are both still experimental.
    Not every operation is supported yet, but coverage is growing quickly thanks to ongoing development by the PyTorch team.
    You can run ``PYTORCH_ENABLE_MPS_FALLBACK=1 python your_script.py`` to fall back to the CPU for unsupported operations.


----

What is Apple silicon?
----------------------
Apple silicon chips are a unified system on a chip (SoC) developed by Apple based on the ARM design.
Among other things, they feature CPU cores, GPU cores, a neural engine and memory shared between all of them.

----

So it's a CPU?
--------------
Among other things, yes: it does include CPU cores. However, running only on the ``CPUAccelerator`` leaves much of the hardware acceleration these M-series SoCs offer untapped, because they also feature a GPU and a neural engine.

To use them, Lightning supports the ``MPSAccelerator``.

----

Run on Apple silicon gpus
-------------------------
Set the following Trainer arguments to run on Apple silicon GPUs (MPS devices).

.. code::

    trainer = Trainer(accelerator="mps", devices=1)

.. note::
    The ``MPSAccelerator`` only supports 1 device at a time. Currently there are no machines with multiple MPS-capable GPUs.

----

What does MPS stand for?
------------------------
MPS is short for `Metal Performance Shaders <https://developer.apple.com/metal/>`_, the technology used under the hood for GPU communication and computation.
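To make the documented flag concrete, here is a minimal training sketch that exercises ``Trainer(accelerator="mps", devices=1)``. The tiny model and random dataset are illustrative stand-ins rather than code from this PR, and the script assumes an Apple silicon machine with a torch>=1.12 MPS build.

    import torch
    from torch.utils.data import DataLoader, Dataset
    import pytorch_lightning as pl

    class RandomDataset(Dataset):
        """Illustrative dataset of random 32-dimensional vectors."""
        def __init__(self, size: int = 64):
            self.data = torch.randn(size, 32)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    class TinyModel(pl.LightningModule):
        """Illustrative one-layer model."""
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    # The only MPS-specific part: select the accelerator documented above.
    trainer = pl.Trainer(accelerator="mps", devices=1, max_epochs=1)
    trainer.fit(TinyModel(), DataLoader(RandomDataset(), batch_size=8))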
4 changes: 3 additions & 1 deletion docs/source/extensions/accelerator.rst
@@ -4,14 +4,15 @@
Accelerator
###########

The Accelerator connects a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, IPUs, ...).
The Accelerator connects a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, IPUs, MPS, ...).
Currently there are accelerators for:

- CPU
- :doc:`GPU <../accelerators/gpu>`
- :doc:`TPU <../accelerators/tpu>`
- :doc:`IPU <../accelerators/ipu>`
- :doc:`HPU <../accelerators/hpu>`
- :doc:`MPS <../accelerators/mps>`

The Accelerator is part of the Strategy which manages communication across multiple devices (distributed communication).
Whenever the Trainer, the loops or any other component in Lightning needs to talk to hardware, it calls into the Strategy and the Strategy calls into the Accelerator.
@@ -127,4 +128,5 @@ Accelerator API
GPUAccelerator
HPUAccelerator
IPUAccelerator
MPSAccelerator
TPUAccelerator
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -210,6 +210,7 @@ Current Lightning Users
Train on single or multiple HPUs <accelerators/hpu>
Train on single or multiple IPUs <accelerators/ipu>
Train on single or multiple TPUs <accelerators/tpu>
Train on MPS <accelerators/mps>
Use a pretrained model <advanced/pretrained>
model/own_your_loop

1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -15,6 +15,7 @@
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.mps import MPSAccelerator # noqa: F401
from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators # noqa: F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401

2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/gpu.py
@@ -74,7 +74,7 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
        """Accelerator device parsing logic."""
        return device_parser.parse_gpu_ids(devices)
        return device_parser.parse_gpu_ids(devices, include_cuda=True)

    @staticmethod
    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
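The only behavioural change in this hunk is the explicit ``include_cuda=True`` flag, which mirrors the ``include_mps=True`` call added in ``mps.py`` below: one shared parser now serves both GPU backends. A hedged usage sketch based only on the signatures visible in this diff; the values in the comments are what one would expect on the respective hardware, not verified output.

    from pytorch_lightning.utilities import device_parser

    cuda_ids = device_parser.parse_gpu_ids(1, include_cuda=True)  # e.g. [0] on a machine with one CUDA GPU
    mps_ids = device_parser.parse_gpu_ids(1, include_mps=True)    # e.g. [0] on Apple silicon with torch>=1.12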
98 changes: 98 additions & 0 deletions pytorch_lightning/accelerators/mps.py
@@ -0,0 +1,98 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import _DEVICE

if _TORCH_GREATER_EQUAL_1_12:
    _MPS_AVAILABLE = torch.backends.mps.is_available()
else:
    _MPS_AVAILABLE = False


class MPSAccelerator(Accelerator):
    """Accelerator for Metal Apple Silicon GPU devices."""

    def setup_environment(self, root_device: torch.device) -> None:
        """
        Raises:
            MisconfigurationException:
                If the selected device is not MPS.
        """
        super().setup_environment(root_device)
        if root_device.type != "mps":
            raise MisconfigurationException(f"Device should be MPS, got {root_device} instead.")

    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
        """Get M1 (cpu + gpu) stats from ``psutil`` package."""
        return get_device_stats()

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
        """Accelerator device parsing logic."""
        parsed_devices = device_parser.parse_gpu_ids(devices, include_mps=True)
        return parsed_devices

    @staticmethod
    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        parsed_devices = MPSAccelerator.parse_devices(devices)
        assert parsed_devices is not None

        return [torch.device("mps", i) for i in range(len(parsed_devices))]

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 1

    @staticmethod
    def is_available() -> bool:
        """MPS is only available for certain torch builds starting at torch>=1.12."""
        return _MPS_AVAILABLE

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "mps",
            cls,
            description=cls.__class__.__name__,
        )


# device metrics
_VM_PERCENT = "M1_vm_percent"
_PERCENT = "M1_percent"
_SWAP_PERCENT = "M1_swap_percent"


def get_device_stats() -> Dict[str, float]:
    if not _PSUTIL_AVAILABLE:
        raise ModuleNotFoundError(
            "Fetching M1 device stats requires `psutil` to be installed."
            " Install it by running `pip install -U psutil`."
        )
    import psutil

    return {
        _VM_PERCENT: psutil.virtual_memory().percent,
        _PERCENT: psutil.cpu_percent(),
        _SWAP_PERCENT: psutil.swap_memory().percent,
    }
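A short usage sketch of the public surface defined above. It assumes an Apple silicon machine with a torch>=1.12 MPS build and ``psutil`` installed; the expected values in the comments follow from the code but are not captured output.

    from pytorch_lightning.accelerators.mps import MPSAccelerator

    if MPSAccelerator.is_available():
        device_ids = MPSAccelerator.parse_devices(1)       # expected: [0]
        devices = MPSAccelerator.get_parallel_devices(1)   # expected: [torch.device("mps", 0)]

        accelerator = MPSAccelerator()
        accelerator.setup_environment(devices[0])          # raises if the root device is not an MPS device
        stats = accelerator.get_device_stats(devices[0])   # psutil-based M1_vm_percent / M1_percent / M1_swap_percent
        print(device_ids, devices, stats)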
1 change: 1 addition & 0 deletions pytorch_lightning/lite/lite.py
@@ -465,6 +465,7 @@ def _supported_device_types() -> Sequence[_AcceleratorType]:
            _AcceleratorType.CPU,
            _AcceleratorType.GPU,
            _AcceleratorType.TPU,
            _AcceleratorType.MPS,
        )

    @staticmethod
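Since ``_AcceleratorType.MPS`` is now part of the supported device types, LightningLite should accept the same accelerator flag as the Trainer. A minimal sketch under that assumption; the subclass body is illustrative, not from this PR.

    from pytorch_lightning.lite import LightningLite

    class Lite(LightningLite):
        def run(self):
            # self.device is expected to be mps:0 when the MPS accelerator is selected
            print(self.device)

    Lite(accelerator="mps", devices=1).run()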
19 changes: 12 additions & 7 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Expand Up @@ -25,6 +25,7 @@
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.accelerators.hpu import HPUAccelerator
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.accelerators.mps import MPSAccelerator
from pytorch_lightning.accelerators.registry import AcceleratorRegistry
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import (
@@ -173,7 +174,7 @@ def __init__(
        self._precision_flag: Optional[Union[int, str]] = None
        self._precision_plugin_flag: Optional[PrecisionPlugin] = None
        self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
        self._parallel_devices: List[Union[int, torch.device]] = []
        self._parallel_devices: List[Union[int, torch.device, str]] = []
        self._layer_sync: Optional[LayerSync] = NativeSyncBatchNorm() if sync_batchnorm else None
        self.checkpoint_io: Optional[CheckpointIO] = None
        self._amp_type_flag: Optional[LightningEnum] = None
@@ -402,7 +403,7 @@ def _check_device_config_and_set_final_flags(
        if self._devices_flag == "auto" and self._accelerator_flag is None:
            raise MisconfigurationException(
                f"You passed `devices={devices}` but haven't specified"
                " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu)` for the devices mapping."
                " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu'|'mps')` for the devices mapping."
            )

    def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
@@ -479,6 +480,8 @@ def _choose_accelerator(self) -> str:
return "ipu"
if _HPU_AVAILABLE:
return "hpu"
if MPSAccelerator.is_available():
justusschock marked this conversation as resolved.
Show resolved Hide resolved
return "mps"
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
return "gpu"
return "cpu"
@@ -566,11 +569,13 @@ def _choose_strategy(self) -> Union[Strategy, str]:
        if self._num_nodes_flag > 1:
            return DDPStrategy.strategy_name
        if len(self._parallel_devices) <= 1:
            device = (
                device_parser.determine_root_gpu_device(self._parallel_devices)  # type: ignore
                if self._accelerator_flag == "gpu"
                else "cpu"
            )
            # TODO: Change this once gpu accelerator was renamed to cuda accelerator
            if isinstance(self._accelerator_flag, (GPUAccelerator, MPSAccelerator)) or (
                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("gpu", "mps")
            ):
                device = device_parser.determine_root_gpu_device(self._parallel_devices)
            else:
                device = "cpu"
            # TODO: lazy initialized device, then here could be self._strategy_flag = "single_device"
            return SingleDeviceStrategy(device=device)  # type: ignore
        if len(self._parallel_devices) > 1:
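For readers skimming the hunk above, this is the accelerator auto-selection order it implies, written out as a standalone sketch. It is illustrative only (the TPU/IPU/HPU checks are reduced to the flags visible in the diff) and is not the actual ``_choose_accelerator`` implementation.

    import torch
    from pytorch_lightning.accelerators.mps import MPSAccelerator

    def choose_accelerator_sketch(ipu_available: bool, hpu_available: bool) -> str:
        # Priority implied by the diff: IPU, then HPU, then MPS, then CUDA GPUs, then CPU.
        if ipu_available:
            return "ipu"
        if hpu_available:
            return "hpu"
        if MPSAccelerator.is_available():
            return "mps"
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return "gpu"
        return "cpu"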
34 changes: 29 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -35,7 +35,14 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, HPUAccelerator, IPUAccelerator, TPUAccelerator
from pytorch_lightning.accelerators import (
    Accelerator,
    GPUAccelerator,
    HPUAccelerator,
    IPUAccelerator,
    MPSAccelerator,
    TPUAccelerator,
)
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
@@ -189,7 +196,7 @@ def __init__(

Args:

accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "auto")
accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
as well as custom accelerator instances.

.. deprecated:: v1.5
@@ -1745,9 +1752,19 @@ def __setup_profiler(self) -> None:
        self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)

    def _log_device_info(self) -> None:
        rank_zero_info(
            f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}"
        )

        if GPUAccelerator.is_available():
            gpu_available = True
            gpu_type = " (cuda)"
        elif MPSAccelerator.is_available():
            gpu_available = True
            gpu_type = " (mps)"
        else:
            gpu_available = False
            gpu_type = ""

        gpu_used = isinstance(self.accelerator, (GPUAccelerator, MPSAccelerator))
        rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")

        num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
        rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")
@@ -1758,6 +1775,7 @@ def _log_device_info(self) -> None:
        num_hpus = self.num_devices if isinstance(self.accelerator, HPUAccelerator) else 0
        rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")

        # TODO: Integrate MPS Accelerator here, once gpu maps to both
        if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
            rank_zero_warn(
                "GPU available but not used. Set `accelerator` and `devices` using"
@@ -1783,6 +1801,12 @@ def _log_device_info(self) -> None:
f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`."
)

if MPSAccelerator.is_available() and not isinstance(self.accelerator, MPSAccelerator):
rank_zero_warn(
"MPS available but not used. Set `accelerator` and `devices` using"
f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`."
)

"""
Data loading methods
"""
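Putting the two ``_log_device_info`` changes together, the messages a user sees follow directly from the f-strings above. The snippet below reconstructs the GPU line for an Apple silicon machine running with ``accelerator="cpu"``; the quoted output is derived from the code rather than captured from a real run.

    from pytorch_lightning.accelerators import MPSAccelerator

    gpu_available, gpu_type = (True, " (mps)") if MPSAccelerator.is_available() else (False, "")
    gpu_used = False  # e.g. when the CPUAccelerator was selected
    print(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
    # expected on Apple silicon: "GPU available: True (mps), used: False"
    # followed in the Trainer by the new warning:
    #   MPS available but not used. Set `accelerator` and `devices` using
    #   `Trainer(accelerator='mps', devices=1)`.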