Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Docstrings for StyleTransferData (#1100)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jan 5, 2022
1 parent 0b8d27d commit 7f45fdf
Showing 1 changed file with 211 additions and 1 deletion.
212 changes: 211 additions & 1 deletion flash/image/style_transfer/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,22 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.utilities.imports import _IMAGE_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.classification.input import ImageClassificationFilesInput, ImageClassificationFolderInput
from flash.image.data import ImageNumpyInput, ImageTensorInput
from flash.image.style_transfer.input_transform import StyleTransferInputTransform

__all__ = ["StyleTransferInputTransform", "StyleTransferData"]
# Skip doctests if requirements aren't available
if not _IMAGE_AVAILABLE:
__doctest_skip__ = ["StyleTransferData", "StyleTransferData.*"]


class StyleTransferData(DataModule):
"""The ``StyleTransferData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
classmethods for loading data for image style transfer."""

input_transform_cls = StyleTransferInputTransform

@classmethod
Expand All @@ -42,6 +48,60 @@ def from_files(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any
) -> "StyleTransferData":
"""Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from lists of image files.
The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
``.tiff``, ``.webp``, and ``.npy``.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_files: The list of image files to use when training.
predict_files: The list of image files to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`.
Examples
________
.. testsetup::
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]
.. doctest::
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> import os
>>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
Expand All @@ -66,6 +126,73 @@ def from_folders(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any
) -> "StyleTransferData":
"""Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from folders containing images.
The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
``.tiff``, ``.webp``, and ``.npy``.
Here's the required folder structure:
.. code-block::
train_folder
├── image_1.png
├── image_2.png
├── image_3.png
...
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_folder: The folder containing images to use when training.
predict_folder: The folder containing images to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`.
Examples
________
.. testsetup::
>>> import os
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> os.makedirs("train_folder", exist_ok=True)
>>> os.makedirs("predict_folder", exist_ok=True)
>>> _ = [rand_image.save(os.path.join("train_folder", f"image_{i}.png")) for i in range(1, 4)]
>>> _ = [rand_image.save(os.path.join("predict_folder", f"predict_image_{i}.png")) for i in range(1, 4)]
.. doctest::
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_folders(
... train_folder="train_folder",
... predict_folder="predict_folder",
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> import shutil
>>> shutil.rmtree("train_folder")
>>> shutil.rmtree("predict_folder")
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
Expand All @@ -90,6 +217,47 @@ def from_numpy(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any
) -> "StyleTransferData":
"""Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from numpy arrays (or lists of
arrays).
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_data: The numpy array or list of arrays to use when training.
predict_data: The numpy array or list of arrays to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`.
Examples
________
.. doctest::
>>> import numpy as np
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_numpy(
... train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
... predict_data=[np.random.rand(3, 64, 64)],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
Expand All @@ -114,6 +282,48 @@ def from_tensors(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any
) -> "StyleTransferData":
"""Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from torch tensors (or lists of
tensors).
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_data: The torch tensor or list of tensors to use when training.
predict_data: The torch tensor or list of tensors to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`.
Examples
________
.. doctest::
>>> import torch
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_tensors(
... train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
... predict_data=[torch.rand(3, 64, 64)],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
Expand Down

0 comments on commit 7f45fdf

Please sign in to comment.