Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[3/N] Data sources - docs #272

Merged
merged 26 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 160 additions & 103 deletions docs/source/custom_task.rst

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ __________
.. autoclass:: flash.data.data_source.DataSource
:members:

.. autoclass:: flash.data.data_source.DefaultDataSources
:members:
:undoc-members:

.. autoclass:: flash.data.data_source.DefaultDataKeys
:members:
:undoc-members:


----------

Expand Down
21 changes: 12 additions & 9 deletions flash/data/auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,20 @@


class BaseAutoDataset(Generic[DATA_TYPE]):

DATASET_KEY = "dataset"
"""This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. ``load_data``
will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` is provided and
``load_sample`` within ``__getitem__``.
"""The ``BaseAutoDataset`` class wraps the output of a call to :meth:`~flash.data.data_source.DataSource.load_data`
and a :class:`~fash.data.data_source.DataSource` and provides the ``_call_load_sample`` method to call
:meth:`~flash.data.data_source.DataSource.load_sample` with the correct
:class:`~flash.data.utils.CurrentRunningStageFuncContext` for the current ``running_stage``. Inheriting classes are
responsible for extracting samples from ``data`` to be given to ``_call_load_sample``.

Args:

data: The output of a call to :meth:`~flash.data.data_source.load_data`.

data: The output of a call to :meth:`~flash.data.data_source.DataSource.load_data`.
data_source: The :class:`~flash.data.data_source.DataSource` which has the ``load_sample`` method.

running_stage: The current running stage.
"""

DATASET_KEY = "dataset"

def __init__(
self,
data: DATA_TYPE,
Expand Down Expand Up @@ -93,6 +92,8 @@ def _call_load_sample(self, sample: Any) -> Any:


class AutoDataset(BaseAutoDataset[Sequence], Dataset):
"""The ``AutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.Dataset`. The `data` argument
must be a ``Sequence`` (it must have a length)."""

def __getitem__(self, index: int) -> Any:
return self._call_load_sample(self.data[index])
Expand All @@ -102,6 +103,8 @@ def __len__(self) -> int:


class IterableAutoDataset(BaseAutoDataset[Iterable], IterableDataset):
"""The ``IterableAutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.IterableDataset`. The `data`
argument must be an ``Iterable``."""

def __iter__(self):
self.data_iter = iter(self.data)
Expand Down
341 changes: 331 additions & 10 deletions flash/data/data_module.py

Large diffs are not rendered by default.

163 changes: 127 additions & 36 deletions flash/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,44 +39,111 @@ def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool:

@dataclass(unsafe_hash=True, frozen=True)
class LabelsState(ProcessState):
""" A :class:`~flash.data.properties.ProcessState` containing ``labels``, a mapping from class index to label. """

labels: Optional[Sequence[str]]


class DefaultDataSources(LightningEnum):
"""The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in
:class:`~flash.data.data_module.DataModule`."""

PATHS = "paths"
NUMPY = "numpy"
TENSOR = "tensor"
CSV = "csv"
JSON = "json"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
return hash(self.value)


class DefaultDataKeys(LightningEnum):
"""The ``DefaultDataKeys`` enum contains the keys that are used by built-in data sources to refer to inputs and
targets."""

INPUT = "input"
TARGET = "target"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
return hash(self.value)


class MockDataset:
"""The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. This is passed to
:meth:`~flash.data.data_source.DataSource.load_data` so that attributes can be set on the generated data set."""

def __init__(self):
self.metadata = {}

def __setattr__(self, key, value):
if key != 'metadata':
self.metadata[key] = value
else:
object.__setattr__(self, key, value)
object.__setattr__(self, key, value)


DATA_TYPE = TypeVar("DATA_TYPE")


class DataSource(Generic[DATA_TYPE], Properties, Module):
"""The ``DataSource`` class encapsulates two hooks: ``load_data`` and ``load_sample``. The
:meth:`~flash.data.data_source.DataSource.to_datasets` method can then be used to automatically construct data sets
from the hooks."""

def load_data(self,
data: DATA_TYPE,
dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]], Iterable[Mapping[str, Any]]]:
"""Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping.
"""Given the ``data`` argument, the ``load_data`` hook produces a sequence or iterable of samples or
sample metadata. The ``data`` argument can be anything, but this method should return a sequence or iterable of
mappings from string (e.g. "input", "target", "bbox", etc.) to data (e.g. a target value) or metadata (e.g. a
filename). Where possible, any heavy data loading should be performed in
:meth:`~flash.data.data_source.DataSource.load_sample`. If the output is an iterable rather than a sequence
(that is, it doesn't have length) then the generated dataset will be an ``IterableDataset``.

Args:
data: The data required to load the sequence or iterable of samples or sample metadata.
dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset
(e.g. ``num_classes``) will also be set on the generated dataset.

Returns:
A sequence or iterable of samples or sample metadata to be used as inputs to
:meth:`~flash.data.data_source.DataSource.load_sample`.

Example::

# data: "."
# output: [("./cat/1.png", 1), ..., ("./dog/10.png", 0)]
# output: [{"input": "./cat/1.png", "target": 1}, ..., {"input": "./dog/10.png", "target": 0}]

output: Mapping = load_data(data)
output: Sequence[Mapping[str, Any]] = load_data(data)

"""
return data

