From 8c53cd641136086fbdc3ffc63224f72a7ec573fc Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 Jan 2022 16:57:17 +0000 Subject: [PATCH] Docstrings for `SemanticSegmentationData` (#1101) --- CHANGELOG.md | 2 + flash/image/segmentation/data.py | 318 +++++++++++++++++++++++++++++- flash/image/segmentation/input.py | 9 +- 3 files changed, 318 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21fbb8ab07..0479946c38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `SpeechRecognitionData.from_datasets` did not work as expected ([#1097](https://github.com/PyTorchLightning/lightning-flash/pull/1097)) +- Fixed a bug where loading data for prediction with `SemanticSegmentationData.from_folders` raised an error ([#1101](https://github.com/PyTorchLightning/lightning-flash/pull/1101)) + ### Removed ## [0.6.0] - 2021-13-12 diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 60a8477577..9324620b40 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -21,7 +21,7 @@ from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, lazy_import from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.segmentation.input import ( @@ -42,9 +42,14 @@ else: fo = None +# Skip doctests if requirements aren't available +if not _IMAGE_AVAILABLE: + __doctest_skip__ = ["SemanticSegmentationData", "SemanticSegmentationData.*"] + class SemanticSegmentationData(DataModule): - """Data module for semantic segmentation tasks.""" + """The ``SemanticSegmentationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of + classmethods for loading data for semantic segmentation.""" input_transforms_registry = FlashRegistry("input_transforms") input_transform_cls = SemanticSegmentationInputTransform @@ -73,6 +78,78 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": + """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from lists of input files and + corresponding lists of mask files. + + The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, + ``.tiff``, ``.webp``, and ``.npy``. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_files: The list of image files to use when training. + train_targets: The list of mask files to use when training. + val_files: The list of image files to use when validating. + val_targets: The list of mask files to use when validating. + test_files: The list of image files to use when testing. + test_targets: The list of mask files to use when testing. + predict_files: The list of image files to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + num_classes: The number of segmentation classes. + labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. + If not provided, a random mapping will be used. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.segmentation.data.SemanticSegmentationData`. + + Examples + ________ + + .. testsetup:: + + >>> from PIL import Image + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8")) + >>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] + >>> _ = [rand_mask.save(f"mask_{i}.png") for i in range(1, 4)] + >>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)] + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.image import SemanticSegmentation, SemanticSegmentationData + >>> datamodule = SemanticSegmentationData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... train_targets=["mask_1.png", "mask_2.png", "mask_3.png"], + ... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"], + ... transform_kwargs=dict(image_size=(128, 128)), + ... num_classes=10, + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 10 + >>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] + >>> _ = [os.remove(f"mask_{i}.png") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -110,6 +187,117 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": + """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from folders containing image + files and folders containing mask files. + + The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, + ``.tiff``, ``.webp``, and ``.npy``. + For train, test, and validation data, the folders are expected to contain the images with a corresponding target + folder which contains the mask in a file of the same name. + For example, if your ``train_images`` folder (passed to the ``train_folder`` argument) looks like this: + + .. code-block:: + + train_images + ├── image_1.png + ├── image_2.png + ├── image_3.png + ... + + your ``train_masks`` folder (passed to the ``train_target_folder`` argument) would need to look like this: + + .. code-block:: + + train_masks + ├── image_1.png + ├── image_2.png + ├── image_3.png + ... + + For prediction, the folder is expected to contain the files for inference, like this: + + .. code-block:: + + predict_folder + ├── predict_image_1.png + ├── predict_image_2.png + ├── predict_image_3.png + ... + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_folder: The folder containing images to use when training. + train_target_folder: The folder containing masks to use when training (files should have the same name as + the files in the ``train_folder``). + val_folder: The folder containing images to use when validating. + val_target_folder: The folder containing masks to use when validating (files should have the same name as + the files in the ``train_folder``). + test_folder: The folder containing images to use when testing. + test_target_folder: The folder containing masks to use when testing (files should have the same name as + the files in the ``train_folder``). + predict_folder: The folder containing images to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + num_classes: The number of segmentation classes. + labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. + If not provided, a random mapping will be used. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.segmentation.data.SemanticSegmentationData`. + + Examples + ________ + + .. testsetup:: + + >>> import os + >>> from PIL import Image + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> rand_mask = Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8")) + >>> os.makedirs("train_images", exist_ok=True) + >>> os.makedirs("train_masks", exist_ok=True) + >>> os.makedirs("predict_folder", exist_ok=True) + >>> _ = [rand_image.save(os.path.join("train_images", f"image_{i}.png")) for i in range(1, 4)] + >>> _ = [rand_mask.save(os.path.join("train_masks", f"image_{i}.png")) for i in range(1, 4)] + >>> _ = [rand_image.save(os.path.join("predict_folder", f"predict_image_{i}.png")) for i in range(1, 4)] + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.image import SemanticSegmentation, SemanticSegmentationData + >>> datamodule = SemanticSegmentationData.from_folders( + ... train_folder="train_images", + ... train_target_folder="train_masks", + ... predict_folder="predict_folder", + ... transform_kwargs=dict(image_size=(128, 128)), + ... num_classes=10, + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 10 + >>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import shutil + >>> shutil.rmtree("train_images") + >>> shutil.rmtree("train_masks") + >>> shutil.rmtree("predict_folder") + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -133,9 +321,9 @@ def from_numpy( train_data: Optional[Collection[np.ndarray]] = None, train_targets: Optional[Collection[np.ndarray]] = None, val_data: Optional[Collection[np.ndarray]] = None, - val_targets: Optional[Sequence[np.ndarray]] = None, + val_targets: Optional[Collection[np.ndarray]] = None, test_data: Optional[Collection[np.ndarray]] = None, - test_targets: Optional[Sequence[np.ndarray]] = None, + test_targets: Optional[Collection[np.ndarray]] = None, predict_data: Optional[Collection[np.ndarray]] = None, train_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, val_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, @@ -147,6 +335,65 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": + """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from numpy arrays containing + images (or lists of arrays) and corresponding numpy arrays containing masks (or lists of arrays). + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_data: The numpy array or list of arrays containing images to use when training. + train_targets: The numpy array or list of arrays containing masks to use when training. + val_data: The numpy array or list of arrays containing images to use when validating. + val_targets: The numpy array or list of arrays containing masks to use when validating. + test_data: The numpy array or list of arrays containing images to use when testing. + test_targets: The numpy array or list of arrays containing masks to use when testing. + predict_data: The numpy array or list of arrays to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + num_classes: The number of segmentation classes. + labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. + If not provided, a random mapping will be used. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.segmentation.data.SemanticSegmentationData`. + + Examples + ________ + + .. doctest:: + + >>> import numpy as np + >>> from flash import Trainer + >>> from flash.image import SemanticSegmentation, SemanticSegmentationData + >>> datamodule = SemanticSegmentationData.from_numpy( + ... train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)], + ... train_targets=[ + ... np.random.randint(0, 10, (1, 64, 64), dtype="uint8"), + ... np.random.randint(0, 10, (1, 64, 64), dtype="uint8"), + ... np.random.randint(0, 10, (1, 64, 64), dtype="uint8"), + ... ], + ... predict_data=[np.random.rand(3, 64, 64)], + ... transform_kwargs=dict(image_size=(128, 128)), + ... num_classes=10, + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 10 + >>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -170,9 +417,9 @@ def from_tensors( train_data: Optional[Collection[torch.Tensor]] = None, train_targets: Optional[Collection[torch.Tensor]] = None, val_data: Optional[Collection[torch.Tensor]] = None, - val_targets: Optional[Sequence[torch.Tensor]] = None, + val_targets: Optional[Collection[torch.Tensor]] = None, test_data: Optional[Collection[torch.Tensor]] = None, - test_targets: Optional[Sequence[torch.Tensor]] = None, + test_targets: Optional[Collection[torch.Tensor]] = None, predict_data: Optional[Collection[torch.Tensor]] = None, train_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, val_transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, @@ -184,6 +431,65 @@ def from_tensors( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": + """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from torch tensors containing + images (or lists of tensors) and corresponding torch tensors containing masks (or lists of tensors). + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_data: The torch tensor or list of tensors containing images to use when training. + train_targets: The torch tensor or list of tensors containing masks to use when training. + val_data: The torch tensor or list of tensors containing images to use when validating. + val_targets: The torch tensor or list of tensors containing masks to use when validating. + test_data: The torch tensor or list of tensors containing images to use when testing. + test_targets: The torch tensor or list of tensors containing masks to use when testing. + predict_data: The torch tensor or list of tensors to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + num_classes: The number of segmentation classes. + labels_map: An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. + If not provided, a random mapping will be used. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.segmentation.data.SemanticSegmentationData`. + + Examples + ________ + + .. doctest:: + + >>> import torch + >>> from flash import Trainer + >>> from flash.image import SemanticSegmentation, SemanticSegmentationData + >>> datamodule = SemanticSegmentationData.from_tensors( + ... train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)], + ... train_targets=[ + ... torch.randint(10, (1, 64, 64)), + ... torch.randint(10, (1, 64, 64)), + ... torch.randint(10, (1, 64, 64)), + ... ], + ... predict_data=[torch.rand(3, 64, 64)], + ... transform_kwargs=dict(image_size=(128, 128)), + ... num_classes=10, + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 10 + >>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), diff --git a/flash/image/segmentation/input.py b/flash/image/segmentation/input.py index 1201597a0b..cc5c0bc64e 100644 --- a/flash/image/segmentation/input.py +++ b/flash/image/segmentation/input.py @@ -20,10 +20,9 @@ from flash.core.data.io.input import DataKeys, ImageLabelsMap, Input from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE from flash.core.data.utilities.samples import to_samples -from flash.core.data.utils import image_default_loader from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TORCHVISION_AVAILABLE, lazy_import -from flash.image.data import ImageDeserializer, IMG_EXTENSIONS +from flash.image.data import image_loader, ImageDeserializer, IMG_EXTENSIONS from flash.image.segmentation.output import SegmentationLabelsOutput SampleCollection = None @@ -101,12 +100,12 @@ def load_data( if mask_files is None: files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS) else: - files, masks = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS) + files, mask_files = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS) return to_samples(files, mask_files) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: filepath = sample[DataKeys.INPUT] - sample[DataKeys.INPUT] = FT.to_tensor(image_default_loader(filepath)) + sample[DataKeys.INPUT] = FT.to_tensor(image_loader(filepath)) if DataKeys.TARGET in sample: sample[DataKeys.TARGET] = torchvision.io.read_image(sample[DataKeys.TARGET])[0] sample = super().load_sample(sample) @@ -141,7 +140,7 @@ def load_data( files.sort() mask_files.sort() return super().load_data(files, mask_files) - return super().load_data(files) + return super().load_data([os.path.join(folder, file) for file in files]) class SemanticSegmentationFiftyOneInput(SemanticSegmentationFilesInput):