Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Adding optional sampler input to DataModule #390

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))

### Changed

Expand Down
34 changes: 33 additions & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import IterableDataset, Subset
from torch.utils.data.sampler import Sampler

from flash.core.data.auto_dataset import BaseAutoDataset, IterableAutoDataset
from flash.core.data.base_viz import BaseVisualization
Expand Down Expand Up @@ -58,6 +59,8 @@ class DataModule(pl.LightningDataModule):
num_workers: The number of workers to use for parallelized loading.
Defaults to None which equals the number of available CPU threads,
or 0 for Windows or Darwin platform.
sampler: A sampler following the :class:`~torch.utils.data.sampler.Sampler` type.
Will be passed to the DataLoader for the training dataset. Defaults to None.
"""

preprocess_cls = DefaultPreprocess
Expand All @@ -76,6 +79,7 @@ def __init__(
val_split: Optional[float] = None,
batch_size: int = 1,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
) -> None:

super().__init__()
Expand Down Expand Up @@ -118,6 +122,7 @@ def __init__(
else:
num_workers = os.cpu_count()
self.num_workers = num_workers
self.sampler = sampler

self.set_running_stages()

Expand Down Expand Up @@ -259,11 +264,14 @@ def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) ->

def _train_dataloader(self) -> DataLoader:
train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds
shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))
shuffle: bool = False
if self.sampler is None:
shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))
return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=self.sampler,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
Expand Down Expand Up @@ -372,6 +380,7 @@ def from_data_source(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to
Expand Down Expand Up @@ -407,6 +416,7 @@ def from_data_source(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -451,6 +461,7 @@ def from_data_source(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
)

@classmethod
Expand All @@ -469,6 +480,7 @@ def from_folders(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the
Expand Down Expand Up @@ -497,6 +509,7 @@ def from_folders(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -527,6 +540,7 @@ def from_folders(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

Expand All @@ -549,6 +563,7 @@ def from_files(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files using
Expand Down Expand Up @@ -580,6 +595,7 @@ def from_files(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -611,6 +627,7 @@ def from_files(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

Expand All @@ -633,6 +650,7 @@ def from_tensors(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the
Expand Down Expand Up @@ -664,6 +682,7 @@ def from_tensors(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -695,6 +714,7 @@ def from_tensors(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

Expand All @@ -717,6 +737,7 @@ def from_numpy(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the
Expand Down Expand Up @@ -748,6 +769,7 @@ def from_numpy(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -779,6 +801,7 @@ def from_numpy(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

Expand All @@ -800,6 +823,7 @@ def from_json(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the
Expand Down Expand Up @@ -830,6 +854,7 @@ def from_json(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -862,6 +887,7 @@ def from_json(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

Expand All @@ -883,6 +909,7 @@ def from_csv(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the
Expand Down Expand Up @@ -913,6 +940,7 @@ def from_csv(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -945,6 +973,7 @@ def from_csv(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)

Expand All @@ -964,6 +993,7 @@ def from_datasets(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the
Expand Down Expand Up @@ -992,6 +1022,7 @@ def from_datasets(
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -1022,5 +1053,6 @@ def from_datasets(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)
32 changes: 32 additions & 0 deletions tests/core/data/test_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from flash import DataModule


@mock.patch("flash.core.data.data_module.DataLoader")
def test_dataloaders_with_sampler(mock_dataloader):
train_ds = val_ds = test_ds = 'dataset'
mock_sampler = 'sampler'
dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler)
assert dm.sampler is mock_sampler
dl = dm.train_dataloader()
kwargs = mock_dataloader.call_args[1]
assert 'sampler' in kwargs
assert kwargs['sampler'] is mock_sampler
for dl in [dm.val_dataloader(), dm.test_dataloader()]:
kwargs = mock_dataloader.call_args[1]
assert 'sampler' not in kwargs