def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any:
"""Loads single sample from dataset"""
"""Given an element from the output of a call to :meth:`~flash.data.data_source.DataSource.load_data`, this hook
should load a single data sample. The keys and values in the ``sample`` argument will be same as the keys and
values in the outputs of :meth:`~flash.data.data_source.DataSource.load_data`.

Args:
sample: An element (sample or sample metadata) from the output of a call to
:meth:`~flash.data.data_source.DataSource.load_data`.
dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset
(e.g. ``num_classes``) will also be set on the generated dataset.

Returns:
The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the
:meth:`~flash.data.process.Preprocess.pre_tensor_transform`.

Example::

# sample: {"input": "./cat/1.png", "target": 1}
# output: {"input": PIL.Image, "target": 1}

output: Mapping[str, Any] = load_sample(sample)

"""
return sample

def to_datasets(
Expand All @@ -86,6 +153,25 @@ def to_datasets(
test_data: Optional[DATA_TYPE] = None,
predict_data: Optional[DATA_TYPE] = None,
) -> Tuple[Optional[BaseAutoDataset], ...]:
"""Construct data sets (of type :class:`~flash.data.auto_dataset.BaseAutoDataset`) from this data source by
calling :meth:`~flash.data.data_source.DataSource.load_data` with each of the ``*_data`` arguments. If an
argument is given as ``None`` then no dataset will be created for that stage (``train``, ``val``, ``test``,
``predict``).

Args:
train_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the train
dataset.
val_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the validation
dataset.
test_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the test
dataset.
predict_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the
predict dataset.

Returns:
A tuple of ``train_dataset``, ``val_dataset``, ``test_dataset``, ``predict_dataset``. If any ``*_data``
argument is not passed to this method then the corresponding ``*_dataset`` will be ``None``.
"""
train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING)
val_dataset = self.generate_dataset(val_data, RunningStage.VALIDATING)
test_dataset = self.generate_dataset(test_data, RunningStage.TESTING)
Expand All @@ -97,6 +183,16 @@ def generate_dataset(
data: Optional[DATA_TYPE],
running_stage: RunningStage,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
"""Generate a single dataset with the given input to :meth:`~flash.data.data_source.DataSource.load_data` for
the given ``running_stage``.

Args:
data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the dataset.
running_stage: The running_stage for this dataset.

Returns:
The constructed :class:`~flash.data.auto_dataset.BaseAutoDataset`.
"""
is_none = data is None

if isinstance(data, Sequence):
Expand Down Expand Up @@ -129,36 +225,21 @@ def generate_dataset(
return dataset


class DefaultDataSources(LightningEnum):

PATHS = "paths"
NUMPY = "numpy"
TENSOR = "tensor"
CSV = "csv"
JSON = "json"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
return hash(self.value)


class DefaultDataKeys(LightningEnum):

INPUT = "input"
TARGET = "target"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
return hash(self.value)


SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE")


class SequenceDataSource(
Generic[SEQUENCE_DATA_TYPE],
DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]],
):
"""The ``SequenceDataSource`` implements default behaviours for data sources which expect the input to
:meth:`~flash.data.data_source.DataSource.load_data` to be a sequence of tuples (``(input, target)`` where target
can be ``None``).

Args:
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.data.data_source.LabelsState`.
"""

def __init__(self, labels: Optional[Sequence[str]] = None):
super().__init__()
Expand Down Expand Up @@ -186,17 +267,25 @@ def predict_load_data(self, data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapp
return [{DefaultDataKeys.INPUT: input} for input in data]


class PathsDataSource(SequenceDataSource): # TODO: Sort out the typing here
class PathsDataSource(SequenceDataSource):
"""The ``PathsDataSource`` implements default behaviours for data sources which expect the input to
:meth:`~flash.data.data_source.DataSource.load_data` to be either a directory with a subdirectory for each class or
a tuple containing list of files and corresponding list of targets.

def __init__(self, extensions: Optional[Tuple[str, ...]] = None):
super().__init__()
Args:
extensions: The file extensions supported by this data source (e.g. ``(".jpg", ".png")``).
labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the
:class:`~flash.data.data_source.LabelsState`.
"""

def __init__(self, extensions: Optional[Tuple[str, ...]] = None, labels: Optional[Sequence[str]] = None):
super().__init__(labels=labels)

self.extensions = extensions

@staticmethod
def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset. Ensures that no class is a subdirectory of another.
"""Finds the class folders in a dataset. Ensures that no class is a subdirectory of another.

Args:
dir: Root directory path.
Expand Down Expand Up @@ -257,8 +346,10 @@ def predict_load_data(self,


class TensorDataSource(SequenceDataSource[torch.Tensor]):
"""""" # TODO: Some docstring here
"""The ``TensorDataSource`` is a ``SequenceDataSource`` which expects the input to
:meth:`~flash.data.data_source.DataSource.load_data` to be a sequence of ``torch.Tensor`` objects."""


class NumpyDataSource(SequenceDataSource[np.ndarray]):
"""""" # TODO: Some docstring here
"""The ``NumpyDataSource`` is a ``SequenceDataSource`` which expects the input to
:meth:`~flash.data.data_source.DataSource.load_data` to be a sequence of ``np.ndarray`` objects."""
Loading