Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Data Pipeline V2: Rename all Output implementations to end in Output #1011

Merged
merged 18 commits into from
Nov 30, 2021
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874))

- Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneDetectionLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011))

### Deprecated

- Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))
Expand Down
10 changes: 5 additions & 5 deletions docs/source/api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ _________________________
:nosignatures:
:template: classtemplate.rst

~flash.core.classification.Classes
~flash.core.classification.ClassesOutput
~flash.core.classification.ClassificationOutput
~flash.core.classification.ClassificationTask
~flash.core.classification.FiftyOneLabels
~flash.core.classification.Labels
~flash.core.classification.Logits
~flash.core.classification.FiftyOneLabelsOutput
~flash.core.classification.LabelsOutput
~flash.core.classification.LogitsOutput
~flash.core.classification.PredsClassificationOutput
~flash.core.classification.Probabilities
~flash.core.classification.ProbabilitiesOutput

flash.core.finetuning
_____________________
Expand Down
6 changes: 3 additions & 3 deletions docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ________________

detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneInput
detection.output.FiftyOneDetectionLabels
detection.output.FiftyOneDetectionLabelsOutput
detection.data.ObjectDetectionInputTransform

Keypoint Detection
Expand Down Expand Up @@ -103,8 +103,8 @@ ____________
segmentation.data.SemanticSegmentationFiftyOneInput
segmentation.data.SemanticSegmentationDeserializer
segmentation.model.SemanticSegmentationOutputTransform
segmentation.output.FiftyOneSegmentationLabels
segmentation.output.SegmentationLabels
segmentation.output.FiftyOneSegmentationLabelsOutput
segmentation.output.SegmentationLabelsOutput

.. autosummary::
:toctree: generated/
Expand Down
4 changes: 2 additions & 2 deletions docs/source/common/finetuning_example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Here's an example of finetuning.
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

Expand Down Expand Up @@ -56,7 +56,7 @@ Once you've finetuned, use the model to predict:
.. testcode:: finetune

# Output predictions as labels, automatically inferred from the training data in part 2.
model.output = Labels()
model.output = LabelsOutput()

predictions = model.predict(
[
Expand Down
2 changes: 1 addition & 1 deletion docs/source/common/training_example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Here's an example:
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

Expand Down
4 changes: 2 additions & 2 deletions docs/source/general/predictions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ reference below).

.. code-block:: python

from flash.core.classification import Probabilities
from flash.core.classification import ProbabilitiesOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassifier

Expand All @@ -78,7 +78,7 @@ reference below).
)

# 3. Attach the Output
model.output = Probabilities()
model.output = ProbabilitiesOutput()

# 4. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
Expand Down
6 changes: 3 additions & 3 deletions docs/source/integrations/fiftyone.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ You can visualize predictions for classification, object detection, and
semantic segmentation tasks. Doing so is as easy as updating your model to use
one of the following outputs:

* :class:`FiftyOneLabels(return_filepath=True)<flash.core.classification.FiftyOneLabels>`
* :class:`FiftyOneSegmentationLabels(return_filepath=True)<flash.image.segmentation.output.FiftyOneSegmentationLabels>`
* :class:`FiftyOneDetectionLabels(return_filepath=True)<flash.image.detection.output.FiftyOneDetectionLabels>`
* :class:`FiftyOneLabelsOutput(return_filepath=True)<flash.core.classification.FiftyOneLabelsOutput>`
* :class:`FiftyOneSegmentationLabelsOutput(return_filepath=True)<flash.image.segmentation.output.FiftyOneSegmentationLabelsOutput>`
* :class:`FiftyOneDetectionLabelsOutput(return_filepath=True)<flash.image.detection.output.FiftyOneDetectionLabelsOutput>`

The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize
your predictions in the
Expand Down
8 changes: 4 additions & 4 deletions docs/source/template/optional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ Specifically, it should include any formatting and transforms that should always
If you want to support different use cases that require different prediction formats, you should add some :class:`~flash.core.data.io.output.Output` implementations in an ``output.py`` file.

Some good examples are in `flash/core/classification.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/classification.py>`_.
Here's the :class:`~flash.core.classification.Classes` :class:`~flash.core.data.io.output.Output`:
Here's the :class:`~flash.core.classification.ClassesOutput` :class:`~flash.core.data.io.output.Output`:

.. literalinclude:: ../../../flash/core/classification.py
:language: python
:pyobject: Classes
:pyobject: ClassesOutput

Alternatively, here's the :class:`~flash.core.classification.Logits` :class:`~flash.core.data.io.output.Output`:
Alternatively, here's the :class:`~flash.core.classification.LogitsOutput` :class:`~flash.core.data.io.output.Output`:

.. literalinclude:: ../../../flash/core/classification.py
:language: python
:pyobject: Logits
:pyobject: LogitsOutput

Take a look at :ref:`predictions` to learn more.

Expand Down
33 changes: 25 additions & 8 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn

from flash.core.adapter import AdapterTask
from flash.core.data.io.classification_input import ClassificationState
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
output=output or Classes(multi_label=multi_label),
output=output or ClassesOutput(multi_label=multi_label),
**kwargs,
)

Expand All @@ -102,7 +102,7 @@ def __init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
output=output or Classes(multi_label=multi_label),
output=output or ClassesOutput(multi_label=multi_label),
**kwargs,
)

Expand Down Expand Up @@ -137,14 +137,14 @@ def transform(self, sample: Any) -> Any:
return sample


class Logits(PredsClassificationOutput):
class LogitsOutput(PredsClassificationOutput):
"""A :class:`.Output` which simply converts the model outputs (assumed to be logits) to a list."""

