Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix val_split #1079

Merged
merged 3 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where prediction would sometimes give the wrong number of outputs ([#1077](https://github.com/PyTorchLightning/lightning-flash/pull/1077))

- Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079))

### Removed

## [0.6.0] - 2021-12-13
Expand Down
16 changes: 8 additions & 8 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def __init__(
self._test_input = test_input
self._predict_input = predict_input

if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0:
raise MisconfigurationException(
"A `val_dataset` was provided with `val_split`. Please, choose one or the other."
)

if self._train_input and (val_split is not None and not self._val_input):
self._train_input, self._val_input = self._split_train_val(self._train_input, val_split)

self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher()

self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input)
Expand All @@ -122,14 +130,6 @@ def __init__(
self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input)
self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input)

if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0:
raise MisconfigurationException(
"A `val_dataset` was provided with `val_split`. Please, choose one or the other."
)

if self._train_input is not None and (val_split is not None and self._val_input is None):
self._train_input, self._val_input = self._split_train_val(self._train_input, val_split)

if self._train_input:
self.train_dataloader = self._train_dataloader

Expand Down
12 changes: 12 additions & 0 deletions tests/core/data/test_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,15 @@ def test_dataloaders_with_sampler(mock_dataloader):
for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]:
kwargs = mock_dataloader.call_args[1]
assert "sampler" not in kwargs


def test_val_split():
    """Check that ``val_split`` carves a validation set out of the training input.

    A ``DataModule`` is built from a single training ``Input`` of 100 samples
    with ``val_split=0.2`` and no validation input; the test expects 80 samples
    to remain in the train dataset and 20 to land in the val dataset.
    """
    train_input = Input(RunningStage.TRAINING, [1] * 100)
    datamodule = DataModule(train_input, batch_size=2, num_workers=0, val_split=0.2)

    assert len(datamodule.train_dataset) == 80
    assert len(datamodule.val_dataset) == 20