Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix val_split #1079

Merged
merged 3 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where prediction would sometimes give the wrong number of outputs ([#1077](https://github.com/PyTorchLightning/lightning-flash/pull/1077))

- Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079))

### Removed

## [0.6.0] - 2021-13-12
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
"A `val_dataset` was provided with `val_split`. Please, choose one or the other."
)

if self._train_input is not None and (val_split is not None and self._val_input is None):
if self._train_input and (val_split is not None and not self._val_input):
self._train_input, self._val_input = self._split_train_val(self._train_input, val_split)

if self._train_input:
Expand Down
12 changes: 12 additions & 0 deletions tests/core/data/test_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,15 @@ def test_dataloaders_with_sampler(mock_dataloader):
for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]:
kwargs = mock_dataloader.call_args[1]
assert "sampler" not in kwargs


def test_val_split():
datamodule = DataModule(
Input(RunningStage.TRAINING, [1] * 100),
batch_size=2,
num_workers=0,
val_split=0.2,
)

assert len(datamodule.train_dataset) == 80
assert len(datamodule.val_dataset) == 20