Improve data docs (#355)
* Fix pytorch link

* typo

* Fix section title

* typo

* typo

* Ignore data dir

* Update docs

* Update docs/source/general/data.rst
akihironitta authored Jul 8, 2021
1 parent 8077f7b commit 74dbee8
Showing 3 changed files with 54 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -143,6 +143,7 @@ data_folder
*.zip
flash_notebooks/*.py
flash_notebooks/data
+ /data
MNIST*
titanic
hymenoptera_data
79 changes: 49 additions & 30 deletions docs/source/general/data.rst
@@ -21,25 +21,28 @@ Here are common terms you need to be familiar with:

* - Term
- Definition
+ * - :class:`~flash.core.data.process.Deserializer`
+ - The :class:`~flash.core.data.process.Deserializer` provides a single :meth:`~flash.core.data.process.Deserializer.deserialize` method.
* - :class:`~flash.core.data.data_module.DataModule`
- The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders.
* - :class:`~flash.core.data.data_pipeline.DataPipeline`
- - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage: :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.process.Serializer` objects.
+ - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.process.Serializer` objects.
* - :class:`~flash.core.data.data_source.DataSource`
- The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names).
* - :class:`~flash.core.data.process.Preprocess`
- The :class:`~flash.core.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic.
These hooks (such as :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device).
The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed.
- The :class:`~flash.core.data.process.Preprocess` hooks can be either overriden directly or provided as a dictionary of transforms (mapping hook name to callable transform).
+ The :class:`~flash.core.data.process.Preprocess` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform).
* - :class:`~flash.core.data.process.Postprocess`
- The :class:`~flash.core.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.core.data.process.Postprocess` hooks cover from model outputs to predictions export.
* - :class:`~flash.core.data.process.Serializer`
- - The :class:`~flash.core.data.process.Serializer` provides a single ``serialize`` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction.
+ - The :class:`~flash.core.data.process.Serializer` provides a single :meth:`~flash.core.data.process.Serializer.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction.
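
For illustration only, here is a minimal sketch of a custom :class:`~flash.core.data.process.Serializer` (the argmax-to-class-index behaviour is an arbitrary example, not a built-in):

.. code-block:: python

    from typing import Any

    import torch

    from flash.core.data.process import Serializer


    class ClassIndexSerializer(Serializer):
        """Convert a single per-sample model output into a predicted class index."""

        def serialize(self, sample: Any) -> int:
            # ``sample`` is one model output, after the ``Postprocess`` has run.
            return torch.argmax(torch.as_tensor(sample)).item()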


*******************************************
- How to use out-of-the-box flashdatamodules
+ How to use out-of-the-box Flash DataModules
*******************************************

Flash provides several DataModules with helper functions.
@@ -49,14 +52,14 @@ Check out the :ref:`image_classification` section (or the sections for any of our
Data Processing
***************

- Currently, it is common practice to implement a :class:`pytorch.utils.data.Dataset`
- and provide it to a :class:`pytorch.utils.data.DataLoader`.
+ Currently, it is common practice to implement a :class:`torch.utils.data.Dataset`
+ and provide it to a :class:`torch.utils.data.DataLoader`.
However, after model training, a lot of engineering overhead is required to run inference on raw data and to deploy the model in a production environment.
Usually, extra processing logic should be added to bridge the gap between training data and raw data.

The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way.
The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` classes can be used to manage the preprocessing and postprocessing transforms.
- The :class:`~flash.core.data.process.Serializer` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilites, etc.).
+ The :class:`~flash.core.data.process.Serializer` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).

By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms),
Flash gives the user much more granular control over their data processing flow.
@@ -75,15 +78,14 @@ hooks by adding ``train``, ``val``, ``test`` or ``predict``.
Check out :class:`~flash.core.data.process.Preprocess` for some examples.
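
For instance (a sketch: the horizontal flip is an arbitrary augmentation, and it assumes samples are dictionaries keyed by ``DefaultDataKeys``), a training-only transform can be expressed by overriding the ``train_``-prefixed hook:

.. code-block:: python

    from typing import Any

    import torchvision.transforms.functional as T

    from flash.core.data.data_source import DefaultDataKeys
    from flash.core.data.process import Preprocess


    class MyPreprocess(Preprocess):

        def train_pre_tensor_transform(self, sample: Any) -> Any:
            # Runs on training samples only, before ``to_tensor_transform``.
            sample[DefaultDataKeys.INPUT] = T.hflip(sample[DefaultDataKeys.INPUT])
            return sample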

*************************************
- How to customize existing datamodules
+ How to customize existing DataModules
*************************************

Any Flash :class:`~flash.core.data.data_module.DataModule` can be created directly from datasets using the :meth:`~flash.core.data.data_module.DataModule.from_datasets` like this:

.. code-block:: python
- from flash import Trainer
- from flash.core.data.data_module import DataModule
+ from flash import DataModule, Trainer
data_module = DataModule.from_datasets(train_dataset=MyDataset())
trainer = Trainer()
@@ -95,6 +97,10 @@ In each ``from_*`` method, the :class:`~flash.core.data.data_module.DataModule`
Flash :class:`~flash.core.data.auto_dataset.AutoDataset` instances are created from the :class:`~flash.core.data.data_source.DataSource` for train, val, test, and predict.
The :class:`~flash.core.data.data_module.DataModule` populates the ``DataLoader`` for each stage with the corresponding :class:`~flash.core.data.auto_dataset.AutoDataset`.
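
As a quick sketch (``MyDataset`` stands in for any ``torch.utils.data.Dataset``, and the ``batch_size`` argument is assumed to be accepted by ``from_datasets``), the stage-specific loaders can then be retrieved as usual:

.. code-block:: python

    from flash import DataModule

    datamodule = DataModule.from_datasets(train_dataset=MyDataset(), batch_size=4)

    # The returned DataLoader wraps the AutoDataset created for the train stage.
    train_dataloader = datamodule.train_dataloader()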

+ **************************************
+ Customize preprocessing of DataModules
+ **************************************

The :class:`~flash.core.data.process.Preprocess` contains the processing logic related to a given task.
Each :class:`~flash.core.data.process.Preprocess` provides some default transforms through the :meth:`~flash.core.data.process.Preprocess.default_transforms` method.
Users can easily override these by providing their own transforms to the :class:`~flash.core.data.data_module.DataModule`.
@@ -139,16 +145,16 @@ Alternatively, the user may directly override the hooks for their needs like this:
)
- ******************************
- Custom Preprocess + Datamodule
- ******************************
+ *****************************************
+ Create your own Preprocess and DataModule
+ *****************************************

The example below shows a very simple ``ImageClassificationPreprocess`` with a single ``ImageClassificationFoldersDataSource`` and an ``ImageClassificationDataModule``.

1. User-Facing API design
_________________________

- Designing an easy to use API is key. This is the first and most important step.
+ Designing an easy-to-use API is key. This is the first and most important step.
We want the ``ImageClassificationDataModule`` to generate a dataset from folders of images arranged in this way.

Example::
@@ -194,15 +200,21 @@ Here's the full ``ImageClassificationFoldersDataSource``:
def load_data(self, folder: str, dataset: Any) -> Iterable:
# The dataset is optional but can be useful to save some metadata.
- # metadata contains the image path and its corresponding label with the following structure:
+ # `metadata` contains the image path and its corresponding label
+ # with the following structure:
# [(image_path_1, label_1), ... (image_path_n, label_n)].
metadata = make_dataset(folder)
- # for the train ``AutoDataset``, we want to store the ``num_classes``.
+ # for the train `AutoDataset`, we want to store the `num_classes`.
if self.training:
dataset.num_classes = len(np.unique([m[1] for m in metadata]))
- return [{DefaultDataKeys.INPUT: file, DefaultDataKeys.TARGET: target} for file, target in metadata]
+ return [
+     {
+         DefaultDataKeys.INPUT: file,
+         DefaultDataKeys.TARGET: target,
+     } for file, target in metadata
+ ]
def predict_load_data(self, predict_folder: str) -> Iterable:
# This returns [image_path_1, ... image_path_m].
@@ -226,7 +238,7 @@ Next, implement your custom ``ImageClassificationPreprocess`` with some default
from flash.core.data.process import Preprocess
import torchvision.transforms.functional as T
- # Subclass ``Preprocess``
+ # Subclass `Preprocess`
class ImageClassificationPreprocess(Preprocess):
def __init__(
@@ -268,11 +280,11 @@ All we need to do is attach our :class:`~flash.core.data.process.Preprocess` class

.. code-block:: python
- from flash.core.data.data_module import DataModule
+ from flash import DataModule
class ImageClassificationDataModule(DataModule):
- # Set ``preprocess_cls`` with your custom ``preprocess``.
+ # Set `preprocess_cls` with your custom `Preprocess`.
preprocess_cls = ImageClassificationPreprocess
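
Usage could then look like the following (a sketch: it assumes the custom :class:`~flash.core.data.process.Preprocess` registered the folders :class:`~flash.core.data.data_source.DataSource` as its default, so that ``from_folders`` can resolve it, and it continues from the class defined above):

.. code-block:: python

    datamodule = ImageClassificationDataModule.from_folders(
        train_folder="./data/train",
        predict_folder="./data/predict",
        batch_size=4,
    )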
@@ -283,24 +295,27 @@ How it works behind the scenes
DataSource
__________

- .. note:: The ``load_data`` and ``load_sample`` will be used to generate an AutoDataset object.
+ .. note::
+     The :meth:`~flash.core.data.data_source.DataSource.load_data` and
+     :meth:`~flash.core.data.data_source.DataSource.load_sample` will be used to generate an
+     :class:`~flash.core.data.auto_dataset.AutoDataset` object.

- Here is the ``AutoDataset`` pseudo-code.
+ Here is the :class:`~flash.core.data.auto_dataset.AutoDataset` pseudo-code.

- Example::
+ .. code-block:: python
- class AutoDataset
+ class AutoDataset:
def __init__(
self,
- data: List[Any], # The result of a call to DataSource.load_data
+ data: List[Any], # output of `DataSource.load_data`
data_source: DataSource,
running_stage: RunningStage,
- ) -> None:
+ ):
self.data = data
self.data_source = data_source
- def __getitem__(self, index):
+ def __getitem__(self, index: int):
return self.data_source.load_sample(self.data[index])
def __len__(self):
@@ -311,8 +326,12 @@ __________

.. note::

- The ``pre_tensor_transform``, ``to_tensor_transform``, ``post_tensor_transform``, ``collate``,
- ``per_batch_transform`` are injected as the ``collate_fn`` function of the DataLoader.
+ The :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`,
+ :meth:`~flash.core.data.process.Preprocess.to_tensor_transform`,
+ :meth:`~flash.core.data.process.Preprocess.post_tensor_transform`,
+ :meth:`~flash.core.data.process.Preprocess.collate`,
+ :meth:`~flash.core.data.process.Preprocess.per_batch_transform` are injected as the
+ :paramref:`torch.utils.data.DataLoader.collate_fn` function of the DataLoader.

Here is the pseudo-code using the preprocess hook names.
Flash takes care of calling the right hooks for each stage.
@@ -385,7 +404,7 @@ Here is the pseudo-code:

Example::

- # This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor`
+ # This will be wrapped into a :class:`~flash.core.data.batch._Postprocessor`
def uncollate_fn(batch: Any) -> Any:

batch = per_batch_transform(batch)
4 changes: 4 additions & 0 deletions flash/core/data/process.py
@@ -454,6 +454,10 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):


class Postprocess(Properties):
+ """
+ The :class:`~flash.core.data.process.Postprocess` encapsulates all the data processing logic that should run after
+ the model.
+ """

def __init__(self, save_path: Optional[str] = None):
super().__init__()
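
For context, a user-defined :class:`Postprocess` typically just overrides one of its hooks. A minimal sketch (the softmax is an arbitrary example):

.. code-block:: python

    from typing import Any

    import torch

    from flash.core.data.process import Postprocess


    class SoftmaxPostprocess(Postprocess):
        """Apply a softmax to the model outputs before they are uncollated."""

        def per_batch_transform(self, batch: Any) -> Any:
            return torch.softmax(batch, dim=-1)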
