diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 7e97c2305d..51299cf115 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -110,6 +110,14 @@ def __init__( self._test_input = test_input self._predict_input = predict_input + if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0: + raise MisconfigurationException( + "A `val_dataset` was provided with `val_split`. Please, choose one or the other." + ) + + if self._train_input and (val_split is not None and not self._val_input): + self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) + self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input) @@ -122,14 +130,6 @@ def __init__( self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input) self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input) - if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0: - raise MisconfigurationException( - "A `val_dataset` was provided with `val_split`. Please, choose one or the other." - ) - - if self._train_input and (val_split is not None and not self._val_input): - self._train_input, self._val_input = self._split_train_val(self._train_input, val_split) - if self._train_input: self.train_dataloader = self._train_dataloader