From 7337db92143b7884d616dbd060d7bff1a3e81ddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 16 Aug 2020 17:19:57 +0200 Subject: [PATCH] ddp fix for trainer.test() + add basic ddp tests (#2997) * add ddp script variations * add ddp test * rename * shell * test * test * try call * try without subprocess * test * display the error * list all variations * try string * try copy env * debug * pythonpath * path * update test * change * simple ddp test * replace * remove random port * random port * str * clean up * check run spawn * clean up * docs * docs * update test * docs * changelog * changelog --- CHANGELOG.md | 2 + docs/source/multi_gpu.rst | 12 ++--- pytorch_lightning/accelerators/ddp_backend.py | 26 ++++++----- .../accelerators/ddp_spawn_backend.py | 5 ++- pytorch_lightning/core/lightning.py | 4 -- pytorch_lightning/core/saving.py | 1 + .../trainer/distrib_data_parallel.py | 21 --------- pytorch_lightning/trainer/trainer.py | 3 -- pytorch_lightning/utilities/distributed.py | 15 +++++++ tests/base/model_valid_epoch_ends.py | 1 - .../models/data/ddp/train_test_variations.py | 44 +++++++++++++++++++ tests/models/test_gpu.py | 34 ++++++++++++++ 12 files changed, 122 insertions(+), 46 deletions(-) create mode 100644 tests/models/data/ddp/train_test_variations.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d404a13842913..5aa98650d2949d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -144,6 +144,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986)) +- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997)) + ## [0.8.5] - 2020-07-09 ### Added diff --git a/docs/source/multi_gpu.rst b/docs/source/multi_gpu.rst index 706a977290f352..57b3e0813f54f9 100644 --- a/docs/source/multi_gpu.rst +++ b/docs/source/multi_gpu.rst @@ -286,17 +286,19 @@ variables: MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python my_file.py --gpus 3 --etc MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python my_file.py --gpus 3 --etc -If your code does not support this (ie: jupyter notebook, colab, or a nested script without a root package), -use `dp` or `ddp_spawn`. We use DDP this way because `ddp_spawn` has a few limitations (due to Python and PyTorch): 1. Since `.spawn()` trains the model in subprocesses, the model on the main process does not get updated. - 2. Dataloader(num_workers=N), where N is large, bottlenecks training with DDP... ie: it will be VERY slow or won't work at all. This is a PyTorch limitation. - 3. Forces everything to be picklable. -However, if you don't mind these limitations, you can use `ddp_spawn`. +There are cases in which it is not possible to use DDP. Examples are: + +- Jupyter Notebook, Google COLAB, Kaggle, etc. +- You have a nested script without a root package +- Your script needs to invoke `.fit` or `.test` multiple times + +In these situations you should use `dp` or `ddp_spawn` instead. 
Distributed Data Parallel 2 ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index ca4ade2d11f6bd..6866a66543e4a0 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -24,7 +24,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port try: from hydra.utils import to_absolute_path, get_original_cwd @@ -45,6 +45,7 @@ class DDPBackend(object): def __init__(self, trainer): self.trainer = trainer self.task_idx = None + self._has_spawned_children = False def slurm_setup(self): self.task_idx = int(os.environ['SLURM_LOCALID']) @@ -56,19 +57,17 @@ def train(self, model): self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model) def spawn_ddp_children(self, model): - port = os.environ['MASTER_PORT'] + assert self.trainer.global_rank == 0 + self._check_can_spawn_children() + self._has_spawned_children = True - master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR'] - os.environ['MASTER_PORT'] = f'{port}' - os.environ['MASTER_ADDR'] = f'{master_address}' + os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', '127.0.0.1') + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) # allow the user to pass the node rank node_rank = '0' - if 'NODE_RANK' in os.environ: - node_rank = os.environ['NODE_RANK'] - if 'GROUP_RANK' in os.environ: - node_rank = os.environ['GROUP_RANK'] - + node_rank = os.environ.get('NODE_RANK', node_rank) + node_rank = os.environ.get('GROUP_RANK', node_rank) os.environ['NODE_RANK'] = node_rank os.environ['LOCAL_RANK'] = '0' @@ -235,3 +234,10 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']: return results + + def _check_can_spawn_children(self): + if self._has_spawned_children: + raise RuntimeError( + "You tried to run `.fit` or `.test` multiple times in the same script." + " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." + ) diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py index 6fce4cb7a530ec..ee4954057964df 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License +import os import torch import torch.multiprocessing as mp from pytorch_lightning import _logger as log from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port try: from apex import amp @@ -32,7 +33,7 @@ def __init__(self, trainer): self.mp_queue = None def setup(self): - self.trainer.set_random_port() + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) # pass in a state q smp = mp.get_context('spawn') diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 66d067a5146b62..ba00e702f26c82 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -893,10 +893,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}") torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) - """ - configure_sync_batchnorm - ^^^^^^^^^^^^^^^^^^^^^^^^ - """ def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule': """ Add global batchnorm for a model spread across multiple GPUs and nodes. diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 16a9467c71554f..1835251b2de54e 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -353,6 +353,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: with open(config_yaml, 'w', newline='') as fp: yaml.dump(hparams, fp) + def convert(val: str) -> Union[int, float, bool, str]: try: return ast.literal_eval(val) diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 5555c26172b112..438e4be3594413 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -132,7 +132,6 @@ def train_fx(trial_hparams, cluster_manager, _): from abc import ABC, abstractmethod from typing import Union, List, Optional, Tuple -import numpy as np import torch from pytorch_lightning import _logger as log @@ -163,10 +162,6 @@ def train_fx(trial_hparams, cluster_manager, _): else: XLA_AVAILABLE = True -PID = os.getpid() -RNG1 = np.random.RandomState(PID) -RANDOM_PORTS = RNG1.randint(10000, 19999, 1000) - class TrainerDDPMixin(ABC): @@ -389,22 +384,6 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids): # don't make this debug... 
this is good UX rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]') - def set_random_port(self, force=False): - """ - When running DDP NOT managed by SLURM, the ports might collide - """ - # pick a random port first - assert self.num_nodes == 1, 'random port can only be called from single node training' - global RANDOM_PORTS - default_port = RANDOM_PORTS[-1] - RANDOM_PORTS = RANDOM_PORTS[:-1] - - # when not forced, use the user port - if not force: - default_port = os.environ.get('MASTER_PORT', default_port) - - os.environ['MASTER_PORT'] = str(default_port) - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): if self.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']: return diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8c2eb4bfe4c324..98d93ad6357fa4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1039,7 +1039,6 @@ def fit( # ddp elif self.distributed_backend == 'ddp': - self.set_random_port() self.accelerator_backend = DDPBackend(self) results = self.accelerator_backend.spawn_ddp_children(model) @@ -1377,7 +1376,6 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # run tests self.tested_ckpt_path = ckpt_path - self.set_random_port(force=True) self.testing = True os.environ['PL_TESTING_MODE'] = '1' self.model = model @@ -1400,7 +1398,6 @@ def __test_given_model(self, model, test_dataloaders): # run test # sets up testing so we short circuit to eval - self.set_random_port(force=True) self.testing = True self.model = model results = self.fit(model) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index cd0621496fe42e..5375bfd0fa6ced 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -34,3 +34,18 @@ def _debug(*args, **kwargs): rank_zero_debug = rank_zero_only(_debug) rank_zero_info = rank_zero_only(_info) rank_zero_warn = rank_zero_only(_warn) + + +def find_free_network_port() -> int: + """ + Finds a free port on localhost. + It is useful in single-node training when we don't want to connect to a real master node but + have to set the `MASTER_PORT` environment variable. + """ + import socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + s.close() + return port diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py index 89742244096248..943227e738caec 100644 --- a/tests/base/model_valid_epoch_ends.py +++ b/tests/base/model_valid_epoch_ends.py @@ -21,7 +21,6 @@ def _mean(res, key): # recursive mean for multilevel dicts return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean() - print('in validation epoch end') val_loss_mean = _mean(outputs, 'val_loss') val_acc_mean = _mean(outputs, 'val_acc') diff --git a/tests/models/data/ddp/train_test_variations.py b/tests/models/data/ddp/train_test_variations.py new file mode 100644 index 00000000000000..f37bd27e8a005c --- /dev/null +++ b/tests/models/data/ddp/train_test_variations.py @@ -0,0 +1,44 @@ +""" +Runs either `.fit()` or `.test()` on a single node across multiple gpus. 
+""" +from argparse import ArgumentParser + +from pytorch_lightning import Trainer, seed_everything +from tests.base import EvalModelTemplate + + +def variation_fit(trainer, model): + trainer.fit(model) + + +def variation_test(trainer, model): + trainer.test(model) + + +def get_variations(): + variations = [ + "variation_fit", + "variation_test", + ] + return variations + + +def main(): + seed_everything(1234) + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parser) + parser.add_argument('--variation', default=variation_fit.__name__) + parser.set_defaults(gpus=2) + parser.set_defaults(distributed_backend="ddp") + args = parser.parse_args() + + model = EvalModelTemplate() + trainer = Trainer.from_argparse_args(args) + + # run the chosen variation + run_variation = globals()[args.variation] + run_variation(trainer, model) + + +if __name__ == '__main__': + main() diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 509de4c07563a2..b6a2efbb8621b5 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -1,10 +1,15 @@ +import os +import subprocess +import sys from collections import namedtuple +from pathlib import Path from unittest.mock import patch import pytest import torch from torchtext.data import Batch, Dataset, Example, Field, LabelField +import pytorch_lightning import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer @@ -12,6 +17,7 @@ from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from tests.models.data.ddp import train_test_variations PRETEND_N_OF_GPUS = 16 @@ -94,6 +100,34 @@ def test_multi_gpu_model_dp(tmpdir): memory.get_memory_profile('min_max') +@pytest.mark.parametrize('cli_args', [ + pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'), +]) +@pytest.mark.parametrize('variation', train_test_variations.get_variations()) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_multi_gpu_model_ddp(tmpdir, cli_args, variation): + """ Runs a basic training and test run with distributed_backend=ddp. """ + file = Path(train_test_variations.__file__).absolute() + cli_args = cli_args.split(' ') if cli_args else [] + cli_args += ['--default_root_dir', str(tmpdir)] + cli_args += ['--variation', variation] + command = [sys.executable, str(file)] + cli_args + + # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment + env = os.environ.copy() + env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '') + + # for running in ddp mode, we need to lauch it's own process or pytest will get stuck + p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) + + std, err = p.communicate(timeout=60) + std = std.decode('utf-8').strip() + err = err.decode('utf-8').strip() + assert std, f"{variation} produced no output" + if p.returncode > 0: + pytest.fail(err) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_multi_gpu_model_ddp_spawn(tmpdir): tutils.set_random_master_port()