Skip to content

Commit

Permalink
Remove TPU Availability check from parse devices (#12326)
Browse files Browse the repository at this point in the history
* Remove TPU Availability check from parse devices
* Update tests
  • Loading branch information
kaushikb11 authored Mar 29, 2022
1 parent 4fe0076 commit 041da41
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 12 deletions.
6 changes: 1 addition & 5 deletions pytorch_lightning/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _DEVICE

Expand Down Expand Up @@ -122,7 +121,7 @@ def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional
Raises:
MisconfigurationException:
If TPU cores aren't 1 or 8 cores, or no TPU devices are found
If TPU cores aren't 1, 8 or [<1-8>]
"""
_check_data_type(tpu_cores)

Expand All @@ -132,9 +131,6 @@ def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional
if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")

if tpu_cores is not None and not _TPU_AVAILABLE:
raise MisconfigurationException("No TPU devices were found.")

return tpu_cores


Expand Down
6 changes: 2 additions & 4 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,7 @@ def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):


@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
def test_ipython_compatible_strategy_tpu(mock_devices, mock_tpu_acc_avail, monkeypatch):
def test_ipython_compatible_strategy_tpu(mock_tpu_acc_avail, monkeypatch):
monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
trainer = Trainer(accelerator="tpu")
assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
Expand Down Expand Up @@ -894,8 +893,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock


@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
def test_unsupported_tpu_choice(mock_devices, mock_tpu_acc_avail):
def test_unsupported_tpu_choice(mock_tpu_acc_avail):

with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
Trainer(accelerator="tpu", precision=64)
Expand Down
5 changes: 2 additions & 3 deletions tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,9 +1139,8 @@ def test_trainer_gpus(monkeypatch, trainer_kwargs):


def test_trainer_tpu_cores(monkeypatch):
monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda: True)
monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "parse_devices", lambda: 8)
trainer = Trainer(accelerator="TPU", devices=8)
monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda _: True)
trainer = Trainer(accelerator="tpu", devices=8)
with pytest.deprecated_call(
match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
"Please use `Trainer.num_devices` instead."
Expand Down

0 comments on commit 041da41

Please sign in to comment.