Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename GPUAccelerator to CUDAAccelerator #13636

Merged
merged 2 commits into from
Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source-pytorch/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ accelerators

Accelerator
CPUAccelerator
GPUAccelerator
CUDAAccelerator
HPUAccelerator
IPUAccelerator
TPUAccelerator
justusschock marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
8 changes: 4 additions & 4 deletions docs/source-pytorch/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ Example::

.. code-block:: python

# This is part of the built-in `GPUAccelerator`
class GPUAccelerator(Accelerator):
# This is part of the built-in `CUDAAccelerator`
class CUDAAccelerator(Accelerator):
"""Accelerator for GPU devices."""

@staticmethod
Expand Down Expand Up @@ -603,8 +603,8 @@ based on the accelerator type (``"cpu", "gpu", "tpu", "ipu", "auto"``).

.. code-block:: python

# This is part of the built-in `GPUAccelerator`
class GPUAccelerator(Accelerator):
# This is part of the built-in `CUDAAccelerator`
class CUDAAccelerator(Accelerator):
"""Accelerator for GPU devices."""

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion docs/source-pytorch/extensions/accelerator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Accelerator API

Accelerator
CPUAccelerator
GPUAccelerator
CUDAAccelerator
HPUAccelerator
IPUAccelerator
MPSAccelerator
Expand Down
1 change: 1 addition & 0 deletions src/pytorch_lightning/accelerators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.cuda import CUDAAccelerator # noqa: F401
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
Expand Down
167 changes: 167 additions & 0 deletions src/pytorch_lightning/accelerators/cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import subprocess
from typing import Any, Dict, List, Optional, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE

_log = logging.getLogger(__name__)


