Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[feat] add from datasets (#304)
Browse files Browse the repository at this point in the history
* update

* update

* Update flash/core/data/data_module.py

Co-authored-by: Ethan Harris <[email protected]>

* Update flash/core/data/data_module.py

Co-authored-by: Ethan Harris <[email protected]>

* update

Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
tchaton and ethanwharris authored May 17, 2021
1 parent 8eef52f commit 094fad0
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 5 deletions.
79 changes: 78 additions & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from flash.core.data.base_viz import BaseVisualization
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess
from flash.core.data.data_source import DataSource, DefaultDataSources
from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX

Expand Down Expand Up @@ -947,3 +947,80 @@ def from_csv(
num_workers=num_workers,
**preprocess_kwargs,
)

@classmethod
def from_datasets(
cls,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
predict_dataset: Optional[Dataset] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the
:class:`~flash.core.data.data_source.DataSource`
of name :attr:`~flash.core.data.data_source.DefaultDataSources.DATASET`
from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
Args:
train_dataset: Dataset used during training.
val_dataset: Dataset used during validating.
test_dataset: Dataset used during testing.
predict_dataset: Dataset used during predicting.
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Returns:
The constructed data module.
Examples::
data_module = DataModule.from_datasets(
train_dataset=train_dataset,
train_transform={
"to_tensor_transform": torch.as_tensor,
},
)
"""
return cls.from_data_source(
DefaultDataSources.DATASET,
train_dataset,
val_dataset,
test_dataset,
predict_dataset,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**preprocess_kwargs,
)
19 changes: 19 additions & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.enums import LightningEnum
from torch.nn import Module
from torch.utils.data.dataset import Dataset

from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset
from flash.core.data.properties import ProcessState, Properties
Expand Down Expand Up @@ -143,6 +144,7 @@ class DefaultDataSources(LightningEnum):
TENSORS = "tensors"
CSV = "csv"
JSON = "json"
DATASET = "dataset"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
Expand Down Expand Up @@ -321,6 +323,23 @@ def generate_dataset(
SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE")


class DatasetDataSource(DataSource):

def load_data(self, dataset: Dataset, auto_dataset: AutoDataset) -> Dataset:
if self.training:
# store a sample to infer the shape
parameters = signature(self.load_sample).parameters
if len(parameters) > 1 and AutoDataset.DATASET_KEY in parameters:
auto_dataset.sample = self.load_sample(dataset[0], self)
else:
auto_dataset.sample = self.load_sample(dataset[0])
return dataset

def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any]) -> Any:
# wrap everything within `.INPUT`.
return {DefaultDataKeys.INPUT: sample}


class SequenceDataSource(
Generic[SEQUENCE_DATA_TYPE],
DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]],
Expand Down
5 changes: 4 additions & 1 deletion flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flash.core.data.batch import default_uncollate
from flash.core.data.callback import FlashCallback
from flash.core.data.data_source import DataSource
from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources
from flash.core.data.properties import Properties
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext

Expand Down Expand Up @@ -217,6 +217,9 @@ def __init__(
self._test_transform = convert_to_modules(self.test_transform)
self._predict_transform = convert_to_modules(self.predict_transform)

if DefaultDataSources.DATASET not in data_sources:
data_sources[DefaultDataSources.DATASET] = DatasetDataSource()

self._data_sources = data_sources
self._default_data_source = default_data_source

Expand Down
9 changes: 8 additions & 1 deletion tests/data/test_auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset
from flash.core.data.callback import FlashCallback
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.data_source import DataSource
from flash.core.data.data_source import DataSource, DefaultDataKeys
from flash.core.data.process import Preprocess


Expand Down Expand Up @@ -190,3 +191,9 @@ def test_preprocessing_data_source_with_running_stage(with_dataset):
else:
assert data_source.train_load_sample_count == len(dataset)
assert data_source.train_load_data_count == 1


def test_dataset_data_source():

dm = DataModule.from_datasets(range(10), range(10))
assert dm.train_dataset.sample == {DefaultDataKeys.INPUT: 0}
4 changes: 2 additions & 2 deletions tests/data/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ def test_available_data_sources():

assert DefaultDataSources.TENSORS in preprocess.available_data_sources()
assert "test" in preprocess.available_data_sources()
assert len(preprocess.available_data_sources()) == 2
assert len(preprocess.available_data_sources()) == 3

data_module = DataModule(preprocess=preprocess)

assert DefaultDataSources.TENSORS in data_module.available_data_sources()
assert "test" in data_module.available_data_sources()
assert len(data_module.available_data_sources()) == 2
assert len(data_module.available_data_sources()) == 3

0 comments on commit 094fad0

Please sign in to comment.