Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix batchsampler does not work correctly #20327

Merged
merged 6 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lightning/pytorch/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ def _is_dataloader_shuffled(dataloader: object) -> bool:
if not hasattr(dataloader, "sampler"):
# shuffling is enabled via a sampler. No sampler, no shuffling
return False
sampler = dataloader.sampler
batch_sampler = dataloader.batch_sampler
sampler = batch_sampler.sampler if batch_sampler is not None else dataloader.sampler
if isinstance(sampler, SequentialSampler):
return False
return isinstance(sampler, RandomSampler)
28 changes: 27 additions & 1 deletion tests/tests_pytorch/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.data import (
_get_dataloader_init_args_and_kwargs,
_is_dataloader_shuffled,
_update_dataloader,
extract_batch_size,
has_len_all_ranks,
Expand All @@ -20,7 +21,7 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler


def test_extract_batch_size():
Expand Down Expand Up @@ -304,6 +305,31 @@ def __init__(self, extra_arg):
_ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)


def test_batch_sampler_shuffle_setting():
"""Test whether the `shuffle` state is correctly set in the `BatchSampler`."""

random_sampler = RandomSampler(range(10))
seq_sampler = SequentialSampler(range(10))
shuffled_dataloader = DataLoader(
range(10), batch_sampler=BatchSampler(random_sampler, batch_size=2, drop_last=False)
)
sequential_dataloader = DataLoader(
range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=2, drop_last=False)
)

# if batch_size is 1, the pytorch init a default SequentialSampler and set BatchSampler to None
single_dataloader = DataLoader(range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=1, drop_last=False))
assert _is_dataloader_shuffled(shuffled_dataloader)
assert not _is_dataloader_shuffled(sequential_dataloader)
assert not _is_dataloader_shuffled(single_dataloader)

# if batch_size is 1, and no batch_sampler is set, the pytorch will set BatchSampler to None
single_dataloader = DataLoader(range(10), batch_size=1)
shuffled_single_dataloader = DataLoader(range(10), batch_size=1, shuffle=True)
assert not _is_dataloader_shuffled(single_dataloader)
assert _is_dataloader_shuffled(shuffled_single_dataloader)


@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
"""Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
Expand Down
Loading