From 036ecf2708c3025d72834f390b1163c8bdadb2e8 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 7 Feb 2022 09:36:58 -0800
Subject: [PATCH 1/2] Update tpu.py

---
 pytorch_lightning/accelerators/tpu.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 34c37dcd95e7f..a325aa9b17c91 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -16,7 +16,8 @@
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -25,6 +26,15 @@
 class TPUAccelerator(Accelerator):
     """Accelerator for TPU devices."""
 
+    def setup_environment(self, root_device: torch.device) -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the TPU device is not available.
+        """
+        if not _XLA_AVAILABLE:
+            raise MisconfigurationException("The TPU Accelerator requires torch_xla and a TPU device to run.")
+
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for the given TPU device.
 

From 88e65e81d751a4c08f23dd60013e5e0c0b383956 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 7 Feb 2022 09:42:27 -0800
Subject: [PATCH 2/2] Update CHANGELOG.md

---
 CHANGELOG.md                          | 3 +++
 pytorch_lightning/accelerators/tpu.py | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 38d1c7c3f2833..daf6445425461 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -98,6 +98,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
 
 
+- Added checks to `TPUAccelerator.setup_environment` to assert device availability ([#11799](https://github.com/PyTorchLightning/pytorch-lightning/pull/11799))
+
+
 ### Changed
 
 - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
 
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index a325aa9b17c91..3d552ddf6d5c0 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -26,7 +26,7 @@
 class TPUAccelerator(Accelerator):
     """Accelerator for TPU devices."""
 
-    def setup_environment(self, root_device: torch.device) -> None:
+    def __init__(self) -> None:
         """
         Raises:
             MisconfigurationException:
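
Note: the availability check introduced in PATCH 1/2 follows a common guard pattern for optional dependencies: an import-time flag records whether the dependency resolves, and setup fails fast with an actionable error instead of a late ImportError. The sketch below is illustrative only; it mirrors the names in the diff but flattens the real module layout, drops the root_device argument, and uses a local stand-in for MisconfigurationException.

# illustrative_tpu_guard.py
# Minimal, self-contained sketch of the guard pattern from PATCH 1/2.
# Not PyTorch Lightning's actual code: the module layout is flattened,
# the root_device argument is dropped, and MisconfigurationException is a local stand-in.
import importlib.util

# Import-time flag: True only when the optional torch_xla dependency resolves.
_XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None


class MisconfigurationException(Exception):
    """Raised when the runtime environment cannot satisfy the requested setup."""


class TPUAccelerator:
    """Accelerator for TPU devices (sketch)."""

    def setup_environment(self) -> None:
        # Fail fast, before any training work begins, if torch_xla is missing.
        if not _XLA_AVAILABLE:
            raise MisconfigurationException("The TPU Accelerator requires torch_xla and a TPU device to run.")


if __name__ == "__main__":
    try:
        TPUAccelerator().setup_environment()
        print("torch_xla found; TPU setup can proceed.")
    except MisconfigurationException as err:
        print(f"Fails early with a clear message: {err}")

A side effect of moving the check from setup_environment (PATCH 1/2) into __init__ (PATCH 2/2) is that the error surfaces at accelerator construction time rather than during environment setup, which is the earliest point the misconfiguration can be reported.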