Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Data Pipeline V2: Rename all Output implementations to end in Output #1011

Merged
merged 18 commits into from
Nov 30, 2021
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874))

- Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneDetectionLabels`, `DetectionLabels` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011))

### Deprecated

- Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))
Expand Down
6 changes: 3 additions & 3 deletions docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ________________

detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneInput
detection.output.FiftyOneDetectionLabels
detection.output.FiftyOneDetectionLabelsOutput
detection.data.ObjectDetectionInputTransform

Keypoint Detection
Expand Down Expand Up @@ -102,8 +102,8 @@ ____________
segmentation.data.SemanticSegmentationFiftyOneInput
segmentation.data.SemanticSegmentationDeserializer
segmentation.model.SemanticSegmentationOutputTransform
segmentation.output.FiftyOneSegmentationLabels
segmentation.output.SegmentationLabels
segmentation.output.FiftyOneSegmentationLabelsOutput
segmentation.output.SegmentationLabelsOutput

.. autosummary::
:toctree: generated/
Expand Down
4 changes: 2 additions & 2 deletions docs/source/integrations/fiftyone.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ semantic segmentation tasks. Doing so is as easy as updating your model to use
one of the following outputs:

* :class:`FiftyOneLabels(return_filepath=True)<flash.core.classification.FiftyOneLabels>`
* :class:`FiftyOneSegmentationLabels(return_filepath=True)<flash.image.segmentation.output.FiftyOneSegmentationLabels>`
* :class:`FiftyOneDetectionLabels(return_filepath=True)<flash.image.detection.output.FiftyOneDetectionLabels>`
* :class:`FiftyOneSegmentationLabelsOutput(return_filepath=True)<flash.image.segmentation.output.FiftyOneSegmentationLabelsOutput>`
* :class:`FiftyOneDetectionLabelsOutput(return_filepath=True)<flash.image.detection.output.FiftyOneDetectionLabelsOutput>`

The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize
your predictions in the
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.core.data.io.output import Output


class Preds(Output):
class PredsOutput(Output):
"""A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs."""

def transform(self, sample: Any) -> Union[int, List[int]]:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.detection.backbones import OBJECT_DETECTION_HEADS
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/detection/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
fo = None


class FiftyOneDetectionLabels(Output):
class FiftyOneDetectionLabelsOutput(Output):
"""A :class:`.Output` which converts model outputs to FiftyOne detection format.

Args:
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(

def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]:
if DataKeys.METADATA not in sample:
raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabels output.")
raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabelsOutput output.")

labels = None
if self._labels is not None:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/face_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import fastface as ff


class DetectionLabels(Output):
class DetectionLabelsOutput(Output):
"""A :class:`.Output` which extracts predictions from sample dict."""

def transform(self, sample: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or DetectionLabels(),
output=output or DetectionLabelsOutput(),
input_transform=input_transform or FaceDetectionInputTransform(),
)

Expand Down
4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from flash.core.adapter import AdapterTask
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/keypoint_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.output import Preds
from flash.core.data.output import PredsOutput
from flash.core.registry import FlashRegistry
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output=output or Preds(),
output=output or PredsOutput(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
12 changes: 6 additions & 6 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from flash.core.utilities.stages import RunningStage
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS
from flash.image.segmentation.output import SegmentationLabels
from flash.image.segmentation.output import SegmentationLabelsOutput
from flash.image.segmentation.transforms import default_transforms, predict_default_transforms, train_default_transforms

SampleCollection = None
Expand Down Expand Up @@ -244,7 +244,7 @@ def __init__(
self.image_size = image_size
self.num_classes = num_classes
if num_classes:
labels_map = labels_map or SegmentationLabels.create_random_labels_map(num_classes)
labels_map = labels_map or SegmentationLabelsOutput.create_random_labels_map(num_classes)

super().__init__(
train_transform=train_transform,
Expand Down Expand Up @@ -329,9 +329,9 @@ def from_input(

num_classes = input_transform_kwargs["num_classes"]

labels_map = getattr(input_transform_kwargs, "labels_map", None) or SegmentationLabels.create_random_labels_map(
num_classes
)
labels_map = getattr(
input_transform_kwargs, "labels_map", None
) or SegmentationLabelsOutput.create_random_labels_map(num_classes)

data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map)

Expand Down Expand Up @@ -494,7 +494,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
raise TypeError(f"Unknown data type. Got: {type(data)}.")
# convert images and labels to numpy and stack horizontally
image_vis: np.ndarray = self._to_numpy(image.byte())
label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(label.squeeze().byte(), self.labels_map)
label_tmp: torch.Tensor = SegmentationLabelsOutput.labels_to_image(label.squeeze().byte(), self.labels_map)
label_vis: np.ndarray = self._to_numpy(label_tmp)
img_vis = np.hstack((image_vis, label_vis))
# send to visualiser
Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
from flash.image.segmentation.output import SegmentationLabels
from flash.image.segmentation.output import SegmentationLabelsOutput

if _KORNIA_AVAILABLE:
import kornia as K
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
lr_scheduler=lr_scheduler,
metrics=metrics,
learning_rate=learning_rate,
output=output or SegmentationLabels(),
output=output or SegmentationLabelsOutput(),
output_transform=output_transform or self.output_transform_cls(),
)

Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
K = None


class SegmentationLabels(Output):
class SegmentationLabelsOutput(Output):
"""A :class:`.Output` which converts the model outputs to the label of the argmax classification per pixel in
the image for semantic segmentation tasks.

Expand Down Expand Up @@ -100,7 +100,7 @@ def transform(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor:
return labels.tolist()


class FiftyOneSegmentationLabels(SegmentationLabels):
class FiftyOneSegmentationLabelsOutput(SegmentationLabelsOutput):
"""A :class:`.Output` which converts the model outputs to FiftyOne segmentation format.

Args:
Expand Down
4 changes: 2 additions & 2 deletions flash_examples/integrations/fiftyone/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flash.core.integrations.fiftyone import visualize
from flash.core.utilities.imports import example_requires
from flash.image import ObjectDetectionData, ObjectDetector
from flash.image.detection.output import FiftyOneDetectionLabels
from flash.image.detection.output import FiftyOneDetectionLabelsOutput

example_requires("image")

Expand All @@ -42,7 +42,7 @@
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Set the output and get some predictions
model.output = FiftyOneDetectionLabels(return_filepath=True) # output FiftyOne format
model.output = FiftyOneDetectionLabelsOutput(return_filepath=True) # output FiftyOne format
predictions = trainer.predict(model, datamodule=datamodule)
predictions = list(chain.from_iterable(predictions)) # flatten batches

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.image import SemanticSegmentation
from flash.image.segmentation.output import SegmentationLabels
from flash.image.segmentation.output import SegmentationLabelsOutput

model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.6.0/semantic_segmentation_model.pt"
)
model.output = SegmentationLabels(visualize=False)
model.output = SegmentationLabelsOutput(visualize=False)
model.serve()
14 changes: 7 additions & 7 deletions tests/image/detection/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
from flash.image.detection.output import FiftyOneDetectionLabels
from flash.image.detection.output import FiftyOneDetectionLabelsOutput


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
class TestFiftyOneDetectionLabels:
class TestFiftyOneDetectionLabelsOutput:
@staticmethod
def test_smoke():
serial = FiftyOneDetectionLabels()
serial = FiftyOneDetectionLabelsOutput()
assert serial is not None

@staticmethod
def test_serialize_fiftyone():
labels = ["class_1", "class_2", "class_3"]
serial = FiftyOneDetectionLabels()
filepath_serial = FiftyOneDetectionLabels(return_filepath=True)
threshold_serial = FiftyOneDetectionLabels(threshold=0.9)
labels_serial = FiftyOneDetectionLabels(labels=labels)
serial = FiftyOneDetectionLabelsOutput()
filepath_serial = FiftyOneDetectionLabelsOutput(return_filepath=True)
threshold_serial = FiftyOneDetectionLabelsOutput(threshold=0.9)
labels_serial = FiftyOneDetectionLabelsOutput(labels=labels)

sample = {
DataKeys.PREDS: {
Expand Down
14 changes: 7 additions & 7 deletions tests/image/segmentation/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
from flash.image.segmentation.output import FiftyOneSegmentationLabels, SegmentationLabels
from flash.image.segmentation.output import FiftyOneSegmentationLabelsOutput, SegmentationLabelsOutput
from tests.helpers.utils import _IMAGE_TESTING


class TestSemanticSegmentationLabels:
class TestSemanticSegmentationLabelsOutput:
@pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.")
@staticmethod
def test_smoke():
serial = SegmentationLabels()
serial = SegmentationLabelsOutput()
assert serial is not None
assert serial.labels_map is None
assert serial.visualize is False

@pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.")
@staticmethod
def test_exception():
serial = SegmentationLabels()
serial = SegmentationLabelsOutput()

with pytest.raises(Exception):
sample = torch.zeros(1, 5, 2, 3)
Expand All @@ -45,7 +45,7 @@ def test_exception():
@pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.")
@staticmethod
def test_serialize():
serial = SegmentationLabels()
serial = SegmentationLabelsOutput()

sample = torch.zeros(5, 2, 3)
sample[1, 1, 2] = 1 # add peak in class 2
Expand All @@ -59,8 +59,8 @@ def test_serialize():
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
@staticmethod
def test_serialize_fiftyone():
serial = FiftyOneSegmentationLabels()
filepath_serial = FiftyOneSegmentationLabels(return_filepath=True)
serial = FiftyOneSegmentationLabelsOutput()
filepath_serial = FiftyOneSegmentationLabelsOutput(return_filepath=True)

preds = torch.zeros(5, 2, 3)
preds[1, 1, 2] = 1 # add peak in class 2
Expand Down