diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 9371212423..ef5e118cc3 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -460,9 +460,9 @@ def from_data_source( val_data: Any = None, test_data: Any = None, predict_data: Any = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, @@ -522,6 +522,7 @@ def from_data_source( }, ) """ + preprocess = preprocess or cls.preprocess_cls( train_transform, val_transform, diff --git a/flash/core/data/process.py b/flash/core/data/process.py index c2ad49c390..5ebb4d15b0 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -14,7 +14,7 @@ import inspect import os from abc import ABC, abstractclassmethod, abstractmethod -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union import torch from pytorch_lightning.trainer.states import RunningStage @@ -28,6 +28,7 @@ from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.properties import Properties from flash.core.data.states import CollateFn +from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext @@ -177,10 +178,10 @@ def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: def __init__( self, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, data_sources: Optional[Dict[str, "DataSource"]] = None, deserializer: Optional["Deserializer"] = None, default_data_source: Optional[str] = None, @@ -252,6 +253,11 @@ def _check_transforms( if transform is None: return transform + if isinstance(transform, list): + transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.Sequential(*transform))} + elif callable(transform): + transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, transform)} + if not isinstance(transform, Dict): raise MisconfigurationException( "Transform should be a dict. " f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}." @@ -439,10 +445,10 @@ def data_source_of_name(self, data_source_name: str) -> DataSource: class DefaultPreprocess(Preprocess): def __init__( self, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, data_sources: Optional[Dict[str, "DataSource"]] = None, default_data_source: Optional[str] = None, ): diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index a9f1185dd4..f5f03fcd6a 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -120,9 +120,9 @@ def from_data_frame( predict_data_frame: Optional[pd.DataFrame] = None, predict_images_root: Optional[str] = None, predict_resolver: Optional[Callable[[str, str], str]] = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, @@ -217,9 +217,9 @@ def from_csv( predict_file: Optional[str] = None, predict_images_root: Optional[str] = None, predict_resolver: Optional[Callable[[str, str], str]] = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, + train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, + test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, diff --git a/tests/core/data/test_process.py b/tests/core/data/test_process.py index 509bbce3f8..61ab591591 100644 --- a/tests/core/data/test_process.py +++ b/tests/core/data/test_process.py @@ -88,7 +88,6 @@ class Serializer2State(ProcessState): def test_saving_with_serializers(tmpdir): - checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") class CustomModel(Task): @@ -122,7 +121,6 @@ def __init__(self): def test_data_source_of_name(): - preprocess = CustomPreprocess() assert preprocess.data_source_of_name("test")() == "test" @@ -135,7 +133,6 @@ def test_data_source_of_name(): def test_available_data_sources(): - preprocess = CustomPreprocess() assert DefaultDataSources.TENSORS in preprocess.available_data_sources() @@ -147,3 +144,9 @@ def test_available_data_sources(): assert DefaultDataSources.TENSORS in data_module.available_data_sources() assert "test" in data_module.available_data_sources() assert len(data_module.available_data_sources()) == 3 + + +def test_check_transforms(): + transform = torch.nn.Identity() + DefaultPreprocess(train_transform=transform) + DefaultPreprocess(train_transform=[transform])