def transform(self, sample: Any) -> Any:
return super().transform(sample).tolist()


class Probabilities(PredsClassificationOutput):
class ProbabilitiesOutput(PredsClassificationOutput):
"""A :class:`.Output` which applies a softmax to the model outputs (assumed to be logits) and converts to a
list."""

Expand All @@ -155,7 +155,7 @@ def transform(self, sample: Any) -> Any:
return torch.softmax(sample, -1).tolist()


class Classes(PredsClassificationOutput):
class ClassesOutput(PredsClassificationOutput):
"""A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and
converts to a list.

Expand All @@ -181,7 +181,7 @@ def transform(self, sample: Any) -> Union[int, List[int]]:
return torch.argmax(sample, -1).tolist()


class Labels(Classes):
class LabelsOutput(ClassesOutput):
"""A :class:`.Output` which converts the model outputs (either logits or probabilities) to the label of the
argmax classification.

Expand Down Expand Up @@ -219,7 +219,7 @@ def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]:
return classes


class FiftyOneLabels(ClassificationOutput):
class FiftyOneLabelsOutput(ClassificationOutput):
"""A :class:`.Output` which converts the model outputs to FiftyOne classification format.

Args:
Expand Down Expand Up @@ -339,3 +339,20 @@ def transform(
filepath = sample[DataKeys.METADATA]["filepath"]
return {"filepath": filepath, "predictions": fo_predictions}
return fo_predictions


class Labels(LabelsOutput):
def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False, threshold: float = 0.5):
rank_zero_deprecation(
"`Labels` was deprecated in v0.6.0 and will be removed in v0.7.0." "Please use `LabelsOutput` instead."
)
super().__init__(labels=labels, multi_label=multi_label, threshold=threshold)


class Probabilities(ProbabilitiesOutput):
def __init__(self, multi_label: bool = False):
rank_zero_deprecation(
"`Probabilities` was deprecated in v0.6.0 and will be removed in v0.7.0."
"Please use `ProbabilitiesOutput` instead."
)
super().__init__(multi_label=multi_label)
2 changes: 1 addition & 1 deletion flash/core/data/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.core.data.io.output import Output


class Preds(Output):
class PredsOutput(Output):
"""A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs."""

def transform(self, sample: Any) -> Union[int, List[int]]:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn

from flash.core.classification import ClassificationAdapterTask, Labels
from flash.core.classification import ClassificationAdapterTask, LabelsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.classification.adapters import TRAINING_STRATEGIES
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
optimizer=optimizer,
lr_scheduler=lr_scheduler,
multi_label=multi_label,
output=output or Labels(multi_label=multi_label),
output=output or LabelsOutput(multi_label=multi_label),
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.detection.backbones import OBJECT_DETECTION_HEADS
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/detection/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
fo = None


class FiftyOneDetectionLabels(Output):
class FiftyOneDetectionLabelsOutput(Output):
"""A :class:`.Output` which converts model outputs to FiftyOne detection format.

Args:
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(

def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]:
if DataKeys.METADATA not in sample:
raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabels output.")
raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabelsOutput output.")

labels = None
if self._labels is not None:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/face_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import fastface as ff


class DetectionLabels(Output):
class DetectionLabelsOutput(Output):
"""A :class:`.Output` which extracts predictions from sample dict."""

def transform(self, sample: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or DetectionLabels(),
output=output or DetectionLabelsOutput(),
input_transform=input_transform or FaceDetectionInputTransform(),
)

Expand Down
4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from flash.core.adapter import AdapterTask
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/keypoint_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
12 changes: 6 additions & 6 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from flash.core.utilities.stages import RunningStage
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS
from flash.image.segmentation.output import SegmentationLabels
from flash.image.segmentation.output import SegmentationLabelsOutput
from flash.image.segmentation.transforms import default_transforms, predict_default_transforms, train_default_transforms

SampleCollection = None
Expand Down Expand Up @@ -244,7 +244,7 @@ def __init__(
self.image_size = image_size
self.num_classes = num_classes
if num_classes:
labels_map = labels_map or SegmentationLabels.create_random_labels_map(num_classes)
labels_map = labels_map or SegmentationLabelsOutput.create_random_labels_map(num_classes)

super().__init__(
train_transform=train_transform,
Expand Down Expand Up @@ -329,9 +329,9 @@ def from_input(

num_classes = input_transform_kwargs["num_classes"]

labels_map = getattr(input_transform_kwargs, "labels_map", None) or SegmentationLabels.create_random_labels_map(
num_classes
)
labels_map = getattr(
input_transform_kwargs, "labels_map", None
) or SegmentationLabelsOutput.create_random_labels_map(num_classes)

data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map)

Expand Down Expand Up @@ -494,7 +494,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
raise TypeError(f"Unknown data type. Got: {type(data)}.")
# convert images and labels to numpy and stack horizontally
image_vis: np.ndarray = self._to_numpy(image.byte())
label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(label.squeeze().byte(), self.labels_map)
label_tmp: torch.Tensor = SegmentationLabelsOutput.labels_to_image(label.squeeze().byte(), self.labels_map)
label_vis: np.ndarray = self._to_numpy(label_tmp)
img_vis = np.hstack((image_vis, label_vis))
# send to visualiser
Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
from flash.image.segmentation.output import SegmentationLabels
from flash.image.segmentation.output import SegmentationLabelsOutput

if _KORNIA_AVAILABLE:
import kornia as K
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
lr_scheduler=lr_scheduler,
metrics=metrics,
learning_rate=learning_rate,
output=output or SegmentationLabels(),
output=output or SegmentationLabelsOutput(),
output_transform=output_transform or self.output_transform_cls(),
)

Expand Down
Loading