class CUDAAccelerator(Accelerator):
    """Accelerator for NVIDIA CUDA devices."""

    def setup_environment(self, root_device: torch.device) -> None:
        """Validate the root device and make it the current CUDA device.

        Raises:
            MisconfigurationException:
                If the selected device is not GPU.
        """
        super().setup_environment(root_device)
        if root_device.type != "cuda":
            raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
        torch.cuda.set_device(root_device)

    def setup(self, trainer: "pl.Trainer") -> None:
        """Log device visibility and free cached CUDA memory before training starts."""
        # TODO refactor input from trainer to local_rank @four4fish
        self.set_nvidia_flags(trainer.local_rank)
        # clear cache before training
        torch.cuda.empty_cache()

    @staticmethod
    def set_nvidia_flags(local_rank: int) -> None:
        """Pin CUDA device ordering to PCI bus order and log the visible devices."""
        # set the correct cuda visible devices (using pci order)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
        devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
        _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
        """Gets stats for the given GPU device.

        Args:
            device: GPU device for which to get stats

        Returns:
            A dictionary mapping the metrics to their values.
        """
        # Reads stats from the PyTorch CUDA memory allocator; unlike
        # `get_nvidia_gpu_stats` below, this does not shell out to nvidia-smi
        # and therefore cannot raise FileNotFoundError.
        return torch.cuda.memory_stats(device)

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
        """Accelerator device parsing logic."""
        return device_parser.parse_gpu_ids(devices, include_cuda=True)

    @staticmethod
    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        return [torch.device("cuda", i) for i in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return torch.cuda.device_count()

    @staticmethod
    def is_available() -> bool:
        """Return whether at least one CUDA device is visible to this process."""
        return torch.cuda.device_count() > 0

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        """Register this accelerator under the names ``"cuda"`` and ``"gpu"``.

        Note: the description previously used ``cls.__class__.__name__``, which is
        the name of the *metaclass* (``"type"``) because ``cls`` is already a class;
        ``cls.__name__`` yields the intended ``"CUDAAccelerator"``.
        """
        accelerator_registry.register(
            "cuda",
            cls,
            description=cls.__name__,
        )
        # temporarily enable "gpu" to point to the CUDA Accelerator
        accelerator_registry.register(
            "gpu",
            cls,
            description=cls.__name__,
        )

    def teardown(self) -> None:
        """Release cached CUDA memory at the end of the run."""
        # clean up memory
        torch.cuda.empty_cache()


def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:  # pragma: no-cover
    """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

    Args:
        device: GPU device for which to get stats

    Returns:
        A dictionary mapping the metrics to their values.

    Raises:
        FileNotFoundError:
            If nvidia-smi installation not found
    """
    smi_executable = shutil.which("nvidia-smi")
    if smi_executable is None:
        raise FileNotFoundError("nvidia-smi: command not found")

    # (query field, unit) pairs: the fields are passed to nvidia-smi and the
    # units label the keys of the returned dictionary.
    gpu_stat_metrics = [
        ("utilization.gpu", "%"),
        ("memory.used", "MB"),
        ("memory.free", "MB"),
        ("utilization.memory", "%"),
        ("fan.speed", "%"),
        ("temperature.gpu", "°C"),
        ("temperature.memory", "°C"),
    ]
    gpu_query = ",".join(field for field, _ in gpu_stat_metrics)

    # Map the torch device index back to the real (unmasked) GPU id.
    real_gpu_id = _get_gpu_id(torch._utils._get_device_index(device))
    proc = subprocess.run(
        [smi_executable, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={real_gpu_id}"],
        encoding="utf-8",
        capture_output=True,
        check=True,
    )

    def _to_float(token: str) -> float:
        # nvidia-smi reports "[N/A]" (or similar) for unsupported fields.
        try:
            return float(token)
        except ValueError:
            return 0.0

    values = [_to_float(token) for token in proc.stdout.strip().split(", ")]
    return {f"{field} ({unit})": value for (field, unit), value in zip(gpu_stat_metrics, values)}


def _get_gpu_id(device_id: int) -> str:
    """Get the unmasked real GPU IDs."""
    # When `CUDA_VISIBLE_DEVICES` is unset, every device on the machine is visible.
    all_devices = ",".join(str(idx) for idx in range(torch.cuda.device_count()))
    visible = os.getenv("CUDA_VISIBLE_DEVICES", default=all_devices).split(",")
    return visible[device_id].strip()
154 changes: 12 additions & 142 deletions src/pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,151 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import subprocess
from typing import Any, Dict, List, Optional, Union
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE
class GPUAccelerator(CUDAAccelerator):
"""Accelerator for NVIDIA GPU devices.

_log = logging.getLogger(__name__)
.. deprecated:: 1.9


class GPUAccelerator(Accelerator):
"""Accelerator for GPU devices."""

def setup_environment(self, root_device: torch.device) -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not GPU.
"""
super().setup_environment(root_device)
if root_device.type != "cuda":
raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
torch.cuda.set_device(root_device)

def setup(self, trainer: "pl.Trainer") -> None:
# TODO refactor input from trainer to local_rank @four4fish
self.set_nvidia_flags(trainer.local_rank)
# clear cache before training
torch.cuda.empty_cache()

@staticmethod
def set_nvidia_flags(local_rank: int) -> None:
# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
"""Gets stats for the given GPU device.

Args:
device: GPU device for which to get stats

Returns:
A dictionary mapping the metrics to their values.

Raises:
FileNotFoundError:
If nvidia-smi installation not found
"""
return torch.cuda.memory_stats(device)

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
"""Accelerator device parsing logic."""
return device_parser.parse_gpu_ids(devices, include_cuda=True)

@staticmethod
def get_parallel_devices(devices: List[int]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
return [torch.device("cuda", i) for i in devices]

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return torch.cuda.device_count()

@staticmethod
def is_available() -> bool:
return torch.cuda.device_count() > 0

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"gpu",
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
# clean up memory
torch.cuda.empty_cache()


def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

Args:
device: GPU device for which to get stats

Returns:
A dictionary mapping the metrics to their values.

Raises:
FileNotFoundError:
If nvidia-smi installation not found
Please use the ``CUDAAccelerator`` instead.
"""
nvidia_smi_path = shutil.which("nvidia-smi")
if nvidia_smi_path is None:
raise FileNotFoundError("nvidia-smi: command not found")

gpu_stat_metrics = [
("utilization.gpu", "%"),
("memory.used", "MB"),
("memory.free", "MB"),
("utilization.memory", "%"),
("fan.speed", "%"),
("temperature.gpu", "°C"),
("temperature.memory", "°C"),
]
gpu_stat_keys = [k for k, _ in gpu_stat_metrics]
gpu_query = ",".join(gpu_stat_keys)

index = torch._utils._get_device_index(device)
gpu_id = _get_gpu_id(index)
result = subprocess.run(
[nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
encoding="utf-8",
capture_output=True,
check=True,
)

def _to_float(x: str) -> float:
try:
return float(x)
except ValueError:
return 0.0

s = result.stdout.strip()
stats = [_to_float(x) for x in s.split(", ")]
gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)}
return gpu_stats


def _get_gpu_id(device_id: int) -> str:
"""Get the unmasked real GPU IDs."""
# All devices if `CUDA_VISIBLE_DEVICES` unset
default = ",".join(str(i) for i in range(torch.cuda.device_count()))
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
return cuda_visible_devices[device_id].strip()
def __init__(self) -> None:
rank_zero_deprecation(
"The `GPUAccelerator` has been renamed to `CUDAAccelerator` and will be removed in v1.9."
" Please use the `CUDAAccelerator` instead!"
)
super().__init__()
4 changes: 2 additions & 2 deletions src/pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.utils.data.dataloader import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.accelerators import CUDAAccelerator
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
Expand Down Expand Up @@ -411,7 +411,7 @@ def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher
)
return DataLoaderIterDataFetcher
elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
if not isinstance(trainer.accelerator, GPUAccelerator):
if not isinstance(trainer.accelerator, CUDAAccelerator):
raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
return InterBatchParallelDataFetcher
return DataFetcher
Loading