Rename GPUAccelerator to CUDAAccelerator
rohitgr7 committed Jul 19, 2022
1 parent 1d59b3f commit 7080ef7
Showing 21 changed files with 90 additions and 84 deletions.
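
The user-facing surface of the rename is small: the class moves from `pytorch_lightning.accelerators.gpu` to `pytorch_lightning.accelerators.cuda` and stays re-exported from the package root. A minimal sketch of the import paths implied by the diffs below:

    # Both imports refer to the same class after this commit.
    from pytorch_lightning.accelerators import CUDAAccelerator
    from pytorch_lightning.accelerators.cuda import CUDAAccelerator  # equivalent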

docs/source-pytorch/api_references.rst (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@ accelerators
 
     Accelerator
     CPUAccelerator
-    GPUAccelerator
+    CUDAAccelerator
     HPUAccelerator
     IPUAccelerator
     TPUAccelerator

docs/source-pytorch/common/trainer.rst (8 changes: 4 additions & 4 deletions)

@@ -249,8 +249,8 @@ Example::
 
 .. code-block:: python
 
-    # This is part of the built-in `GPUAccelerator`
-    class GPUAccelerator(Accelerator):
+    # This is part of the built-in `CUDAAccelerator`
+    class CUDAAccelerator(Accelerator):
         """Accelerator for GPU devices."""
 
         @staticmethod
@@ -603,8 +603,8 @@ based on the accelerator type (``"cpu", "gpu", "tpu", "ipu", "auto"``).
 
 .. code-block:: python
 
-    # This is part of the built-in `GPUAccelerator`
-    class GPUAccelerator(Accelerator):
+    # This is part of the built-in `CUDAAccelerator`
+    class CUDAAccelerator(Accelerator):
         """Accelerator for GPU devices."""
 
         @staticmethod

docs/source-pytorch/extensions/accelerator.rst (2 changes: 1 addition & 1 deletion)

@@ -125,7 +125,7 @@ Accelerator API
 
     Accelerator
     CPUAccelerator
-    GPUAccelerator
+    CUDAAccelerator
    HPUAccelerator
    IPUAccelerator
    MPSAccelerator

src/pytorch_lightning/accelerators/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@
 # limitations under the License.
 from pytorch_lightning.accelerators.accelerator import Accelerator  # noqa: F401
 from pytorch_lightning.accelerators.cpu import CPUAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
+from pytorch_lightning.accelerators.cuda import CUDAAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.hpu import HPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.mps import MPSAccelerator  # noqa: F401

src/pytorch_lightning/accelerators/gpu.py → src/pytorch_lightning/accelerators/cuda.py (file renamed)
@@ -28,8 +28,8 @@
 _log = logging.getLogger(__name__)
 
 
-class GPUAccelerator(Accelerator):
-    """Accelerator for GPU devices."""
+class CUDAAccelerator(Accelerator):
+    """Accelerator for NVIDIA CUDA devices."""
 
     def setup_environment(self, root_device: torch.device) -> None:
         """
@@ -92,6 +92,12 @@ def is_available() -> bool:
 
     @classmethod
     def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "cuda",
+            cls,
+            description=f"{cls.__name__}",
+        )
+        # temporarily enable "gpu" to point to the CUDA Accelerator
         accelerator_registry.register(
             "gpu",
             cls,
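
Because the class is registered under both names, either string selects it. A minimal usage sketch, assuming a machine with at least one CUDA device (the Trainer validates accelerator availability):

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import CUDAAccelerator

    # "cuda" is the new canonical registry name...
    trainer = Trainer(accelerator="cuda", devices=1)
    assert isinstance(trainer.accelerator, CUDAAccelerator)

    # ...and "gpu" is temporarily kept as an alias for the same class.
    trainer = Trainer(accelerator="gpu", devices=1)
    assert isinstance(trainer.accelerator, CUDAAccelerator)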

src/pytorch_lightning/loops/dataloader/evaluation_loop.py (4 changes: 2 additions & 2 deletions)

@@ -23,7 +23,7 @@
 from torch.utils.data.dataloader import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import GPUAccelerator
+from pytorch_lightning.accelerators import CUDAAccelerator
 from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
@@ -411,7 +411,7 @@ def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
         )
         return DataLoaderIterDataFetcher
     elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
-        if not isinstance(trainer.accelerator, GPUAccelerator):
+        if not isinstance(trainer.accelerator, CUDAAccelerator):
             raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
         return InterBatchParallelDataFetcher
     return DataFetcher
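
This code path is opt-in via an environment variable. A sketch of enabling it in a hypothetical training script, assuming a CUDA machine (the check above rejects every other accelerator):

    import os
    from pytorch_lightning import Trainer

    # Read when the fit/evaluation loops pick their data fetcher.
    os.environ["PL_INTER_BATCH_PARALLELISM"] = "1"

    # Any non-CUDA accelerator would raise the MisconfigurationException above.
    trainer = Trainer(accelerator="cuda", devices=1)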

src/pytorch_lightning/loops/fit_loop.py (4 changes: 2 additions & 2 deletions)

@@ -17,7 +17,7 @@
 from typing import Optional, Type
 
 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import GPUAccelerator
+from pytorch_lightning.accelerators import CUDAAccelerator
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
 from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
@@ -340,7 +340,7 @@ def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
         )
         return DataLoaderIterDataFetcher
     elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
-        if not isinstance(trainer.accelerator, GPUAccelerator):
+        if not isinstance(trainer.accelerator, CUDAAccelerator):
             raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
         return InterBatchParallelDataFetcher
     return DataFetcher

src/pytorch_lightning/strategies/deepspeed.py (4 changes: 2 additions & 2 deletions)

@@ -27,7 +27,7 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
-from pytorch_lightning.accelerators.gpu import GPUAccelerator
+from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -452,7 +452,7 @@ def init_deepspeed(self):
         if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
             raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")
 
-        if not isinstance(self.accelerator, GPUAccelerator):
+        if not isinstance(self.accelerator, CUDAAccelerator):
             raise MisconfigurationException(
                 f"DeepSpeed strategy is only supported on GPU but `{self.accelerator.__class__.__name__}` is used."
             )
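
The check above means a DeepSpeed run must resolve to the CUDA accelerator. A minimal configuration sketch, assuming the `deepspeed` package is installed ("deepspeed" is the registered strategy name):

    from pytorch_lightning import Trainer

    # init_deepspeed() raises unless trainer.accelerator is a CUDAAccelerator.
    trainer = Trainer(accelerator="cuda", devices=2, strategy="deepspeed")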

src/pytorch_lightning/strategies/hivemind.py (4 changes: 2 additions & 2 deletions)

@@ -172,9 +172,9 @@ def num_peers(self) -> int:
     @property
     def root_device(self) -> torch.device:
         from pytorch_lightning.accelerators.cpu import CPUAccelerator
-        from pytorch_lightning.accelerators.gpu import GPUAccelerator
+        from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 
-        if isinstance(self.accelerator, GPUAccelerator):
+        if isinstance(self.accelerator, CUDAAccelerator):
             return torch.device(f"cuda:{torch.cuda.current_device()}")
         elif isinstance(self.accelerator, CPUAccelerator):
             return torch.device("cpu")

src/pytorch_lightning/trainer/connectors/accelerator_connector.py (20 changes: 10 additions & 10 deletions)

@@ -22,7 +22,7 @@
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
-from pytorch_lightning.accelerators.gpu import GPUAccelerator
+from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.accelerators.hpu import HPUAccelerator
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
 from pytorch_lightning.accelerators.mps import MPSAccelerator
@@ -370,12 +370,12 @@ def _check_config_and_set_final_flags(
                 )
                 self._accelerator_flag = "cpu"
             if self._strategy_flag.parallel_devices[0].type == "cuda":
-                if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
+                if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"):
                     raise MisconfigurationException(
                         f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                         f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                     )
-                self._accelerator_flag = "gpu"
+                self._accelerator_flag = "cuda"
                 self._parallel_devices = self._strategy_flag.parallel_devices
 
         amp_type = amp_type if isinstance(amp_type, str) else None
@@ -475,7 +475,7 @@ def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
         if tpu_cores:
             self._accelerator_flag = "tpu"
         if gpus:
-            self._accelerator_flag = "gpu"
+            self._accelerator_flag = "cuda"
         if num_processes:
             self._accelerator_flag = "cpu"
 
@@ -497,7 +497,7 @@ def _choose_accelerator(self) -> str:
         if MPSAccelerator.is_available():
             return "mps"
         if torch.cuda.is_available() and torch.cuda.device_count() > 0:
-            return "gpu"
+            return "cuda"
         return "cpu"
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -534,7 +534,7 @@ def _set_devices_flag_if_auto_passed(self) -> None:
             self._devices_flag = self.accelerator.auto_device_count()
 
     def _set_devices_flag_if_auto_select_gpus_passed(self) -> None:
-        if self._auto_select_gpus and isinstance(self._gpus, int) and isinstance(self.accelerator, GPUAccelerator):
+        if self._auto_select_gpus and isinstance(self._gpus, int) and isinstance(self.accelerator, CUDAAccelerator):
             self._devices_flag = pick_multiple_gpus(self._gpus)
             log.info(f"Auto select gpus: {self._devices_flag}")
 
@@ -579,8 +579,8 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             return DDPStrategy.strategy_name
         if len(self._parallel_devices) <= 1:
             # TODO: Change this once gpu accelerator was renamed to cuda accelerator
-            if isinstance(self._accelerator_flag, (GPUAccelerator, MPSAccelerator)) or (
-                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("gpu", "mps")
+            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
             ):
                 device = device_parser.determine_root_gpu_device(self._parallel_devices)
             else:
@@ -609,7 +609,7 @@ def _check_strategy_and_fallback(self) -> None:
         if (
             strategy_flag in DDPFullyShardedNativeStrategy.get_registered_strategies()
             or isinstance(self._strategy_flag, DDPFullyShardedNativeStrategy)
-        ) and self._accelerator_flag != "gpu":
+        ) and self._accelerator_flag not in ("cuda", "gpu"):
             raise MisconfigurationException(
                 f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
                 "but GPU accelerator is not used."
@@ -632,7 +632,7 @@ def _handle_horovod(self) -> None:
             )
 
         hvd.init()
-        if isinstance(self.accelerator, GPUAccelerator):
+        if isinstance(self.accelerator, CUDAAccelerator):
             # Horovod assigns one local GPU per process
             self._parallel_devices = [torch.device(f"cuda:{i}") for i in range(hvd.local_size())]
         else:
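
One behavioral consequence of `_choose_accelerator` above: with `accelerator="auto"`, a CUDA machine now resolves to "cuda" rather than "gpu". A sketch of observing this, assuming no TPU/IPU/HPU/MPS device (those are checked first):

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import CUDAAccelerator

    trainer = Trainer(accelerator="auto", devices="auto")
    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        # "auto" picked the CUDA accelerator on this machine.
        assert isinstance(trainer.accelerator, CUDAAccelerator)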

src/pytorch_lightning/trainer/trainer.py (16 changes: 8 additions & 8 deletions)

@@ -38,7 +38,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import (
     Accelerator,
-    GPUAccelerator,
+    CUDAAccelerator,
     HPUAccelerator,
     IPUAccelerator,
     MPSAccelerator,
@@ -1735,7 +1735,7 @@ def __setup_profiler(self) -> None:
 
     def _log_device_info(self) -> None:
 
-        if GPUAccelerator.is_available():
+        if CUDAAccelerator.is_available():
             gpu_available = True
             gpu_type = " (cuda)"
         elif MPSAccelerator.is_available():
@@ -1745,7 +1745,7 @@ def _log_device_info(self) -> None:
             gpu_available = False
             gpu_type = ""
 
-        gpu_used = isinstance(self.accelerator, (GPUAccelerator, MPSAccelerator))
+        gpu_used = isinstance(self.accelerator, (CUDAAccelerator, MPSAccelerator))
         rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
 
         num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
@@ -1758,10 +1758,10 @@ def _log_device_info(self) -> None:
         rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")
 
         # TODO: Integrate MPS Accelerator here, once gpu maps to both
-        if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
+        if torch.cuda.is_available() and not isinstance(self.accelerator, CUDAAccelerator):
             rank_zero_warn(
                 "GPU available but not used. Set `accelerator` and `devices` using"
-                f" `Trainer(accelerator='gpu', devices={GPUAccelerator.auto_device_count()})`.",
+                f" `Trainer(accelerator='gpu', devices={CUDAAccelerator.auto_device_count()})`.",
                 category=PossibleUserWarning,
             )
 
@@ -2069,7 +2069,7 @@ def root_gpu(self) -> Optional[int]:
             "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. "
             "Please use `Trainer.strategy.root_device.index` instead."
         )
-        return self.strategy.root_device.index if isinstance(self.accelerator, GPUAccelerator) else None
+        return self.strategy.root_device.index if isinstance(self.accelerator, CUDAAccelerator) else None
 
     @property
     def tpu_cores(self) -> int:
@@ -2093,7 +2093,7 @@ def num_gpus(self) -> int:
             "`Trainer.num_gpus` was deprecated in v1.6 and will be removed in v1.8."
             " Please use `Trainer.num_devices` instead."
         )
-        return self.num_devices if isinstance(self.accelerator, GPUAccelerator) else 0
+        return self.num_devices if isinstance(self.accelerator, CUDAAccelerator) else 0
 
     @property
     def devices(self) -> int:
@@ -2109,7 +2109,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:
             "`Trainer.data_parallel_device_ids` was deprecated in v1.6 and will be removed in v1.8."
             " Please use `Trainer.device_ids` instead."
         )
-        return self.device_ids if isinstance(self.accelerator, GPUAccelerator) else None
+        return self.device_ids if isinstance(self.accelerator, CUDAAccelerator) else None
 
     @property
     def lightning_module(self) -> "pl.LightningModule":
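
The deprecated Trainer properties touched here keep working until v1.8; the deprecation messages above spell out the replacements. A sketch, given an existing `trainer`:

    # Deprecated in v1.6, removed in v1.8  ->  replacement
    root_index = trainer.strategy.root_device.index  # instead of trainer.root_gpu
    n_devices = trainer.num_devices                  # instead of trainer.num_gpus
    device_ids = trainer.device_ids                  # instead of trainer.data_parallel_device_ids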

src/pytorch_lightning/utilities/memory.py (2 changes: 1 addition & 1 deletion)

@@ -101,7 +101,7 @@ def get_gpu_memory_map() -> Dict[str, float]:
     r"""
     .. deprecated:: v1.5
         This function was deprecated in v1.5 in favor of
-        `pytorch_lightning.accelerators.gpu._get_nvidia_gpu_stats` and will be removed in v1.7.
+        `pytorch_lightning.accelerators.cuda._get_nvidia_gpu_stats` and will be removed in v1.7.
 
     Get the current gpu usage.

[Diffs for the remaining 9 of the 21 changed files are not shown.]