From defbaceb0440fe999c9cf11123e163b971f09de5 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 14 Feb 2022 12:27:33 +0000 Subject: [PATCH] Update formatting targets guide (#1165) --- CHANGELOG.md | 2 + .../source/general/classification_targets.rst | 221 +++++++++++++++++- flash/core/data/data_module.py | 22 +- flash/core/data/utilities/classification.py | 2 +- .../data/utilities/test_classification.py | 2 +- 5 files changed, 234 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 793070c1b1..f213a32b41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where backbones for the `ObjectDetector`, `KeypointDetector`, and `InstanceSegmentation` tasks were not always frozen correctly when finetuning ([#1163](https://github.com/PyTorchLightning/lightning-flash/pull/1163)) +- Fixed a bug where `DataModule.multi_label` would sometimes be `None` when it had been inferred to be `False` ([#1165](https://github.com/PyTorchLightning/lightning-flash/pull/1165)) + ### Removed - Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128)) diff --git a/docs/source/general/classification_targets.rst b/docs/source/general/classification_targets.rst index 2315dac0d1..afa4f66f8f 100644 --- a/docs/source/general/classification_targets.rst +++ b/docs/source/general/classification_targets.rst @@ -4,4 +4,223 @@ Formatting Classification Targets ********************************* -.. note:: The contents of this page are currently being updated. Stay tuned! +This guide details the different target formats supported by classification tasks in Flash. +By default, the target format and any additional metadata (``labels``, ``num_classes``, ``multi_label``) will be inferred from your training data. + +.. 
testsetup:: targets + + import numpy as np + from PIL import Image + + rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] + +Single Label +############ + +Classification targets are described as single label (``DataModule.multi_label = False``) if each data sample corresponds to a single class. + +Class Indexes +_____________ + +Targets formatted as class indexes are represented by a single number, e.g. ``train_targets = [0, 1, 0]``. +No ``labels`` will be inferred. +The inferred ``num_classes`` is the maximum index plus one (we assume that class indexes are zero-based). +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[0, 1, 0], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 2 + >>> datamodule.labels is None + True + >>> datamodule.multi_label + False + +Labels +______ + +Targets formatted as labels are represented by a single string, e.g. ``train_targets = ["cat", "dog", "cat"]``. +The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically. +The inferred ``num_classes`` is the number of labels. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["cat", "dog", "cat"], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... 
) + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['cat', 'dog'] + >>> datamodule.multi_label + False + +One-hot Binaries +________________ + +Targets formatted as one-hot binaries are represented by a binary list with a single index (the target class index) set to ``1``, e.g. ``train_targets = [[1, 0], [0, 1], [1, 0]]``. +No ``labels`` will be inferred. +The inferred ``num_classes`` is the length of the binary list. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[[1, 0], [0, 1], [1, 0]], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 2 + >>> datamodule.labels is None + True + >>> datamodule.multi_label + False + +Multi Label +########### + +Classification targets are described as multi label (``DataModule.multi_label = True``) if each data sample corresponds to zero or more (and perhaps many) classes. + +Class Indexes +_____________ + +Targets formatted as multi label class indexes are represented by a list of class indexes, e.g. ``train_targets = [[0], [0, 1], [1, 2]]``. +No ``labels`` will be inferred. +The inferred ``num_classes`` is the maximum target value plus one (we assume that targets are zero-based). +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[[0], [0, 1], [1, 2]], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... 
) + >>> datamodule.num_classes + 3 + >>> datamodule.labels is None + True + >>> datamodule.multi_label + True + +Labels +______ + +Targets formatted as multi label labels are represented by a list of strings, e.g. ``train_targets = [["cat"], ["cat", "dog"], ["dog", "rabbit"]]``. +The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically. +The inferred ``num_classes`` is the number of labels. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['cat', 'dog', 'rabbit'] + >>> datamodule.multi_label + True + +Comma Delimited +_______________ + +Targets formatted as comma delimited multi label are given as comma delimited strings, e.g. ``train_targets = ["cat", "cat,dog", "dog,rabbit"]``. +The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically. +The inferred ``num_classes`` is the number of labels. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["cat", "cat,dog", "dog,rabbit"], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['cat', 'dog', 'rabbit'] + >>> datamodule.multi_label + True + +Space Delimited +_______________ + +Targets formatted as space delimited multi label are given as space delimited strings, e.g. 
``train_targets = ["cat", "cat dog", "dog rabbit"]``. +The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically. +The inferred ``num_classes`` is the number of labels. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["cat", "cat dog", "dog rabbit"], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['cat', 'dog', 'rabbit'] + >>> datamodule.multi_label + True + +Multi-hot Binaries +__________________ + +Targets formatted as multi-hot binaries are represented by a binary list with zero or more indices (the target class indices) set to ``1``, e.g. ``train_targets = [[1, 0, 0], [1, 1, 0], [0, 1, 1]]``. +No ``labels`` will be inferred. +The inferred ``num_classes`` is the length of the binary list. +Here's an example: + +.. doctest:: targets + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... 
) + >>> datamodule.num_classes + 3 + >>> datamodule.labels is None + True + >>> datamodule.multi_label + True diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 8d3792edaf..9056805542 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -481,29 +481,27 @@ def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample", stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING] self._show_batch(stage_name, hooks_names, reset=reset) + def _get_property(self, property_name: str) -> Optional[Any]: + train = getattr(self.train_dataset, property_name, None) + val = getattr(self.val_dataset, property_name, None) + test = getattr(self.test_dataset, property_name, None) + filtered = list(filter(lambda x: x is not None, [train, val, test])) + return filtered[0] if len(filtered) > 0 else None + @property def num_classes(self) -> Optional[int]: """Property that returns the number of classes of the datamodule if a multiclass task.""" - n_cls_train = getattr(self.train_dataset, "num_classes", None) - n_cls_val = getattr(self.val_dataset, "num_classes", None) - n_cls_test = getattr(self.test_dataset, "num_classes", None) - return n_cls_train or n_cls_val or n_cls_test + return self._get_property("num_classes") @property def labels(self) -> Optional[int]: """Property that returns the labels if this ``DataModule`` contains classification data.""" - n_cls_train = getattr(self.train_dataset, "labels", None) - n_cls_val = getattr(self.val_dataset, "labels", None) - n_cls_test = getattr(self.test_dataset, "labels", None) - return n_cls_train or n_cls_val or n_cls_test + return self._get_property("labels") @property def multi_label(self) -> Optional[bool]: """Property that returns ``True`` if this ``DataModule`` contains multi-label data.""" - multi_label_train = getattr(self.train_dataset, "multi_label", None) - multi_label_val = getattr(self.val_dataset, "multi_label", None) - multi_label_test = 
getattr(self.test_dataset, "multi_label", None) - return multi_label_train or multi_label_val or multi_label_test + return self._get_property("multi_label") @property def inputs(self) -> Optional[Union[Input, List[InputBase]]]: diff --git a/flash/core/data/utilities/classification.py b/flash/core/data/utilities/classification.py index d991c5f5f5..e8c264e8c2 100644 --- a/flash/core/data/utilities/classification.py +++ b/flash/core/data/utilities/classification.py @@ -345,7 +345,7 @@ def _get_target_formatter_type(target: Any) -> Type[TargetFormatter]: MultiBinaryTargetFormatter: [MultiNumericTargetFormatter], SingleBinaryTargetFormatter: [MultiBinaryTargetFormatter, MultiNumericTargetFormatter], SingleLabelTargetFormatter: [CommaDelimitedMultiLabelTargetFormatter, SpaceDelimitedTargetFormatter], - SingleNumericTargetFormatter: [MultiNumericTargetFormatter], + SingleNumericTargetFormatter: [SingleBinaryTargetFormatter, MultiNumericTargetFormatter], } diff --git a/tests/core/data/utilities/test_classification.py b/tests/core/data/utilities/test_classification.py index 79609ffe19..bb03489dd9 100644 --- a/tests/core/data/utilities/test_classification.py +++ b/tests/core/data/utilities/test_classification.py @@ -62,7 +62,7 @@ 3, ), # Ambiguous - Case([[0], [1, 2], [2, 0]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], MultiNumericTargetFormatter, None, 3), + Case([[0], [0, 1], [1, 2]], [[1, 0, 0], [1, 1, 0], [0, 1, 1]], MultiNumericTargetFormatter, None, 3), Case([[1, 0, 0], [0, 1, 1], [1, 0, 1]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], MultiBinaryTargetFormatter, None, 3), Case( [["blue"], ["green", "red"], ["red", "blue"]],