diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e22b81cf3fa1..8804dba4d9c8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -423,6 +423,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307)) +- Disbled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507)) + + ## [1.5.8] - 2022-01-05 ### Fixed diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 9ff8473cb744f..b365ab99f3cf3 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -228,7 +228,11 @@ def _get_dataloader_init_kwargs( # kwargs to re-construct the dataloader dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults} - dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)) + if isinstance(dl_kwargs["dataset"], IterableDataset): + dl_kwargs["batch_sampler"] = None + dl_kwargs["sampler"] = None + else: + dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)) required_args = { p.name @@ -263,10 +267,6 @@ def _get_dataloader_init_kwargs( f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`." ) - if isinstance(dl_kwargs["dataset"], IterableDataset): - dl_kwargs["batch_sampler"] = None - dl_kwargs["sampler"] = None - if _FaultTolerantMode.detect_current_mode().is_automatic: dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs) diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 629d141505004..0d874e81d8c67 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -3,7 +3,9 @@ from torch.utils.data.dataloader import DataLoader from pytorch_lightning import Trainer +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( + _get_dataloader_init_kwargs, _replace_dataloader_init_method, _update_dataloader, extract_batch_size, @@ -172,3 +174,16 @@ def __init__(self, attribute1, attribute2, *args, **kwargs): dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2) assert dataloader.attribute1 == "attribute1" assert dataloader.attribute2 == "attribute2" + + +@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING]) +def test_dataloader_kwargs_replacement_with_iterable_dataset(mode): + """Test that DataLoader kwargs are not replaced when using Iterable Dataset.""" + dataset = RandomIterableDataset(7, 100) + dataloader = DataLoader(dataset, batch_size=32) + dl_kwargs = _get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode) + assert dl_kwargs["sampler"] is None + assert dl_kwargs["batch_sampler"] is None + assert dl_kwargs["batch_size"] is dataloader.batch_size + assert dl_kwargs["dataset"] is dataloader.dataset + assert dl_kwargs["collate_fn"] is dataloader.collate_fn