Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix val_split (#1079)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Dec 16, 2021
1 parent f37e50d commit 4d00c34
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where prediction would sometimes give the wrong number of outputs ([#1077](https://github.com/PyTorchLightning/lightning-flash/pull/1077))

- Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079))

### Removed

## [0.6.0] - 2021-13-12
Expand Down
16 changes: 8 additions & 8 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def __init__(
self._test_input = test_input
self._predict_input = predict_input

if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0:
raise MisconfigurationException(
"A `val_dataset` was provided with `val_split`. Please, choose one or the other."
)

if self._train_input and (val_split is not None and not self._val_input):
self._train_input, self._val_input = self._split_train_val(self._train_input, val_split)

self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher()

self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input)
Expand All @@ -122,14 +130,6 @@ def __init__(
self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input)
self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input)

if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0:
raise MisconfigurationException(
"A `val_dataset` was provided with `val_split`. Please, choose one or the other."
)

if self._train_input is not None and (val_split is not None and self._val_input is None):
self._train_input, self._val_input = self._split_train_val(self._train_input, val_split)

if self._train_input:
self.train_dataloader = self._train_dataloader

Expand Down
12 changes: 12 additions & 0 deletions tests/core/data/test_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,15 @@ def test_dataloaders_with_sampler(mock_dataloader):
for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]:
kwargs = mock_dataloader.call_args[1]
assert "sampler" not in kwargs


def test_val_split():
datamodule = DataModule(
Input(RunningStage.TRAINING, [1] * 100),
batch_size=2,
num_workers=0,
val_split=0.2,
)

assert len(datamodule.train_dataset) == 80
assert len(datamodule.val_dataset) == 20

0 comments on commit 4d00c34

Please sign in to comment.