From de3e340f45c18092914203e81f9ee0eba05a50b3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 16 Dec 2021 18:05:43 +0000 Subject: [PATCH 1/3] Fix val_split --- flash/core/data/data_module.py | 2 +- tests/core/data/test_data_module.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index f273d18383..7e97c2305d 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -127,7 +127,7 @@ def __init__( "A `val_dataset` was provided with `val_split`. Please, choose one or the other." ) - if self._train_input is not None and (val_split is not None and self._val_input is None): + if self._train_input and (val_split is not None and not self._val_input): self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) if self._train_input: diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py index a3a494143d..8cef9c25a4 100644 --- a/tests/core/data/test_data_module.py +++ b/tests/core/data/test_data_module.py @@ -439,3 +439,15 @@ def test_dataloaders_with_sampler(mock_dataloader): for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]: kwargs = mock_dataloader.call_args[1] assert "sampler" not in kwargs + + +def test_val_split(): + datamodule = DataModule( + Input(RunningStage.TRAINING, [1] * 100), + batch_size=2, + num_workers=0, + val_split=0.2, + ) + + assert len(datamodule.train_dataset) == 80 + assert len(datamodule.val_dataset) == 20 From 5232f5c18cf25d91b68963079815087a6ac8a2a9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 16 Dec 2021 18:09:10 +0000 Subject: [PATCH 2/3] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71679f0c45..5cb78ee80b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where prediction would sometimes give the wrong number of outputs ([#1077](https://github.com/PyTorchLightning/lightning-flash/pull/1077)) +- Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079)) + ### Removed ## [0.6.0] - 2021-13-12 From 255d507c0e930d8fd19257ef40fde956fcd2c940 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 16 Dec 2021 18:21:07 +0000 Subject: [PATCH 3/3] Fixes --- flash/core/data/data_module.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 7e97c2305d..51299cf115 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -110,6 +110,14 @@ def __init__( self._test_input = test_input self._predict_input = predict_input + if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0: + raise MisconfigurationException( + "A `val_dataset` was provided with `val_split`. Please, choose one or the other." + ) + + if self._train_input and (val_split is not None and not self._val_input): + self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) + self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input) @@ -122,14 +130,6 @@ def __init__( self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input) self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input) - if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0: - raise MisconfigurationException( - "A `val_dataset` was provided with `val_split`. Please, choose one or the other." - ) - - if self._train_input and (val_split is not None and not self._val_input): - self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) - if self._train_input: self.train_dataloader = self._train_dataloader