Skip to content

Commit

Permalink
Disable non blocking to device with MPS (#14368)
Browse files Browse the repository at this point in the history
* disable non-blocking for mps due to race condition bug

* fixed typo

* fixed: unknown mps device for non arm systems

* Removed unrobust test case

* moved _MPS_DEVICES such that we used in apply_func

* Resolve circular dependencies

* Comment rewording

* changed torchElasticEnvironment to a global import

* simplified if statement to blocking device type

* Added change to CHANGELOG

* Update src/pytorch_lightning/utilities/apply_func.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed mypy not detecting casting of device

* Moved check into if statement to mainain original behavior

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
5 people authored Aug 26, 2022
1 parent d4bcafa commit cced335
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))


- Fixed incorrect values after transferring data to a MPS device ([#13285](https://github.com/Lightning-AI/lightning/issues/13285))


- Fixed an issue to avoid the impact of sanity check on `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964))


Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_lightning/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.device_parser import parse_cpu_cores
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
from pytorch_lightning.utilities.types import _DEVICE
Expand All @@ -42,13 +42,13 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> int:
"""Accelerator device parsing logic."""
devices = device_parser.parse_cpu_cores(devices)
devices = parse_cpu_cores(devices)
return devices

@staticmethod
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = device_parser.parse_cpu_cores(devices)
devices = parse_cpu_cores(devices)
return [torch.device("cpu")] * devices

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions src/pytorch_lightning/accelerators/hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import _HPU_AVAILABLE, device_parser
from pytorch_lightning.utilities.device_parser import parse_hpus
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

if _HPU_AVAILABLE:
Expand Down Expand Up @@ -61,7 +62,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[int]:
"""Accelerator device parsing logic."""
return device_parser.parse_hpus(devices)
return parse_hpus(devices)

@staticmethod
def get_parallel_devices(devices: int) -> List[torch.device]:
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/accelerators/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import _IPU_AVAILABLE
from pytorch_lightning.utilities.imports import _IPU_AVAILABLE


class IPUAccelerator(Accelerator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
Expand Down
8 changes: 6 additions & 2 deletions src/pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
Batch = type(None)


_CPU_DEVICES = ("cpu", torch.device("cpu"))
_BLOCKING_DEVICE_TYPES = ("cpu", "mps")


def to_dtype_tensor(
Expand Down Expand Up @@ -322,6 +322,9 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
- :class:`torch.device`
"""

if isinstance(device, str):
device = torch.device(device)

def batch_to(data: Any) -> Any:
# try to move torchtext data first
if _TORCHTEXT_LEGACY and isinstance(data, Batch):
Expand All @@ -342,7 +345,8 @@ def batch_to(data: Any) -> Any:

kwargs = {}
# Don't issue non-blocking transfers to CPU
if isinstance(data, Tensor) and device not in _CPU_DEVICES:
# Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015
if isinstance(data, Tensor) and isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES:
kwargs["non_blocking"] = True
data_output = data.to(device, **kwargs)
if data_output is not None:
Expand Down
1 change: 1 addition & 0 deletions src/pytorch_lightning/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def parse_gpu_ids(
gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")

if (
TorchElasticEnvironment.detect()
and len(gpus) != 1
Expand Down

0 comments on commit cced335

Please sign in to comment.