Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add ability to override the TargetFormatter in classification tasks (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Feb 15, 2022
1 parent 4c62482 commit 1b1b939
Show file tree
Hide file tree
Showing 10 changed files with 345 additions and 55 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for passing the `Output` object (or a string e.g. `"labels"`) to the `flash.Trainer.predict` method ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157))

- Added support for passing the `TargetFormatter` object to `from_*` methods for classification to override target handling ([#1171](https://github.com/PyTorchLightning/lightning-flash/pull/1171))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
Expand Down
177 changes: 177 additions & 0 deletions docs/source/general/classification_targets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Formatting Classification Targets

This guide details the different target formats supported by classification tasks in Flash.
By default, the target format and any additional metadata (``labels``, ``num_classes``, ``multi_label``) will be inferred from your training data.
You can override this behaviour by passing your own :class:`~flash.core.data.utilities.classification.TargetFormatter` using the ``target_formatter`` argument.

.. testsetup:: targets

Expand Down Expand Up @@ -45,6 +46,28 @@ Here's an example:
>>> datamodule.multi_label
False

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SingleNumericTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleNumericTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[0, 1, 0],
... target_formatter=SingleNumericTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
False

Labels
______

Expand All @@ -70,6 +93,28 @@ Here's an example:
>>> datamodule.multi_label
False

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SingleLabelTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "dog", "cat"],
... target_formatter=SingleLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
False

One-hot Binaries
________________

Expand All @@ -95,6 +140,28 @@ Here's an example:
>>> datamodule.multi_label
False

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SingleBinaryTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleBinaryTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0], [0, 1], [1, 0]],
... target_formatter=SingleLabelTargetFormatter(labels=["dog", "cat"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['dog', 'cat']
>>> datamodule.multi_label
False

Multi Label
###########

Expand Down Expand Up @@ -125,6 +192,28 @@ Here's an example:
>>> datamodule.multi_label
True

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.MultiNumericTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiNumericTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[0], [0, 1], [1, 2]],
... target_formatter=MultiNumericTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Labels
______

Expand All @@ -150,6 +239,28 @@ Here's an example:
>>> datamodule.multi_label
True

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.MultiLabelTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]],
... target_formatter=MultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Comma Delimited
_______________

Expand All @@ -175,6 +286,28 @@ Here's an example:
>>> datamodule.multi_label
True

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.CommaDelimitedMultiLabelTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import CommaDelimitedMultiLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat,dog", "dog,rabbit"],
... target_formatter=CommaDelimitedMultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Space Delimited
_______________

Expand All @@ -200,6 +333,28 @@ Here's an example:
>>> datamodule.multi_label
True

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.SpaceDelimitedTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SpaceDelimitedTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat dog", "dog rabbit"],
... target_formatter=SpaceDelimitedTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Multi-hot Binaries
__________________

Expand All @@ -224,3 +379,25 @@ Here's an example:
True
>>> datamodule.multi_label
True

Alternatively, you can provide a :class:`~flash.core.data.utilities.classification.MultiBinaryTargetFormatter` to override the behaviour.
Here's an example:

.. doctest:: targets

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiBinaryTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]],
... target_formatter=MultiBinaryTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True
Loading

0 comments on commit 1b1b939

Please sign in to comment.