From 4d00c34a636515ee25fc18e9ba35b469f5e90a89 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 16 Dec 2021 19:25:21 +0000 Subject: [PATCH] Fix val_split (#1079) --- CHANGELOG.md | 2 ++ flash/core/data/data_module.py | 16 ++++++++-------- tests/core/data/test_data_module.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71679f0c45..5cb78ee80b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where prediction would sometimes give the wrong number of outputs ([#1077](https://github.com/PyTorchLightning/lightning-flash/pull/1077)) +- Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079)) + ### Removed ## [0.6.0] - 2021-13-12 diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index f273d18383..51299cf115 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -110,6 +110,14 @@ def __init__( self._test_input = test_input self._predict_input = predict_input + if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0: + raise MisconfigurationException( + "A `val_dataset` was provided with `val_split`. Please, choose one or the other." + ) + + if self._train_input and (val_split is not None and not self._val_input): + self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) + self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input) @@ -122,14 +130,6 @@ def __init__( self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input) self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input) - if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0: - raise MisconfigurationException( - "A `val_dataset` was provided with `val_split`. Please, choose one or the other." - ) - - if self._train_input is not None and (val_split is not None and self._val_input is None): - self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) - if self._train_input: self.train_dataloader = self._train_dataloader diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py index a3a494143d..8cef9c25a4 100644 --- a/tests/core/data/test_data_module.py +++ b/tests/core/data/test_data_module.py @@ -439,3 +439,15 @@ def test_dataloaders_with_sampler(mock_dataloader): for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]: kwargs = mock_dataloader.call_args[1] assert "sampler" not in kwargs + + +def test_val_split(): + datamodule = DataModule( + Input(RunningStage.TRAINING, [1] * 100), + batch_size=2, + num_workers=0, + val_split=0.2, + ) + + assert len(datamodule.train_dataset) == 80 + assert len(datamodule.val_dataset) == 20