From 6da53fe99b220edacf69ea1701ee082ce76ef184 Mon Sep 17 00:00:00 2001 From: karthikrangasai <39360170+karthikrangasai@users.noreply.github.com> Date: Sat, 26 Mar 2022 01:28:25 +0530 Subject: [PATCH] Refactor InputTransform and DataModule (#1233) Co-authored-by: Kushashwa Ravi Shrimali Co-authored-by: Ethan Harris --- README.md | 2 +- docs/source/integrations/icevision.rst | 4 +- .../source/reference/image_classification.rst | 2 +- docs/source/reference/object_detection.rst | 2 +- flash/audio/classification/data.py | 152 ++++------- flash/audio/speech_recognition/data.py | 94 +++---- flash/core/adapter.py | 23 +- flash/core/data/data_module.py | 249 +++++++++++------- flash/core/data/io/input.py | 64 +---- flash/core/data/io/input_transform.py | 219 ++++++++------- flash/core/data/io/output_transform.py | 3 +- flash/core/integrations/icevision/adapter.py | 43 ++- .../core/integrations/icevision/backbones.py | 26 +- flash/core/integrations/icevision/wrappers.py | 43 +++ flash/core/model.py | 47 +++- flash/core/serve/flash_components.py | 9 +- flash/graph/classification/data.py | 25 +- flash/image/classification/adapters.py | 51 ++-- flash/image/classification/data.py | 237 ++++++----------- flash/image/detection/backbones.py | 26 +- flash/image/detection/data.py | 110 +++----- flash/image/face_detection/data.py | 29 +- .../image/instance_segmentation/backbones.py | 10 +- flash/image/instance_segmentation/data.py | 62 ++--- flash/image/keypoint_detection/backbones.py | 8 +- flash/image/keypoint_detection/data.py | 50 ++-- flash/image/segmentation/data.py | 120 +++------ flash/image/style_transfer/data.py | 76 ++---- flash/pointcloud/detection/data.py | 45 ++-- flash/pointcloud/detection/model.py | 7 + flash/pointcloud/segmentation/data.py | 51 ++-- flash/pointcloud/segmentation/model.py | 7 + flash/tabular/classification/data.py | 46 ++-- flash/tabular/forecasting/data.py | 23 +- flash/tabular/regression/data.py | 46 ++-- flash/template/classification/data.py | 92 
++----- flash/text/classification/data.py | 169 +++++------- flash/text/question_answering/data.py | 92 +++---- flash/text/seq2seq/summarization/data.py | 95 +++---- flash/text/seq2seq/translation/data.py | 95 +++---- flash/video/classification/data.py | 139 +++------- flash/video/classification/input_transform.py | 13 +- .../flash_components/custom_data_loading.py | 165 ++++-------- .../image_classification_imagenette_mini.py | 4 +- tests/audio/classification/test_data.py | 7 +- tests/core/data/test_callback.py | 16 +- tests/core/data/test_callbacks.py | 9 +- tests/core/data/test_data_module.py | 125 ++++----- tests/core/data/test_data_pipeline.py | 3 +- tests/core/data/test_input_transform.py | 99 ++++--- tests/core/test_model.py | 3 +- tests/graph/classification/test_data.py | 5 +- tests/graph/classification/test_model.py | 18 +- tests/graph/embedding/test_model.py | 14 +- tests/image/classification/test_data.py | 7 +- tests/image/detection/test_model.py | 56 +++- tests/image/embedding/utils.py | 5 +- 57 files changed, 1386 insertions(+), 1856 deletions(-) create mode 100644 flash/core/integrations/icevision/wrappers.py diff --git a/README.md b/README.md index 506c4b0f48..a3d024c289 100644 --- a/README.md +++ b/README.md @@ -259,7 +259,7 @@ class MixUpInputTransform(InputTransform): datamodule = ImageClassificationData.from_folders( train_folder="data/train", - train_transform=MixUpInputTransform, + transform=MixUpInputTransform, batch_size=2, ) diff --git a/docs/source/integrations/icevision.rst b/docs/source/integrations/icevision.rst index bfb71356b2..28fabc0f18 100644 --- a/docs/source/integrations/icevision.rst +++ b/docs/source/integrations/icevision.rst @@ -34,11 +34,11 @@ Here's an example: from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter from flash.image import ObjectDetectionData - train_transform = { + transform = { "per_sample_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]), } datamodule = 
ObjectDetectionData.from_coco( ..., - train_transform=train_transform, + transform=transform, ) diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index 78f103cb90..59c7617acd 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -142,7 +142,7 @@ Here's an example: datamodule = ImageClassificationData.from_folders( train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", - train_transform=ImageClassificationInputTransform, + transform=ImageClassificationInputTransform, transform_kwargs=dict(image_size=(128, 128)), batch_size=1, ) diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 0b3c14107e..9786d79e73 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -112,6 +112,6 @@ creating a subclass of :class:`~flash.core.data.io.input_transform.InputTransfor train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", val_split=0.1, - train_transform=BrightnessContrastTransform, + transform=BrightnessContrastTransform, batch_size=4, ) diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index e520410b14..1ae4ca868d 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -32,7 +32,6 @@ from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _AUDIO_TESTING from flash.core.utilities.stages import RunningStage from flash.image.classification.data import MatplotlibVisualization @@ -47,7 +46,6 @@ class AudioClassificationData(DataModule): classmethods for loading data for audio 
classification.""" input_transform_cls = AudioClassificationInputTransform - input_transforms_registry = FlashRegistry("input_transforms") @classmethod def from_files( @@ -59,13 +57,10 @@ def from_files( test_files: Optional[Sequence[str]] = None, test_targets: Optional[Sequence[Any]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationFilesInput, + transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, transform_kwargs: Optional[Dict] = None, + target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from lists of files and @@ -86,14 +81,10 @@ def from_files( test_files: The list of spectrogram image files to use when testing. test_targets: The list of targets to use when testing. predict_files: The list of spectrogram image files to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. 
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -147,18 +138,18 @@ def from_files( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_files, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_files, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -169,13 +160,10 @@ def from_folders( val_folder: Optional[str] = None, test_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = 
AudioClassificationFolderInput, + transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, transform_kwargs: Optional[Dict] = None, + target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from folders containing @@ -215,14 +203,10 @@ def from_folders( val_folder: The folder containing spectrogram images to use when validating. test_folder: The folder containing spectrogram images to use when testing. predict_folder: The folder containing spectrogram images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -279,18 +263,18 @@ def from_folders( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_folder, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_folder, **ds_kw), + input_cls(RunningStage.TESTING, test_folder, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -304,13 +288,10 @@ def from_numpy( test_data: Optional[Collection[np.ndarray]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[np.ndarray]] = None, - train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationNumpyInput, + transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, transform_kwargs: Optional[Dict] = None, + target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from numpy arrays (or lists @@ -329,14 +310,10 @@ def from_numpy( test_data: The numpy array or list of 
arrays to use when testing. test_targets: The list of targets to use when testing. predict_data: The numpy array or list of arrays to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -374,18 +351,18 @@ def from_numpy( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -399,13 +376,10 @@ def from_tensors( test_data: Optional[Collection[torch.Tensor]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[torch.Tensor]] = None, - train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationTensorInput, + transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, transform_kwargs: Optional[Dict] = None, + target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from torch tensors (or lists @@ 
-424,14 +398,10 @@ def from_tensors( test_data: The torch tensor or list of tensors to use when testing. test_targets: The list of targets to use when testing. predict_data: The torch tensor or list of tensors to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -469,18 +439,18 @@ def from_tensors( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -501,13 +471,10 @@ def from_data_frame( predict_data_frame: Optional[pd.DataFrame] = None, predict_images_root: Optional[str] = None, predict_resolver: Optional[Callable[[str, str], str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationDataFrameInput, + transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, transform_kwargs: Optional[Dict] = None, + target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from pandas DataFrame 
objects @@ -540,14 +507,10 @@ def from_data_frame( predict_images_root: The root directory containing predict spectrogram images. predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a spectrogram image file path. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -623,8 +586,6 @@ def from_data_frame( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver) @@ -632,14 +593,16 @@ def from_data_frame( test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver) - train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, *train_data, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, *val_data, **ds_kw), + input_cls(RunningStage.TESTING, *test_data, **ds_kw), + input_cls(RunningStage.PREDICTING, *predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -660,13 +623,10 @@ def from_csv( predict_file: Optional[str] = None, predict_images_root: Optional[str] = None, predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, - train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, - target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationCSVInput, + transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, transform_kwargs: 
Optional[Dict] = None, + target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from CSV files containing @@ -699,14 +659,10 @@ def from_csv( predict_images_root: The root directory containing predict spectrogram images. predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a spectrogram image file path. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -792,8 +748,6 @@ def from_csv( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) train_data = (train_file, input_field, target_fields, train_images_root, train_resolver) @@ -801,14 +755,16 @@ def from_csv( test_data = (test_file, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver) - train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, *train_data, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, *val_data, **ds_kw), + input_cls(RunningStage.TESTING, *test_data, **ds_kw), + input_cls(RunningStage.PREDICTING, *predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index 0b44c4ac26..e7d94df21d 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -25,7 +25,6 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform -from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _AUDIO_TESTING from flash.core.utilities.stages import RunningStage @@ -40,7 +39,6 @@ class SpeechRecognitionData(DataModule): input_transform_cls = InputTransform output_transform_cls = SpeechRecognitionOutputTransform - 
input_transforms_registry = FlashRegistry("input_transforms") @classmethod def from_files( @@ -53,11 +51,8 @@ def from_files( test_targets: Optional[Sequence[str]] = None, predict_files: Optional[Sequence[str]] = None, sampling_rate: int = 16000, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = SpeechRecognitionPathsInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": @@ -77,12 +72,8 @@ def from_files( test_targets: The list of targets (ground truth speech transcripts) to use when testing. predict_files: The list of audio files to use when predicting. sampling_rate: Sampling rate to use when loading the audio files. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -127,16 +118,16 @@ def from_files( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, sampling_rate=sampling_rate, ) return cls( - input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw), + input_cls(RunningStage.VALIDATING, val_files, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_files, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -150,11 +141,8 @@ def from_csv( test_file: Optional[str] = None, predict_file: Optional[str] = None, sampling_rate: int = 16000, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = SpeechRecognitionCSVInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": @@ -175,12 +163,8 @@ def from_csv( test_file: The CSV file to use when testing. predict_file: The CSV file to use when predicting. sampling_rate: Sampling rate to use when loading the audio files. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. 
- test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -255,17 +239,17 @@ def from_csv( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, input_key=input_field, sampling_rate=sampling_rate, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, target_key=target_field, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, target_key=target_field, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, target_key=target_field, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, target_key=target_field, **ds_kw), + input_cls(RunningStage.TESTING, test_file, target_key=target_field, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -280,11 +264,8 @@ def from_json( predict_file: Optional[str] = None, sampling_rate: int = 16000, field: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = 
InputTransform, input_cls: Type[Input] = SpeechRecognitionJSONInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": @@ -306,12 +287,8 @@ def from_json( predict_file: The JSON file to use when predicting. sampling_rate: Sampling rate to use when loading the audio files. field: The field that holds the data in the JSON file. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -384,18 +361,18 @@ def from_json( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, input_key=input_field, sampling_rate=sampling_rate, field=field, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, target_key=target_field, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, target_key=target_field, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, target_key=target_field, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, target_key=target_field, **ds_kw), + input_cls(RunningStage.TESTING, test_file, target_key=target_field, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -406,12 +383,9 @@ def from_datasets( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, sampling_rate: int = 16000, input_cls: Type[Input] = SpeechRecognitionDatasetInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": @@ -433,12 +407,8 @@ def from_datasets( test_dataset: The Dataset to use when testing. predict_dataset: The Dataset to use when predicting. sampling_rate: Sampling rate to use when loading the audio files. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. 
- val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -540,15 +510,15 @@ def from_datasets( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, sampling_rate=sampling_rate, ) return cls( - input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/core/adapter.py b/flash/core/adapter.py index 1e3ca81e92..433fdb74cb 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -20,6 +20,7 @@ import flash from flash.core.data.io.input import InputBase +from flash.core.data.io.input_transform import InputTransform from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task from 
flash.core.utilities.types import INPUT_TRANSFORM_TYPE @@ -121,6 +122,7 @@ def process_train_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -131,8 +133,9 @@ def process_train_dataset( persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_train_dataset( - dataset, - trainer, + dataset=dataset, + trainer=trainer, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, @@ -147,6 +150,7 @@ def process_val_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -157,8 +161,9 @@ def process_val_dataset( persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_val_dataset( - dataset, - trainer, + dataset=dataset, + trainer=trainer, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, @@ -173,6 +178,7 @@ def process_test_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -183,8 +189,9 @@ def process_test_dataset( persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_test_dataset( - dataset, - trainer, + dataset=dataset, + trainer=trainer, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, @@ -198,6 +205,7 @@ def process_test_dataset( def process_predict_dataset( self, dataset: InputBase, + input_transform: InputTransform, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, @@ -208,7 +216,8 @@ def process_predict_dataset( persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_predict_dataset( - dataset, + dataset=dataset, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, 
pin_memory=pin_memory, diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index c7ce317471..ebc95df9d0 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union import numpy as np import pytorch_lightning as pl import torch +from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader, Dataset -from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataset import IterableDataset from torch.utils.data.sampler import Sampler @@ -28,16 +29,17 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.io.input import DataKeys, Input, InputBase, IterableInput from flash.core.data.io.input_transform import ( - _create_collate_input_transform_processors, - _InputTransformProcessorV2, - create_transform, + _InputTransformProcessor, + create_device_input_transform_processor, + create_or_configure_input_transform, + create_worker_input_transform_processor, InputTransform, ) -from flash.core.data.io.output_transform import OutputTransform from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX from flash.core.registry import FlashRegistry from flash.core.utilities.stages import RunningStage +from flash.core.utilities.types import INPUT_TRANSFORM_TYPE class DatasetInput(Input): @@ -63,6 +65,8 @@ class DataModule(pl.LightningDataModule): data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the :class:`~flash.core.data.io.input_transform.InputTransform`. 
If ``None``, the output from :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. val_split: An optional float which gives the relative amount of the training dataset to use for the validation dataset. batch_size: The batch size to be used by the DataLoader. @@ -103,7 +107,7 @@ class DataModule(pl.LightningDataModule): """ input_transform_cls = InputTransform - input_transforms_registry: Optional[FlashRegistry] = None + input_transforms_registry = FlashRegistry("input_transforms") def __init__( self, @@ -112,6 +116,8 @@ def __init__( test_input: Optional[Input] = None, predict_input: Optional[Input] = None, data_fetcher: Optional[BaseDataFetcher] = None, + transform: INPUT_TRANSFORM_TYPE = InputTransform, + transform_kwargs: Optional[Dict] = None, val_split: Optional[float] = None, batch_size: Optional[int] = None, num_workers: int = 0, @@ -126,8 +132,11 @@ def __init__( if flash._IS_TESTING and torch.cuda.is_available(): batch_size = 16 - self._input_transform: Optional[OutputTransform] = None - self._viz: Optional[BaseVisualization] = None + self.input_transform = DataModule.configure_input_transform( + transform=transform, transform_kwargs=transform_kwargs + ) + + self.viz: Optional[BaseVisualization] = None self._train_input = train_input self._val_input = val_input @@ -144,17 +153,17 @@ def __init__( self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() - self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input) - self._val_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._val_input) - self._test_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._test_input) - self._predict_dataloader_collate_fn = 
self._resolve_dataloader_collate_fn(self._predict_input) + self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(RunningStage.TRAINING) + self._val_dataloader_collate_fn = self._resolve_dataloader_collate_fn(RunningStage.VALIDATING) + self._test_dataloader_collate_fn = self._resolve_dataloader_collate_fn(RunningStage.TESTING) + self._predict_dataloader_collate_fn = self._resolve_dataloader_collate_fn(RunningStage.PREDICTING) self._on_after_batch_transfer_fns = { - RunningStage.TRAINING: self._resolve_on_after_batch_transfer_fn(self._train_input), - RunningStage.VALIDATING: self._resolve_on_after_batch_transfer_fn(self._val_input), - RunningStage.SANITY_CHECKING: self._resolve_on_after_batch_transfer_fn(self._val_input), - RunningStage.TESTING: self._resolve_on_after_batch_transfer_fn(self._test_input), - RunningStage.PREDICTING: self._resolve_on_after_batch_transfer_fn(self._predict_input), + RunningStage.TRAINING: self._resolve_on_after_batch_transfer_fn(RunningStage.TRAINING), + RunningStage.VALIDATING: self._resolve_on_after_batch_transfer_fn(RunningStage.VALIDATING), + RunningStage.SANITY_CHECKING: self._resolve_on_after_batch_transfer_fn(RunningStage.VALIDATING), + RunningStage.TESTING: self._resolve_on_after_batch_transfer_fn(RunningStage.TESTING), + RunningStage.PREDICTING: self._resolve_on_after_batch_transfer_fn(RunningStage.PREDICTING), } self._model_on_after_batch_transfer_fns = None @@ -200,33 +209,28 @@ def predict_dataset(self) -> Optional[Input]: """This property returns the prediction dataset.""" return self._predict_input - def _resolve_dataloader_collate_fn(self, ds: Optional[Input]) -> Optional[Callable]: - if not ds: - return None - if isinstance(ds.transform, InputTransform): - return ds._create_dataloader_collate_fn([self.data_fetcher]) - return default_collate + ##################################### + # METHODS PERTAINING TO DATALOADERS # + ##################################### - def _resolve_on_after_batch_transfer_fn(self, 
ds: Optional[Input]) -> Optional[Callable]: - if not ds: - return None - if isinstance(ds.transform, InputTransform): - return ds._create_on_after_batch_transfer_fn([self.data_fetcher]) + def _resolve_dataloader_collate_fn(self, stage: RunningStage) -> _InputTransformProcessor: + return create_worker_input_transform_processor(stage, self.input_transform, [self.data_fetcher]) + + def _resolve_on_after_batch_transfer_fn(self, stage: RunningStage) -> _InputTransformProcessor: + return create_device_input_transform_processor(stage, self.input_transform, [self.data_fetcher]) def _train_dataloader(self) -> DataLoader: train_ds: Input = self._train_input - collate_fn = self._train_dataloader_collate_fn + transform_processor = self._train_dataloader_collate_fn if isinstance(getattr(self, "trainer", None), pl.Trainer): input_transform = getattr(self.trainer.lightning_module, "input_transform", None) if input_transform is not None: - input_transform = create_transform(input_transform, RunningStage.TRAINING) - collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] - - transform_processor = None - if isinstance(collate_fn, _InputTransformProcessorV2): - transform_processor = collate_fn - collate_fn = transform_processor.collate_fn + input_transform = create_or_configure_input_transform(input_transform, RunningStage.TRAINING) + transform_processor = create_worker_input_transform_processor( + RunningStage.TRAINING, input_transform, [self.data_fetcher] + ) + self.input_transform = input_transform shuffle: bool = False if isinstance(train_ds, IterableDataset): @@ -242,16 +246,24 @@ def _train_dataloader(self) -> DataLoader: else: sampler = self.sampler + # `transform_processor` is an _InputTransformProcessor object + # Use the `transform_processor` object directly as the collate_fn for the DataLoader. 
+ + # `self.input_transform` is an InputTransform object + # Inject the `self.collate_fn` returned by the model into the `transforms` dict of the `input_transform` object + # through the process_train_dataset method of the model. + if isinstance(getattr(self, "trainer", None), pl.Trainer): dataloader = self.trainer.lightning_module.process_train_dataset( train_ds, trainer=self.trainer, + input_transform=self.input_transform, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, shuffle=shuffle, drop_last=drop_last, - collate_fn=collate_fn, + collate_fn=transform_processor, sampler=sampler, persistent_workers=self.persistent_workers, ) @@ -264,40 +276,42 @@ def _train_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, drop_last=drop_last, - collate_fn=collate_fn, + collate_fn=transform_processor, persistent_workers=self.persistent_workers, ) - if transform_processor is not None: - transform_processor.collate_fn = dataloader.collate_fn - dataloader.collate_fn = transform_processor - self._model_on_after_batch_transfer_fns = None return dataloader def _val_dataloader(self) -> DataLoader: val_ds: Input = self._val_input - collate_fn = self._val_dataloader_collate_fn + transform_processor = self._val_dataloader_collate_fn if isinstance(getattr(self, "trainer", None), pl.Trainer): input_transform = getattr(self.trainer.lightning_module, "input_transform", None) if input_transform is not None: - input_transform = create_transform(input_transform, RunningStage.VALIDATING) - collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] + input_transform = create_or_configure_input_transform(input_transform, RunningStage.VALIDATING) + transform_processor = create_worker_input_transform_processor( + RunningStage.VALIDATING, input_transform, [self.data_fetcher] + ) + self.input_transform = input_transform + + # `transform_processor` is an _InputTransformProcessor object + 
# Use the `transform_processor` object directly as the collate_fn for the DataLoader. - transform_processor = None - if isinstance(collate_fn, _InputTransformProcessorV2): - transform_processor = collate_fn - collate_fn = transform_processor.collate_fn + # `self.input_transform` is an InputTransform object + # Inject the `self.collate_fn` returned by the model into the `transforms` dict of the `input_transform` object + # through the process_train_dataset method of the model. if isinstance(getattr(self, "trainer", None), pl.Trainer): dataloader = self.trainer.lightning_module.process_val_dataset( val_ds, trainer=self.trainer, + input_transform=self.input_transform, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, - collate_fn=collate_fn, + collate_fn=transform_processor, persistent_workers=self.persistent_workers, ) else: @@ -306,40 +320,42 @@ def _val_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, - collate_fn=collate_fn, + collate_fn=transform_processor, persistent_workers=self.persistent_workers, ) - if transform_processor is not None: - transform_processor.collate_fn = dataloader.collate_fn - dataloader.collate_fn = transform_processor - self._model_on_after_batch_transfer_fns = None return dataloader def _test_dataloader(self) -> DataLoader: test_ds: Input = self._test_input - collate_fn = self._test_dataloader_collate_fn + transform_processor = self._test_dataloader_collate_fn if isinstance(getattr(self, "trainer", None), pl.Trainer): input_transform = getattr(self.trainer.lightning_module, "input_transform", None) if input_transform is not None: - input_transform = create_transform(input_transform, RunningStage.TESTING) - collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] + input_transform = create_or_configure_input_transform(input_transform, RunningStage.TESTING) + transform_processor = 
create_worker_input_transform_processor( + RunningStage.TESTING, input_transform, [self.data_fetcher] + ) + self.input_transform = input_transform + + # `transform_processor` is an _InputTransformProcessor object + # Use the `transform_processor` object directly as the collate_fn for the DataLoader. - transform_processor = None - if isinstance(collate_fn, _InputTransformProcessorV2): - transform_processor = collate_fn - collate_fn = transform_processor.collate_fn + # `self.input_transform` is an InputTransform object + # Inject the `self.collate_fn` returned by the model into the `transforms` dict of the `input_transform` object + # through the process_train_dataset method of the model. if isinstance(getattr(self, "trainer", None), pl.Trainer): dataloader = self.trainer.lightning_module.process_test_dataset( test_ds, trainer=self.trainer, + input_transform=self.input_transform, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, - collate_fn=collate_fn, + collate_fn=transform_processor, persistent_workers=self.persistent_workers, ) else: @@ -348,44 +364,46 @@ def _test_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, - collate_fn=collate_fn, + collate_fn=transform_processor, persistent_workers=self.persistent_workers, ) - if transform_processor is not None: - transform_processor.collate_fn = dataloader.collate_fn - dataloader.collate_fn = transform_processor - self._model_on_after_batch_transfer_fns = None return dataloader def _predict_dataloader(self) -> DataLoader: predict_ds: Input = self._predict_input - collate_fn = self._predict_dataloader_collate_fn + transform_processor = self._predict_dataloader_collate_fn if isinstance(getattr(self, "trainer", None), pl.Trainer): input_transform = getattr(self.trainer.lightning_module, "input_transform", None) if input_transform is not None: - input_transform = create_transform(input_transform, RunningStage.PREDICTING) 
- collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] - - transform_processor = None - if isinstance(collate_fn, _InputTransformProcessorV2): - transform_processor = collate_fn - collate_fn = transform_processor.collate_fn + input_transform = create_or_configure_input_transform(input_transform, RunningStage.PREDICTING) + transform_processor = create_worker_input_transform_processor( + RunningStage.PREDICTING, input_transform, [self.data_fetcher] + ) + self.input_transform = input_transform if isinstance(predict_ds, IterableDataset): batch_size = self.batch_size else: batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1) + # `transform_processor` is an _InputTransformProcessor object + # Use the `transform_processor` object directly as the collate_fn for the DataLoader. + + # `self.input_transform` is an InputTransform object + # Inject the `self.collate_fn` returned by the model into the `transforms` dict of the `input_transform` object + # through the process_train_dataset method of the model. 
+ if isinstance(getattr(self, "trainer", None), pl.Trainer): dataloader = self.trainer.lightning_module.process_predict_dataset( predict_ds, + input_transform=self.input_transform, batch_size=batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, - collate_fn=collate_fn, + collate_fn=transform_processor, persistent_workers=self.persistent_workers, ) else: @@ -394,17 +412,17 @@ def _predict_dataloader(self) -> DataLoader: batch_size=batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, - collate_fn=collate_fn, + collate_fn=transform_processor, persistent_workers=self.persistent_workers, ) - if transform_processor is not None: - transform_processor.collate_fn = dataloader.collate_fn - dataloader.collate_fn = transform_processor - self._model_on_after_batch_transfer_fns = None return dataloader + ############################################################ + # METHODS RELATED TO on_after_batch_transfer FUNCTIONALITY # + ############################################################ + def _load_model_on_after_batch_transfer_fns(self) -> None: self._model_on_after_batch_transfer_fns = {} @@ -419,10 +437,14 @@ def _load_model_on_after_batch_transfer_fns(self) -> None: if isinstance(getattr(self, "trainer", None), pl.Trainer): input_transform = getattr(self.trainer.lightning_module, "input_transform", None) if input_transform is not None: - input_transform = create_transform( + input_transform = create_or_configure_input_transform( input_transform, stage if stage != RunningStage.SANITY_CHECKING else RunningStage.VALIDATING ) - transform = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[1] + transform = create_device_input_transform_processor( + stage if stage != RunningStage.SANITY_CHECKING else RunningStage.VALIDATING, + input_transform, + [self.data_fetcher], + ) self._model_on_after_batch_transfer_fns[stage] = transform def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: @@ -442,13 
+464,9 @@ def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: batch = transform(batch) return batch - @property - def viz(self) -> BaseVisualization: - return self._viz or DataModule.configure_data_fetcher() - - @viz.setter - def viz(self, viz: BaseVisualization) -> None: - self._viz = viz + ################################### + # METHODS RELATED TO DATA FETCHER # + ################################### @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: @@ -467,6 +485,55 @@ def data_fetcher(self) -> BaseDataFetcher: def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None: self._data_fetcher = data_fetcher + ###################################### + # METHODS RELATED TO INPUT TRANSFORM # + ###################################### + + @property + def input_transform(self) -> InputTransform: + """This property returns the data fetcher.""" + return self._input_transform + + @input_transform.setter + def input_transform(self, input_transform: InputTransform) -> None: + self._input_transform = input_transform + + @staticmethod + def configure_input_transform( + transform: INPUT_TRANSFORM_TYPE, transform_kwargs: Optional[Dict] = None + ) -> InputTransform: + """This function is used to configure a :class:`~flash.core.data.io.input_transform.InputTransform`. + + Override with your custom one. + """ + return create_or_configure_input_transform( + transform=transform, + transform_kwargs=transform_kwargs, + input_transforms_registry=DataModule.input_transforms_registry, + ) + + @classmethod + def register_input_transform( + cls, enum: Union[LightningEnum, str], fn: Union[Type["flash.InputTransform"], partial] + ) -> None: + if cls.input_transforms_registry is None: + raise MisconfigurationException( + "The class attribute `input_transforms_registry` should be set as a class attribute. 
" + ) + cls.input_transforms_registry(fn=fn, name=enum) + + #################################### + # METHODS RELATED TO VISUALIZATION # + #################################### + + @property + def viz(self) -> BaseVisualization: + return self._viz or DataModule.configure_data_fetcher() + + @viz.setter + def viz(self, viz: BaseVisualization) -> None: + self._viz = viz + def _reset_iterator(self, stage: str) -> Iterable[Any]: iter_name = f"_{stage}_iter" # num_workers has to be set to 0 to work properly diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index 3fb6468855..506f937884 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -15,20 +15,15 @@ import os import sys from copy import deepcopy -from functools import partial -from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, cast, Dict, Iterable, List, Sequence, Tuple, Union from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset -import flash -from flash.core.data.callback import FlashCallback from flash.core.data.properties import Properties from flash.core.data.utils import _STAGES_PREFIX -from flash.core.registry import FlashRegistry from flash.core.utilities.stages import RunningStage -from flash.core.utilities.types import INPUT_TRANSFORM_TYPE if sys.version_info < (3, 7): from typing import GenericMeta @@ -166,45 +161,14 @@ class InputBase(Properties, metaclass=_InputMeta): **kwargs: Any additional keyword arguments to pass to the ``load_data`` hook. 
""" - input_transforms_registry = FlashRegistry("input_transforms") - - def __init__( - self, - running_stage: RunningStage, - *args: Any, - transform: INPUT_TRANSFORM_TYPE = None, - transform_kwargs: Optional[Dict] = None, - input_transforms_registry: Optional[FlashRegistry] = None, - **kwargs: Any, - ) -> None: - from flash.core.data.io.input_transform import create_transform - - self.transform = create_transform( - transform, - running_stage, - input_transforms_registry or self.input_transforms_registry, - transform_kwargs, - ) + def __init__(self, running_stage: RunningStage, *args: Any, **kwargs: Any) -> None: + super().__init__(running_stage=running_stage) self.data = None if len(args) >= 1 and args[0] is not None: self.data = getattr(self, f"{_STAGES_PREFIX[running_stage]}_load_data")(*args, **kwargs) - def _create_dataloader_collate_fn(self, callbacks: List[FlashCallback]) -> Optional[Callable]: - from flash.core.data.io.input_transform import _create_collate_input_transform_processors - - if not self.transform: - return - return _create_collate_input_transform_processors(self.transform, callbacks)[0] - - def _create_on_after_batch_transfer_fn(self, callbacks: List[FlashCallback]) -> Optional[Callable]: - from flash.core.data.io.input_transform import _create_collate_input_transform_processors - - if not self.transform: - return - return _create_collate_input_transform_processors(self.transform, callbacks)[1] - def _call_load_sample(self, sample: Any) -> Any: # Deepcopy the sample to avoid leaks with complex data structures return getattr(self, f"{_STAGES_PREFIX[self.running_stage]}_load_sample")(deepcopy(sample)) @@ -307,16 +271,6 @@ def __bool__(self): """ return self.data is not None - @classmethod - def register_input_transform( - cls, enum: Union[LightningEnum, str], fn: Union[Type["flash.InputTransform"], partial] - ) -> None: - if cls.input_transforms_registry is None: - raise MisconfigurationException( - "The class attribute 
`input_transforms_registry` should be set as a class attribute. " - ) - cls.input_transforms_registry(fn=fn, name=enum) - class Input(InputBase, Dataset): def __getitem__(self, index: int) -> Any: @@ -336,19 +290,11 @@ def __next__(self) -> Any: class ServeInput(Input): - def __init__( - self, - transform: INPUT_TRANSFORM_TYPE = None, - transform_kwargs: Optional[Dict] = None, - ) -> None: + def __init__(self) -> None: if hasattr(self, "serve_load_data"): raise MisconfigurationException("`serve_load_data` shouldn't be implemented.") - super().__init__( - RunningStage.SERVING, - transform=transform, - transform_kwargs=transform_kwargs, - ) + super().__init__(RunningStage.SERVING) def serve_load_sample(self, sample: Any) -> List[Any]: raise NotImplementedError diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py index 62bbce2258..d40bec94c5 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -13,16 +13,16 @@ # limitations under the License. 
import inspect from dataclasses import dataclass -from functools import partial, wraps +from functools import partial from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.warnings import rank_zero_warn from torch.utils.data._utils.collate import default_collate from flash.core.data.callback import ControlFlow, FlashCallback from flash.core.data.io.input import DataKeys -from flash.core.data.properties import Properties from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX from flash.core.registry import FlashRegistry @@ -45,15 +45,7 @@ class ApplyToKeyPrefix(LightningEnum): TARGET = "target" -def transform_context(func: Callable, current_fn: str) -> Callable: - @wraps(func) - def wrapper(self, *args, **kwargs) -> Any: - self.current_fn = current_fn - result = func(self, *args, **kwargs) - self.current_fn = None - return result - - return wrapper +INVALID_STAGES_FOR_INPUT_TRANSFORMS = [RunningStage.SANITY_CHECKING, RunningStage.TUNING] # Credit to Torchvision Team: @@ -82,32 +74,38 @@ def __repr__(self): @dataclass -class InputTransform(Properties): +class _InputTransformPerStage: + collate_in_worker_from_transform: Optional[bool] = None + transforms: Optional[Dict[str, Callable]] = None - running_stage: RunningStage +@dataclass +class InputTransform: def __post_init__(self): - # used to keep track of provided transforms - self._collate_in_worker_from_transform: Optional[bool] = None - self._transform = None - self._transform = self._check_transforms(self._resolve_transforms(self.running_stage), self.running_stage) - - # Hack - Properties.__init__(self, running_stage=self.running_stage) - @property - def current_transform(self) -> Callable: - if self._transform: - return 
self._get_transform(self._transform) - return self._identity + # used to keep track of provided transforms + self._transform: Dict[RunningStage, _InputTransformPerStage] = {} + + # For all the stages possible, set/load the transforms. + for stage in RunningStage: + if stage not in INVALID_STAGES_FOR_INPUT_TRANSFORMS: + self._populate_transforms_for_stage(stage) + + def current_transform(self, stage: RunningStage, current_fn: str) -> Callable: + if stage in [RunningStage.SANITY_CHECKING, RunningStage.TUNING]: + raise KeyError( + f"Transforms are only defined for stages:" + f"\t{[stage for stage in RunningStage if stage not in INVALID_STAGES_FOR_INPUT_TRANSFORMS]}" + f"But received {stage} instead." + ) - @property - def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: - """The transforms currently being used by this - :class:`~flash.core.data.io.input_transform.InputTransform`.""" - return { - "transform": self._transform, - } + # Check if transforms are present and the key is from the Enum defined above. + if InputTransformPlacement.from_str(current_fn) is None: + raise KeyError( + f"{[fn for fn in InputTransformPlacement]} are the only allowed keys to retrieve the transform." + f"But received {current_fn} instead." 
+ ) + return self._transform[stage].transforms.get(current_fn, self._identity) ######################## # PER SAMPLE TRANSFORM # @@ -831,33 +829,29 @@ def collate(self) -> Callable: # HOOKS CALLED INTERNALLY WITHIN FLASH # ######################################## - @partial(transform_context, current_fn="per_sample_transform") - def _per_sample_transform(self, sample: Any) -> Any: - fn = self.current_transform + def _per_sample_transform(self, sample: Any, stage: RunningStage) -> Any: + fn = self.current_transform(stage=stage, current_fn="per_sample_transform") if isinstance(sample, list): return [fn(s) for s in sample] return fn(sample) - @partial(transform_context, current_fn="per_batch_transform") - def _per_batch_transform(self, batch: Any) -> Any: + def _per_batch_transform(self, batch: Any, stage: RunningStage) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). .. note:: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are specified, uncollation has to be applied. 
""" - return self.current_transform(batch) + return self.current_transform(stage=stage, current_fn="per_batch_transform")(batch) - @partial(transform_context, current_fn="collate") - def _collate(self, samples: Sequence, metadata=None) -> Any: + def _collate(self, samples: Sequence, stage: RunningStage, metadata=None) -> Any: """Transform to convert a sequence of samples to a collated batch.""" - collate_fn = self.current_transform + collate_fn = self.current_transform(stage=stage, current_fn="collate") parameters = inspect.signature(collate_fn).parameters if len(parameters) > 1 and DataKeys.METADATA in parameters: return collate_fn(samples, metadata) return collate_fn(samples) - @partial(transform_context, current_fn="per_sample_transform_on_device") - def _per_sample_transform_on_device(self, sample: Any) -> Any: + def _per_sample_transform_on_device(self, sample: Any, stage: RunningStage) -> Any: """Transforms to apply to the data before the collation (per-sample basis). .. note:: This option is mutually exclusive with :meth:`per_batch_transform`, since if both are @@ -865,25 +859,41 @@ def _per_sample_transform_on_device(self, sample: Any) -> Any: workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). """ - fn = self.current_transform + fn = self.current_transform(stage=stage, current_fn="per_sample_transform_on_device") if isinstance(sample, list): return [fn(s) for s in sample] return fn(sample) - @partial(transform_context, current_fn="per_batch_transform_on_device") - def _per_batch_transform_on_device(self, batch: Any) -> Any: + def _per_batch_transform_on_device(self, batch: Any, stage: RunningStage) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). .. 
note:: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). """ - return self.current_transform(batch) + return self.current_transform(stage=stage, current_fn="per_batch_transform_on_device")(batch) ############# # UTILITIES # ############# - def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]: + def inject_collate_fn(self, collate_fn: Callable): + # For all the stages possible, set collate function + for stage in RunningStage: + if stage not in [RunningStage.SANITY_CHECKING, RunningStage.TUNING]: + self._transform[stage].transforms[InputTransformPlacement.COLLATE.value] = collate_fn + + def _populate_transforms_for_stage(self, running_stage: RunningStage): + transform, collate_in_worker = self.__check_transforms( + transform=self.__resolve_transforms(running_stage), stage=running_stage + ) + if self._transform is None: + self._transform = {} + self._transform[running_stage] = _InputTransformPerStage( + collate_in_worker_from_transform=collate_in_worker, + transforms=transform, + ) + + def __resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]: from flash.core.data.data_pipeline import DataPipeline transforms_out = {} @@ -957,9 +967,9 @@ def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, return transforms_out - def _check_transforms( + def __check_transforms( self, transform: Optional[Dict[str, Callable]], stage: RunningStage - ) -> Optional[Dict[str, Callable]]: + ) -> Tuple[Optional[Dict[str, Callable]], Optional[bool]]: if transform is None: return transform @@ -986,18 +996,12 @@ def _check_transforms( elif is_per_sample_transform_on_device_in: collate_in_worker = False - self._collate_in_worker_from_transform = collate_in_worker - return transform + return transform, collate_in_worker @staticmethod def _identity(x: Any) 
-> Any: return x - def _get_transform(self, transform: Dict[str, Callable]) -> Callable: - if self.current_fn in transform: - return transform[self.current_fn] - return self._identity - def __str__(self) -> str: return f"{self.__class__.__name__}(" + f"running_stage={self.running_stage}, transform={self._transform})" @@ -1035,12 +1039,11 @@ def _sanitize_registry_transform( return enum, transform_kwargs -def create_transform( +def create_or_configure_input_transform( transform: INPUT_TRANSFORM_TYPE, - running_stage: RunningStage, input_transforms_registry: Optional[FlashRegistry] = None, transform_kwargs: Optional[Dict] = None, -) -> Optional["InputTransform"]: +) -> Optional[InputTransform]: if not transform_kwargs: transform_kwargs = {} @@ -1049,14 +1052,19 @@ def create_transform( return transform if inspect.isclass(transform) and issubclass(transform, InputTransform): - return transform(running_stage=running_stage, **transform_kwargs) + # Deprecation Warning + rank_zero_warn( + "Please pass an instantiated object of the `InputTransform` class. 
Passing the Class and keyword arguments" + " separately will be deprecated in v0.9.0.", + FutureWarning, + ) + return transform(**transform_kwargs) if isinstance(transform, partial) and transform.func.__name__ == "LambdaInputTransform": - return transform(running_stage=running_stage, **transform_kwargs) + return transform(**transform_kwargs) if isinstance(transform, Callable): return LambdaInputTransform( - running_stage=running_stage, transform=transform, **transform_kwargs, ) @@ -1064,7 +1072,7 @@ def create_transform( if isinstance(transform, tuple) or isinstance(transform, (LightningEnum, str)): enum, transform_kwargs = _sanitize_registry_transform(transform, input_transforms_registry) transform_cls = input_transforms_registry.get(enum) - return transform_cls(running_stage, **transform_kwargs) + return transform_cls(**transform_kwargs) if not transform: return None @@ -1072,15 +1080,9 @@ def create_transform( raise MisconfigurationException(f"The format for the transform isn't correct. 
Found {transform}") -def _make_collates(input_transform: "InputTransform", on_device: bool, collate: Callable) -> Tuple[Callable, Callable]: - if on_device: - return input_transform._identity, collate - return collate, input_transform._identity - - -class _InputTransformProcessorV2: +class _InputTransformProcessor: """ - This class is used to encapsulate the following functions of a InputTransformInputTransform Object: + This class is used to encapsulate the following functions of an `InputTransform` Object: Inside a worker: per_sample_transform: Function to transform an individual sample collate: Function to merge sample into a batch @@ -1131,7 +1133,7 @@ def __call__(self, samples: Sequence[Any]) -> Any: else: list_samples = samples - transformed_samples = [self.per_sample_transform(sample) for sample in list_samples] + transformed_samples = [self.per_sample_transform(sample, self.stage) for sample in list_samples] for sample in transformed_samples: if self.on_device: @@ -1141,16 +1143,16 @@ def __call__(self, samples: Sequence[Any]) -> Any: extracted_samples, metadata = self._extract_metadata(transformed_samples) try: - collated_samples = self.collate_fn(extracted_samples, metadata) + collated_samples = self.collate_fn(extracted_samples, self.stage, metadata) except TypeError: - collated_samples = self.collate_fn(extracted_samples) + collated_samples = self.collate_fn(extracted_samples, self.stage) if metadata and isinstance(collated_samples, dict): collated_samples[DataKeys.METADATA] = metadata self.callback.on_collate(collated_samples, self.stage) else: collated_samples = samples - transformed_collated_samples = self.per_batch_transform(collated_samples) + transformed_collated_samples = self.per_batch_transform(collated_samples, self.stage) if self.on_device: self.callback.on_per_batch_transform_on_device(transformed_collated_samples, self.stage) else: @@ -1170,15 +1172,22 @@ def __str__(self) -> str: ) -def _create_collate_input_transform_processors( - 
input_transform: "InputTransform", callbacks: List[FlashCallback] -) -> Tuple[_InputTransformProcessorV2, _InputTransformProcessorV2]: - """This utility is used to create the 2 `_InputTransformProcessorV2` objects which contain the transforms used - as the DataLoader `collate_fn` and the DataModule `on_after_batch_transfer` hook.""" +def __make_collates(input_transform: InputTransform, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]: + """Returns the appropriate collate functions based on whether the transforms happen in a DataLoader worker or + on the device (main process).""" + if on_device: + return input_transform._identity, collate + return collate, input_transform._identity + + +def __configure_worker_and_device_collate_fn( + running_stage: RunningStage, input_transform: InputTransform +) -> Tuple[Callable, Callable]: from flash.core.data.data_pipeline import DataPipeline - prefix: str = _STAGES_PREFIX[input_transform.running_stage] + prefix: str = _STAGES_PREFIX[running_stage] + transform_for_stage: _InputTransformPerStage = input_transform._transform[running_stage] per_batch_transform_overridden: bool = DataPipeline._is_overridden_recursive( "per_batch_transform", input_transform, InputTransform, prefix=prefix @@ -1189,41 +1198,63 @@ def _create_collate_input_transform_processors( ) is_per_overridden = per_batch_transform_overridden and per_sample_transform_on_device_overridden - if input_transform._collate_in_worker_from_transform is None and is_per_overridden: + if transform_for_stage.collate_in_worker_from_transform is None and is_per_overridden: raise MisconfigurationException( f"{input_transform.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` " - f"are mutually exclusive for stage {input_transform.running_stage}" + f"are mutually exclusive for stage {running_stage}" ) - if isinstance(input_transform._collate_in_worker_from_transform, bool): - worker_collate_fn, device_collate_fn = _make_collates( - 
input_transform, not input_transform._collate_in_worker_from_transform, input_transform._collate + if isinstance(transform_for_stage.collate_in_worker_from_transform, bool): + worker_collate_fn, device_collate_fn = __make_collates( + input_transform, not transform_for_stage.collate_in_worker_from_transform, input_transform._collate ) else: - worker_collate_fn, device_collate_fn = _make_collates( + worker_collate_fn, device_collate_fn = __make_collates( input_transform, per_sample_transform_on_device_overridden, input_transform._collate ) worker_collate_fn = ( - worker_collate_fn.collate_fn if isinstance(worker_collate_fn, _InputTransformProcessorV2) else worker_collate_fn + worker_collate_fn.collate_fn if isinstance(worker_collate_fn, _InputTransformProcessor) else worker_collate_fn ) - worker_input_transform_processor = _InputTransformProcessorV2( + return worker_collate_fn, device_collate_fn + + +def create_worker_input_transform_processor( + running_stage: RunningStage, input_transform: InputTransform, callbacks: List[FlashCallback] +) -> _InputTransformProcessor: + """This utility is used to create the 2 `_InputTransformProcessor` objects which contain the transforms used as + the DataLoader `collate_fn`.""" + worker_collate_fn, _ = __configure_worker_and_device_collate_fn( + running_stage=running_stage, input_transform=input_transform + ) + worker_input_transform_processor = _InputTransformProcessor( input_transform, worker_collate_fn, input_transform._per_sample_transform, input_transform._per_batch_transform, - input_transform.running_stage, + running_stage, callbacks=callbacks, ) - device_input_transform_processor = _InputTransformProcessorV2( + return worker_input_transform_processor + + +def create_device_input_transform_processor( + running_stage: RunningStage, input_transform: InputTransform, callbacks: List[FlashCallback] +) -> _InputTransformProcessor: + """This utility is used to create a `_InputTransformProcessor` object which contain the transforms 
used as the + DataModule `on_after_batch_transfer` hook.""" + _, device_collate_fn = __configure_worker_and_device_collate_fn( + running_stage=running_stage, input_transform=input_transform + ) + device_input_transform_processor = _InputTransformProcessor( input_transform, device_collate_fn, input_transform._per_sample_transform_on_device, input_transform._per_batch_transform_on_device, - input_transform.running_stage, + running_stage, apply_per_sample_transform=device_collate_fn != input_transform._identity, on_device=True, callbacks=callbacks, ) - return worker_input_transform_processor, device_input_transform_processor + return device_input_transform_processor diff --git a/flash/core/data/io/output_transform.py b/flash/core/data/io/output_transform.py index 76bc5edf2f..6e9e27dbe6 100644 --- a/flash/core/data/io/output_transform.py +++ b/flash/core/data/io/output_transform.py @@ -14,10 +14,9 @@ from typing import Any, Sequence from flash.core.data.batch import default_uncollate -from flash.core.data.properties import Properties -class OutputTransform(Properties): +class OutputTransform: """The :class:`~flash.core.data.io.output_transform.OutputTransform` encapsulates all the data processing logic that should run after the model.""" diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index bff52bc5e9..cdb6e12ae1 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -20,11 +20,13 @@ import flash from flash.core.adapter import Adapter from flash.core.data.io.input import DataKeys, InputBase +from flash.core.data.io.input_transform import InputTransform from flash.core.integrations.icevision.transforms import ( from_icevision_predictions, from_icevision_record, to_icevision_record, ) +from flash.core.integrations.icevision.wrappers import wrap_icevision_adapter from flash.core.model import Task from flash.core.utilities.imports import _ICEVISION_AVAILABLE 
from flash.core.utilities.url_error import catch_url_error @@ -56,7 +58,7 @@ def __init__(self, model_type, model, icevision_adapter, backbone, predict_kwarg # Modules can't be pickled so just store the name self.model_type = model_type.__name__ self.model = model - self.icevision_adapter = icevision_adapter + self.icevision_adapter = wrap_icevision_adapter(icevision_adapter) self.backbone = backbone self.predict_kwargs = predict_kwargs @@ -100,6 +102,7 @@ def process_train_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -119,13 +122,23 @@ def process_train_dataset( sampler=sampler, persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) + # Assign the InputTransform + if self.input_transform is None: + self.input_transform = input_transform + + # Inject the actual `collate function` into the InputTransform object. + self.input_transform.inject_collate_fn(functools.partial(self._wrap_collate_fn, data_loader.collate_fn)) + + # Replace the collate_fn with _InputTransformProcessor object so that the complete + # InputTransform sequence is called. 
+ data_loader.collate_fn = collate_fn return data_loader def process_val_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -145,13 +158,18 @@ def process_val_dataset( sampler=sampler, persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) + # Assign the InputTransform + if self.input_transform is None: + self.input_transform = input_transform + self.input_transform.inject_collate_fn(functools.partial(self._wrap_collate_fn, data_loader.collate_fn)) + data_loader.collate_fn = collate_fn return data_loader def process_test_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -171,12 +189,17 @@ def process_test_dataset( sampler=sampler, persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) + # Assign the InputTransform + if self.input_transform is None: + self.input_transform = input_transform + self.input_transform.inject_collate_fn(functools.partial(self._wrap_collate_fn, data_loader.collate_fn)) + data_loader.collate_fn = collate_fn return data_loader def process_predict_dataset( self, dataset: InputBase, + input_transform: InputTransform, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, @@ -196,7 +219,11 @@ def process_predict_dataset( sampler=sampler, persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) + # Assign the InputTransform + if self.input_transform is None: + self.input_transform = input_transform + self.input_transform.inject_collate_fn(functools.partial(self._wrap_collate_fn, data_loader.collate_fn)) + data_loader.collate_fn = collate_fn return data_loader def training_step(self, batch, batch_idx) -> 
Any: @@ -229,3 +256,9 @@ def validation_epoch_end(self, outputs) -> None: def test_epoch_end(self, outputs) -> None: return self.icevision_adapter.validation_epoch_end(outputs) + + def __setstate__(self, newstate): + super().__setstate__(newstate) + + # Re-wrap IceVision adapter + self.icevision_adapter = wrap_icevision_adapter(self.icevision_adapter) diff --git a/flash/core/integrations/icevision/backbones.py b/flash/core/integrations/icevision/backbones.py index 80d92f1bbe..3d038daa5a 100644 --- a/flash/core/integrations/icevision/backbones.py +++ b/flash/core/integrations/icevision/backbones.py @@ -22,21 +22,7 @@ from icevision.backbones import BackboneConfig -def _log_with_prog_bar_override(self, name, value, **kwargs): - if "prog_bar" not in kwargs: - kwargs["prog_bar"] = True - return self._original_log(name.split("/")[-1], value, **kwargs) - - -def icevision_model_adapter(model_type): - adapter = model_type.lightning.ModelAdapter - if not hasattr(adapter, "_original_log"): - adapter._original_log = adapter.log - adapter.log = _log_with_prog_bar_override - return adapter - - -def load_icevision(adapter, model_type, backbone, num_classes, **kwargs): +def load_icevision(model_type, backbone, num_classes, **kwargs): model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs) backbone = nn.Module() @@ -49,16 +35,16 @@ def load_icevision(adapter, model_type, backbone, num_classes, **kwargs): if hasattr(model, "backbone") and hasattr(model.backbone, "param_groups"): del model.backbone.param_groups - return model_type, model, adapter(model_type), backbone + return model_type, model, model_type.lightning.ModelAdapter, backbone -def load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): - return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) +def load_icevision_ignore_image_size(model_type, backbone, num_classes, image_size=None, **kwargs): + return load_icevision(model_type, 
backbone, num_classes, **kwargs) -def load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): +def load_icevision_with_image_size(model_type, backbone, num_classes, image_size=None, **kwargs): kwargs["img_size"] = image_size - return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + return load_icevision(model_type, backbone, num_classes, **kwargs) def get_backbones(model_type): diff --git a/flash/core/integrations/icevision/wrappers.py b/flash/core/integrations/icevision/wrappers.py new file mode 100644 index 0000000000..76da568c1f --- /dev/null +++ b/flash/core/integrations/icevision/wrappers.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial + +import torch + +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.models.ross.efficientdet.lightning.model_adapter import ModelAdapter as EffDetModelAdapter + + +def _log_with_prog_bar_override(log, name, value, **kwargs): + if "prog_bar" not in kwargs: + kwargs["prog_bar"] = True + return log(name.split("/")[-1], value, **kwargs) + + +def _effdet_validation_step(validation_step, batch, batch_idx): + images = batch[0][0] + batch[0][1]["img_scale"] = torch.ones_like(images[:, 0, 0, 0]).unsqueeze(1) + batch[0][1]["img_size"] = (torch.ones_like(images[:, 0, 0, 0]) * images[0].shape[-1]).unsqueeze(1).repeat(1, 2) + return validation_step(batch, batch_idx) + + +def wrap_icevision_adapter(adapter): + if not isinstance(adapter.log, partial): + adapter.log = partial(_log_with_prog_bar_override, adapter.log) + + if isinstance(adapter, EffDetModelAdapter) and not isinstance(adapter.validation_step, partial): + adapter.validation_step = partial(_effdet_validation_step, adapter.validation_step) + return adapter diff --git a/flash/core/model.py b/flash/core/model.py index 597128c8b6..771e403b52 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -104,16 +104,17 @@ def collate_fn(self, collate_fn: Callable) -> None: @torch.jit.unused @property - def input_transform(self) -> Optional[INPUT_TRANSFORM_TYPE]: + def input_transform(self) -> Optional[InputTransform]: return self._input_transform @input_transform.setter - def input_transform(self, input_transform: INPUT_TRANSFORM_TYPE) -> None: + def input_transform(self, input_transform: InputTransform) -> None: self._input_transform = input_transform def _process_dataset( self, dataset: InputBase, + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -123,6 +124,16 @@ def _process_dataset( sampler: Optional[Sampler] = None, persistent_workers: bool = False, ) -> DataLoader: + + # Assign the 
InputTransform + if self.input_transform is None: + self.input_transform = input_transform + + # Now inject the `self.collate_fn` so that it doesn't override `InputTransform._collate` but is called through + # the `InputTransform._collate` method. + if self.collate_fn is not None: + self.input_transform.inject_collate_fn(self.collate_fn) + return DataLoader( dataset, batch_size=batch_size, @@ -131,7 +142,7 @@ def _process_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, - collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, + collate_fn=collate_fn, persistent_workers=persistent_workers, ) @@ -139,6 +150,7 @@ def process_train_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -150,10 +162,11 @@ def process_train_dataset( ) -> DataLoader: return self._process_dataset( dataset, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, @@ -164,6 +177,7 @@ def process_val_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -175,10 +189,11 @@ def process_val_dataset( ) -> DataLoader: return self._process_dataset( dataset, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, @@ -189,6 +204,7 @@ def process_test_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -200,10 +216,11 @@ def process_test_dataset( ) -> DataLoader: return 
self._process_dataset( dataset, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, @@ -213,6 +230,7 @@ def process_test_dataset( def process_predict_dataset( self, dataset: InputBase, + input_transform: InputTransform, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, @@ -224,10 +242,11 @@ def process_predict_dataset( ) -> DataLoader: return self._process_dataset( dataset, + input_transform=input_transform, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, @@ -438,7 +457,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A batch = torch.stack(batch) return self(batch) - def modules_to_freeze(self) -> Optional[Union[nn.Module]]: + def modules_to_freeze(self) -> Optional[nn.Module]: """By default, we try to get the ``backbone`` attribute from the task and return it or ``None`` if not present. 
@@ -807,13 +826,15 @@ def configure_callbacks(self): return [BenchmarkConvergenceCI()] @requires("serve") - def run_serve_sanity_check(self, serve_input: ServeInput, output: Output): + def run_serve_sanity_check( + self, serve_input: ServeInput, transform: INPUT_TRANSFORM_TYPE, transform_kwargs: Optional[Dict], output: Output + ): from fastapi.testclient import TestClient from flash.core.serve.flash_components import build_flash_serve_model_component print("Running serve sanity check") - comp = build_flash_serve_model_component(self, serve_input, output) + comp = build_flash_serve_model_component(self, serve_input, output, transform, transform_kwargs) composition = Composition(predict=comp, TESTING=True, DEBUG=True) app = composition.serve(host="0.0.0.0", port=8000) @@ -850,16 +871,16 @@ def serve( if input_cls is None: raise NotImplementedError("The `input_cls` must be provided to enable serving.") - serve_input = input_cls(transform=transform, transform_kwargs=transform_kwargs) + serve_input = input_cls() output = output or Output() if isinstance(output, str): output = self.outputs.get(output).from_task(self) if sanity_check: - self.run_serve_sanity_check(serve_input, output) + self.run_serve_sanity_check(serve_input, transform, transform_kwargs, output) - comp = build_flash_serve_model_component(self, serve_input, output) + comp = build_flash_serve_model_component(self, serve_input, output, transform, transform_kwargs) composition = Composition(predict=comp, TESTING=flash._IS_TESTING) composition.serve(host=host, port=port) return composition diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index c7370f93d8..513c715d8a 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -53,9 +53,14 @@ def deserialize(self, data: str) -> Any: # pragma: no cover return None -def build_flash_serve_model_component(model, serve_input, output): +def build_flash_serve_model_component(model, 
serve_input, output, transform, transform_kwargs): # TODO: Resolve this hack - data_module = DataModule(predict_input=serve_input, batch_size=1) + data_module = DataModule( + predict_input=serve_input, + batch_size=1, + transform=transform, + transform_kwargs=transform_kwargs, + ) class MockTrainer(Trainer): def __init__(self): diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index e340535961..86936406ba 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -42,13 +42,10 @@ def from_datasets( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, - target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = GraphClassificationDatasetInput, + transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, transform_kwargs: Optional[Dict] = None, + target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs, ) -> "GraphClassificationData": """Load the :class:`~flash.graph.classification.data.GraphClassificationData` from PyTorch Dataset objects. @@ -67,14 +64,10 @@ def from_datasets( val_dataset: The Dataset to use when validating. test_dataset: The Dataset to use when testing. predict_dataset: The Dataset to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. 
- predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. If ``None`` then no formatting will be applied to targets. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -175,18 +168,18 @@ def from_datasets( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_dataset, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index 66064ccdb4..b0d4bb5e5c 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -29,6 +29,7 @@ import flash from flash.core.adapter import Adapter, 
AdapterTask from flash.core.data.io.input import DataKeys, InputBase +from flash.core.data.io.input_transform import InputTransform from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.compatibility import accelerator_connector @@ -325,6 +326,7 @@ def process_train_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -350,10 +352,11 @@ def process_train_dataset( return super().process_train_dataset( dataset, trainer, - self._sanetize_batch_size(batch_size), - num_workers, - False, - collate_fn, + input_transform=input_transform, + batch_size=self._sanetize_batch_size(batch_size), + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, @@ -364,6 +367,7 @@ def process_val_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -387,12 +391,13 @@ def process_val_dataset( shuffle = False sampler = None return super().process_train_dataset( - dataset, - trainer, - self._sanetize_batch_size(batch_size), - num_workers, - False, - collate_fn, + dataset=dataset, + trainer=trainer, + input_transform=input_transform, + batch_size=self._sanetize_batch_size(batch_size), + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, @@ -403,6 +408,7 @@ def process_test_dataset( self, dataset: InputBase, trainer: "flash.Trainer", + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -426,12 +432,13 @@ def process_test_dataset( shuffle = False sampler = None return super().process_train_dataset( - dataset, - trainer, - self._sanetize_batch_size(batch_size), - num_workers, - False, - collate_fn, + dataset=dataset, + trainer=trainer, + input_transform=input_transform, + 
batch_size=self._sanetize_batch_size(batch_size), + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, @@ -441,6 +448,7 @@ def process_test_dataset( def process_predict_dataset( self, dataset: InputBase, + input_transform: InputTransform, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, @@ -457,11 +465,12 @@ def process_predict_dataset( ) return super().process_predict_dataset( - dataset, - batch_size, - num_workers, - pin_memory, - collate_fn, + dataset=dataset, + input_transform=input_transform, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 9ec5f81196..f87171a961 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -26,7 +26,6 @@ from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioImageClassificationInput -from flash.core.registry import FlashRegistry from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, @@ -77,7 +76,6 @@ class ImageClassificationData(DataModule): """The ``ImageClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods for loading data for image classification.""" - input_transforms_registry = FlashRegistry("input_transforms") input_transform_cls = ImageClassificationInputTransform @classmethod @@ -90,12 +88,9 @@ def from_files( test_files: Optional[Sequence[str]] = None, test_targets: Optional[Sequence[Any]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: 
INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationFilesInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": @@ -117,14 +112,10 @@ def from_files( test_files: The list of image files to use when testing. test_targets: The list of targets to use when testing. predict_files: The list of image files to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -172,18 +163,18 @@ def from_files( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_files, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_files, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -194,12 +185,9 @@ def from_folders( val_folder: Optional[str] = None, test_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationFolderInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": @@ -240,14 +228,10 @@ def from_folders( val_folder: The folder containing images to use when validating. test_folder: The folder containing images to use when testing. 
predict_folder: The folder containing images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -300,18 +284,18 @@ def from_folders( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_folder, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_folder, **ds_kw), + input_cls(RunningStage.TESTING, test_folder, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -325,12 +309,9 @@ def from_numpy( test_data: Optional[Collection[np.ndarray]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[np.ndarray]] = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationNumpyInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": @@ -350,14 +331,10 @@ def from_numpy( test_data: The numpy array or list of arrays to use when testing. test_targets: The list of targets to use when testing. predict_data: The numpy array or list of arrays to use when predicting. 
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -393,18 +370,18 @@ def from_numpy( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -418,12 +395,9 @@ def from_tensors( test_data: Optional[Collection[torch.Tensor]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[torch.Tensor]] = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationTensorInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": @@ -443,14 +417,10 @@ def from_tensors( test_data: The torch tensor or list of tensors to use when testing. test_targets: The list of targets to use when testing. 
predict_data: The torch tensor or list of tensors to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -486,18 +456,18 @@ def from_tensors( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -518,12 +488,9 @@ def from_data_frame( predict_data_frame: Optional[pd.DataFrame] = None, predict_images_root: Optional[str] = None, predict_resolver: Optional[Callable[[str, str], str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationDataFrameInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": @@ -557,14 +524,10 @@ def from_data_frame( predict_images_root: The root directory containing predict images. 
predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into an image file path. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -632,8 +595,6 @@ def from_data_frame( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver) @@ -641,14 +602,16 @@ def from_data_frame( test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver) - train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, *train_data, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, *val_data, **ds_kw), + input_cls(RunningStage.TESTING, *test_data, **ds_kw), + input_cls(RunningStage.PREDICTING, *predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -669,12 +632,9 @@ def from_csv( predict_file: Optional[str] = None, predict_images_root: Optional[str] = None, predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationCSVInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: 
Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": @@ -708,14 +668,10 @@ def from_csv( predict_images_root: The root directory containing predict images. predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into an image file path. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -797,8 +753,6 @@ def from_csv( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) train_data = (train_file, input_field, target_fields, train_images_root, train_resolver) @@ -806,14 +760,16 @@ def from_csv( test_data = (test_file, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver) - train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, *train_data, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, *val_data, **ds_kw), + input_cls(RunningStage.TESTING, *test_data, **ds_kw), + input_cls(RunningStage.PREDICTING, *predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -826,12 +782,9 @@ def from_fiftyone( test_dataset: Optional[SampleCollection] = None, predict_dataset: Optional[SampleCollection] = None, label_field: str = "ground_truth", - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationFiftyOneInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) 
-> "ImageClassificationData": @@ -851,14 +804,10 @@ def from_fiftyone( test_dataset: The ``SampleCollection`` to use when testing. predict_dataset: The ``SampleCollection`` to use when predicting. label_field: The field in the ``SampleCollection`` objects containing the targets. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -923,20 +872,18 @@ def from_fiftyone( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls( - RunningStage.TRAINING, train_dataset, transform=train_transform, label_field=label_field, **ds_kw - ) + train_input = input_cls(RunningStage.TRAINING, train_dataset, label_field=label_field, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, label_field=label_field, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, label_field=label_field, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, label_field=label_field, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, label_field=label_field, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -953,11 +900,8 @@ def from_labelstudio( val_data_folder: str = None, test_data_folder: str = None, predict_data_folder: str = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, input_cls: Type[Input] = LabelStudioImageClassificationInput, + transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, val_split: Optional[float] = None, multi_label: Optional[bool] = False, @@ -971,33 +915,20 @@ def from_labelstudio( Args: export_json: path to label studio export file - train_export_json: path to label studio export file for train 
set, - overrides export_json if specified + train_export_json: path to label studio export file for train set.(overrides export_json if specified) val_export_json: path to label studio export file for validation test_export_json: path to label studio export file for test predict_export_json: path to label studio export file for predict data_folder: path to label studio data folder - train_data_folder: path to label studio data folder for train data set, - overrides data_folder if specified + train_data_folder: path to label studio data folder for train data set.(overrides data_folder if specified) val_data_folder: path to label studio data folder for validation data test_data_folder: path to label studio data folder for test data predict_data_folder: path to label studio data folder for predict data - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. - input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` - will be constructed and used. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. 
+ transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. multi_label: Whether the labels are multi encoded - image_size: Size of the image. data_module_kwargs: Additional keyword arguments to use when constructing the datamodule. Returns: @@ -1027,19 +958,18 @@ def from_labelstudio( multi_label=multi_label, ) - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() - train_input = input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw) ds_kw["parameters"] = getattr(train_input, "parameters", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, val_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, **ds_kw), + input_cls(RunningStage.TESTING, val_data, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -1050,11 +980,8 @@ def from_datasets( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, input_cls: Type[Input] = DatasetInput, + transform: INPUT_TRANSFORM_TYPE = 
ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "DataModule": @@ -1068,15 +995,8 @@ def from_datasets( val_dataset: Dataset used during validating. test_dataset: Dataset used during testing. predict_dataset: Dataset used during predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. input_cls: Input class used to create the datasets. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Additional keyword arguments to be used when constructing the transform. data_module_kwargs: Additional keyword arguments to use when constructing the DataModule. 
@@ -1089,16 +1009,15 @@ def from_datasets( train_dataset=train_dataset, ) """ - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py index adfa271a5a..54bc0ee4d5 100644 --- a/flash/image/detection/backbones.py +++ b/flash/image/detection/backbones.py @@ -14,13 +14,10 @@ from functools import partial from typing import Optional -import torch - from flash.core.adapter import Adapter from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric from flash.core.integrations.icevision.backbones import ( get_backbones, - icevision_model_adapter, load_icevision_ignore_image_size, load_icevision_with_image_size, ) @@ -66,7 +63,7 @@ def from_task( if _TORCHVISION_AVAILABLE: for model_type in [icevision_models.torchvision.retinanet, icevision_models.torchvision.faster_rcnn]: OBJECT_DETECTION_HEADS( - partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + partial(load_icevision_ignore_image_size, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), adapter=IceVisionObjectDetectionAdapter, @@ -76,7 +73,7 @@ def from_task( if _module_available("yolov5"): 
model_type = icevision_models.ultralytics.yolov5 OBJECT_DETECTION_HEADS( - partial(load_icevision_with_image_size, icevision_model_adapter, model_type), + partial(load_icevision_with_image_size, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), adapter=IceVisionObjectDetectionAdapter, @@ -91,7 +88,7 @@ def from_task( icevision_models.mmdet.sparse_rcnn, ]: OBJECT_DETECTION_HEADS( - partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + partial(load_icevision_ignore_image_size, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), adapter=IceVisionObjectDetectionAdapter, @@ -100,24 +97,9 @@ def from_task( if _module_available("effdet"): - def _icevision_effdet_validation_step(self, batch, batch_idx): - images = batch[0][0] - batch[0][1]["img_scale"] = torch.ones_like(images[:, 0, 0, 0]).unsqueeze(1) - batch[0][1]["img_size"] = ( - (torch.ones_like(images[:, 0, 0, 0]) * images[0].shape[-1]).unsqueeze(1).repeat(1, 2) - ) - return self._original_validation_step(batch, batch_idx) - - def _icevision_effdet_model_adapter(model_type): - adapter = icevision_model_adapter(model_type) - if not hasattr(adapter, "_original_validation_step"): - adapter._original_validation_step = adapter.validation_step - adapter.validation_step = _icevision_effdet_validation_step - return adapter - model_type = icevision_models.ross.efficientdet OBJECT_DETECTION_HEADS( - partial(load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), + partial(load_icevision_with_image_size, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), adapter=IceVisionObjectDetectionAdapter, diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 0376b1752c..32c089f1a9 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -63,12 +63,9 @@ def from_files( test_targets: Optional[Sequence[Sequence[Any]]] = None, 
test_bboxes: Optional[Sequence[Sequence[Dict[str, int]]]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ObjectDetectionFilesInput, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": @@ -95,14 +92,10 @@ def from_files( test_targets: The list of lists of targets to use when testing. test_bboxes: The list of lists of bounding boxes to use when testing. predict_files: The list of image files to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -157,7 +150,6 @@ def from_files( ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, ) train_input = input_cls( @@ -165,7 +157,6 @@ def from_files( train_files, train_targets, train_bboxes, - transform=train_transform, **ds_kw, ) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -177,7 +168,6 @@ def from_files( val_files, val_targets, val_bboxes, - transform=val_transform, **ds_kw, ), input_cls( @@ -185,10 +175,11 @@ def from_files( test_files, test_targets, test_bboxes, - transform=test_transform, **ds_kw, ), - input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -205,17 +196,14 @@ def from_icedata( test_ann_file: Optional[str] = None, test_parser_kwargs: Optional[Dict[str, Any]] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, parser: Optional[Union[Callable, Type[Parser]]] = None, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "ObjectDetectionData": - ds_kw = dict(parser=parser, transform_kwargs=transform_kwargs) + ds_kw = dict(parser=parser) return cls( input_cls( @@ -223,7 +211,6 @@ def from_icedata( train_folder, train_ann_file, parser_kwargs=train_parser_kwargs, - transform=train_transform, **ds_kw, ), input_cls( @@ -231,7 +218,6 @@ def from_icedata( val_folder, val_ann_file, parser_kwargs=val_parser_kwargs, - transform=val_transform, **ds_kw, ), input_cls( @@ -239,10 +225,11 @@ def from_icedata( test_folder, test_ann_file, 
parser_kwargs=test_parser_kwargs, - transform=test_transform, **ds_kw, ), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -256,10 +243,7 @@ def from_coco( test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -283,12 +267,8 @@ def from_coco( test_folder: The folder containing images to use when testing. test_ann_file: The COCO format annotation file to use when testing. predict_folder: The folder containing images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -391,12 +371,9 @@ def from_coco( test_folder=test_folder, test_ann_file=test_ann_file, predict_folder=predict_folder, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, parser=COCOBBoxParser, input_cls=input_cls, + transform=transform, transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -412,10 +389,7 @@ def from_voc( test_folder: Optional[str] = None, test_ann_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -438,12 +412,8 @@ def from_voc( test_folder: The folder containing images to use when testing. test_ann_folder: The folder containing VOC format annotation files to use when testing. predict_folder: The folder containing images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. 
data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -557,12 +527,9 @@ def from_voc( test_folder=test_folder, test_ann_file=test_ann_folder, predict_folder=predict_folder, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, parser=partial(VOCBBoxParser, class_map=ClassMap(list(sorted_alphanumeric(labels)))), input_cls=input_cls, + transform=transform, transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -579,10 +546,7 @@ def from_via( test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, input_cls: Type[Input] = IceVisionInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -606,12 +570,8 @@ def from_via( test_folder: The folder containing images to use when testing. test_ann_file: The VIA format annotation file to use when testing. predict_folder: The folder containing images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. 
+ transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -718,16 +678,13 @@ def from_via( test_folder=test_folder, test_ann_file=test_ann_file, predict_folder=predict_folder, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, parser=partial( VIABBoxParser, class_map=ClassMap(list(sorted_alphanumeric(labels))), label_field=label_field, ), input_cls=input_cls, + transform=transform, transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -742,10 +699,7 @@ def from_fiftyone( predict_dataset: Optional[SampleCollection] = None, label_field: str = "ground_truth", iscrowd: str = "iscrowd", - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, input_cls: Type[Input] = ObjectDetectionFiftyOneInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -764,12 +718,8 @@ def from_fiftyone( predict_dataset: The ``SampleCollection`` to use when predicting. label_field: The field in the ``SampleCollection`` objects containing the targets. iscrowd: The field in the ``SampleCollection`` objects containing the ``iscrowd`` annotation (if required). - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. 
- test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -841,13 +791,15 @@ def from_fiftyone( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict(transform_kwargs=transform_kwargs) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_dataset, label_field, iscrowd, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_dataset, label_field, iscrowd, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, label_field, iscrowd, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_dataset, label_field, iscrowd, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, label_field, iscrowd, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, label_field, iscrowd, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -872,8 +824,8 @@ def from_folders( The constructed data module. 
""" return cls( - predict_input=input_cls( - RunningStage.PREDICTING, predict_folder, transform=predict_transform, transform_kwargs=transform_kwargs - ), + predict_input=input_cls(RunningStage.PREDICTING, predict_folder), + transform=predict_transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index 422e2e3028..9e77526540 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -34,22 +34,21 @@ def from_datasets( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = FaceDetectionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = FaceDetectionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = FaceDetectionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = FaceDetectionInputTransform, input_cls: Type[Input] = FaceDetectionInput, + transform: INPUT_TRANSFORM_TYPE = FaceDetectionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "FaceDetectionData": - ds_kw = dict(transform_kwargs=transform_kwargs) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -63,10 +62,10 @@ def from_files( **data_module_kwargs: Any, ) -> 
"FaceDetectionData": - ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs) - return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_files), + transform=predict_transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -80,9 +79,9 @@ def from_folders( **data_module_kwargs: Any, ) -> "FaceDetectionData": - ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs) - return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_folder), + transform=predict_transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/image/instance_segmentation/backbones.py b/flash/image/instance_segmentation/backbones.py index 9811d6fa78..3617c2a796 100644 --- a/flash/image/instance_segmentation/backbones.py +++ b/flash/image/instance_segmentation/backbones.py @@ -16,11 +16,7 @@ from flash.core.adapter import Adapter from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric -from flash.core.integrations.icevision.backbones import ( - get_backbones, - icevision_model_adapter, - load_icevision_ignore_image_size, -) +from flash.core.integrations.icevision.backbones import get_backbones, load_icevision_ignore_image_size from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE @@ -63,7 +59,7 @@ def from_task( if _TORCHVISION_AVAILABLE: model_type = icevision_models.torchvision.mask_rcnn INSTANCE_SEGMENTATION_HEADS( - partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + partial(load_icevision_ignore_image_size, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), adapter=IceVisionInstanceSegmentationAdapter, @@ 
-73,7 +69,7 @@ def from_task( if _module_available("mmdet"): model_type = icevision_models.mmdet.mask_rcnn INSTANCE_SEGMENTATION_HEADS( - partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + partial(load_icevision_ignore_image_size, model_type), f"mmdet_{model_type.__name__.split('.')[-1]}", backbones=get_backbones(model_type), adapter=IceVisionInstanceSegmentationAdapter, diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 2d855ba107..b30d117436 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -62,17 +62,14 @@ def from_icedata( test_ann_file: Optional[str] = None, test_parser_kwargs: Optional[Dict[str, Any]] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, parser: Optional[Union[Callable, Type[Parser]]] = None, input_cls: Type[Input] = IceVisionInput, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "InstanceSegmentationData": - ds_kw = dict(parser=parser, transform_kwargs=transform_kwargs) + ds_kw = dict(parser=parser) return cls( input_cls( @@ -80,7 +77,6 @@ def from_icedata( train_folder, train_ann_file, parser_kwargs=train_parser_kwargs, - transform=train_transform, **ds_kw, ), input_cls( @@ -88,7 +84,6 @@ def from_icedata( val_folder, val_ann_file, parser_kwargs=val_parser_kwargs, - transform=val_transform, **ds_kw, ), input_cls( @@ -96,10 +91,11 @@ def from_icedata( test_folder, test_ann_file, parser_kwargs=test_parser_kwargs, - transform=test_transform, **ds_kw, ), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + 
input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -113,11 +109,8 @@ def from_coco( test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, input_cls: Type[Input] = IceVisionInput, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ): @@ -138,12 +131,8 @@ def from_coco( test_folder: The folder containing images to use when testing. test_ann_file: The COCO format annotation file to use when testing. predict_folder: The folder containing images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -252,13 +241,10 @@ def from_coco( test_folder=test_folder, test_ann_file=test_ann_file, predict_folder=predict_folder, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - transform_kwargs=transform_kwargs, parser=COCOMaskParser, input_cls=input_cls, + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -276,11 +262,8 @@ def from_voc( test_target_folder: Optional[str] = None, test_ann_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, input_cls: Type[Input] = IceVisionInput, + transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ): @@ -307,12 +290,8 @@ def from_voc( test_target_folder: The folder containing mask images to use when testing. test_ann_folder: The folder containing VOC format annotation files to use when testing. predict_folder: The folder containing images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. 
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -444,13 +423,10 @@ def from_voc( test_ann_file=test_ann_folder, test_parser_kwargs={"masks_dir": test_target_folder}, predict_folder=predict_folder, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - transform_kwargs=transform_kwargs, parser=partial(VOCMaskParser, class_map=ClassMap(list(sorted_alphanumeric(labels)))), input_cls=input_cls, + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -477,10 +453,10 @@ def from_folders( Returns: The constructed data module. """ - ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs) - return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_folder), + transform=predict_transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -507,9 +483,9 @@ def from_files( Returns: The constructed data module. 
""" - ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs) - return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_files), + transform=predict_transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/image/keypoint_detection/backbones.py b/flash/image/keypoint_detection/backbones.py index 72334761f2..0df353f5dd 100644 --- a/flash/image/keypoint_detection/backbones.py +++ b/flash/image/keypoint_detection/backbones.py @@ -16,11 +16,7 @@ from flash.core.adapter import Adapter from flash.core.integrations.icevision.adapter import IceVisionAdapter -from flash.core.integrations.icevision.backbones import ( - get_backbones, - icevision_model_adapter, - load_icevision_ignore_image_size, -) +from flash.core.integrations.icevision.backbones import get_backbones, load_icevision_ignore_image_size from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE @@ -64,7 +60,7 @@ def from_task( if _TORCHVISION_AVAILABLE: model_type = icevision_models.torchvision.keypoint_rcnn KEYPOINT_DETECTION_HEADS( - partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + partial(load_icevision_ignore_image_size, model_type), model_type.__name__.split(".")[-1], backbones=get_backbones(model_type), adapter=IceVisionKeypointDetectionAdapter, diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 870e8e839e..af0661419d 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -79,17 +79,14 @@ def from_icedata( test_ann_file: Optional[str] = None, test_parser_kwargs: Optional[Dict[str, Any]] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, - val_transform: 
INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, parser: Optional[Union[Callable, Type[Parser]]] = None, input_cls: Type[Input] = IceVisionInput, + transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "KeypointDetectionData": - ds_kw = dict(parser=parser, transform_kwargs=transform_kwargs) + ds_kw = dict(parser=parser) return cls( input_cls( @@ -97,7 +94,6 @@ def from_icedata( train_folder, train_ann_file, parser_kwargs=train_parser_kwargs, - transform=train_transform, **ds_kw, ), input_cls( @@ -105,7 +101,6 @@ def from_icedata( val_folder, val_ann_file, parser_kwargs=val_parser_kwargs, - transform=val_transform, **ds_kw, ), input_cls( @@ -113,10 +108,11 @@ def from_icedata( test_folder, test_ann_file, parser_kwargs=test_parser_kwargs, - transform=test_transform, **ds_kw, ), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -130,11 +126,8 @@ def from_coco( test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, input_cls: Type[Input] = IceVisionInput, + transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ): @@ -155,12 +148,8 @@ def from_coco( test_folder: The folder containing images to use when testing. 
test_ann_file: The COCO format annotation file to use when testing. predict_folder: The folder containing images to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -273,12 +262,9 @@ def from_coco( test_folder=test_folder, test_ann_file=test_ann_file, predict_folder=predict_folder, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, parser=FlashCOCOKeyPointsParser, input_cls=input_cls, + transform=transform, transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -287,8 +273,8 @@ def from_coco( def from_folders( cls, predict_folder: Optional[str] = None, - predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, input_cls: Type[Input] = IceVisionInput, + predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "DataModule": @@ -298,18 +284,18 @@ def from_folders( Args: predict_folder: The folder containing the predict data. 
- predict_transform: The dictionary of transforms to use during predicting which maps input_cls: The :class:`~flash.core.data.io.input.Input` used to create the dataset. + predict_transform: The dictionary of transforms to use during predicting which maps transform_kwargs: Keyword arguments provided to the transform on instantiation. data_module_kwargs: The keywords arguments for creating the datamodule. Returns: The constructed data module. """ - ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs) - return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_folder), + transform=predict_transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -317,8 +303,8 @@ def from_folders( def from_files( cls, predict_files: Optional[List[str]] = None, - predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, input_cls: Type[Input] = IceVisionInput, + predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "DataModule": @@ -328,17 +314,17 @@ def from_files( Args: predict_files: The list of files containing the predict data. - predict_transform: The dictionary of transforms to use during predicting which maps. input_cls: The :class:`~flash.core.data.io.input.Input` used to create the dataset. + predict_transform: The dictionary of transforms to use during predicting which maps. transform_kwargs: Keyword arguments provided to the transform on instantiation. data_module_kwargs: The keywords arguments for creating the datamodule. Returns: The constructed data module. 
""" - ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs) - return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_files), + transform=predict_transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index e6ae4bf9a4..66d07d46a0 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -19,7 +19,6 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input -from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _IMAGE_TESTING, lazy_import from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE @@ -58,7 +57,6 @@ class SemanticSegmentationData(DataModule): """The ``SemanticSegmentationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods for loading data for semantic segmentation.""" - input_transforms_registry = FlashRegistry("input_transforms") input_transform_cls = SemanticSegmentationInputTransform @property @@ -75,13 +73,10 @@ def from_files( test_files: Optional[Sequence[str]] = None, test_targets: Optional[Sequence[str]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, input_cls: Type[Input] = SemanticSegmentationFilesInput, num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, + transform: INPUT_TRANSFORM_TYPE = 
SemanticSegmentationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": @@ -101,15 +96,11 @@ def from_files( test_files: The list of image files to use when testing. test_targets: The list of mask files to use when testing. predict_files: The list of image files to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. num_classes: The number of segmentation classes. labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -159,17 +150,17 @@ def from_files( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, labels_map=labels_map, ) return cls( - input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw), + input_cls(RunningStage.VALIDATING, val_files, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_files, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -183,13 +174,10 @@ def from_folders( test_folder: Optional[str] = None, test_target_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, input_cls: Type[Input] = SemanticSegmentationFolderInput, num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, + transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": @@ -245,15 +233,11 @@ def from_folders( test_target_folder: The folder containing masks to use when testing (files should have the same name as the files in the ``train_folder``). predict_folder: The folder containing images to use when predicting. 
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. num_classes: The number of segmentation classes. labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -307,17 +291,17 @@ def from_folders( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, labels_map=labels_map, ) return cls( - input_cls(RunningStage.TRAINING, train_folder, train_target_folder, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_folder, val_target_folder, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_folder, test_target_folder, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_folder, train_target_folder, **ds_kw), + input_cls(RunningStage.VALIDATING, val_folder, val_target_folder, **ds_kw), + input_cls(RunningStage.TESTING, test_folder, test_target_folder, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -331,13 +315,10 @@ def from_numpy( test_data: Optional[Collection[np.ndarray]] = None, test_targets: Optional[Collection[np.ndarray]] = None, predict_data: Optional[Collection[np.ndarray]] = None, - train_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, input_cls: Type[Input] = SemanticSegmentationNumpyInput, num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, + transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": @@ -355,15 +336,11 @@ def from_numpy( test_data: The numpy array or list of arrays containing images to use when testing. 
test_targets: The numpy array or list of arrays containing masks to use when testing. predict_data: The numpy array or list of arrays to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. num_classes: The number of segmentation classes. labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -402,17 +379,17 @@ def from_numpy( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, labels_map=labels_map, ) return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -426,13 +403,10 @@ def from_tensors( test_data: Optional[Collection[torch.Tensor]] = None, test_targets: Optional[Collection[torch.Tensor]] = None, predict_data: Optional[Collection[torch.Tensor]] = None, - train_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, input_cls: Type[Input] = SemanticSegmentationTensorInput, num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, + transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": @@ -450,15 +424,11 @@ def from_tensors( test_data: The torch tensor or list of tensors containing images to use when testing. test_targets: The torch tensor or list of tensors containing masks to use when testing. 
predict_data: The torch tensor or list of tensors to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. num_classes: The number of segmentation classes. labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -497,17 +467,17 @@ def from_tensors( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, labels_map=labels_map, ) return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -518,13 +488,10 @@ def from_fiftyone( val_dataset: Optional[SampleCollection] = None, test_dataset: Optional[SampleCollection] = None, predict_dataset: Optional[SampleCollection] = None, - train_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, input_cls: Type[Input] = SemanticSegmentationFiftyOneInput, num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, + transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, transform_kwargs: Optional[Dict] = None, label_field: str = "ground_truth", **data_module_kwargs: Any, @@ -544,15 +511,11 @@ def from_fiftyone( test_dataset: The ``SampleCollection`` to use when testing. predict_dataset: The ``SampleCollection`` to use when predicting. 
label_field: The field in the ``SampleCollection`` objects containing the targets. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. num_classes: The number of segmentation classes. labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -614,40 +577,31 @@ def from_fiftyone( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) - return cls( input_cls( RunningStage.TRAINING, train_dataset, - transform=train_transform, label_field=label_field, num_classes=num_classes, labels_map=labels_map, - **ds_kw, ), input_cls( RunningStage.VALIDATING, val_dataset, - transform=val_transform, label_field=label_field, num_classes=num_classes, labels_map=labels_map, - **ds_kw, ), input_cls( RunningStage.TESTING, test_dataset, - transform=test_transform, label_field=label_field, num_classes=num_classes, labels_map=labels_map, - **ds_kw, ), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 9400f39e99..9ffc9b9082 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -41,9 +41,8 @@ def from_files( cls, train_files: Optional[Sequence[str]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, input_cls: Type[Input] = ImageClassificationFilesInput, + transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": @@ -57,10 +56,9 @@ def from_files( Args: train_files: The list of image files to use when training. predict_files: The list of image files to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. 
- predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -102,14 +100,11 @@ def from_files( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) - return cls( - input_cls(RunningStage.TRAINING, train_files, transform=train_transform, **ds_kw), - predict_input=input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_files), + predict_input=input_cls(RunningStage.PREDICTING, predict_files), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -118,9 +113,8 @@ def from_folders( cls, train_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, input_cls: Type[Input] = ImageClassificationFolderInput, + transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": @@ -144,10 +138,9 @@ def from_folders( Args: train_folder: The folder containing images to use when training. predict_folder: The folder containing images to use when predicting.
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -192,14 +185,11 @@ def from_folders( >>> shutil.rmtree("predict_folder") """ - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) - return cls( - input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw), - predict_input=input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_folder), + predict_input=input_cls(RunningStage.PREDICTING, predict_folder), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -208,9 +198,8 @@ def from_numpy( cls, train_data: Optional[Collection[np.ndarray]] = None, predict_data: Optional[Collection[np.ndarray]] = None, - train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, input_cls: Type[Input] = ImageNumpyInput, + transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": @@ -223,10 +212,9 @@ def from_numpy( Args: train_data: The numpy array or list of arrays to use when training. predict_data: The numpy array or list of arrays to use when predicting.
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -256,14 +244,11 @@ def from_numpy( Predicting... """ - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) - return cls( - input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw), - predict_input=input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_data), + predict_input=input_cls(RunningStage.PREDICTING, predict_data), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -272,9 +257,8 @@ def from_tensors( cls, train_data: Optional[Collection[torch.Tensor]] = None, predict_data: Optional[Collection[torch.Tensor]] = None, - train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, input_cls: Type[Input] = ImageTensorInput, + transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": @@ -287,10 +271,9 @@ def from_tensors( Args: train_data: The torch tensor or list of tensors to use when training. predict_data: The torch tensor or list of tensors to use when predicting.
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -320,13 +303,10 @@ def from_tensors( Predicting... """ - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) - return cls( - input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw), - predict_input=input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_data), + predict_input=input_cls(RunningStage.PREDICTING, predict_data), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 7b83a7cc03..6a7306b691 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -40,15 +40,12 @@ def from_folders( val_folder: Optional[str] = None, test_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, scans_folder_name: Optional[str] = "scans", labels_folder_name: Optional[str] = "labels", calibrations_folder_name:
Optional[str] = "calibs", data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI, input_cls: Type[Input] = PointCloudObjectDetectorFoldersInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": @@ -58,15 +55,15 @@ def from_folders( labels_folder_name=labels_folder_name, calibrations_folder_name=calibrations_folder_name, data_format=data_format, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_folder, **ds_kw), + input_cls(RunningStage.VALIDATING, val_folder, **ds_kw), + input_cls(RunningStage.TESTING, test_folder, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -74,12 +71,12 @@ def from_folders( def from_files( cls, predict_files: Optional[List[str]] = None, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, scans_folder_name: Optional[str] = "scans", labels_folder_name: Optional[str] = "labels", calibrations_folder_name: Optional[str] = "calibs", data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI, input_cls: Type[Input] = PointCloudObjectDetectorFoldersInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": @@ -89,12 +86,12 @@ def from_files( labels_folder_name=labels_folder_name, calibrations_folder_name=calibrations_folder_name, 
data_format=data_format, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -105,24 +102,20 @@ def from_datasets( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = PointCloudObjectDetectorDatasetInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index d400b20c49..a11962e959 100644 --- a/flash/pointcloud/detection/model.py 
+++ b/flash/pointcloud/detection/model.py @@ -19,6 +19,7 @@ from torch.utils.data import DataLoader, Sampler from flash.core.data.io.input import DataKeys, Input +from flash.core.data.io.input_transform import InputTransform from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.apply_func import get_callable_dict @@ -134,6 +135,7 @@ def forward(self, x) -> torch.Tensor: def _process_dataset( self, dataset: Input, + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -146,6 +148,11 @@ def _process_dataset( dataset.input_transform_fn = self.model.preprocess dataset.transform_fn = self.model.transform + if self.input_transform is None: + self.input_transform = input_transform + if self.collate_fn is not None: + self.input_transform.inject_collate_fn(self.collate_fn) + return DataLoader( dataset, batch_size=batch_size, diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 3afb8aa000..fbe99c7403 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -34,25 +34,21 @@ def from_folders( val_folder: Optional[str] = None, test_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = PointCloudSegmentationFoldersInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, 
val_folder, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_folder, **ds_kw), + input_cls(RunningStage.VALIDATING, val_folder, **ds_kw), + input_cls(RunningStage.TESTING, test_folder, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -60,19 +56,18 @@ def from_folders( def from_files( cls, predict_files: Optional[List[str]] = None, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = PointCloudSegmentationFoldersInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() return cls( - predict_input=input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -83,24 +78,20 @@ def from_datasets( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = PointCloudSegmentationDatasetInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict( - transform_kwargs=transform_kwargs, - 
input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index 0af3a95392..d5fb82b66d 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -20,6 +20,7 @@ from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys, Input +from flash.core.data.io.input_transform import InputTransform from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, _TM_GREATER_EQUAL_0_7_0 from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE @@ -143,6 +144,7 @@ def forward(self, x) -> torch.Tensor: def _process_dataset( self, dataset: Input, + input_transform: InputTransform, batch_size: int, num_workers: int, pin_memory: bool, @@ -161,6 +163,11 @@ def _process_dataset( use_cache=False, ) + if self.input_transform is None: + self.input_transform = input_transform + if self.collate_fn is not None: + self.input_transform.inject_collate_fn(self.collate_fn) + return DataLoader( dataset, batch_size=batch_size, diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 
3e21d5b56a..6a9da1dcad 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -46,12 +46,9 @@ def from_data_frame( val_data_frame: Optional[DataFrame] = None, test_data_frame: Optional[DataFrame] = None, predict_data_frame: Optional[DataFrame] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TabularClassificationDataFrameInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularClassificationData": @@ -80,14 +77,10 @@ def from_data_frame( val_data_frame: The DataFrame to use when validating. test_data_frame: The DataFrame to use when testing. predict_data_frame: The DataFrame to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. 
data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -161,23 +154,23 @@ def from_data_frame( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, numerical_fields=numerical_fields, target_fields=target_fields, parameters=parameters, ) - train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data_frame, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data_frame, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data_frame, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data_frame, **ds_kw), + input_cls(RunningStage.TESTING, test_data_frame, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data_frame, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -192,12 +185,9 @@ def from_csv( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TabularClassificationCSVInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularClassificationData": @@ -226,14 +216,10 
@@ def from_csv( val_file: The path to the CSV file to use when validating. test_file: The path to the CSV file to use when testing. predict_file: The path to the CSV file to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -306,22 +292,22 @@ def from_csv( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, numerical_fields=numerical_fields, target_fields=target_fields, parameters=parameters, ) - train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 439b805eaf..f2e4158eef 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -57,11 +57,8 @@ def from_data_frame( val_data_frame: Optional[DataFrame] = None, test_data_frame: Optional[DataFrame] = None, predict_data_frame: Optional[DataFrame] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = TabularForecastingDataFrameInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, data_fetcher: Optional[BaseDataFetcher] = None, val_split: Optional[float] 
= None, @@ -97,12 +94,8 @@ def from_data_frame( val_data_frame: The pandas DataFrame to use when validating. test_data_frame: The pandas DataFrame to use when testing. predict_data_frame: The pandas DataFrame to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. input_kwargs: Additional keyword arguments to be used when creating the TimeSeriesDataset. 
@@ -166,8 +159,6 @@ def from_data_frame( """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, time_idx=time_idx, group_ids=group_ids, target=target, @@ -175,14 +166,16 @@ def from_data_frame( **input_kwargs, ) - train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data_frame, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data_frame, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data_frame, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data_frame, **ds_kw), + input_cls(RunningStage.TESTING, test_data_frame, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data_frame, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, data_fetcher=data_fetcher, val_split=val_split, batch_size=batch_size, diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py index 065f8fd16f..67acee7559 100644 --- a/flash/tabular/regression/data.py +++ b/flash/tabular/regression/data.py @@ -45,11 +45,8 @@ def from_data_frame( val_data_frame: Optional[DataFrame] = None, test_data_frame: Optional[DataFrame] = None, predict_data_frame: Optional[DataFrame] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = TabularRegressionDataFrameInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularRegressionData": @@ -77,12 +74,8 @@ def from_data_frame( 
val_data_frame: The DataFrame to use when validating. test_data_frame: The DataFrame to use when testing. predict_data_frame: The DataFrame to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -151,22 +144,22 @@ def from_data_frame( >>> del predict_data """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, numerical_fields=numerical_fields, target_field=target_field, parameters=parameters, ) - train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data_frame, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data_frame, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data_frame, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data_frame, **ds_kw), + input_cls(RunningStage.TESTING, test_data_frame, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data_frame, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -181,11 +174,8 @@ def from_csv( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = TabularRegressionCSVInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularRegressionData": @@ -212,12 +202,8 @@ def from_csv( val_file: The path to the CSV file to use when validating. test_file: The path to the CSV file to use when testing. predict_file: The path to the CSV file to use when predicting. 
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -285,21 +271,21 @@ def from_csv( >>> os.remove("predict_data.csv") """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, numerical_fields=numerical_fields, target_field=target_field, parameters=parameters, ) - train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters return cls( train_input, - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/template/classification/data.py 
b/flash/template/classification/data.py index 38b8892be5..2e2b33adf0 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -134,11 +134,8 @@ def from_numpy( test_data: Optional[Collection[np.ndarray]] = None, test_targets: Optional[Sequence[Any]] = None, predict_data: Optional[Collection[np.ndarray]] = None, - train_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, input_cls: Type[Input] = TemplateNumpyClassificationInput, + transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TemplateData": @@ -153,46 +150,28 @@ def from_numpy( test_data: The numpy ``Array`` containing the test data. test_targets: The sequence of test targets. predict_data: The numpy ``Array`` containing the predict data. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. 
+ transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed data module. """ - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() - train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) target_formatter = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls( - RunningStage.VALIDATING, - val_data, - val_targets, - transform=val_transform, - target_formatter=target_formatter, - **ds_kw, - ), - input_cls( - RunningStage.TESTING, - test_data, - test_targets, - transform=test_transform, - target_formatter=target_formatter, - **ds_kw, - ), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, target_formatter=target_formatter, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, target_formatter=target_formatter, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -203,11 +182,8 @@ def from_sklearn( val_bunch: Optional[Bunch] = None, test_bunch: Optional[Bunch] = None, predict_bunch: Optional[Bunch] = None, - train_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, input_cls: Type[Input] = TemplateSKLearnClassificationInput, + transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform, transform_kwargs: Optional[Dict] = 
None, **data_module_kwargs: Any, ) -> "TemplateData": @@ -219,43 +195,27 @@ def from_sklearn( val_bunch: The scikit-learn ``Bunch`` containing the validation data. test_bunch: The scikit-learn ``Bunch`` containing the test data. predict_bunch: The scikit-learn ``Bunch`` containing the predict data. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed data module. 
""" - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() - train_input = input_cls(RunningStage.TRAINING, train_bunch, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_bunch, **ds_kw) target_formatter = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls( - RunningStage.VALIDATING, - val_bunch, - transform=val_transform, - target_formatter=target_formatter, - **ds_kw, - ), - input_cls( - RunningStage.TESTING, - test_bunch, - transform=test_transform, - target_formatter=target_formatter, - **ds_kw, - ), - input_cls(RunningStage.PREDICTING, predict_bunch, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_bunch, target_formatter=target_formatter, **ds_kw), + input_cls(RunningStage.TESTING, test_bunch, target_formatter=target_formatter, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_bunch, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index ecee9ea2e5..df6f09c30f 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -57,12 +57,9 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: Optional[Dict[str, Callable]] = InputTransform, - val_transform: Optional[Dict[str, Callable]] = InputTransform, - test_transform: Optional[Dict[str, Callable]] = InputTransform, - predict_transform: Optional[Dict[str, Callable]] = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationCSVInput, + transform: Optional[Dict[str, Callable]] = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TextClassificationData": @@ -82,14 +79,10 @@ def from_csv( 
val_file: The CSV file to use when validating. test_file: The CSV file to use when testing. predict_file: The CSV file to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -162,18 +155,18 @@ def from_csv( target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -186,12 +179,9 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: Optional[Dict[str, Callable]] = InputTransform, - val_transform: Optional[Dict[str, Callable]] = InputTransform, - test_transform: Optional[Dict[str, Callable]] = InputTransform, - predict_transform: Optional[Dict[str, Callable]] = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationJSONInput, + transform: Optional[Dict[str, Callable]] = InputTransform, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, **data_module_kwargs: Any, @@ -212,14 +202,10 @@ def from_json( val_file: The JSON file to use when validating. test_file: The JSON file to use when testing. predict_file: The JSON file to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. 
- val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: To specify the field that holds the data in the JSON file. data_module_kwargs: Additional keyword arguments to provide to the @@ -292,18 +278,18 @@ def from_json( input_key=input_field, target_keys=target_fields, field=field, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -316,12 +302,9 @@ def from_parquet( val_file: Optional[PATH_TYPE] = None, 
test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: Optional[Dict[str, Callable]] = InputTransform, - val_transform: Optional[Dict[str, Callable]] = InputTransform, - test_transform: Optional[Dict[str, Callable]] = InputTransform, - predict_transform: Optional[Dict[str, Callable]] = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationParquetInput, + transform: Optional[Dict[str, Callable]] = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TextClassificationData": @@ -341,14 +324,10 @@ def from_parquet( val_file: The PARQUET file to use when validating. test_file: The PARQUET file to use when testing. predict_file: The PARQUET file to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -421,18 +400,18 @@ def from_parquet( target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -445,12 +424,9 @@ def from_hf_datasets( val_hf_dataset: Optional[Dataset] = None, test_hf_dataset: Optional[Dataset] = None, predict_hf_dataset: Optional[Dataset] = None, - train_transform: Optional[Dict[str, Callable]] = InputTransform, - val_transform: Optional[Dict[str, Callable]] = InputTransform, - test_transform: Optional[Dict[str, Callable]] = InputTransform, - predict_transform: Optional[Dict[str, Callable]] = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationInput, + transform: Optional[Dict[str, Callable]] = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TextClassificationData": @@ -470,14 +446,10 @@ def from_hf_datasets( val_hf_dataset: The ``Dataset`` to use when validating. test_hf_dataset: The ``Dataset`` to use when testing. predict_hf_dataset: The ``Dataset`` to use when predicting. 
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -531,18 +503,18 @@ def from_hf_datasets( target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_hf_dataset, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_hf_dataset, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_hf_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_hf_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_hf_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_hf_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_hf_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_hf_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -555,12 +527,9 @@ def from_data_frame( val_data_frame: Optional[DataFrame] = None, test_data_frame: Optional[DataFrame] = None, predict_data_frame: Optional[DataFrame] = None, - train_transform: Optional[Dict[str, Callable]] = InputTransform, - val_transform: Optional[Dict[str, Callable]] = InputTransform, - test_transform: Optional[Dict[str, Callable]] = InputTransform, - predict_transform: Optional[Dict[str, Callable]] = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationDataFrameInput, + transform: Optional[Dict[str, Callable]] = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TextClassificationData": @@ -581,14 +550,10 @@ def from_data_frame( val_data_frame: The ``DataFrame`` to use when validating. test_data_frame: The ``DataFrame`` to use when testing. predict_data_frame: The ``DataFrame`` to use when predicting. 
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -642,18 +607,18 @@ def from_data_frame( target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data_frame, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data_frame, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data_frame, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data_frame, **ds_kw), + input_cls(RunningStage.TESTING, test_data_frame, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data_frame, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -667,12 +632,9 @@ def from_lists( test_data: Optional[List[str]] = None, test_targets: Optional[Union[List[Any], List[List[Any]]]] = None, predict_data: Optional[List[str]] = None, - train_transform: Optional[Dict[str, Callable]] = InputTransform, - val_transform: Optional[Dict[str, Callable]] = InputTransform, - test_transform: Optional[Dict[str, Callable]] = InputTransform, - predict_transform: Optional[Dict[str, Callable]] = InputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationListInput, + transform: Optional[Dict[str, Callable]] = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TextClassificationData": @@ -692,14 +654,10 @@ def from_lists( test_data: The list of text snippets to use when testing. test_targets: The list of targets to use when testing. predict_data: The list of text snippets to use when predicting. 
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -733,18 +691,18 @@ def from_lists( """ ds_kw = dict( target_formatter=target_formatter, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) - train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -761,11 +719,8 @@ def from_labelstudio( val_data_folder: str = None, test_data_folder: str = None, predict_data_folder: str = None, - train_transform: Optional[Dict[str, Callable]] = InputTransform, - val_transform: Optional[Dict[str, Callable]] = InputTransform, - test_transform: Optional[Dict[str, Callable]] = InputTransform, - predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = LabelStudioTextClassificationInput, + transform: Optional[Dict[str, Callable]] = InputTransform, transform_kwargs: Optional[Dict] = None, val_split: Optional[float] = None, multi_label: Optional[bool] = False, @@ -790,14 +745,9 @@ def from_labelstudio( val_data_folder: path to label studio data folder for validation data test_data_folder: path to label studio data folder for test data predict_data_folder: path to label studio data folder for predict data - train_transform: The dictionary of transforms to use during 
training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. multi_label: Whether the labels are multi encoded. data_module_kwargs: Additional keyword arguments to use when constructing the datamodule. 
@@ -821,18 +771,17 @@ def from_labelstudio( multi_label=multi_label, ) - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() - train_input = input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw) ds_kw["parameters"] = getattr(train_input, "parameters", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, **ds_kw), + input_cls(RunningStage.TESTING, test_data, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 80b717dfad..43e6158e60 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -45,11 +45,8 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = QuestionAnsweringCSVInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, question_column_name: str = "question", context_column_name: str = "context", @@ -70,12 +67,8 @@ def from_csv( val_file: The CSV file containing the validation data. test_file: The CSV file containing the testing data. 
predict_file: The CSV file containing the data to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. question_column_name: The key in the JSON file to recognize the question field. context_column_name: The key in the JSON file to recognize the context field. @@ -171,15 +164,15 @@ def from_csv( question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -190,11 +183,8 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: 
Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = QuestionAnsweringJSONInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, question_column_name: str = "question", @@ -216,12 +206,8 @@ def from_json( val_file: The JSON file containing the validation data. test_file: The JSON file containing the testing data. predict_file: The JSON file containing the data to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: The field that holds the data in the JSON file. question_column_name: The key in the JSON file to recognize the question field. 
@@ -323,15 +309,15 @@ def from_json( question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -342,11 +328,8 @@ def from_squad_v2( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = QuestionAnsweringSQuADInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, question_column_name: str = "question", context_column_name: str = "context", @@ -367,12 +350,8 @@ def from_squad_v2( val_file: The JSON file containing the validation data. test_file: The JSON file containing the testing data. predict_file: The JSON file containing the predict data. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. 
- test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. question_column_name: The key in the JSON file to recognize the question field. context_column_name: The key in the JSON file to recognize the context field. @@ -611,15 +590,15 @@ def from_squad_v2( question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -630,11 +609,8 @@ def from_dicts( val_data: Optional[Dict[str, Any]] = None, test_data: Optional[Dict[str, Any]] = None, predict_data: Optional[Dict[str, Any]] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: 
Type[Input] = QuestionAnsweringDictionaryInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, question_column_name: str = "question", context_column_name: str = "context", @@ -655,12 +631,8 @@ def from_dicts( val_data: The dictionary containing the validation data. test_data: The dictionary containing the testing data. predict_data: The dictionary containing the data to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. question_column_name: The key in the JSON file to recognize the question field. context_column_name: The key in the JSON file to recognize the context field. 
@@ -731,14 +703,14 @@ def from_dicts( question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_data, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, **ds_kw), + input_cls(RunningStage.TESTING, test_data, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 7ad9c5f8a9..3496d3ef6f 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -47,11 +47,8 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqCSVInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SummarizationData": @@ -70,12 +67,8 @@ def from_csv( val_file: The CSV file to use when validating. test_file: The CSV file to use when testing. predict_file: The CSV file to use when predicting. 
- train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -142,15 +135,15 @@ def from_csv( ds_kw = dict( input_key=input_field, target_key=target_field, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -163,11 +156,8 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: 
INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqJSONInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, **data_module_kwargs: Any, @@ -187,12 +177,8 @@ def from_json( val_file: The JSON file to use when validating. test_file: The JSON file to use when testing. predict_file: The JSON file to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: The field that holds the data in the JSON file. 
data_module_kwargs: Additional keyword arguments to provide to the @@ -259,15 +245,15 @@ def from_json( input_key=input_field, target_key=target_field, field=field, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -280,11 +266,8 @@ def from_hf_datasets( val_hf_dataset: Optional[Dataset] = None, test_hf_dataset: Optional[Dataset] = None, predict_hf_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqInputBase, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SummarizationData": @@ -303,12 +286,8 @@ def from_hf_datasets( val_hf_dataset: The ``Dataset`` to use when validating. test_hf_dataset: The ``Dataset`` to use when testing. predict_hf_dataset: The ``Dataset`` to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. 
- test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -358,15 +337,15 @@ def from_hf_datasets( ds_kw = dict( input_key=input_field, target_key=target_field, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_hf_dataset, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_hf_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_hf_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_hf_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_hf_dataset, **ds_kw), + input_cls(RunningStage.VALIDATING, val_hf_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_hf_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_hf_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -380,11 +359,8 @@ def from_lists( test_data: Optional[List[str]] = None, test_targets: Optional[List[str]] = None, predict_data: Optional[List[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqListInput, + transform: 
INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SummarizationData": @@ -402,12 +378,8 @@ def from_lists( test_data: The list of input text snippets to use when testing. test_targets: The list of target text snippets to use when testing. predict_data: The list of input text snippets to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -436,15 +408,14 @@ def from_lists( Predicting... 
""" - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 14b6d62b69..2704f03fe2 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -47,11 +47,8 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqCSVInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TranslationData": @@ -70,12 +67,8 @@ def from_csv( val_file: The CSV file to use when validating. test_file: The CSV file to use when testing. predict_file: The CSV file to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. 
- val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -141,15 +134,15 @@ def from_csv( ds_kw = dict( input_key=input_field, target_key=target_field, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -162,11 +155,8 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: 
Type[Input] = Seq2SeqJSONInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, **data_module_kwargs: Any, @@ -186,12 +176,8 @@ def from_json( val_file: The JSON file to use when validating. test_file: The JSON file to use when testing. predict_file: The JSON file to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: The field that holds the data in the JSON file. 
data_module_kwargs: Additional keyword arguments to provide to the @@ -257,15 +243,15 @@ def from_json( input_key=input_field, target_key=target_field, field=field, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, **ds_kw), + input_cls(RunningStage.TESTING, test_file, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_file, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -278,11 +264,8 @@ def from_hf_datasets( val_hf_dataset: Optional[Dataset] = None, test_hf_dataset: Optional[Dataset] = None, predict_hf_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqInputBase, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TranslationData": @@ -301,12 +284,8 @@ def from_hf_datasets( val_hf_dataset: The ``Dataset`` to use when validating. test_hf_dataset: The ``Dataset`` to use when testing. predict_hf_dataset: The ``Dataset`` to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. 
- test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -356,15 +335,15 @@ def from_hf_datasets( ds_kw = dict( input_key=input_field, target_key=target_field, - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, ) return cls( - input_cls(RunningStage.TRAINING, train_hf_dataset, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_hf_dataset, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_hf_dataset, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_hf_dataset, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_hf_dataset, **ds_kw), + input_cls(RunningStage.VALIDATING, val_hf_dataset, **ds_kw), + input_cls(RunningStage.TESTING, test_hf_dataset, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_hf_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -378,11 +357,8 @@ def from_lists( test_data: Optional[List[str]] = None, test_targets: Optional[List[str]] = None, predict_data: Optional[List[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = InputTransform, - val_transform: INPUT_TRANSFORM_TYPE = InputTransform, - test_transform: INPUT_TRANSFORM_TYPE = InputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqListInput, + transform: 
INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TranslationData": @@ -400,12 +376,8 @@ def from_lists( test_data: The list of input text snippets to use when testing. test_targets: The list of target text snippets to use when testing. predict_data: The list of input text snippets to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -434,15 +406,14 @@ def from_lists( Predicting... 
""" - ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, - ) + ds_kw = dict() return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, val_targets, **ds_kw), + input_cls(RunningStage.TESTING, test_data, test_targets, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index d74c7aaca6..f60b5039a5 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -84,10 +84,6 @@ def from_files( test_files: Optional[Sequence[str]] = None, test_targets: Optional[Sequence[Any]] = None, predict_files: Optional[Sequence[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, @@ -97,6 +93,7 @@ def from_files( decoder: str = "pyav", input_cls: Type[Input] = VideoClassificationFilesInput, predict_input_cls: Type[Input] = VideoClassificationPathsPredictInput, + transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, transform_kwargs: Optional[Dict] = None, 
**data_module_kwargs, ) -> "VideoClassificationData": @@ -117,11 +114,6 @@ def from_files( test_files: The list of video files to use when testing. test_targets: The list of targets to use when testing. predict_files: The list of video files to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. @@ -134,6 +126,7 @@ def from_files( videos. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. predict_input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the prediction data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -181,8 +174,6 @@ def from_files( >>> _ = [os.remove(f"predict_video_{i}.mp4") for i in range(1, 4)] """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, clip_duration=clip_duration, clip_sampler_kwargs=clip_sampler_kwargs, @@ -194,7 +185,6 @@ def from_files( RunningStage.TRAINING, train_files, train_targets, - transform=train_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -207,7 +197,6 @@ def from_files( RunningStage.VALIDATING, val_files, val_targets, - transform=val_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -216,12 +205,13 @@ def from_files( RunningStage.TESTING, test_files, test_targets, - transform=test_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, ), - predict_input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), + predict_input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -232,10 +222,6 @@ def from_folders( val_folder: Optional[str] = None, test_folder: Optional[str] = None, predict_folder: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, @@ -245,6 +231,7 @@ def from_folders( decoder: str = "pyav", input_cls: Type[Input] = VideoClassificationFoldersInput, predict_input_cls: Type[Input] = VideoClassificationPathsPredictInput, + transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, transform_kwargs: 
Optional[Dict] = None, **data_module_kwargs, ) -> "VideoClassificationData": @@ -284,11 +271,6 @@ def from_folders( val_folder: The folder containing videos to use when validating. test_folder: The folder containing videos to use when testing. predict_folder: The folder containing videos to use when predicting. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. @@ -301,6 +283,7 @@ def from_folders( videos. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. predict_input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the prediction data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -357,8 +340,6 @@ def from_folders( >>> shutil.rmtree("predict_folder") """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, clip_duration=clip_duration, clip_sampler_kwargs=clip_sampler_kwargs, @@ -369,7 +350,6 @@ def from_folders( train_input = input_cls( RunningStage.TRAINING, train_folder, - transform=train_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -381,7 +361,6 @@ def from_folders( input_cls( RunningStage.VALIDATING, val_folder, - transform=val_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -389,12 +368,13 @@ def from_folders( input_cls( RunningStage.TESTING, test_folder, - transform=test_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, ), - predict_input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), + predict_input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -415,10 +395,6 @@ def from_data_frame( predict_data_frame: Optional[pd.DataFrame] = None, predict_videos_root: Optional[str] = None, predict_resolver: Optional[Callable[[str, str], str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, @@ -428,6 +404,7 @@ def from_data_frame( decoder: str = "pyav", input_cls: Type[Input] = VideoClassificationDataFrameInput, predict_input_cls: Type[Input] = VideoClassificationDataFramePredictInput, + transform: INPUT_TRANSFORM_TYPE = 
VideoClassificationInputTransform, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "VideoClassificationData": @@ -460,11 +437,6 @@ def from_data_frame( predict_videos_root: The root directory containing predict videos. predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a video file path. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. @@ -477,6 +449,7 @@ def from_data_frame( videos. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. predict_input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the prediction data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -550,8 +523,6 @@ def from_data_frame( >>> del predict_data_frame """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, clip_duration=clip_duration, clip_sampler_kwargs=clip_sampler_kwargs, @@ -567,7 +538,6 @@ def from_data_frame( train_input = input_cls( RunningStage.TRAINING, *train_data, - transform=train_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -579,7 +549,6 @@ def from_data_frame( input_cls( RunningStage.VALIDATING, *val_data, - transform=val_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -587,12 +556,13 @@ def from_data_frame( input_cls( RunningStage.TESTING, *test_data, - transform=test_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, ), - predict_input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + predict_input_cls(RunningStage.PREDICTING, *predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -613,10 +583,6 @@ def from_csv( predict_file: Optional[str] = None, predict_videos_root: Optional[str] = None, predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, - train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, @@ -626,6 +592,7 @@ def from_csv( decoder: str = "pyav", input_cls: Type[Input] = VideoClassificationCSVInput, predict_input_cls: Type[Input] = VideoClassificationCSVPredictInput, + transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, 
transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "VideoClassificationData": @@ -658,11 +625,6 @@ def from_csv( predict_videos_root: The root directory containing predict videos. predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a video file path. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. @@ -675,6 +637,7 @@ def from_csv( videos. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. predict_input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the prediction data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -762,8 +725,6 @@ def from_csv( >>> os.remove("predict_data.csv") """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, clip_duration=clip_duration, clip_sampler_kwargs=clip_sampler_kwargs, @@ -779,7 +740,6 @@ def from_csv( train_input = input_cls( RunningStage.TRAINING, *train_data, - transform=train_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -791,7 +751,6 @@ def from_csv( input_cls( RunningStage.VALIDATING, *val_data, - transform=val_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, @@ -799,12 +758,13 @@ def from_csv( input_cls( RunningStage.TESTING, *test_data, - transform=test_transform, video_sampler=video_sampler, target_formatter=target_formatter, **ds_kw, ), - predict_input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + predict_input_cls(RunningStage.PREDICTING, *predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -816,10 +776,6 @@ def from_fiftyone( val_dataset: Optional[SampleCollection] = None, test_dataset: Optional[SampleCollection] = None, predict_dataset: Optional[SampleCollection] = None, - train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, @@ -830,6 +786,7 @@ def from_fiftyone( label_field: str = "ground_truth", input_cls: Type[Input] = VideoClassificationFiftyOneInput, predict_input_cls: Type[Input] = VideoClassificationFiftyOnePredictInput, + transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, 
transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "VideoClassificationData": @@ -848,11 +805,6 @@ def from_fiftyone( test_dataset: The ``SampleCollection`` to use when testing. predict_dataset: The ``SampleCollection`` to use when predicting. label_field: The field in the ``SampleCollection`` objects containing the targets. - train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. - val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. - test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. - predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. @@ -865,6 +817,7 @@ def from_fiftyone( videos. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. predict_input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the prediction data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
@@ -929,8 +882,6 @@ def from_fiftyone( >>> del predict_dataset """ ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, clip_duration=clip_duration, clip_sampler_kwargs=clip_sampler_kwargs, @@ -941,7 +892,6 @@ def from_fiftyone( train_input = input_cls( RunningStage.TRAINING, train_dataset, - transform=train_transform, video_sampler=video_sampler, label_field=label_field, target_formatter=target_formatter, @@ -954,7 +904,6 @@ def from_fiftyone( input_cls( RunningStage.VALIDATING, val_dataset, - transform=val_transform, video_sampler=video_sampler, label_field=label_field, target_formatter=target_formatter, @@ -963,13 +912,14 @@ def from_fiftyone( input_cls( RunningStage.TESTING, test_dataset, - transform=test_transform, video_sampler=video_sampler, label_field=label_field, target_formatter=target_formatter, **ds_kw, ), - predict_input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), + predict_input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -986,10 +936,6 @@ def from_labelstudio( val_data_folder: str = None, test_data_folder: str = None, predict_data_folder: str = None, - train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, val_split: Optional[float] = None, multi_label: Optional[bool] = False, clip_sampler: Union[str, "ClipSampler"] = "random", @@ -999,6 +945,7 @@ def from_labelstudio( decode_audio: bool = False, decoder: str = "pyav", input_cls: Type[Input] = LabelStudioVideoClassificationInput, + transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, transform_kwargs: Optional[Dict] = 
None, **data_module_kwargs, ) -> "VideoClassificationData": @@ -1010,30 +957,15 @@ def from_labelstudio( Args: export_json: path to label studio export file - train_export_json: path to label studio export file for train set, - overrides export_json if specified + train_export_json: path to label studio export file for train set. (overrides export_json if specified) val_export_json: path to label studio export file for validation test_export_json: path to label studio export file for test predict_export_json: path to label studio export file for predict data_folder: path to label studio data folder - train_data_folder: path to label studio data folder for train data set, - overrides data_folder if specified + train_data_folder: path to label studio data folder for train data set. (overrides data_folder if specified) val_data_folder: path to label studio data folder for validation data test_data_folder: path to label studio data folder for test data predict_data_folder: path to label studio data folder for predict data - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. 
- input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.input_transform_cls`` - will be constructed and used. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. multi_label: Whether the label are multi encoded. clip_sampler: Defines how clips should be sampled from each video. @@ -1043,6 +975,9 @@ def from_labelstudio( if necessary, the distributed split. decode_audio: If True, also decode audio from video. decoder: Defines what type of decoder used to decode a video. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to use when constructing the datamodule. 
Returns: @@ -1073,8 +1008,6 @@ def from_labelstudio( ) ds_kw = dict( - transform_kwargs=transform_kwargs, - input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, clip_duration=clip_duration, clip_sampler_kwargs=clip_sampler_kwargs, @@ -1083,13 +1016,15 @@ def from_labelstudio( decoder=decoder, ) - train_input = input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw) ds_kw["parameters"] = getattr(train_input, "parameters", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, transform=test_transform, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, val_data, **ds_kw), + input_cls(RunningStage.TESTING, test_data, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) diff --git a/flash/video/classification/input_transform.py b/flash/video/classification/input_transform.py index 5626be5e8e..e0156ae13e 100644 --- a/flash/video/classification/input_transform.py +++ b/flash/video/classification/input_transform.py @@ -46,10 +46,15 @@ class VideoClassificationInputTransform(InputTransform): same_on_frame: bool = False def per_sample_transform(self) -> Callable: - if self.training: - per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)] - else: - per_sample_transform = [CenterCrop(self.image_size)] + per_sample_transform = [CenterCrop(self.image_size)] + + return ApplyToKeys( + "video", + Compose([UniformTemporalSubsample(self.temporal_sub_sample), normalize] + per_sample_transform), + ) + + def train_per_sample_transform(self) -> Callable: + per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)] return ApplyToKeys( "video", 
diff --git a/flash_examples/flash_components/custom_data_loading.py b/flash_examples/flash_components/custom_data_loading.py index c95988364b..65bc4c330b 100644 --- a/flash_examples/flash_components/custom_data_loading.py +++ b/flash_examples/flash_components/custom_data_loading.py @@ -38,23 +38,25 @@ # Your loader would take a list of individual class folder and load the images from them # # The folder paths are independent and when loading the order of folder. # # would determine the classification label. # -# Note: This is simple enough to show you the flexibility of the Flash API. # +# NOTE: This is simple enough to show you the flexibility of the Flash API. # ############################################################################################# ############################################################################################# -# Step 1 / 2: Implement a Input # +# Step 1 / 3: Implement a custom Input # # # # An `Input` is a state-aware (c.f training, validating, testing and predicting) # -# dataset. # -# and with specialized hooks (c.f load_data, load_sample) for each of those stages. # +# dataset and with specialized hooks (c.f load_data, load_sample) for each of those stages. # +# # # The hook resolution for the function is done in the following way. # -# If {state}_load_data is implemented then it would be used exclusively for that stage. # -# Otherwise, it would use the load_data function. # +# If {state}_load_data is implemented then it would be used exclusively for that stage. # +# Otherwise, it would use the load_data function. # +# # # If you use Input outside of Flash, the only requirements are to return a Sequence # # from load_data with Input or an Iterable with FlashIterableDataset. 
# +# # # When using Input with Flash Tasks, the model expects the `load_sample` to return a # -# dictionary with `DataKeys` as its keys (c.f `input`, `target`, metadata) # +# dictionary with `DataKeys` as its keys (c.f `input`, `target`, `metadata`) # # # ############################################################################################# @@ -92,11 +94,22 @@ def predict_load_data(self, predict_folder: str) -> List[Dict[DataKeys, Any]]: ############################################################################################# -# Step 2 / 2: [optional] Implement a InputTransform # +# Step 2 / 3: [optional] Implement a custom InputTransform # +# # +# An `InputTransform` is a state-aware (c.f training, validating, testing and predicting) # +# python dataclass that acts as a callback resolver for each stage of the pipeline with # +# specialized hooks (c.f per_sample_transform, per_sample_transform_on_device, # +# per_batch_transform, per_batch_transform_on_device, collate_fn) for each of those stages. # +# Each of the hooks returns a Callable type that acts on the samples to transform them. # # # -# A `InputTransform` is a state-aware (c.f training, validating, testing and predicting) # -# transform. You would have to implement a `configure_transforms` hook with your transform # # # +# The hook resolution for the function is done in the following way. # +# - If {state}_per_sample_transform is implemented then it would be used exclusively # +# for that stage. Otherwise, it would use the per_sample_transform function. # +# - If {state}_{key}_per_sample_transform is implemented then it would be used # +# exclusively for that stage and the specific key of the sample. Otherwise, it would # +# use the per_sample_transform function. 
# + ############################################################################################# @@ -105,7 +118,7 @@ class BaseImageInputTransform(InputTransform): image_size: Tuple[int, int] = (224, 224) - def input_per_sample_transform(self) -> Any: + def input_per_sample_transform(self) -> Callable: # this will be used to transform only the input value associated with # the `input` key within each sample. return T.Compose([T.Resize(self.image_size), T.ToTensor()]) @@ -119,107 +132,39 @@ class ImageRandomRotationInputTransform(BaseImageInputTransform): rotation: float = 0 - def input_per_sample_transform(self) -> Any: + def train_input_per_sample_transform(self) -> Callable: + # this will be used to transform only the input value associated with + # the `input` key within each sample of the train batch. + transforms = [T.Resize(self.image_size), T.ToTensor(), T.RandomRotation(self.rotation)] + return T.Compose(transforms) + + def input_per_sample_transform(self) -> Callable: # this will be used to transform only the input value associated with # the `input` key within each sample. transforms = [T.Resize(self.image_size), T.ToTensor()] - if self.training: - transforms += [T.RandomRotation(self.rotation)] return T.Compose(transforms) -# Register your transform within the Flash Dataset registry +# Register your transform within the InputTransform registry of the Flash DataModule # Note: Registries can be shared by multiple dataset. 
-MultipleFoldersImageInput.register_input_transform("base", BaseImageInputTransform) -MultipleFoldersImageInput.register_input_transform("random_rotation", ImageRandomRotationInputTransform) -MultipleFoldersImageInput.register_input_transform( - "random_90_def_rotation", partial(ImageRandomRotationInputTransform, rotation=90) -) - -train_dataset = MultipleFoldersImageInput( - RunningStage.TRAINING, - TRAIN_FOLDERS, - transform=("random_rotation", {"rotation": 45}), -) - -# Out: -# ImageRandomRotationInputTransform( -# running_stage=train, -# state: {'image_size': (224, 224), 'rotation': 45} -# transform={ -# 'per_sample_transform': Compose( -# ApplyToKeys(keys='input', transform=Compose( -# Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None) -# ToTensor() -# RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0))), -# ApplyToKeys(keys='target', transform=) -# ), -# 'collate': -# } -# ) - -train_dataset = MultipleFoldersImageInput( - RunningStage.TRAINING, - TRAIN_FOLDERS, - transform="random_90_def_rotation", -) - -print(train_dataset.transform) -# Out: -# ImageRandomRotationInputTransform( -# running_stage=train, -# state: {'image_size': (224, 224), 'rotation': 90} -# transform={ -# 'per_sample_transform': Compose( -# ApplyToKeys(keys='input', transform=Compose( -# Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None) -# ToTensor() -# RandomRotation(degrees=[-90.0, 90.0], interpolation=nearest, expand=False, fill=0))), -# ApplyToKeys(keys='target', transform=) -# ), -# 'collate': -# } -# ) - -val_dataset = MultipleFoldersImageInput(RunningStage.VALIDATING, VAL_FOLDERS, transform="base") -print(val_dataset.transform) -# Out: -# ImageRandomRotationInputTransform( -# running_stage=train, -# state: {'image_size': (224, 224), 'rotation': 90} -# transform={ -# 'per_sample_transform': Compose( -# ApplyToKeys(keys='input', transform=Compose( -# Resize(size=(224, 224), interpolation=bilinear, 
max_size=None, antialias=None) -# ToTensor() -# ), -# ApplyToKeys(keys='target', transform=) -# ), -# 'collate': -# } -# ) - -print(train_dataset[0]) -# Out: -# { -# : , -# : 0, -# : (500, 375) -# } +DataModule.register_input_transform("base", BaseImageInputTransform) +DataModule.register_input_transform("random_rotation", ImageRandomRotationInputTransform) +DataModule.register_input_transform("random_90_def_rotation", partial(ImageRandomRotationInputTransform, rotation=90)) ############################################################################################# -# Step 4 / 5: Create a DataModule # +# Step 3 / 3: Create a DataModule (Part 1) # # # -# The `DataModule` class is a collection of Input and you can pass them directly to # -# its init function. # +# The `DataModule` class is a collection of `Input` for various stages and the # +# `InputTransform` and you can pass them directly to its init function. # # # ############################################################################################# datamodule = DataModule( - train_input=MultipleFoldersImageInput(RunningStage.TRAINING, TRAIN_FOLDERS, transform="random_rotation"), - val_input=MultipleFoldersImageInput(RunningStage.VALIDATING, VAL_FOLDERS, transform="base"), - predict_input=MultipleFoldersImageInput(RunningStage.PREDICTING, PREDICT_FOLDER, transform="base"), + train_input=MultipleFoldersImageInput(RunningStage.TRAINING, TRAIN_FOLDERS), + val_input=MultipleFoldersImageInput(RunningStage.VALIDATING, VAL_FOLDERS), + predict_input=MultipleFoldersImageInput(RunningStage.PREDICTING, PREDICT_FOLDER), + transform="random_rotation", batch_size=2, ) @@ -267,10 +212,10 @@ def input_per_sample_transform(self) -> Any: ############################################################################################# -# Step 5 / 5: Provide your new utility with your DataModule # +# Step 5 / 5: Provide your new utility with your DataModule (Part 2) # # # -# The `DataModule` class is a collection of Input and 
you can pass them directly to # -# its init function. # +# The `DataModule` class is a collection of `Input` for various stages and the # +# `InputTransform` and you can create an extended utility method. # # # ############################################################################################# @@ -283,18 +228,18 @@ def from_multiple_folders( val_folders: Optional[List[str]] = None, test_folders: Optional[List[str]] = None, predict_folder: Optional[str] = None, - train_transform: Optional[INPUT_TRANSFORM_TYPE] = None, - val_transform: Optional[INPUT_TRANSFORM_TYPE] = None, - test_transform: Optional[INPUT_TRANSFORM_TYPE] = None, - predict_transform: Optional[INPUT_TRANSFORM_TYPE] = None, + transform: Optional[INPUT_TRANSFORM_TYPE] = None, + transform_kwargs: Optional[Dict[str, Any]] = None, **data_module_kwargs: Any, ) -> "ImageClassificationDataModule": return cls( - MultipleFoldersImageInput(RunningStage.TRAINING, train_folders, transform=train_transform), - MultipleFoldersImageInput(RunningStage.VALIDATING, val_folders, transform=val_transform), - MultipleFoldersImageInput(RunningStage.VALIDATING, test_folders, transform=test_transform), - MultipleFoldersImageInput(RunningStage.PREDICTING, predict_folder, transform=predict_transform), + MultipleFoldersImageInput(RunningStage.TRAINING, train_folders), + MultipleFoldersImageInput(RunningStage.VALIDATING, val_folders), + MultipleFoldersImageInput(RunningStage.VALIDATING, test_folders), + MultipleFoldersImageInput(RunningStage.PREDICTING, predict_folder), + transform=transform, + transform_kwargs=transform_kwargs, **data_module_kwargs, ) @@ -304,9 +249,7 @@ def from_multiple_folders( train_folders=TRAIN_FOLDERS, val_folders=VAL_FOLDERS, predict_folder=PREDICT_FOLDER, - train_transform="random_rotation", - val_transform="base", - predict_transform="base", + transform="random_90_def_rotation", batch_size=2, ) diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py 
b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index d335349aea..d1f5dea28a 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -34,7 +34,7 @@ train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True) val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True) -train_transform = { +transform = { "per_sample_transform": nn.Sequential( ApplyToKeys( DataKeys.INPUT, @@ -70,7 +70,7 @@ train_targets=torch.from_numpy(train_dataset.y.astype(int)), val_data=val_dataset.x, val_targets=torch.from_numpy(val_dataset.y.astype(int)), - train_transform=train_transform, + transform=transform, ) model = ImageClassifier( diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index 1da0995886..cfdfa7f916 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -281,8 +281,7 @@ def run(transform: Any = None): dm = AudioClassificationData.from_files( train_files=train_filepaths, train_targets=train_labels, - train_transform=transform, - val_transform=transform, + transform=transform, batch_size=B, num_workers=0, val_split=val_split, @@ -311,11 +310,11 @@ def test_from_folders_only_train(tmpdir): _rand_image().save(train_dir / "b" / "1.png") _rand_image().save(train_dir / "b" / "2.png") - spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) + spectrograms_data = AudioClassificationData.from_folders(train_dir, batch_size=1) data = next(iter(spectrograms_data.train_dataloader())) imgs, labels = data["input"], data["target"] - assert imgs.shape == (1, 196, 196, 3) + assert imgs.shape == (1, 3, 128, 128) assert labels.shape == (1,) diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index 
8b138e1297..c9d47af2fe 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -31,10 +31,12 @@ def test_flash_callback(_, __, tmpdir): callback_mock = mock.MagicMock() inputs = [(torch.rand(1), torch.rand(1))] + transform = InputTransform() dm = DataModule( - DatasetInput(RunningStage.TRAINING, inputs, transform=InputTransform), - DatasetInput(RunningStage.VALIDATING, inputs, transform=InputTransform), - DatasetInput(RunningStage.TESTING, inputs, transform=InputTransform), + DatasetInput(RunningStage.TRAINING, inputs), + DatasetInput(RunningStage.VALIDATING, inputs), + DatasetInput(RunningStage.TESTING, inputs), + transform=transform, batch_size=1, num_workers=0, data_fetcher=callback_mock, @@ -63,10 +65,12 @@ def step(self, batch, batch_idx, metrics): limit_train_batches=1, progress_bar_refresh_rate=0, ) + transform = InputTransform() dm = DataModule( - DatasetInput(RunningStage.TRAINING, inputs, transform=InputTransform), - DatasetInput(RunningStage.VALIDATING, inputs, transform=InputTransform), - DatasetInput(RunningStage.TESTING, inputs, transform=InputTransform), + DatasetInput(RunningStage.TRAINING, inputs), + DatasetInput(RunningStage.VALIDATING, inputs), + DatasetInput(RunningStage.TESTING, inputs), + transform=transform, batch_size=1, num_workers=0, data_fetcher=callback_mock, diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py index 486c5de3c6..6ee672c025 100644 --- a/tests/core/data/test_callbacks.py +++ b/tests/core/data/test_callbacks.py @@ -42,10 +42,11 @@ def configure_data_fetcher(): @classmethod def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_data: Any) -> "CustomDataModule": return cls( - Input(RunningStage.TRAINING, train_data, transform=InputTransform), - Input(RunningStage.VALIDATING, val_data, transform=InputTransform), - Input(RunningStage.TESTING, test_data, transform=InputTransform), - Input(RunningStage.PREDICTING, predict_data, 
transform=InputTransform), + Input(RunningStage.TRAINING, train_data), + Input(RunningStage.VALIDATING, val_data), + Input(RunningStage.TESTING, test_data), + Input(RunningStage.PREDICTING, predict_data), + transform=InputTransform(), batch_size=5, ) diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py index 695d42f358..f6a40bab29 100644 --- a/tests/core/data/test_data_module.py +++ b/tests/core/data/test_data_module.py @@ -55,33 +55,31 @@ def fn(x): return fn - def per_batch_transform_on_device(self) -> Callable: - if self.training: - return train_fn - elif self.validating: - return val_fn - elif self.testing: - return test_fn - elif self.predicting: - return predict_fn - - train_dataset = Input(RunningStage.TRAINING, range(10), transform=TestTransform) - assert train_dataset.transform._running_stage == RunningStage.TRAINING + def train_per_batch_transform_on_device(self) -> Callable: + return train_fn + + def val_per_batch_transform_on_device(self) -> Callable: + return val_fn + + def test_per_batch_transform_on_device(self) -> Callable: + return test_fn + + def predict_per_batch_transform_on_device(self) -> Callable: + return predict_fn + + transform = TestTransform() + assert transform._transform is not None + + train_dataset = Input(RunningStage.TRAINING, range(10)) assert train_dataset.running_stage == RunningStage.TRAINING - transform = TestTransform(RunningStage.VALIDATING) - assert transform._running_stage == RunningStage.VALIDATING - val_dataset = Input(RunningStage.VALIDATING, range(10), transform=transform) + val_dataset = Input(RunningStage.VALIDATING, range(10)) assert val_dataset.running_stage == RunningStage.VALIDATING - transform = TestTransform(RunningStage.TESTING) - assert transform._running_stage == RunningStage.TESTING - test_dataset = Input(RunningStage.TESTING, range(10), transform=transform) + test_dataset = Input(RunningStage.TESTING, range(10)) assert test_dataset.running_stage == RunningStage.TESTING - 
transform = TestTransform(RunningStage.PREDICTING) - assert transform._running_stage == RunningStage.PREDICTING - predict_dataset = Input(RunningStage.PREDICTING, range(10), transform=transform) + predict_dataset = Input(RunningStage.PREDICTING, range(10)) assert predict_dataset.running_stage == RunningStage.PREDICTING dm = DataModule( @@ -89,6 +87,7 @@ def per_batch_transform_on_device(self) -> Callable: val_input=val_dataset, test_input=test_dataset, predict_input=predict_dataset, + transform=transform, batch_size=2, ) assert len(dm.train_dataloader()) == 5 @@ -140,9 +139,10 @@ def on_fit_end(self) -> None: trainer.test(model, datamodule=dm) trainer.predict(model, datamodule=dm) - input = Input(RunningStage.TRAINING, transform=TestTransform) - dm = DataModule(train_input=input, batch_size=1) - assert isinstance(dm._train_input.transform, TestTransform) + transform = TestTransform() + input = Input(RunningStage.TRAINING) + dm = DataModule(train_input=input, batch_size=1, transform=transform) + assert isinstance(dm.input_transform, TestTransform) class RandomDataset(Dataset): def __init__(self, size: int, length: int): @@ -155,6 +155,13 @@ def __getitem__(self, index): def __len__(self): return self.len + def _add_hundred(x): + if isinstance(x, Dict): + x["input"] += 100 + else: + x += 100 + return x + class TrainInputTransform(InputTransform): def _add_one(self, x): if isinstance(x, Dict): @@ -166,44 +173,37 @@ def _add_one(self, x): def per_sample_transform(self) -> Callable: return self._add_one - def _add_hundred(x): - if isinstance(x, Dict): - x["input"] += 100 - else: - x += 100 - return x + def val_per_sample_transform(self) -> Callable: + return _add_hundred dm = DataModule( - train_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32), transform=TrainInputTransform), - val_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32), transform=_add_hundred), - test_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32)), + 
train_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32)), + val_input=DatasetInput(RunningStage.VALIDATING, RandomDataset(64, 32)), + test_input=DatasetInput(RunningStage.TESTING, RandomDataset(64, 32)), batch_size=3, + transform=TrainInputTransform(), ) batch = next(iter(dm.train_dataloader())) assert batch["input"][0][0] == 2 batch = next(iter(dm.val_dataloader())) assert batch["input"][0][0] == 101 batch = next(iter(dm.test_dataloader())) - assert batch["input"][0][0] == 1 + assert batch["input"][0][0] == 2 class TestInput(Input): def train_load_data(self, _): - assert self.training return [(0, 1, 2, 3), (0, 1, 2, 3)] def val_load_data(self, _): - assert self.validating self.val_load_sample_called = False return list(range(5)) def val_load_sample(self, sample): - assert self.validating self.val_load_sample_called = True return {"a": sample, "b": sample + 1} def test_load_data(self, _): - assert self.testing return [[torch.rand(1), torch.rand(1)], [torch.rand(1), torch.rand(1)]] @@ -218,8 +218,6 @@ class TestInputTransform(InputTransform): test_per_sample_transform_called = False def _train_per_sample_transform(self, sample): - assert self.training - assert self.current_fn == "per_sample_transform" self.train_per_sample_transform_called = True return sample + (5,) @@ -227,8 +225,6 @@ def train_per_sample_transform(self): return self._train_per_sample_transform def _train_collate(self, samples): - assert self.training - assert self.current_fn == "collate" self.train_collate_called = True return torch.tensor([list(s) for s in samples]) @@ -236,8 +232,6 @@ def train_collate(self): return self._train_collate def _train_per_batch_transform_on_device(self, batch): - assert self.training - assert self.current_fn == "per_batch_transform_on_device" self.train_per_batch_transform_on_device_called = True assert torch.equal(batch, torch.tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]])) @@ -245,8 +239,6 @@ def train_per_batch_transform_on_device(self): return 
self._train_per_batch_transform_on_device def _val_per_sample_transform(self, sample): - assert self.validating - assert self.current_fn == "per_sample_transform" self.val_per_sample_transform_called = True return sample @@ -254,8 +246,6 @@ def val_per_sample_transform(self): return self._val_per_sample_transform def _val_collate(self, samples): - assert self.validating - assert self.current_fn == "collate" self.val_collate_called = True _count = samples[0]["a"] assert samples == [{"a": _count, "b": _count + 1}, {"a": _count + 1, "b": _count + 2}] @@ -265,8 +255,6 @@ def val_collate(self): return self._val_collate def _val_per_batch_transform_on_device(self, batch): - assert self.validating - assert self.current_fn == "per_batch_transform_on_device" self.val_per_batch_transform_on_device_called = True if isinstance(batch, list): batch = batch[0] @@ -278,8 +266,6 @@ def val_per_batch_transform_on_device(self): return self._val_per_batch_transform_on_device def _test_per_sample_transform(self, sample): - assert self.testing - assert self.current_fn == "per_sample_transform" self.test_per_sample_transform_called = True return sample @@ -312,10 +298,12 @@ def test_step(self, batch, batch_idx): def test_transformations(tmpdir): + transform = TestInputTransform() datamodule = DataModule( - TestInput(RunningStage.TRAINING, [1], transform=TestInputTransform), - TestInput(RunningStage.VALIDATING, [1], transform=TestInputTransform), - TestInput(RunningStage.TESTING, [1], transform=TestInputTransform), + TestInput(RunningStage.TRAINING, [1]), + TestInput(RunningStage.VALIDATING, [1]), + TestInput(RunningStage.TESTING, [1]), + transform=transform, batch_size=2, num_workers=0, ) @@ -329,9 +317,10 @@ def test_transformations(tmpdir): batch = next(iter(datamodule.val_dataloader())) datamodule = DataModule( - TestInput(RunningStage.TRAINING, [1], transform=TestInputTransform2), - TestInput(RunningStage.VALIDATING, [1], transform=TestInputTransform2), - 
TestInput(RunningStage.TESTING, [1], transform=TestInputTransform2), + TestInput(RunningStage.TRAINING, [1]), + TestInput(RunningStage.VALIDATING, [1]), + TestInput(RunningStage.TESTING, [1]), + transform=TestInputTransform2, batch_size=2, num_workers=0, ) @@ -351,13 +340,13 @@ def test_transformations(tmpdir): trainer.fit(model, datamodule=datamodule) trainer.test(model, datamodule=datamodule) - assert datamodule.train_dataset.transform.train_per_sample_transform_called - assert datamodule.train_dataset.transform.train_collate_called - assert datamodule.train_dataset.transform.train_per_batch_transform_on_device_called - assert datamodule.train_dataset.transform.train_per_sample_transform_called - assert datamodule.val_dataset.transform.val_collate_called - assert datamodule.val_dataset.transform.val_per_batch_transform_on_device_called - assert datamodule.test_dataset.transform.test_per_sample_transform_called + assert datamodule.input_transform.train_per_sample_transform_called + assert datamodule.input_transform.train_collate_called + assert datamodule.input_transform.train_per_batch_transform_on_device_called + assert datamodule.input_transform.train_per_sample_transform_called + assert datamodule.input_transform.val_collate_called + assert datamodule.input_transform.val_per_batch_transform_on_device_called + assert datamodule.input_transform.test_per_sample_transform_called @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @@ -401,9 +390,11 @@ def validation_step(self, batch, batch_idx): assert torch.max(batch) <= 1.0 assert torch.min(batch) >= 0.0 + transform = ImageClassificationInputTransform() datamodule = DataModule( - ImageInput(RunningStage.TRAINING, [1], transform=ImageClassificationInputTransform), - ImageInput(RunningStage.VALIDATING, [1], transform=ImageClassificationInputTransform), + ImageInput(RunningStage.TRAINING, [1]), + ImageInput(RunningStage.VALIDATING, [1]), + transform=transform, batch_size=2, 
num_workers=0, ) diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index d1d720783e..90ed44cac5 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -16,7 +16,6 @@ from flash.core.data.data_pipeline import DataPipeline from flash.core.data.io.input_transform import InputTransform -from flash.core.utilities.stages import RunningStage def test_is_overridden_recursive(tmpdir): @@ -31,7 +30,7 @@ def collate(self): def val_collate(self): return self.custom_transform - input_transform = TestInputTransform(RunningStage.TRAINING) + input_transform = TestInputTransform() assert DataPipeline._is_overridden_recursive("collate", input_transform, InputTransform, prefix="val") assert DataPipeline._is_overridden_recursive("collate", input_transform, InputTransform, prefix="train") assert not DataPipeline._is_overridden_recursive( diff --git a/tests/core/data/test_input_transform.py b/tests/core/data/test_input_transform.py index 41d5840085..ff82e58ca8 100644 --- a/tests/core/data/test_input_transform.py +++ b/tests/core/data/test_input_transform.py @@ -37,14 +37,16 @@ def input_per_batch_transform(self) -> Callable: MisconfigurationException, match="Only one of per_batch_transform or input_per_batch_transform can be overridden", ): - MyTransform(running_stage=RunningStage.TRAINING) + transform = MyTransform() + transform._populate_transforms_for_stage(RunningStage.TRAINING) class MyTransform(InputTransform): def input_per_batch_transform(self) -> Callable: return None with pytest.raises(MisconfigurationException, match="The hook input_per_batch_transform should return a function."): - MyTransform(running_stage=RunningStage.TRAINING) + transform = MyTransform() + transform._populate_transforms_for_stage(RunningStage.TRAINING) class MyTransform(InputTransform): def target_per_batch_transform(self) -> Callable: @@ -53,11 +55,14 @@ def target_per_batch_transform(self) -> Callable: def 
input_per_batch_transform(self) -> Callable: return fn - transform = MyTransform(running_stage=RunningStage.TRAINING) - assert list(transform._transform.keys()) == ["per_batch_transform", "collate"] - assert isinstance(transform._transform["per_batch_transform"], ApplyToKeys) - assert transform._transform["per_batch_transform"].keys == ["input"] - assert transform._transform["collate"] == default_collate + transform = MyTransform() + for stage in [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]: + transform._populate_transforms_for_stage(stage) + transforms = transform._transform[stage].transforms + assert list(transforms.keys()) == ["per_batch_transform", "collate"] + assert isinstance(transforms["per_batch_transform"], ApplyToKeys) + assert transforms["per_batch_transform"].keys == ["input"] + assert transforms["collate"] == default_collate class MyTransform(InputTransform): def train_per_batch_transform(self) -> Callable: @@ -69,13 +74,19 @@ def target_per_batch_transform(self) -> Callable: def input_per_batch_transform(self) -> Callable: return self.input_per_batch_transform - transform = MyTransform(running_stage=RunningStage.TRAINING) - assert list(transform._transform.keys()) == ["per_batch_transform", "collate"] - assert transform._transform["per_batch_transform"] == transform.train_per_batch_transform + transform = MyTransform() - transform = MyTransform(running_stage=RunningStage.VALIDATING) - assert isinstance(transform._transform["per_batch_transform"], Compose) - assert len(transform._transform["per_batch_transform"].transforms) == 2 + # Tests for RunningStage.TRAINING + transform._populate_transforms_for_stage(RunningStage.TRAINING) + train_transforms = transform._transform[RunningStage.TRAINING].transforms + assert list(train_transforms.keys()) == ["per_batch_transform", "collate"] + assert train_transforms["per_batch_transform"] == transform.train_per_batch_transform + + # Tests for 
RunningStage.VALIDATING + transform._populate_transforms_for_stage(RunningStage.VALIDATING) + val_transforms = transform._transform[RunningStage.VALIDATING].transforms + assert isinstance(val_transforms["per_batch_transform"], Compose) + assert len(val_transforms["per_batch_transform"].transforms) == 2 class MyTransform(InputTransform): def train_per_batch_transform(self) -> Callable: @@ -91,7 +102,8 @@ def input_per_batch_transform(self) -> Callable: MisconfigurationException, match="Only one of train_per_batch_transform or train_target_per_batch_transform can be overridden.", ): - MyTransform(running_stage=RunningStage.TRAINING) + transform = MyTransform() + transform._populate_transforms_for_stage(RunningStage.TRAINING) class MyTransform(InputTransform): def per_batch_transform(self) -> Callable: @@ -109,24 +121,33 @@ def train_collate(self) -> Callable: def collate(self) -> Callable: return self.collate - transform = MyTransform(running_stage=RunningStage.TRAINING) - assert list(transform._transform.keys()) == ["per_batch_transform", "collate"] - assert isinstance(transform._transform["per_batch_transform"], Compose) - assert len(transform._transform["per_batch_transform"].transforms) == 2 - assert transform._transform["collate"] == transform.train_collate - - transform = MyTransform(running_stage=RunningStage.VALIDATING) - assert list(transform._transform.keys()) == ["per_batch_transform", "collate"] - assert transform._transform["per_batch_transform"] == transform.train_per_batch_transform - assert transform._transform["collate"] == transform.collate - - transform = LambdaInputTransform(RunningStage.TRAINING, transform=fn) - assert list(transform._transform.keys()) == ["per_sample_transform", "collate"] - assert transform._transform["per_sample_transform"] == fn + transform = MyTransform() + + # Tests for RunningStage.TRAINING + transform._populate_transforms_for_stage(RunningStage.TRAINING) + train_transforms = 
transform._transform[RunningStage.TRAINING].transforms + assert list(train_transforms.keys()) == ["per_batch_transform", "collate"] + assert isinstance(train_transforms["per_batch_transform"], Compose) + assert len(train_transforms["per_batch_transform"].transforms) == 2 + assert train_transforms["collate"] == transform.train_collate + + # Tests for RunningStage.VALIDATING + transform._populate_transforms_for_stage(RunningStage.VALIDATING) + val_transforms = transform._transform[RunningStage.VALIDATING].transforms + assert list(val_transforms.keys()) == ["per_batch_transform", "collate"] + assert val_transforms["per_batch_transform"] == transform.train_per_batch_transform + assert val_transforms["collate"] == transform.collate + + transform = LambdaInputTransform(transform=fn) + for stage in [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]: + transform._populate_transforms_for_stage(stage) + transforms = transform._transform[stage].transforms + assert list(transforms.keys()) == ["per_sample_transform", "collate"] + assert transforms["per_sample_transform"] == fn class MyTransform(InputTransform): - def __init__(self, value: int, running_stage: RunningStage): - super().__init__(running_stage) + def __init__(self, value: int): + super().__init__() self.value = value def input_per_batch_transform(self) -> Callable: @@ -135,19 +156,19 @@ def input_per_batch_transform(self) -> Callable: return super().input_per_batch_transform with pytest.raises(AttributeError, match="__init__"): - MyTransform(1, running_stage=RunningStage.TRAINING) + MyTransform(1) class MyTransform(InputTransform): - def __init__(self, value: int, running_stage: RunningStage): + def __init__(self, value: int): self.value = value - super().__init__(running_stage) + super().__init__() def input_per_batch_transform(self) -> Callable: if self.value > 0: return self.input_per_batch_transform return super().input_per_batch_transform - MyTransform(1, 
running_stage=RunningStage.TRAINING) + MyTransform(1) class CustomInputTransform(InputTransform): @@ -193,9 +214,9 @@ def test_check_transforms(): input_transform = CustomInputTransform - input_transform(RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="are mutually exclusive"): - input_transform(RunningStage.VALIDATING) + # input_transform._populate_transforms_for_stage(RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="are mutually exclusive"): - input_transform(RunningStage.TESTING) - input_transform(RunningStage.PREDICTING) + input_transform() + # with pytest.raises(MisconfigurationException, match="are mutually exclusive"): + # input_transform._populate_transforms_for_stage(RunningStage.TESTING) + # input_transform._populate_transforms_for_stage(RunningStage.PREDICTING) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 76dddeddfa..6a236aa6aa 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -35,6 +35,7 @@ from flash.audio import SpeechRecognition from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask +from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.utilities.imports import ( _AUDIO_TESTING, @@ -194,7 +195,7 @@ def test_classification_task_trainer_predict(tmpdir): task = ClassificationTask(model) ds = PredictDummyDataset(10) batch_size = 6 - predict_dl = task.process_predict_dataset(ds, batch_size=batch_size) + predict_dl = task.process_predict_dataset(ds, input_transform=InputTransform(), batch_size=batch_size) trainer = pl.Trainer(default_root_dir=tmpdir) predictions = trainer.predict(task, predict_dl) assert len(list(chain.from_iterable(predictions))) == 10 diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py index 9205573358..a143e3d2ed 100644 --- a/tests/graph/classification/test_data.py 
+++ b/tests/graph/classification/test_data.py @@ -86,10 +86,7 @@ def per_sample_transform(self): val_dataset=val_dataset, test_dataset=test_dataset, predict_dataset=predict_dataset, - train_transform=TestInputTransform, - val_transform=TestInputTransform, - test_transform=TestInputTransform, - predict_transform=TestInputTransform, + transform=TestInputTransform, batch_size=2, ) assert dm is not None diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index 6c0bde8b6d..3a09d68089 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -40,7 +40,8 @@ def test_train(tmpdir): tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) datamodule = DataModule( - GraphClassificationDatasetInput(RunningStage.TRAINING, tudataset, transform=GraphClassificationInputTransform), + GraphClassificationDatasetInput(RunningStage.TRAINING, tudataset), + transform=GraphClassificationInputTransform, batch_size=4, ) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -53,9 +54,8 @@ def test_val(tmpdir): tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) datamodule = DataModule( - val_input=GraphClassificationDatasetInput( - RunningStage.VALIDATING, tudataset, transform=GraphClassificationInputTransform - ), + val_input=GraphClassificationDatasetInput(RunningStage.VALIDATING, tudataset), + transform=GraphClassificationInputTransform, batch_size=4, ) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -68,9 +68,8 @@ def test_test(tmpdir): tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) datamodule = DataModule( - test_input=GraphClassificationDatasetInput( - RunningStage.TESTING, 
tudataset, transform=GraphClassificationInputTransform - ), + test_input=GraphClassificationDatasetInput(RunningStage.TESTING, tudataset), + transform=GraphClassificationInputTransform, batch_size=4, ) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -83,9 +82,8 @@ def test_predict_dataset(tmpdir): tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) datamodule = DataModule( - predict_input=GraphClassificationDatasetInput( - RunningStage.TESTING, tudataset, transform=GraphClassificationInputTransform - ), + predict_input=GraphClassificationDatasetInput(RunningStage.TESTING, tudataset), + transform=GraphClassificationInputTransform, batch_size=4, ) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py index f6d415ac7b..a91c556fc3 100644 --- a/tests/graph/embedding/test_model.py +++ b/tests/graph/embedding/test_model.py @@ -39,11 +39,10 @@ def test_not_trainable(tmpdir): tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphEmbedder(GraphClassifier(num_features=1, num_classes=1).backbone) datamodule = DataModule( - GraphClassificationDatasetInput(RunningStage.TRAINING, tudataset, transform=GraphClassificationInputTransform), - GraphClassificationDatasetInput( - RunningStage.VALIDATING, tudataset, transform=GraphClassificationInputTransform - ), - GraphClassificationDatasetInput(RunningStage.TESTING, tudataset, transform=GraphClassificationInputTransform), + GraphClassificationDatasetInput(RunningStage.TRAINING, tudataset), + GraphClassificationDatasetInput(RunningStage.VALIDATING, tudataset), + GraphClassificationDatasetInput(RunningStage.TESTING, tudataset), + transform=GraphClassificationInputTransform, batch_size=4, ) trainer = Trainer(default_root_dir=tmpdir, num_sanity_val_steps=0) @@ -65,9 +64,8 @@ def test_predict_dataset(tmpdir): 
GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes).backbone ) datamodule = DataModule( - predict_input=GraphClassificationDatasetInput( - RunningStage.PREDICTING, tudataset, transform=GraphClassificationInputTransform - ), + predict_input=GraphClassificationDatasetInput(RunningStage.PREDICTING, tudataset), + transform=GraphClassificationInputTransform, batch_size=4, ) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 29affc003c..90e46e82c9 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -284,8 +284,7 @@ def run(transform: Any = None): dm = ImageClassificationData.from_files( train_files=train_filepaths, train_targets=train_labels, - train_transform=transform, - val_transform=transform, + transform=transform, batch_size=B, num_workers=0, val_split=val_split, @@ -311,7 +310,7 @@ def test_from_folders_only_train(tmpdir): _rand_image().save(train_dir / "b" / "1.png") _rand_image().save(train_dir / "b" / "2.png") - img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) + img_data = ImageClassificationData.from_folders(train_dir, batch_size=1) data = img_data.train_dataset[0] imgs, labels = data["input"], data["target"] @@ -646,7 +645,7 @@ def per_batch_transform(self): train_file=single_target_csv, batch_size=2, num_workers=0, - train_transform=MyTransform, + transform=MyTransform, ) batch = next(iter(img_data.train_dataloader())) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index db46114b0c..e699aa97ec 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -21,9 +21,13 @@ from torch.utils.data import Dataset from flash.__main__ import main +from flash.core.data.callback import BaseDataFetcher from flash.core.data.io.input import DataKeys +from 
flash.core.data.io.input_transform import create_worker_input_transform_processor +from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.trainer import Trainer from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.stages import RunningStage from flash.image import ObjectDetector @@ -75,7 +79,16 @@ def test_init(): batch_size = 2 ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) - dl = model.process_predict_dataset(ds, batch_size=batch_size) + input_transform = IceVisionInputTransform() + predict_collate_fn = create_worker_input_transform_processor( + RunningStage.PREDICTING, input_transform, [BaseDataFetcher()] + ) + dl = model.process_predict_dataset( + dataset=ds, + input_transform=input_transform, + batch_size=batch_size, + collate_fn=predict_collate_fn, + ) data = next(iter(dl)) out = model.forward(data[DataKeys.INPUT]) @@ -93,8 +106,21 @@ def test_init(): def test_training(tmpdir, head): model = ObjectDetector(num_classes=2, head=head, pretrained=False) ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + + input_transform = IceVisionInputTransform() + data_fetcher = BaseDataFetcher() + train_collate_fn = create_worker_input_transform_processor(RunningStage.TRAINING, input_transform, [data_fetcher]) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - dl = model.process_train_dataset(ds, trainer, 2, 0, False, None) + dl = model.process_train_dataset( + dataset=ds, + trainer=trainer, + input_transform=input_transform, + batch_size=2, + num_workers=0, + pin_memory=False, + collate_fn=train_collate_fn, + ) trainer.fit(model, dl) @@ -142,10 +168,32 @@ def test_cli(): def test_predict(tmpdir, head): model = ObjectDetector(num_classes=2, head=head, pretrained=False) ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + + input_transform = IceVisionInputTransform() + data_fetcher = BaseDataFetcher() + train_collate_fn = 
create_worker_input_transform_processor(RunningStage.TRAINING, input_transform, [data_fetcher]) + predict_collate_fn = create_worker_input_transform_processor( + RunningStage.PREDICTING, input_transform, [data_fetcher] + ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - dl = model.process_train_dataset(ds, trainer, 2, 0, False, None) + dl = model.process_train_dataset( + dataset=ds, + trainer=trainer, + input_transform=input_transform, + batch_size=2, + num_workers=0, + pin_memory=False, + collate_fn=train_collate_fn, + ) trainer.fit(model, dl) - dl = model.process_predict_dataset(ds, batch_size=2) + + dl = model.process_predict_dataset( + dataset=ds, + input_transform=input_transform, + batch_size=2, + collate_fn=predict_collate_fn, + ) predictions = trainer.predict(model, dl, output="preds") assert len(predictions[0][0]["bboxes"]) > 0 model.predict_kwargs = {"detection_threshold": 2} diff --git a/tests/image/embedding/utils.py b/tests/image/embedding/utils.py index 909757b0c4..72a0485c0b 100644 --- a/tests/image/embedding/utils.py +++ b/tests/image/embedding/utils.py @@ -57,10 +57,7 @@ def collate(self) -> Callable: datamodule = ImageClassificationData.from_datasets( train_dataset=FakeData(), - train_transform=SSLInputTransform, - val_transform=SSLInputTransform, - test_transform=SSLInputTransform, - predict_transform=SSLInputTransform, + transform=SSLInputTransform, transform_kwargs=transform_kwargs, batch_size=batch_size, )