Fix Trainer with a parallel model (#9578)
* Fix Trainer with a parallel model

* More clean up

sgugger authored Jan 14, 2021
1 parent 126fd28 commit 5e1bea4
Showing 2 changed files with 14 additions and 13 deletions.
src/transformers/training_args.py (23 changes: 11 additions & 12 deletions)

@@ -16,7 +16,7 @@
 import os
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
 from .trainer_utils import EvaluationStrategy, SchedulerType
@@ -426,7 +426,6 @@ def __post_init__(self):
 
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
-        self._n_gpu = torch.cuda.device_count()
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
@@ -467,14 +466,14 @@ def eval_batch_size(self) -> int:
 
     @cached_property
     @torch_required
-    def _setup_devices(self) -> Tuple["torch.device", int]:
+    def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
-            n_gpu = 0
+            self._n_gpu = 0
         elif is_torch_tpu_available():
             device = xm.xla_device()
-            n_gpu = 0
+            self._n_gpu = 0
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -485,9 +484,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]:
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
             # the default value.
-            if self._n_gpu == -1:
-                self._n_gpu = torch.cuda.device_count()
-            n_gpu = self._n_gpu
+            self._n_gpu = torch.cuda.device_count()
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
@@ -507,20 +504,20 @@ def _setup_devices(self) -> Tuple["torch.device", int]:
             else:
                 torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
-            n_gpu = 1
+            self._n_gpu = 1
 
         if device.type == "cuda":
             torch.cuda.set_device(device)
 
-        return device, n_gpu
+        return device
 
     @property
     @torch_required
     def device(self) -> "torch.device":
         """
         The device used by this process.
         """
-        return self._setup_devices[0]
+        return self._setup_devices
 
     @property
     @torch_required
@@ -532,7 +529,9 @@ def n_gpu(self):
         This will only be greater than one when you have multiple GPUs available but are not using distributed
         training. For distributed training, it will always be 1.
         """
-        return self._setup_devices[1]
+        # Make sure `self._n_gpu` is properly setup.
+        _ = self._setup_devices
+        return self._n_gpu
 
     @property
     @torch_required
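The mechanics behind the fix: `_setup_devices` is a `cached_property`, so it runs at most once per `TrainingArguments` instance. After this commit it records `_n_gpu` on the instance as a side effect instead of freezing the GPU count into a returned tuple, which lets callers overwrite `_n_gpu` afterwards. Below is a minimal standalone sketch of the pattern, with an illustrative `Args` class and fake device counts (functools' `cached_property` stands in for the backport imported from `.file_utils`); it is not the transformers source:

# Sketch of the lazy device-setup pattern this commit adopts.
from functools import cached_property

class Args:
    def __init__(self, no_cuda: bool = False):
        self.no_cuda = no_cuda
        self._n_gpu = -1  # sentinel until _setup_devices has run

    @cached_property
    def _setup_devices(self) -> str:
        # Runs at most once per instance: populate _n_gpu as a side
        # effect and return only the device, mirroring the new API.
        self._n_gpu = 0 if self.no_cuda else 2  # pretend 2 GPUs are visible
        return "cpu" if self.no_cuda else "cuda:0"

    @property
    def device(self) -> str:
        return self._setup_devices

    @property
    def n_gpu(self) -> int:
        _ = self._setup_devices  # guarantee _n_gpu has been populated
        return self._n_gpu

args = Args()
print(args.n_gpu)   # 2: first access triggers the device setup
print(args.device)  # cuda:0: value is cached, setup does not run again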
tests/test_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -381,9 +381,11 @@ def test_data_is_not_parallelized_when_model_is_parallel(self):
         # Make the Trainer believe it's a parallelized model
         model.is_parallelizable = True
         model.model_parallel = True
-        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
+        trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
         # Check the Trainer was fooled
         self.assertTrue(trainer.is_model_parallel)
+        self.assertEqual(trainer.args.n_gpu, 1)
 
         # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
         self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
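To see the regression the new assertion pins down, continue the `Args` sketch above. Before this commit, `n_gpu` returned `self._setup_devices[1]`, a count captured when the cached property first ran, so the `_n_gpu = 1` override that `Trainer` applies to a model-parallel model (the behavior the `assertEqual(trainer.args.n_gpu, 1)` line checks) was silently ignored and dataloader batch sizes were scaled by the real GPU count:

# Continues the Args sketch above; assumed setup, not the real Trainer.
args = Args()
assert args.n_gpu == 2   # first access runs the fake device setup
args._n_gpu = 1          # what Trainer does once it detects model parallelism
assert args.n_gpu == 1   # post-fix behavior: the override is respected
# Dataloader sizing therefore stays at the per-device value:
per_device_batch_size = 16
assert per_device_batch_size * max(1, args.n_gpu) == 16  # not 16 * 2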
