diff --git a/src/pytorch_lightning/accelerators/__init__.py b/src/pytorch_lightning/accelerators/__init__.py index b4521d931c734..1bba4a42879bc 100644 --- a/src/pytorch_lightning/accelerators/__init__.py +++ b/src/pytorch_lightning/accelerators/__init__.py @@ -13,6 +13,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401 from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.cuda import CUDAAccelerator # noqa: F401 +from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.mps import MPSAccelerator # noqa: F401 diff --git a/src/pytorch_lightning/accelerators/gpu.py b/src/pytorch_lightning/accelerators/gpu.py new file mode 100644 index 0000000000000..a7d054b946393 --- /dev/null +++ b/src/pytorch_lightning/accelerators/gpu.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.accelerators.cuda import CUDAAccelerator +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation + + +class GPUAccelerator(CUDAAccelerator): + """Accelerator for NVIDIA GPU devices. + + .. deprecated:: 1.9 + + Please use the ``CUDAAccelerator`` instead. + """ + + def __init__(self) -> None: + rank_zero_deprecation( + "The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9." + " Please use the `CUDAAccelerator` instead!" + ) + super().__init__() diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 66bbf80d4e3ea..9c7d02d499ab4 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -18,6 +18,7 @@ import pytorch_lightning.loggers.base as logger_base from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.gpu import GPUAccelerator from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.profiler.advanced import AdvancedProfiler @@ -195,3 +196,13 @@ def test_pytorch_profiler_schedule_wrapper_deprecation_warning(): def test_pytorch_profiler_register_record_function_deprecation_warning(): with pytest.deprecated_call(match="RegisterRecordFunction` is deprecated in v1.7 and will be removed in in v1.9."): _ = RegisterRecordFunction(None) + + +def test_gpu_accelerator_deprecation_warning(): + with pytest.deprecated_call( + match=( + "The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9." + + " Please use the `CUDAAccelerator` instead!" + ) + ): + GPUAccelerator()