Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Update formatting targets guide (#1165)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Feb 14, 2022
1 parent 381aa37 commit defbace
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where backbones for the `ObjectDetector`, `KeypointDetector`, and `InstanceSegmentation` tasks were not always frozen correctly when finetuning ([#1163](https://github.com/PyTorchLightning/lightning-flash/pull/1163))

- Fixed a bug where `DataModule.multi_label` would sometimes be `None` when it had been inferred to be `False` ([#1165](https://github.com/PyTorchLightning/lightning-flash/pull/1165))

### Removed

- Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))
Expand Down
221 changes: 220 additions & 1 deletion docs/source/general/classification_targets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,223 @@
Formatting Classification Targets
*********************************

.. note:: The contents of this page are currently being updated. Stay tuned!
This guide details the different target formats supported by classification tasks in Flash.
By default, the target format and any additional metadata (``labels``, ``num_classes``, ``multi_label``) will be inferred from your training data.

.. testsetup:: targets

import numpy as np
from PIL import Image

rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
_ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]

Single Label
############

Classification targets are described as single label (``DataModule.multi_label = False``) if each data sample corresponds to a single class.

Class Indexes
_____________

Targets formatted as class indexes are represented by a single number, e.g. ``train_targets = [0, 1, 0]``.
No ``labels`` will be inferred.
The inferred ``num_classes`` is the maximum index plus one (we assume that class indexes are zero-based).
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[0, 1, 0],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False

Labels
______

Targets formatted as labels are represented by a single string, e.g. ``train_targets = ["cat", "dog", "cat"]``.
The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically.
The inferred ``num_classes`` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "dog", "cat"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> datamodule.multi_label
False

One-hot Binaries
________________

Targets formatted as one-hot binaries are represented by a binary list with a single index (the target class index) set to ``1``, e.g. ``train_targets = [[1, 0], [0, 1], [1, 0]]``.
No ``labels`` will be inferred.
The inferred ``num_classes`` is the length of the binary list.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0], [0, 1], [1, 0]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False

Multi Label
###########

Classification targets are described as multi label (``DataModule.multi_label = True``) if each data sample corresponds to zero or more (and perhaps many) classes.

Class Indexes
_____________

Targets formatted as multi label class indexes are represented by a list of class indexes, e.g. ``train_targets = [[0], [0, 1], [1, 2]]``.
No ``labels`` will be inferred.
The inferred ``num_classes`` is the maximum target value plus one (we assume that targets are zero-based).
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[0], [0, 1], [1, 2]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True

Labels
______

Targets formatted as multi label are represented by a list of strings, e.g. ``train_targets = [["cat"], ["cat", "dog"], ["dog", "rabbit"]]``.
The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically.
The inferred ``num_classes`` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True

Comma Delimited
_______________

Targets formatted as comma delimited mutli label are given as comma delimited strings, e.g. ``train_targets = ["cat", "cat,dog", "dog,rabbit"]``.
The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically.
The inferred ``num_classes`` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat,dog", "dog,rabbit"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True

Space Delimited
_______________

Targets formatted as space delimited mutli label are given as space delimited strings, e.g. ``train_targets = ["cat", "cat dog", "dog rabbit"]``.
The inferred ``labels`` will be the unique labels in the train targets sorted alphanumerically.
The inferred ``num_classes`` is the number of labels.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat dog", "dog rabbit"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True

Multi-hot Binaries
__________________

Targets formatted as one-hot binaries are represented by a binary list with a zero or more indices (the target class indices) set to ``1``, e.g. ``train_targets = [[1, 0, 0], [1, 1, 0], [0, 1, 1]]``.
No ``labels`` will be inferred.
The inferred ``num_classes`` is the length of the binary list.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True
22 changes: 10 additions & 12 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,29 +481,27 @@ def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample",
stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
self._show_batch(stage_name, hooks_names, reset=reset)

def _get_property(self, property_name: str) -> Optional[Any]:
train = getattr(self.train_dataset, property_name, None)
val = getattr(self.val_dataset, property_name, None)
test = getattr(self.test_dataset, property_name, None)
filtered = list(filter(lambda x: x is not None, [train, val, test]))
return filtered[0] if len(filtered) > 0 else None

@property
def num_classes(self) -> Optional[int]:
"""Property that returns the number of classes of the datamodule if a multiclass task."""
n_cls_train = getattr(self.train_dataset, "num_classes", None)
n_cls_val = getattr(self.val_dataset, "num_classes", None)
n_cls_test = getattr(self.test_dataset, "num_classes", None)
return n_cls_train or n_cls_val or n_cls_test
return self._get_property("num_classes")

@property
def labels(self) -> Optional[int]:
"""Property that returns the labels if this ``DataModule`` contains classification data."""
n_cls_train = getattr(self.train_dataset, "labels", None)
n_cls_val = getattr(self.val_dataset, "labels", None)
n_cls_test = getattr(self.test_dataset, "labels", None)
return n_cls_train or n_cls_val or n_cls_test
return self._get_property("labels")

@property
def multi_label(self) -> Optional[bool]:
"""Property that returns ``True`` if this ``DataModule`` contains multi-label data."""
multi_label_train = getattr(self.train_dataset, "multi_label", None)
multi_label_val = getattr(self.val_dataset, "multi_label", None)
multi_label_test = getattr(self.test_dataset, "multi_label", None)
return multi_label_train or multi_label_val or multi_label_test
return self._get_property("multi_label")

@property
def inputs(self) -> Optional[Union[Input, List[InputBase]]]:
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/utilities/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def _get_target_formatter_type(target: Any) -> Type[TargetFormatter]:
MultiBinaryTargetFormatter: [MultiNumericTargetFormatter],
SingleBinaryTargetFormatter: [MultiBinaryTargetFormatter, MultiNumericTargetFormatter],
SingleLabelTargetFormatter: [CommaDelimitedMultiLabelTargetFormatter, SpaceDelimitedTargetFormatter],
SingleNumericTargetFormatter: [MultiNumericTargetFormatter],
SingleNumericTargetFormatter: [SingleBinaryTargetFormatter, MultiNumericTargetFormatter],
}


Expand Down
2 changes: 1 addition & 1 deletion tests/core/data/utilities/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
3,
),
# Ambiguous
Case([[0], [1, 2], [2, 0]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], MultiNumericTargetFormatter, None, 3),
Case([[0], [0, 1], [1, 2]], [[1, 0, 0], [1, 1, 0], [0, 1, 1]], MultiNumericTargetFormatter, None, 3),
Case([[1, 0, 0], [0, 1, 1], [1, 0, 1]], [[1, 0, 0], [0, 1, 1], [1, 0, 1]], MultiBinaryTargetFormatter, None, 3),
Case(
[["blue"], ["green", "red"], ["red", "blue"]],
Expand Down

0 comments on commit defbace

Please sign in to comment.