Update persistent_workers recommendation when using spawn launcher (#18649)

Co-authored-by: Carlos Mocholí <[email protected]>
awaelchli and carmocca authored Sep 29, 2023
1 parent 3cd463e commit 996e768
Showing 3 changed files with 46 additions and 78 deletions.
22 changes: 4 additions & 18 deletions docs/source-pytorch/advanced/speed.rst
@@ -82,13 +82,13 @@ The question of how many workers to specify in ``num_workers`` is tricky. Here's
1. ``num_workers=0`` means ONLY the main process will load batches (that can be a bottleneck).
2. ``num_workers=1`` means ONLY one worker (just not the main process) will load data, but it will still be slow.
3. The performance of high ``num_workers`` depends on the batch size and your machine.
4. A general place to start is to set ``num_workers`` equal to the number of CPU cores on that machine. You can get the number of CPU cores in python using ``os.cpu_count()``, but note that depending on your batch size, you may overflow RAM memory.
4. A general place to start is to set ``num_workers`` equal to the number of CPU cores on that machine. You can get the number of CPU cores in Python using ``os.cpu_count()``, but note that depending on your batch size, you may overflow CPU RAM.

.. warning:: Increasing ``num_workers`` will ALSO increase your CPU memory consumption.

The best thing to do is to increase the ``num_workers`` slowly and stop once there is no more improvement in your training speed.
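As an illustration only (the ``TensorDataset`` below is a placeholder for a real dataset), a starting configuration along these lines might look like:

.. code-block:: python

    import os

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Placeholder dataset; substitute your own.
    train_dataset = TensorDataset(torch.randn(256, 32))

    # Start from the number of CPU cores, then tune up or down based on observed throughput.
    num_workers = os.cpu_count() or 1
    train_loader = DataLoader(train_dataset, batch_size=64, num_workers=num_workers)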

For debugging purposes or for dataloaders that load very small datasets, it is desirable to set ``num_workers=0``. However, this will always log a warning for every dataloader with ``num_workers <= min(2, os.cpu_count())``. In such cases, you can specifically filter this warning by using:
For debugging purposes or for dataloaders that load very small datasets, it is desirable to set ``num_workers=0``. However, this will log a warning that you're not using enough workers. In such cases, you can specifically filter this warning by using:

.. code-block:: python
@@ -101,26 +101,12 @@ For debugging purposes or for dataloaders that load very small datasets, it is d
warnings.filterwarnings("ignore", category=PossibleUserWarning)
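For reference, a self-contained version of this filter might look as follows; the import location of ``PossibleUserWarning`` shown here is an assumption and may differ between Lightning versions:

.. code-block:: python

    import warnings

    # Assumed import path for PossibleUserWarning; adjust to your installed Lightning version.
    from lightning.pytorch.utilities.warnings import PossibleUserWarning

    warnings.filterwarnings("ignore", category=PossibleUserWarning)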
Spawn
^^^^^

When using ``strategy="ddp_spawn"`` or training on TPUs, the way multiple GPUs/TPU cores are used is by calling :obj:`torch.multiprocessing`
``.spawn()`` under the hood. The problem is that PyTorch has issues with ``num_workers>0`` when using ``.spawn()``. For this reason, we recommend you
use ``strategy="ddp"`` so you can increase the ``num_workers``, however since DDP doesn't work in an interactive environment like IPython/Jupyter notebooks
your script has to be callable like so:

.. code-block:: bash
python my_program.py
However, using ``strategy="ddp_spawn"`` enables to reduce memory usage with In-Memory Dataset and shared memory tensors. For more info, check out
:ref:`Sharing Datasets Across Process Boundaries <ddp_spawn_shared_memory>` section.
Persistent Workers
^^^^^^^^^^^^^^^^^^

When using ``strategy="ddp_spawn"`` and ``num_workers>0``, consider setting ``persistent_workers=True`` inside your DataLoader since it can result in data-loading bottlenecks and slowdowns.
This is a limitation of Python ``.spawn()`` and PyTorch.
If you use a large number of ``num_workers`` in your dataloaders or your epochs are very fast, you may notice a slowdown at the beginning of every epoch due to the time it takes for the dataloader to spawn its worker processes.
In this case, setting ``persistent_workers=True`` in your dataloader will significantly speed up the worker startup time across epochs.
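A minimal sketch of this recommendation (the dataset and worker count below are placeholders):

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(256, 32))  # placeholder dataset

    # Keep worker processes alive between epochs instead of re-spawning them,
    # avoiding the per-epoch startup cost described above.
    train_loader = DataLoader(dataset, batch_size=64, num_workers=4, persistent_workers=True)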


TPU Training
33 changes: 12 additions & 21 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from typing import Any, Iterable, Optional, Tuple, Union

import torch.multiprocessing as mp
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

@@ -28,7 +29,6 @@
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import CombinedLoader
@@ -420,28 +420,21 @@ def _check_dataloader_iterable(
)


def _worker_check(trainer: "pl.Trainer", using_spawn: bool, dataloader: object, name: str) -> None:
def _worker_check(trainer: "pl.Trainer", dataloader: object, name: str) -> None:
if not isinstance(dataloader, DataLoader):
return

upper_bound = suggested_max_num_workers(trainer.num_devices)
start_method = (
dataloader.multiprocessing_context.get_start_method()
if dataloader.multiprocessing_context is not None
else mp.get_start_method()
)

# ddp_spawn + num_workers > 0 don't mix! tell the user
if dataloader.num_workers > 0 and using_spawn:
if not dataloader.persistent_workers:
rank_zero_warn(
"num_workers>0, persistent_workers=False, and strategy=ddp_spawn"
" may result in data loading bottlenecks."
" Consider setting persistent_workers=True"
" (this is a limitation of Python .spawn() and PyTorch)"
)

elif dataloader.num_workers == 0 and using_spawn:
if not dataloader.persistent_workers:
rank_zero_warn(
"strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
" Consider setting num_workers>0 and persistent_workers=True"
)
if dataloader.num_workers > 0 and start_method == "spawn" and not dataloader.persistent_workers:
rank_zero_warn(
f"Consider setting `persistent_workers=True` in '{name}' to speed up the dataloader worker initialization."
)
elif dataloader.num_workers <= 2 < upper_bound or dataloader.num_workers < 2 <= upper_bound:
# if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
rank_zero_warn(
@@ -499,13 +492,11 @@ def _process_dataloader(
dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=is_shuffled, mode=stage)

# let the strategy inject its logic
strategy = trainer.strategy
dataloader = strategy.process_dataloader(dataloader)
dataloader = trainer.strategy.process_dataloader(dataloader)

# check the workers
_worker_check(
trainer=trainer,
using_spawn=isinstance(strategy, DDPStrategy) and strategy._start_method == "spawn",
dataloader=dataloader,
name=f"{stage.dataloader_prefix}_dataloader",
)
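In user terms, the reworked check reads the start method from the DataLoader's own multiprocessing context (falling back to the global default) rather than from the strategy. As a hedged sketch, here are two ways a user-side DataLoader can avoid the spawn-related warning; the fork option applies only on platforms that support forking:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 32))  # placeholder dataset

# Option 1: keep workers alive across epochs, which is what the warning suggests.
loader_a = DataLoader(dataset, num_workers=4, persistent_workers=True)

# Option 2: request a "fork" start method explicitly (POSIX only), so the check
# reads "fork" from the loader's multiprocessing_context instead of "spawn".
loader_b = DataLoader(dataset, num_workers=4, multiprocessing_context="fork")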
69 changes: 30 additions & 39 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import redirect_stderr
from io import StringIO
from re import escape
from typing import Sized
from unittest import mock
@@ -44,7 +42,7 @@

@RunIf(skip_windows=True)
@pytest.mark.parametrize("mode", [1, 2])
def test_replace_distributed_sampler(tmpdir, mode):
def test_replace_distributed_sampler(tmp_path, mode):
class IndexedRandomDataset(RandomDataset):
def __getitem__(self, index):
return self.data[index]
@@ -100,7 +98,7 @@ def test_dataloader(self):
model = TestModel(2, mode)

trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=tmp_path,
limit_test_batches=2,
accelerator="cpu",
devices=1,
@@ -110,43 +108,35 @@ def test_dataloader(self):


class TestSpawnBoringModel(BoringModel):
def __init__(self, num_workers):
def __init__(self, warning_expected=False):
super().__init__()
self.num_workers = num_workers

def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)
self.warning_expected = warning_expected

def on_fit_start(self):
self._resout = StringIO()
self.ctx = redirect_stderr(self._resout)
self.ctx.__enter__()
ctx = pytest.warns if self.warning_expected else no_warning_call
self.ctx = ctx(UserWarning, match="Consider setting `persistent_workers=True`")
if self.global_rank == 0:
self.ctx.__enter__()

def on_train_end(self):
def _get_warning_msg():
dl = self.trainer.train_dataloader
if hasattr(dl, "persistent_workers"):
if self.num_workers == 0:
warn_str = "Consider setting num_workers>0 and persistent_workers=True"
else:
warn_str = "Consider setting persistent_workers=True"
else:
warn_str = "Consider setting strategy=ddp"

return warn_str

if self.trainer.is_global_zero:
if self.global_rank == 0:
self.ctx.__exit__(None, None, None)
msg = self._resout.getvalue()
warn_str = _get_warning_msg()
assert warn_str in msg


@RunIf(skip_windows=True)
@pytest.mark.parametrize("num_workers", [0, 1])
def test_dataloader_warnings(tmpdir, num_workers):
trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4)
trainer.fit(TestSpawnBoringModel(num_workers))
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
"""Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cpu",
devices=1,
strategy="ddp_spawn",
max_steps=1,
barebones=True,
)
model = TestSpawnBoringModel(warning_expected=(num_workers > 0))
dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
trainer.fit(model, dataloader)


@pytest.mark.parametrize(
@@ -166,21 +156,22 @@ def test_dataloader_warnings(tmpdir, num_workers):
],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_worker_check(cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
@mock.patch("lightning.pytorch.trainer.connectors.data_connector.mp.get_start_method", return_value="not_spawn")
def test_worker_check(_, cpu_count_mock, num_devices, num_workers, cpu_count, expected_warning, monkeypatch):
monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
trainer = Mock(spec=Trainer)
dataloader = Mock(spec=DataLoader)
dataloader = Mock(spec=DataLoader, persistent_workers=False)
trainer.num_devices = num_devices
dataloader.num_workers = num_workers
cpu_count_mock.return_value = cpu_count

if expected_warning:
ctx = pytest.warns(UserWarning, match="Consider increasing the value of the `num_workers` argument`")
else:
ctx = no_warning_call(UserWarning)
ctx = no_warning_call()

with ctx:
_worker_check(trainer, using_spawn=False, dataloader=dataloader, name="train_dataloader")
_worker_check(trainer, dataloader=dataloader, name="train_dataloader")


def test_update_dataloader_raises():
@@ -628,10 +619,10 @@ def test_error_raised_with_insufficient_float_limit_train_dataloader():
("predict", "dataloaders"),
],
)
def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, dataloader_name, tmpdir):
def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, dataloader_name, tmp_path):
"""Test that passing `Trainer.method(x_dataloader=None)` with no module-method implementations available raises an
error."""
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True)
model = BoringModel()
datamodule = BoringDataModule()
trainer_fn = getattr(trainer, trainer_fn_name)