From cced33542db5896afc0e8b3e3c9263dba44deb4f Mon Sep 17 00:00:00 2001
From: Jordi Smit
Date: Fri, 26 Aug 2022 19:17:20 +0200
Subject: [PATCH] Disable non blocking to device with MPS (#14368)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* disable non-blocking for mps due to race condition bug
* fixed typo
* fixed: unknown mps device for non-arm systems
* removed non-robust test case
* moved _MPS_DEVICES so that it can be used in apply_func
* resolved circular dependencies
* comment rewording
* changed TorchElasticEnvironment to a global import
* simplified if statement to blocking device type
* added change to CHANGELOG
* Update src/pytorch_lightning/utilities/apply_func.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fixed mypy not detecting casting of device
* moved check into if statement to maintain original behavior

Co-authored-by: Carlos Mocholí
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
---
 src/pytorch_lightning/CHANGELOG.md                | 3 +++
 src/pytorch_lightning/accelerators/cpu.py         | 6 +++---
 src/pytorch_lightning/accelerators/hpu.py         | 5 +++--
 src/pytorch_lightning/accelerators/ipu.py         | 2 +-
 .../plugins/environments/xla_environment.py       | 2 +-
 src/pytorch_lightning/utilities/apply_func.py     | 8 ++++++--
 src/pytorch_lightning/utilities/device_parser.py  | 1 +
 7 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 4b5e3a893429e..4a746ac91b911 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -104,6 +104,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))
 
+- Fixed incorrect values after transferring data to a MPS device ([#13285](https://github.com/Lightning-AI/lightning/issues/13285))
+
+
 - Fixed an issue to avoid the impact of sanity check on `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964))
 
 
diff --git a/src/pytorch_lightning/accelerators/cpu.py b/src/pytorch_lightning/accelerators/cpu.py
index fea8ee70d17df..d0981e7269305 100644
--- a/src/pytorch_lightning/accelerators/cpu.py
+++ b/src/pytorch_lightning/accelerators/cpu.py
@@ -16,7 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import device_parser
+from pytorch_lightning.utilities.device_parser import parse_cpu_cores
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE
 from pytorch_lightning.utilities.types import _DEVICE
@@ -42,13 +42,13 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
     @staticmethod
     def parse_devices(devices: Union[int, str, List[int]]) -> int:
         """Accelerator device parsing logic."""
-        devices = device_parser.parse_cpu_cores(devices)
+        devices = parse_cpu_cores(devices)
         return devices
 
     @staticmethod
     def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        devices = device_parser.parse_cpu_cores(devices)
+        devices = parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
 
     @staticmethod
diff --git a/src/pytorch_lightning/accelerators/hpu.py b/src/pytorch_lightning/accelerators/hpu.py
index 8fc242fa55f20..c85e81756c2a9 100644
--- a/src/pytorch_lightning/accelerators/hpu.py
+++ b/src/pytorch_lightning/accelerators/hpu.py
@@ -17,8 +17,9 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _HPU_AVAILABLE, device_parser
+from pytorch_lightning.utilities.device_parser import parse_hpus
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _HPU_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug
 
 if _HPU_AVAILABLE:
@@ -61,7 +62,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
     @staticmethod
     def parse_devices(devices: Union[int, str, List[int]]) -> Optional[int]:
         """Accelerator device parsing logic."""
-        return device_parser.parse_hpus(devices)
+        return parse_hpus(devices)
 
     @staticmethod
     def get_parallel_devices(devices: int) -> List[torch.device]:
diff --git a/src/pytorch_lightning/accelerators/ipu.py b/src/pytorch_lightning/accelerators/ipu.py
index b5110e58028a5..b09fd33c29227 100644
--- a/src/pytorch_lightning/accelerators/ipu.py
+++ b/src/pytorch_lightning/accelerators/ipu.py
@@ -16,7 +16,7 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _IPU_AVAILABLE
+from pytorch_lightning.utilities.imports import _IPU_AVAILABLE
 
 
 class IPUAccelerator(Accelerator):
diff --git a/src/pytorch_lightning/plugins/environments/xla_environment.py b/src/pytorch_lightning/plugins/environments/xla_environment.py
index a78ebeb36a6a4..4072f6f8715f5 100644
--- a/src/pytorch_lightning/plugins/environments/xla_environment.py
+++ b/src/pytorch_lightning/plugins/environments/xla_environment.py
@@ -15,7 +15,7 @@
 import os
 
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py
index 15e9962e40bff..fe8b0b836456f 100644
--- a/src/pytorch_lightning/utilities/apply_func.py
+++ b/src/pytorch_lightning/utilities/apply_func.py
@@ -38,7 +38,7 @@
     Batch = type(None)
 
 
-_CPU_DEVICES = ("cpu", torch.device("cpu"))
+_BLOCKING_DEVICE_TYPES = ("cpu", "mps")
 
 
 def to_dtype_tensor(
@@ -322,6 +322,9 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
         - :meth:`torch.Tensor.to`
         - :class:`torch.device`
     """
+    if isinstance(device, str):
+        device = torch.device(device)
+
     def batch_to(data: Any) -> Any:
         # try to move torchtext data first
         if _TORCHTEXT_LEGACY and isinstance(data, Batch):
@@ -342,7 +345,8 @@ def batch_to(data: Any) -> Any:
 
         kwargs = {}
         # Don't issue non-blocking transfers to CPU
-        if isinstance(data, Tensor) and device not in _CPU_DEVICES:
+        # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015
+        if isinstance(data, Tensor) and isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES:
             kwargs["non_blocking"] = True
         data_output = data.to(device, **kwargs)
         if data_output is not None:
diff --git a/src/pytorch_lightning/utilities/device_parser.py b/src/pytorch_lightning/utilities/device_parser.py
index 1b9f43137943f..32f370b5b246e 100644
--- a/src/pytorch_lightning/utilities/device_parser.py
+++ b/src/pytorch_lightning/utilities/device_parser.py
@@ -93,6 +93,7 @@ def parse_gpu_ids(
     gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
     if not gpus:
        raise MisconfigurationException("GPUs requested but none are available.")
+
    if (
        TorchElasticEnvironment.detect()
        and len(gpus) != 1
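
For reviewers, a minimal standalone sketch of the transfer rule this patch introduces, assuming only that `torch` is installed. It is an illustration of the idea, not the Lightning implementation: the helper name `_to_device` is made up for this example, while `_BLOCKING_DEVICE_TYPES` mirrors the tuple added to `apply_func.py`.

```python
from typing import Union

import torch

# Mirrors the tuple added in apply_func.py: transfers to these device types stay blocking.
# MPS is included because of the race condition in https://github.com/pytorch/pytorch/issues/83015
_BLOCKING_DEVICE_TYPES = ("cpu", "mps")


def _to_device(tensor: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
    """Hypothetical helper showing the same decision ``batch_to`` makes after this patch."""
    if isinstance(device, str):
        # Normalize to torch.device so that ``device.type`` is always available.
        device = torch.device(device)
    kwargs = {}
    # Only request an asynchronous copy for device types where it is known to be safe (e.g. CUDA).
    if device.type not in _BLOCKING_DEVICE_TYPES:
        kwargs["non_blocking"] = True
    return tensor.to(device, **kwargs)


if __name__ == "__main__":
    x = torch.rand(2, 3)
    print(_to_device(x, "cpu").device)  # blocking copy
    if torch.cuda.is_available():
        print(_to_device(x, "cuda").device)  # non-blocking copy
```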