Commit

Fix all is_torch_tpu_available issues (huggingface#17936)

* Fix all is_torch_tpu_available
muellerzr authored and viclzhu committed Jul 18, 2022
1 parent 409c77d commit 91d925a
Showing 10 changed files with 22 additions and 19 deletions.
examples/pytorch/question-answering/trainer_qa.py (1 addition, 1 deletion)
@@ -20,7 +20,7 @@
from transformers.trainer_utils import PredictionOutput


-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

examples/pytorch/question-answering/trainer_seq2seq_qa.py (1 addition, 1 deletion)
@@ -23,7 +23,7 @@
from transformers.trainer_utils import PredictionOutput


-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

@@ -30,7 +30,7 @@

logger = logging.getLogger(__name__)

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

src/transformers/benchmark/benchmark_args.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
if is_torch_available():
    import torch

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm


src/transformers/testing_utils.py (1 addition, 1 deletion)
@@ -467,7 +467,7 @@ def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)


if is_torch_available():
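
For context, a minimal sketch of how the `require_torch_tpu` decorator is typically applied; the test class and test name below are illustrative, not taken from this commit:

import unittest

from transformers.testing_utils import require_torch_tpu


@require_torch_tpu
class TpuOnlyTest(unittest.TestCase):
    # The decorator skips this class unless is_torch_tpu_available(check_device=False)
    # is True, i.e. it only requires torch_xla to be installed, not an
    # already-initialized XLA device.
    def test_placeholder(self):
        self.assertTrue(True)
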
src/transformers/trainer.py (1 addition, 1 deletion)
@@ -171,7 +171,7 @@
if is_datasets_available():
    import datasets

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl
src/transformers/trainer_pt_utils.py (1 addition, 1 deletion)
@@ -43,7 +43,7 @@
if is_training_run_on_sagemaker():
    logging.add_handler(StreamHandler(sys.stdout))

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm

# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
src/transformers/trainer_utils.py (2 additions, 2 deletions)
@@ -307,7 +307,7 @@ def is_main_process(local_rank):
    Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
    `local_rank`.
    """
-    if is_torch_tpu_available():
+    if is_torch_tpu_available(check_device=True):
        import torch_xla.core.xla_model as xm

        return xm.get_ordinal() == 0
@@ -318,7 +318,7 @@ def total_processes_number(local_rank):
"""
Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
"""
if is_torch_tpu_available():
if is_torch_tpu_available(check_device=True):
import torch_xla.core.xla_model as xm

return xm.xrt_world_size()
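
Aside (illustrative, not part of the diff): these two helpers keep `check_device=True` because they run during training, when an XLA device must actually exist before `xm.get_ordinal()` or `xm.xrt_world_size()` can be queried. A hedged usage sketch, with `local_rank` standing in for a value normally taken from the training arguments:

from transformers.trainer_utils import is_main_process, total_processes_number

local_rank = -1  # placeholder; usually supplied by torch.distributed / TrainingArguments

# Typical pattern: only the main process emits logs or writes checkpoints.
if is_main_process(local_rank):
    print(f"Launched {total_processes_number(local_rank)} process(es).")
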
src/transformers/training_args.py (1 addition, 1 deletion)
@@ -52,7 +52,7 @@
    import torch
    import torch.distributed as dist

-if is_torch_tpu_available():
+if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm


src/transformers/utils/import_utils.py (12 additions, 9 deletions)
@@ -396,19 +396,22 @@ def is_ftfy_available():
    return _ftfy_available


-def is_torch_tpu_available():
+def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
    if not _torch_available:
        return False
-    if importlib.util.find_spec("torch_xla") is None:
-        return False
-    import torch_xla.core.xla_model as xm
-    # We need to check if `xla_device` can be found, will raise a RuntimeError if not
-    try:
-        xm.xla_device()
+    if importlib.util.find_spec("torch_xla") is not None:
+        if check_device:
+            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+            try:
+                import torch_xla.core.xla_model as xm
+
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
        return True
-    except RuntimeError:
-        return False
+    return False


def is_torchdynamo_available():
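
Usage note (illustrative, not part of the diff): the new `check_device` flag separates "is `torch_xla` installed?" from "is an XLA device actually reachable?". The module-level guards above pass `check_device=False` so that importing `transformers` does not fail on machines without a configured TPU, while the runtime helpers in `trainer_utils.py` keep the stricter default. A minimal sketch, assuming the top-level `from transformers import is_torch_tpu_available` export:

from transformers import is_torch_tpu_available

# Cheap check: only verifies that torch and torch_xla are installed.
# Safe to call at import time on machines with no TPU attached.
if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm  # noqa: F401

# Strict check (the default, check_device=True): additionally calls
# xm.xla_device(), which raises RuntimeError when no XLA device is configured,
# so this only returns True when a TPU can really be acquired.
if is_torch_tpu_available():
    print("An XLA device is available; TPU-specific code paths can be used.")
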
