From 1b1b9391f2cd5b8ba7edaf4a0464cc11351c2461 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 15 Feb 2022 14:11:39 +0000 Subject: [PATCH] Add ability to override the TargetFormatter in classification tasks (#1171) --- CHANGELOG.md | 2 + .../source/general/classification_targets.rst | 177 ++++++++++++++++++ flash/audio/classification/data.py | 49 +++-- flash/core/data/data_module.py | 4 +- flash/core/data/utilities/classification.py | 5 +- flash/image/classification/data.py | 64 ++++--- flash/image/detection/data.py | 13 +- flash/tabular/classification/data.py | 11 +- flash/text/classification/data.py | 31 ++- flash/video/classification/data.py | 44 ++++- 10 files changed, 345 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c573622c2..a837619ebd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for passing the `Output` object (or a string e.g. `"labels"`) to the `flash.Trainer.predict` method ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157)) +- Added support for passing the `TargetFormatter` object to `from_*` methods for classification to override target handling ([#1171](https://github.com/PyTorchLightning/lightning-flash/pull/1171)) + ### Changed - Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075)) diff --git a/docs/source/general/classification_targets.rst b/docs/source/general/classification_targets.rst index afa4f66f8f..bc58d2d09d 100644 --- a/docs/source/general/classification_targets.rst +++ b/docs/source/general/classification_targets.rst @@ -6,6 +6,7 @@ Formatting Classification Targets This guide details the different target formats supported by classification tasks in Flash. By default, the target format and any additional metadata (``labels``, ``num_classes``, ``multi_label``) will be inferred from your training data. +You can override this behaviour by passing your own :class:`~flash.core.data.utilities.classification.TargetFormatter` using the ``target_formatter`` argument. .. testsetup:: targets @@ -45,6 +46,28 @@ Here's an example: >>> datamodule.multi_label False +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SingleNumericTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import SingleNumericTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[0, 1, 0], + ... target_formatter=SingleNumericTargetFormatter(labels=["dog", "cat", "rabbit"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['dog', 'cat', 'rabbit'] + >>> datamodule.multi_label + False + Labels ______ @@ -70,6 +93,28 @@ Here's an example: >>> datamodule.multi_label False +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SingleLabelTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["cat", "dog", "cat"], + ... target_formatter=SingleLabelTargetFormatter(labels=["dog", "cat", "rabbit"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['dog', 'cat', 'rabbit'] + >>> datamodule.multi_label + False + One-hot Binaries ________________ @@ -95,6 +140,28 @@ Here's an example: >>> datamodule.multi_label False +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SingleBinaryTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import SingleBinaryTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[[1, 0], [0, 1], [1, 0]], + ... target_formatter=SingleLabelTargetFormatter(labels=["dog", "cat"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['dog', 'cat'] + >>> datamodule.multi_label + False + Multi Label ########### @@ -125,6 +192,28 @@ Here's an example: >>> datamodule.multi_label True +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.MultiNumericTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import MultiNumericTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[[0], [0, 1], [1, 2]], + ... target_formatter=MultiNumericTargetFormatter(labels=["dog", "cat", "rabbit"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['dog', 'cat', 'rabbit'] + >>> datamodule.multi_label + True + Labels ______ @@ -150,6 +239,28 @@ Here's an example: >>> datamodule.multi_label True +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.MultiLabelTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import MultiLabelTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]], + ... target_formatter=MultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['dog', 'cat', 'rabbit'] + >>> datamodule.multi_label + True + Comma Delimited _______________ @@ -175,6 +286,28 @@ Here's an example: >>> datamodule.multi_label True +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.CommaDelimitedMultiLabelTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import CommaDelimitedMultiLabelTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["cat", "cat,dog", "dog,rabbit"], + ... target_formatter=CommaDelimitedMultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['dog', 'cat', 'rabbit'] + >>> datamodule.multi_label + True + Space Delimited _______________ @@ -200,6 +333,28 @@ Here's an example: >>> datamodule.multi_label True +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SpaceDelimitedTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import SpaceDelimitedTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["cat", "cat dog", "dog rabbit"], + ... target_formatter=SpaceDelimitedTargetFormatter(labels=["dog", "cat", "rabbit"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['dog', 'cat', 'rabbit'] + >>> datamodule.multi_label + True + Multi-hot Binaries __________________ @@ -224,3 +379,25 @@ Here's an example: True >>> datamodule.multi_label True + +Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.MultiBinaryTargetFormatter` to override the behaviour. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> from flash.core.data.utilities.classification import MultiBinaryTargetFormatter + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]], + ... target_formatter=MultiBinaryTargetFormatter(labels=["dog", "cat", "rabbit"]), + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['dog', 'cat', 'rabbit'] + >>> datamodule.multi_label + True diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index c9920b4f2a..e520410b14 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -30,6 +30,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _AUDIO_TESTING @@ -62,6 +63,7 @@ def from_files( val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationFilesInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -88,11 +90,13 @@ def from_files( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.audio.classification.data.AudioClassificationData`. @@ -142,6 +146,7 @@ def from_files( """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -168,6 +173,7 @@ def from_folders( val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationFolderInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -213,11 +219,13 @@ def from_folders( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -270,6 +278,7 @@ def from_folders( """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -299,6 +308,7 @@ def from_numpy( val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationNumpyInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -323,11 +333,13 @@ def from_numpy( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.audio.classification.data.AudioClassificationData`. @@ -361,6 +373,7 @@ def from_numpy( """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -390,6 +403,7 @@ def from_tensors( val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationTensorInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -414,11 +428,13 @@ def from_tensors( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.audio.classification.data.AudioClassificationData`. @@ -452,6 +468,7 @@ def from_tensors( """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -488,6 +505,7 @@ def from_data_frame( val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationDataFrameInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -526,11 +544,13 @@ def from_data_frame( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -602,6 +622,7 @@ def from_data_frame( """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -643,6 +664,7 @@ def from_csv( val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = AudioClassificationCSVInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -681,11 +703,13 @@ def from_csv( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.audio.classification.data.AudioClassificationData`. @@ -767,6 +791,7 @@ def from_csv( """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 2c0dbeb400..f2e3fd00b6 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -65,10 +65,8 @@ class DataModule(pl.LightningDataModule): :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used. val_split: An optional float which gives the relative amount of the training dataset to use for the validation dataset. - batch_size: The batch size to be used by the DataLoader. Defaults to 1. + batch_size: The batch size to be used by the DataLoader. num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Windows or Darwin platform. sampler: A sampler following the :class:`~torch.utils.data.sampler.Sampler` type. Will be passed to the DataLoader for the training dataset. Defaults to None. """ diff --git a/flash/core/data/utilities/classification.py b/flash/core/data/utilities/classification.py index e8c264e8c2..0bd4660f2c 100644 --- a/flash/core/data/utilities/classification.py +++ b/flash/core/data/utilities/classification.py @@ -72,6 +72,9 @@ class TargetFormatter: labels: Optional[List[str]] = None num_classes: Optional[int] = None + def __post_init__(self): + self.num_classes = len(self.labels) if self.labels is not None else self.num_classes + def __call__(self, target: Any) -> Any: return self.format(target) @@ -132,7 +135,7 @@ class SingleLabelTargetFormatter(TargetFormatter): binary: ClassVar[Optional[bool]] = False def __post_init__(self): - self.num_classes = len(self.labels) if self.num_classes is None else self.num_classes + super().__post_init__() self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)} def format(self, target: Any) -> Any: diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 59d33bcfa4..9ec5f81196 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -23,6 +23,7 @@ from flash.core.data.data_module import DataModule, DatasetInput from flash.core.data.io.input import DataKeys, Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioImageClassificationInput from flash.core.registry import FlashRegistry @@ -93,6 +94,7 @@ def from_files( val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationFilesInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -119,11 +121,13 @@ def from_files( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -166,8 +170,8 @@ def from_files( >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -194,6 +198,7 @@ def from_folders( val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationFolderInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -239,11 +244,13 @@ def from_folders( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -291,8 +298,8 @@ def from_folders( >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -322,6 +329,7 @@ def from_numpy( val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationNumpyInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -346,11 +354,13 @@ def from_numpy( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -381,8 +391,8 @@ def from_numpy( >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -412,6 +422,7 @@ def from_tensors( val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationTensorInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -436,11 +447,13 @@ def from_tensors( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -471,8 +484,8 @@ def from_tensors( >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -509,6 +522,7 @@ def from_data_frame( val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationDataFrameInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -547,11 +561,13 @@ def from_data_frame( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -614,8 +630,8 @@ def from_data_frame( >>> del train_data_frame >>> del predict_data_frame """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -657,6 +673,7 @@ def from_csv( val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationCSVInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -695,11 +712,13 @@ def from_csv( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -776,8 +795,8 @@ def from_csv( >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -811,6 +830,7 @@ def from_fiftyone( val_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ImageClassificationFiftyOneInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs, @@ -835,11 +855,13 @@ def from_fiftyone( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.classification.data.ImageClassificationData`. @@ -899,8 +921,8 @@ def from_fiftyone( >>> del train_dataset >>> del predict_dataset """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 14472c1493..0376b1752c 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -16,6 +16,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.sort import sorted_alphanumeric from flash.core.integrations.icevision.data import IceVisionInput from flash.core.integrations.icevision.transforms import IceVisionInputTransform @@ -66,6 +67,7 @@ def from_files( val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = ObjectDetectionFilesInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -97,11 +99,13 @@ def from_files( val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when - predicting. + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.image.detection.data.ObjectDetectionData`. @@ -151,7 +155,10 @@ def from_files( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict(transform_kwargs=transform_kwargs) + ds_kw = dict( + target_formatter=target_formatter, + transform_kwargs=transform_kwargs, + ) train_input = input_cls( RunningStage.TRAINING, diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 518b45f35b..3e21d5b56a 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -15,6 +15,7 @@ from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform +from flash.core.data.utilities.classification import TargetFormatter from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING from flash.core.utilities.stages import RunningStage from flash.tabular.classification.input import TabularClassificationCSVInput, TabularClassificationDataFrameInput @@ -49,6 +50,7 @@ def from_data_frame( val_transform: INPUT_TRANSFORM_TYPE = InputTransform, test_transform: INPUT_TRANSFORM_TYPE = InputTransform, predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TabularClassificationDataFrameInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -83,10 +85,12 @@ def from_data_frame( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the - :class:`~flash.core.data.data_module.DataModule` constructor. + :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.tabular.classification.data.TabularClassificationData`. @@ -156,6 +160,7 @@ def from_data_frame( >>> del predict_data """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, @@ -191,6 +196,7 @@ def from_csv( val_transform: INPUT_TRANSFORM_TYPE = InputTransform, test_transform: INPUT_TRANSFORM_TYPE = InputTransform, predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TabularClassificationCSVInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -225,6 +231,8 @@ def from_csv( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the @@ -297,6 +305,7 @@ def from_csv( >>> os.remove("predict_data.csv") """ ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index ec842edf83..ecee9ea2e5 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -18,6 +18,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioTextClassificationInput from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING @@ -60,6 +61,7 @@ def from_csv( val_transform: Optional[Dict[str, Callable]] = InputTransform, test_transform: Optional[Dict[str, Callable]] = InputTransform, predict_transform: Optional[Dict[str, Callable]] = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationCSVInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -85,6 +87,8 @@ def from_csv( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the @@ -154,8 +158,8 @@ def from_csv( >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") """ - ds_kw = dict( + target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, transform_kwargs=transform_kwargs, @@ -186,6 +190,7 @@ def from_json( val_transform: Optional[Dict[str, Callable]] = InputTransform, test_transform: Optional[Dict[str, Callable]] = InputTransform, predict_transform: Optional[Dict[str, Callable]] = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationJSONInput, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, @@ -212,6 +217,8 @@ def from_json( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: To specify the field that holds the data in the JSON file. @@ -280,8 +287,8 @@ def from_json( >>> os.remove("train_data.json") >>> os.remove("predict_data.json") """ - ds_kw = dict( + target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, field=field, @@ -313,6 +320,7 @@ def from_parquet( val_transform: Optional[Dict[str, Callable]] = InputTransform, test_transform: Optional[Dict[str, Callable]] = InputTransform, predict_transform: Optional[Dict[str, Callable]] = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationParquetInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -338,6 +346,8 @@ def from_parquet( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the @@ -407,8 +417,8 @@ def from_parquet( >>> os.remove("train_data.parquet") >>> os.remove("predict_data.parquet") """ - ds_kw = dict( + target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, transform_kwargs=transform_kwargs, @@ -439,6 +449,7 @@ def from_hf_datasets( val_transform: Optional[Dict[str, Callable]] = InputTransform, test_transform: Optional[Dict[str, Callable]] = InputTransform, predict_transform: Optional[Dict[str, Callable]] = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -464,6 +475,8 @@ def from_hf_datasets( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the @@ -514,8 +527,8 @@ def from_hf_datasets( >>> del train_data >>> del predict_data """ - ds_kw = dict( + target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, transform_kwargs=transform_kwargs, @@ -546,6 +559,7 @@ def from_data_frame( val_transform: Optional[Dict[str, Callable]] = InputTransform, test_transform: Optional[Dict[str, Callable]] = InputTransform, predict_transform: Optional[Dict[str, Callable]] = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationDataFrameInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -572,6 +586,8 @@ def from_data_frame( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the @@ -622,8 +638,8 @@ def from_data_frame( >>> del train_data >>> del predict_data """ - ds_kw = dict( + target_formatter=target_formatter, input_key=input_field, target_keys=target_fields, transform_kwargs=transform_kwargs, @@ -655,6 +671,7 @@ def from_lists( val_transform: Optional[Dict[str, Callable]] = InputTransform, test_transform: Optional[Dict[str, Callable]] = InputTransform, predict_transform: Optional[Dict[str, Callable]] = InputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TextClassificationListInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, @@ -680,6 +697,8 @@ def from_lists( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the @@ -712,8 +731,8 @@ def from_lists( >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ - ds_kw = dict( + target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index d15c3be800..d74c7aaca6 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -20,6 +20,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioVideoClassificationInput from flash.core.utilities.imports import ( @@ -87,6 +88,7 @@ def from_files( val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, @@ -120,6 +122,8 @@ def from_files( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. clip_duration: The duration of clips to sample. clip_sampler_kwargs: Additional keyword arguments to use when constructing the clip sampler. @@ -176,7 +180,6 @@ def from_files( >>> _ = [os.remove(f"video_{i}.mp4") for i in range(1, 4)] >>> _ = [os.remove(f"predict_video_{i}.mp4") for i in range(1, 4)] """ - ds_kw = dict( transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, @@ -193,6 +196,7 @@ def from_files( train_targets, transform=train_transform, video_sampler=video_sampler, + target_formatter=target_formatter, **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -232,6 +236,7 @@ def from_folders( val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, @@ -284,6 +289,8 @@ def from_folders( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. clip_duration: The duration of clips to sample. clip_sampler_kwargs: Additional keyword arguments to use when constructing the clip sampler. @@ -349,7 +356,6 @@ def from_folders( >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") """ - ds_kw = dict( transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, @@ -361,7 +367,12 @@ def from_folders( ) train_input = input_cls( - RunningStage.TRAINING, train_folder, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TRAINING, + train_folder, + transform=train_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -408,6 +419,7 @@ def from_data_frame( val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, @@ -453,6 +465,8 @@ def from_data_frame( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. clip_duration: The duration of clips to sample. clip_sampler_kwargs: Additional keyword arguments to use when constructing the clip sampler. @@ -535,7 +549,6 @@ def from_data_frame( >>> del train_data_frame >>> del predict_data_frame """ - ds_kw = dict( transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, @@ -552,7 +565,12 @@ def from_data_frame( predict_data = (predict_data_frame, input_field, predict_videos_root, predict_resolver) train_input = input_cls( - RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TRAINING, + *train_data, + transform=train_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -599,6 +617,7 @@ def from_csv( val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, @@ -644,6 +663,8 @@ def from_csv( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. clip_duration: The duration of clips to sample. clip_sampler_kwargs: Additional keyword arguments to use when constructing the clip sampler. @@ -740,7 +761,6 @@ def from_csv( >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") """ - ds_kw = dict( transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, @@ -757,7 +777,12 @@ def from_csv( predict_data = (predict_file, input_field, predict_videos_root, predict_resolver) train_input = input_cls( - RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TRAINING, + *train_data, + transform=train_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -795,6 +820,7 @@ def from_fiftyone( val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, @@ -827,6 +853,8 @@ def from_fiftyone( test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. See :ref:`formatting_classification_targets` for more details. clip_sampler: The clip sampler to use. One of: ``"uniform"``, ``"random"``, ``"constant_clips_per_video"``. clip_duration: The duration of clips to sample. clip_sampler_kwargs: Additional keyword arguments to use when constructing the clip sampler. @@ -900,7 +928,6 @@ def from_fiftyone( >>> del train_dataset >>> del predict_dataset """ - ds_kw = dict( transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, @@ -917,6 +944,7 @@ def from_fiftyone( transform=train_transform, video_sampler=video_sampler, label_field=label_field, + target_formatter=target_formatter, **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None)