
Disable non blocking to device with MPS #14368

Status: Merged (19 commits, Aug 26, 2022)

Commits
92ea1a0 disable non-blocking for mps due to race condition bug (j0rd1smit, Aug 23, 2022)
437a366 fixed typo (j0rd1smit, Aug 23, 2022)
401f97a fixed: unknown mps device for non arm systems (j0rd1smit, Aug 23, 2022)
a112f95 Removed unrobust test case (j0rd1smit, Aug 23, 2022)
1e63b36 moved _MPS_DEVICES such that we used in apply_func (j0rd1smit, Aug 23, 2022)
e56026d Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (j0rd1smit, Aug 23, 2022)
02220a7 Resolve circular dependencies (carmocca, Aug 23, 2022)
fa5ea1b Comment rewording (carmocca, Aug 23, 2022)
a81ee6f changed torchElasticEnvironment to a global import (j0rd1smit, Aug 24, 2022)
8762196 simplified if statement to blocking device type (j0rd1smit, Aug 24, 2022)
c81b8c2 Added change to CHANGELOG (j0rd1smit, Aug 24, 2022)
9ff968a Update src/pytorch_lightning/utilities/apply_func.py (justusschock, Aug 24, 2022)
3add8c9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 24, 2022)
b89148d fixed mypy not detecting casting of device (j0rd1smit, Aug 24, 2022)
a89eaeb Moved check into if statement to mainain original behavior (j0rd1smit, Aug 25, 2022)
39af0b9 Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (justusschock, Aug 25, 2022)
523b3a8 Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (Borda, Aug 26, 2022)
f364736 Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (carmocca, Aug 26, 2022)
fa87e28 Merge branch 'master' into bugfix/13285_disable_non_blocking_to_devic… (carmocca, Aug 26, 2022)
src/pytorch_lightning/utilities/apply_func.py (9 additions, 2 deletions)

```diff
@@ -15,6 +15,7 @@

 import dataclasses
 import operator
+import platform
 from abc import ABC
 from collections import defaultdict, OrderedDict
 from collections.abc import Mapping, Sequence
@@ -27,7 +28,7 @@
 from torch import Tensor

 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
+from pytorch_lightning.utilities.imports import _compare_version, _TORCH_GREATER_EQUAL_1_12, _TORCHTEXT_LEGACY
 from pytorch_lightning.utilities.warnings import rank_zero_deprecation

 if _TORCHTEXT_LEGACY:
@@ -40,6 +41,10 @@


 _CPU_DEVICES = ("cpu", torch.device("cpu"))
+if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
+    _MPS_DEVICES = ("mps", torch.device("mps"))
+else:
+    _MPS_DEVICES = ("mps",)


 def to_dtype_tensor(
@@ -343,7 +348,9 @@ def batch_to(data: Any) -> Any:

     kwargs = {}
     # Don't issue non-blocking transfers to CPU
-    if isinstance(data, Tensor) and device not in _CPU_DEVICES:
+    # Don't issue non-blocking transfers to MPS due to race condition bug:
+    # https://github.com/pytorch/pytorch/issues/83015
+    if isinstance(data, Tensor) and device not in _CPU_DEVICES and device not in _MPS_DEVICES:
         kwargs["non_blocking"] = True
     data_output = data.to(device, **kwargs)
     if data_output is not None:
```
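The net effect of the change above is that `non_blocking=True` is now withheld for MPS targets as well as CPU targets. A minimal, framework-free sketch of that decision, using plain device strings instead of `torch.device` objects and a hypothetical `transfer_kwargs` helper name (neither appears in the PR itself):

```python
# Sketch of the kwargs decision in batch_to: non-blocking transfers are
# skipped for CPU (no benefit) and for MPS (race condition, see
# pytorch/pytorch#83015); every other device gets non_blocking=True.

_CPU_DEVICES = ("cpu",)
_MPS_DEVICES = ("mps",)


def transfer_kwargs(device: str) -> dict:
    """Return the extra kwargs that would be passed to Tensor.to(device)."""
    if device not in _CPU_DEVICES and device not in _MPS_DEVICES:
        return {"non_blocking": True}
    return {}


print(transfer_kwargs("cuda:0"))  # {'non_blocking': True}
print(transfer_kwargs("cpu"))     # {}
print(transfer_kwargs("mps"))     # {}
```

In the real code the tuples also contain `torch.device(...)` instances, since `batch_to` may receive either form of the device argument.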
tests/tests_pytorch/accelerators/test_mps.py (11 additions)

```diff
@@ -162,3 +162,14 @@ def to(self, *args, **kwargs):

     assert batch.text.type() == "torch.mps.LongTensor"
     assert batch.label.type() == "torch.mps.LongTensor"
+
+
+@RunIf(mps=True)
+def test_data_is_not_changed_after_move_to_mps_device():
+    trainer = Trainer(accelerator="mps", devices=1)
+    x = torch.zeros([10, 10])
+    device = torch.device("mps")
+
+    for _ in range(1000):
+        x_mps = trainer.strategy.batch_to_device(x.clone(), device)
+        torch.testing.assert_close(x_mps, x)
```