diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index a42e37e096..acb3761d4a 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -19,16 +19,22 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input +from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.classification.input import ImageClassificationFilesInput, ImageClassificationFolderInput from flash.image.data import ImageNumpyInput, ImageTensorInput from flash.image.style_transfer.input_transform import StyleTransferInputTransform -__all__ = ["StyleTransferInputTransform", "StyleTransferData"] +# Skip doctests if requirements aren't available +if not _IMAGE_AVAILABLE: + __doctest_skip__ = ["StyleTransferData", "StyleTransferData.*"] class StyleTransferData(DataModule): + """The ``StyleTransferData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of + classmethods for loading data for image style transfer.""" + input_transform_cls = StyleTransferInputTransform @classmethod @@ -42,6 +48,60 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from lists of image files. + + The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, + ``.tiff``, ``.webp``, and ``.npy``. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_files: The list of image files to use when training. + predict_files: The list of image files to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. testsetup:: + + >>> from PIL import Image + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] + >>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)] + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -66,6 +126,73 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from folders containing images. + + The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, + ``.tiff``, ``.webp``, and ``.npy``. + Here's the required folder structure: + + .. code-block:: + + train_folder + ├── image_1.png + ├── image_2.png + ├── image_3.png + ... + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_folder: The folder containing images to use when training. + predict_folder: The folder containing images to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. testsetup:: + + >>> import os + >>> from PIL import Image + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> os.makedirs("train_folder", exist_ok=True) + >>> os.makedirs("predict_folder", exist_ok=True) + >>> _ = [rand_image.save(os.path.join("train_folder", f"image_{i}.png")) for i in range(1, 4)] + >>> _ = [rand_image.save(os.path.join("predict_folder", f"predict_image_{i}.png")) for i in range(1, 4)] + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_folders( + ... train_folder="train_folder", + ... predict_folder="predict_folder", + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import shutil + >>> shutil.rmtree("train_folder") + >>> shutil.rmtree("predict_folder") + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -90,6 +217,47 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from numpy arrays (or lists of + arrays). + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_data: The numpy array or list of arrays to use when training. + predict_data: The numpy array or list of arrays to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. doctest:: + + >>> import numpy as np + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_numpy( + ... train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)], + ... predict_data=[np.random.rand(3, 64, 64)], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -114,6 +282,48 @@ def from_tensors( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from torch tensors (or lists of + tensors). + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_data: The torch tensor or list of tensors to use when training. + predict_data: The torch tensor or list of tensors to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. doctest:: + + >>> import torch + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_tensors( + ... train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)], + ... predict_data=[torch.rand(3, 64, 64)], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + """ + ds_kw = dict( data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs,