diff --git a/docs/source-pytorch/advanced/speed.rst b/docs/source-pytorch/advanced/speed.rst
index 97503f7f5b482..6c5302b538848 100644
--- a/docs/source-pytorch/advanced/speed.rst
+++ b/docs/source-pytorch/advanced/speed.rst
@@ -82,13 +82,13 @@ The question of how many workers to specify in ``num_workers`` is tricky. Here's
 1. ``num_workers=0`` means ONLY the main process will load batches (that can be a bottleneck).
 2. ``num_workers=1`` means ONLY one worker (just not the main process) will load data, but it will still be slow.
 3. The performance of high ``num_workers`` depends on the batch size and your machine.
-4. A general place to start is to set ``num_workers`` equal to the number of CPU cores on that machine. You can get the number of CPU cores in python using ``os.cpu_count()``, but note that depending on your batch size, you may overflow RAM memory.
+4. A general place to start is to set ``num_workers`` equal to the number of CPU cores on that machine. You can get the number of CPU cores in Python using ``os.cpu_count()``, but note that depending on your batch size, you may overflow CPU RAM.
 
 .. warning:: Increasing ``num_workers`` will ALSO increase your CPU memory consumption.
 
 The best thing to do is to increase the ``num_workers`` slowly and stop once there is no more improvement in your training speed.
 
-For debugging purposes or for dataloaders that load very small datasets, it is desirable to set ``num_workers=0``. However, this will always log a warning for every dataloader with ``num_workers <= min(2, os.cpu_count())``. In such cases, you can specifically filter this warning by using:
+For debugging purposes or for dataloaders that load very small datasets, it is desirable to set ``num_workers=0``. However, this will log a warning that you're not using enough workers. In such cases, you can specifically filter this warning by using:
 
 .. code-block:: python
 
@@ -101,26 +101,12 @@ For debugging purposes or for dataloaders that load very small datasets, it is d
     warnings.filterwarnings("ignore", category=PossibleUserWarning)
 
 
-Spawn
-^^^^^
-
-When using ``strategy="ddp_spawn"`` or training on TPUs, the way multiple GPUs/TPU cores are used is by calling :obj:`torch.multiprocessing`
-``.spawn()`` under the hood. The problem is that PyTorch has issues with ``num_workers>0`` when using ``.spawn()``. For this reason, we recommend you
-use ``strategy="ddp"`` so you can increase the ``num_workers``, however since DDP doesn't work in an interactive environment like IPython/Jupyter notebooks
-your script has to be callable like so:
-
-.. code-block:: bash
-
-    python my_program.py
-
-However, using ``strategy="ddp_spawn"`` enables to reduce memory usage with In-Memory Dataset and shared memory tensors. For more info, check out
-:ref:`Sharing Datasets Across Process Boundaries ` section.
 
 Persistent Workers
 ^^^^^^^^^^^^^^^^^^
 
-When using ``strategy="ddp_spawn"`` and ``num_workers>0``, consider setting ``persistent_workers=True`` inside your DataLoader since it can result in data-loading bottlenecks and slowdowns.
-This is a limitation of Python ``.spawn()`` and PyTorch.
+If you use a large number of ``num_workers`` in your dataloaders or your epochs are very fast, you may notice a slowdown at the beginning of every epoch due to the time it takes for the dataloader to spawn its worker processes.
+In this case, setting ``persistent_workers=True`` in your dataloader will significantly speed up the worker startup time across epochs.
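+
+As an illustrative sketch (``train_data`` and ``num_workers=4`` below are placeholder values), enabling persistent workers is a single extra argument on the ``DataLoader``:
+
+.. code-block:: python
+
+    from torch.utils.data import DataLoader
+
+    # Keep worker processes alive between epochs instead of re-spawning them each time.
+    # Note: persistent_workers=True requires num_workers > 0.
+    train_loader = DataLoader(train_data, num_workers=4, persistent_workers=True)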
 
 
 TPU Training
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index a4723cad0b99e..3c9371f4e2a4d 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -15,6 +15,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Iterable, Optional, Tuple, Union
 
+import torch.multiprocessing as mp
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
@@ -28,7 +29,6 @@
 )
 from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
-from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
@@ -420,28 +420,21 @@ def _check_dataloader_iterable(
         )
 
 
-def _worker_check(trainer: "pl.Trainer", using_spawn: bool, dataloader: object, name: str) -> None:
+def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None:
     if not isinstance(dataloader, DataLoader):
         return
 
     upper_bound = suggested_max_num_workers(trainer.num_devices)
+    start_method = (
+        dataloader.multiprocessing_context.get_start_method()
+        if dataloader.multiprocessing_context is not None
+        else mp.get_start_method()
+    )
 
-    # ddp_spawn + num_workers > 0 don't mix! tell the user
-    if dataloader.num_workers > 0 and using_spawn:
-        if not dataloader.persistent_workers:
-            rank_zero_warn(
-                "num_workers>0, persistent_workers=False, and strategy=ddp_spawn"
-                " may result in data loading bottlenecks."
-                " Consider setting persistent_workers=True"
-                " (this is a limitation of Python .spawn() and PyTorch)"
-            )
-
-    elif dataloader.num_workers == 0 and using_spawn:
-        if not dataloader.persistent_workers:
-            rank_zero_warn(
-                "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
-                " Consider setting num_workers>0 and persistent_workers=True"
-            )
+    if dataloader.num_workers > 0 and start_method == "spawn" and not dataloader.persistent_workers:
+        rank_zero_warn(
+            f"Consider setting `persistent_workers=True` in '{name}' to speed up the dataloader worker initialization."
+        )
     elif dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound:
         # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
         rank_zero_warn(
@@ -499,13 +492,11 @@ def _process_dataloader(
     dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=is_shuffled, mode=stage)
 
     # let the strategy inject its logic
-    strategy = trainer.strategy
-    dataloader = strategy.process_dataloader(dataloader)
+    dataloader = trainer.strategy.process_dataloader(dataloader)
 
     # check the workers
     _worker_check(
         trainer=trainer,
-        using_spawn=isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
         dataloader=dataloader,
         name=f"{stage.dataloader_prefix}_dataloader",
     )
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 9402420ff3a03..b8025f5dcba66 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import redirect_stderr
-from io import StringIO
 from re import escape
 from typing import Sized
 from unittest import mock
@@ -44,7 +42,7 @@
 
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("mode", [1, 2])
-def test_replace_distributed_sampler(tmpdir, mode):
+def test_replace_distributed_sampler(tmp_path, mode):
     class IndexedRandomDataset(RandomDataset):
         def __getitem__(self, index):
             return self.data[index]
@@ -100,7 +98,7 @@ def test_dataloader(self):
     model = TestModel(2, mode)
 
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         limit_test_batches=2,
         accelerator="cpu",
         devices=1,
@@ -110,43 +108,35 @@
 
 
 class TestSpawnBoringModel(BoringModel):
-    def __init__(self, num_workers):
+    def __init__(self, warning_expected=False):
         super().__init__()
-        self.num_workers = num_workers
-
-    def train_dataloader(self):
-        return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)
+        self.warning_expected = warning_expected
 
     def on_fit_start(self):
-        self._resout = StringIO()
-        self.ctx = redirect_stderr(self._resout)
-        self.ctx.__enter__()
+        ctx = pytest.warns if self.warning_expected else no_warning_call
+        self.ctx = ctx(UserWarning, match="Consider setting `persistent_workers=True`")
+        if self.global_rank == 0:
+            self.ctx.__enter__()
 
     def on_train_end(self):
-        def _get_warning_msg():
-            dl = self.trainer.train_dataloader
-            if hasattr(dl, "persistent_workers"):
-                if self.num_workers == 0:
-                    warn_str = "Consider setting num_workers>0 and persistent_workers=True"
-                else:
-                    warn_str = "Consider setting persistent_workers=True"
-            else:
-                warn_str = "Consider setting strategy=ddp"
-
-            return warn_str
-
-        if self.trainer.is_global_zero:
+        if self.global_rank == 0:
             self.ctx.__exit__(None, None, None)
-            msg = self._resout.getvalue()
-            warn_str = _get_warning_msg()
-            assert warn_str in msg
 
 
-@RunIf(skip_windows=True)
-@pytest.mark.parametrize("num_workers", [0, 1])
-def test_dataloader_warnings(tmpdir, num_workers):
-    trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4)
-    trainer.fit(TestSpawnBoringModel(num_workers))
+@pytest.mark.parametrize("num_workers", [0, 1, 2])
+def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
+    """Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="cpu",
+        devices=1,
+        strategy="ddp_spawn",
+        max_steps=1,
+        barebones=True,
+    )
+    model = TestSpawnBoringModel(warning_expected=(num_workers > 0))
+    dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
+    trainer.fit(model, dataloader)
 
 
 @pytest.mark.parametrize(
@@ -166,10 +156,11 @@ def test_dataloader_warnings(tmpdir, num_workers):
     ],
 )
 @mock.patch("lightning.fabric.utilities.data.os.cpu_count")
-def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
+@mock.patch("lightning.pytorch.trainer.connectors.data_connector.mp.get_start_method", return_value="not_spawn")
+def test_worker_check(_, cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
     monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
     trainer = Mock(spec=Trainer)
-    dataloader = Mock(spec=DataLoader)
+    dataloader = Mock(spec=DataLoader, persistent_workers=False)
     trainer.num_devices = num_devices
     dataloader.num_workers = num_workers
     cpu_count_mock.return_value = cpu_count
@@ -177,10 +168,10 @@ def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expec
     if expected_warning:
         ctx = pytest.warns(UserWarning, match="Consider increasing the value of the `num_workers` argument`")
     else:
-        ctx = no_warning_call(UserWarning)
+        ctx = no_warning_call()
 
     with ctx:
-        _worker_check(trainer, using_spawn=False, dataloader=dataloader, name="train_dataloader")
+        _worker_check(trainer, dataloader=dataloader, name="train_dataloader")
 
 
 def test_update_dataloader_raises():
@@ -628,10 +619,10 @@ def test_error_raised_with_insufficient_float_limit_train_dataloader():
         ("predict", "dataloaders"),
     ],
 )
-def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, dataloader_name, tmpdir):
+def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, dataloader_name, tmp_path):
     """Test that passing `Trainer.method(x_dataloader=None)` with no module-method implementations available raises
     an error."""
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True)
     model = BoringModel()
     datamodule = BoringDataModule()
     trainer_fn = getattr(trainer, trainer_fn_name)