Skip to content

Commit

Permalink
Fix tpu spawn plugin test (#11131)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored and lexierule committed Dec 21, 2021
1 parent a542f82 commit 8b2c39b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ local tputests = base.BaseTest {
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
coverage run --source=pytorch_lightning -m pytest -v --capture=no \
tests/plugins/test_tpu_spawn.py \
tests/profiler/test_xla_profiler.py \
pytorch_lightning/utilities/xla_device.py \
tests/accelerators/test_tpu.py \
Expand Down
6 changes: 3 additions & 3 deletions tests/plugins/test_tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,20 @@ def test_error_process_iterable_dataloader(_):

class BoringModelTPU(BoringModel):
def on_train_start(self) -> None:
assert self.device == torch.device("xla")
assert self.device == torch.device("xla", index=1)
assert os.environ.get("PT_XLA_DEBUG") == "1"


@RunIf(tpu=True)
@pl_multi_process_test
def test_model_tpu_one_core():
"""Tests if device/debug flag is set correctely when training and after teardown for TPUSpawnPlugin."""
trainer = Trainer(tpu_cores=1, fast_dev_run=True, plugin=TPUSpawnPlugin(debug=True))
trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnPlugin(debug=True))
# assert training type plugin attributes for device setting
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
assert not trainer.training_type_plugin.on_gpu
assert trainer.training_type_plugin.on_tpu
assert trainer.training_type_plugin.root_device == torch.device("xla")
assert trainer.training_type_plugin.root_device == torch.device("xla", index=1)
model = BoringModelTPU()
trainer.fit(model)
assert "PT_XLA_DEBUG" not in os.environ
Expand Down

0 comments on commit 8b2c39b

Please sign in to comment.