Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix overfit_batch sampler replacement logic #10486

Merged
merged 4 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))


- Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))


-


-


Expand Down
18 changes: 9 additions & 9 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,7 @@ def _reset_eval_dataloader(
for loader_i in range(len(dataloaders)):
loader = dataloaders[loader_i]

if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):

if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
# when overfitting, the dataloader should not have sampler
if self.overfit_batches > 0 and mode.evaluating:
rank_zero_warn(
Expand Down Expand Up @@ -591,16 +590,17 @@ def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:

@staticmethod
def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
has_random_sampler = False
all_have_sequential_sampler = True

def resolve_had_random_sampler(dataloader: DataLoader):
nonlocal has_random_sampler
if not has_random_sampler:
has_random_sampler = isinstance(dataloader.sampler, RandomSampler)
def resolve_has_no_sequential_sampler(dataloader: DataLoader):
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
nonlocal all_have_sequential_sampler
all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
dataloader.sampler, SequentialSampler
)

apply_to_collection(dataloader, DataLoader, resolve_had_random_sampler)
apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler)

if has_random_sampler:
if not all_have_sequential_sampler:
rank_zero_warn(
"You requested to overfit but enabled training dataloader shuffling."
" We are turning off the training dataloader shuffling for you."
Expand Down
63 changes: 53 additions & 10 deletions tests/trainer/flags/test_overfit_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# limitations under the License.
import pytest
import torch
from torch.utils.data.sampler import Sampler, SequentialSampler

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset


def test_overfit_multiple_val_loaders(tmpdir):
"""Tests that only training_step can be used."""
"""Tests that overfit batches works with multiple val dataloaders."""
val_dl_count = 2
overfit_batches = 3

class TestModel(BoringModel):
def validation_step(self, batch, batch_idx, dataloader_idx):
Expand All @@ -31,25 +34,65 @@ def validation_epoch_end(self, outputs) -> None:
pass

def val_dataloader(self):
dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64))
dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64))
return [dl1, dl2]
dls = [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(val_dl_count)]
return dls

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir, max_epochs=2, overfit_batches=1, log_every_n_steps=1, enable_model_summary=False
default_root_dir=tmpdir,
max_epochs=2,
overfit_batches=overfit_batches,
log_every_n_steps=1,
enable_model_summary=False,
)

trainer.fit(model)
assert trainer.num_training_batches == overfit_batches
assert len(trainer.num_val_batches) == val_dl_count
assert all(nbatches == overfit_batches for nbatches in trainer.num_val_batches)


@pytest.mark.parametrize("overfit", [1, 2, 0.1, 0.25, 1.0])
def test_overfit_basic(tmpdir, overfit):
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
"""Tests that only training_step can be used."""
@pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0])
def test_overfit_basic(tmpdir, overfit_batches):
"""Tests that only training_step can be used when overfitting."""

model = BoringModel()
model.validation_step = None
total_train_samples = len(BoringModel().train_dataloader())

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit, enable_model_summary=False)

trainer = Trainer(
default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit_batches, enable_model_summary=False
)
trainer.fit(model)

assert trainer.num_val_batches == []
assert trainer.num_training_batches == int(
overfit_batches * (1 if isinstance(overfit_batches, int) else total_train_samples)
)


def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
class NonSequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source

def __iter__(self):
return iter(range(len(self.data_source)))

def __len__(self):
return len(self.data_source)

class TestModel(BoringModel):
def train_dataloader(self):
dataset = RandomDataset(32, 64)
sampler = NonSequentialSampler(dataset)
return torch.utils.data.DataLoader(dataset, sampler=sampler)

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)

with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
trainer.fit(model)

assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)