diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet
index 55454e7cac0a2..530c40e49ed3e 100644
--- a/dockers/tpu-tests/tpu_test_cases.jsonnet
+++ b/dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -33,6 +33,7 @@ local tputests = base.BaseTest {
       echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
       export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
       coverage run --source=pytorch_lightning -m pytest -v --capture=no \
+          tests/plugins/test_tpu_spawn.py \
           tests/profiler/test_xla_profiler.py \
           pytorch_lightning/utilities/xla_device.py \
           tests/accelerators/test_tpu.py \
diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py
index 5f4abf560d6a6..ba5dc0e9d5f0d 100644
--- a/tests/plugins/test_tpu_spawn.py
+++ b/tests/plugins/test_tpu_spawn.py
@@ -86,7 +86,7 @@ def test_error_process_iterable_dataloader(_):
 
 class BoringModelTPU(BoringModel):
     def on_train_start(self) -> None:
-        assert self.device == torch.device("xla")
+        assert self.device == torch.device("xla", index=1)
         assert os.environ.get("PT_XLA_DEBUG") == "1"
 
 
@@ -94,12 +94,12 @@ def on_train_start(self) -> None:
 @pl_multi_process_test
 def test_model_tpu_one_core():
     """Tests if device/debug flag is set correctely when training and after teardown for TPUSpawnPlugin."""
-    trainer = Trainer(tpu_cores=1, fast_dev_run=True, plugin=TPUSpawnPlugin(debug=True))
+    trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnPlugin(debug=True))
     # assert training type plugin attributes for device setting
     assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
     assert not trainer.training_type_plugin.on_gpu
     assert trainer.training_type_plugin.on_tpu
-    assert trainer.training_type_plugin.root_device == torch.device("xla")
+    assert trainer.training_type_plugin.root_device == torch.device("xla", index=1)
     model = BoringModelTPU()
     trainer.fit(model)
     assert "PT_XLA_DEBUG" not in os.environ
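
A note on the assertion change above: `torch.device` equality compares both the device type and the index, so the index-less `torch.device("xla")` no longer matches the root device once the plugin reports an explicit ordinal. A minimal sketch of that behavior (plain PyTorch only; `"xla"` is a registered device-type string, so constructing the device object does not require a TPU or torch_xla at runtime):

```python
import torch

# Equality on torch.device takes the index into account:
# an index-less device does not compare equal to an indexed one.
assert torch.device("xla") != torch.device("xla", index=1)
assert torch.device("xla", index=1) == torch.device("xla", index=1)

# The index-less form carries index None, not 0.
assert torch.device("xla").index is None
assert torch.device("xla", index=1).index == 1
```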