diff --git a/docs/source/custom_task.rst b/docs/source/custom_task.rst index 93250e98a5..81c641d343 100644 --- a/docs/source/custom_task.rst +++ b/docs/source/custom_task.rst @@ -12,44 +12,49 @@ which is stored as numpy arrays. .. note:: Find the complete tutorial example at - `flash_examples/custom_task.py `_. + `flash_examples/custom_task.py `_. 1. Imports ---------- -.. testcode:: python +.. code:: python - from typing import Any, List, Tuple + from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch from pytorch_lightning import seed_everything from sklearn import datasets - from sklearn.model_selection import train_test_split - from torch import nn + from torch import nn, Tensor import flash - from flash.data.auto_dataset import AutoDataset - from flash.data.process import Postprocess, Preprocess + from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources + from flash.data.process import Preprocess + from flash.data.transforms import ApplyToKeys # set the random seeds. seed_everything(42) + ND = np.ndarray + 2. The Task: Linear regression ------------------------------- -Here we create a basic linear regression task by subclassing -:class:`~flash.core.model.Task`. For the majority of tasks, you will likely only need to -override the ``__init__`` and ``forward`` methods. +Here we create a basic linear regression task by subclassing :class:`~flash.core.model.Task`. For the majority of tasks, +you will likely need to override the ``__init__``, ``forward``, and the ``{train,val,test,predict}_step`` methods. The +``__init__`` should be overridden to configure the model and any additional arguments to be passed to the base +:class:`~flash.core.model.Task`. ``forward`` may need to be overridden to apply the model forward pass to the inputs. +It's best practice in flash for the data to be provide as a dictionary which maps string keys to their values. The +``{train,val,test,predict}_step`` methods need to be overridden to extract the data from the input dictionary. -.. testcode:: +Example:: class RegressionTask(flash.Task): - def __init__(self, num_inputs, learning_rate=0.001, metrics=None): + def __init__(self, num_inputs, learning_rate=0.2, metrics=None): # what kind of model do we want? model = nn.Linear(num_inputs, 1) @@ -57,7 +62,7 @@ override the ``__init__`` and ``forward`` methods. loss_fn = torch.nn.functional.mse_loss # what optimizer to do we want? - optimizer = torch.optim.SGD + optimizer = torch.optim.Adam super().__init__( model=model, @@ -67,6 +72,31 @@ override the ``__init__`` and ``forward`` methods. learning_rate=learning_rate, ) + def training_step(self, batch: Any, batch_idx: int) -> Any: + return super().training_step( + (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), + batch_idx, + ) + + def validation_step(self, batch: Any, batch_idx: int) -> None: + return super().validation_step( + (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), + batch_idx, + ) + + def test_step(self, batch: Any, batch_idx: int) -> None: + return super().test_step( + (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), + batch_idx, + ) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return super().predict_step( + batch[DefaultDataKeys.INPUT], + batch_idx, + dataloader_idx, + ) + def forward(self, x): # we don't actually need to override this method for this example return self.model(x) @@ -77,19 +107,16 @@ override the ``__init__`` and ``forward`` methods. Registries are Flash internal key-value database to store a mapping between a name and a function. In simple words, they are just advanced dictionary storing a function from a key string. They are useful to store list of backbones and make them available for a :class:`~flash.core.model.Task`. - Check out to learn more :ref:`registry`. + Check out :ref:`registry` to learn more. Where is the training step? ~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Most models can be trained simply by passing the output of ``forward`` -to the supplied ``loss_fn``, and then passing the resulting loss to the -supplied ``optimizer``. If you need a more custom configuration, you can -override ``step`` (which is called for training, validation, and -testing) or override ``training_step``, ``validation_step``, and -``test_step`` individually. These methods behave identically to PyTorch -Lightning’s +Most models can be trained simply by passing the output of ``forward`` to the supplied ``loss_fn``, and then passing the +resulting loss to the supplied ``optimizer``. If you need a more custom configuration, you can override ``step`` (which +is called for training, validation, and testing) or override ``training_step``, ``validation_step``, and ``test_step`` +individually. These methods behave identically to PyTorch Lightning’s `methods `__. Here is the pseudo code behind :class:`~flash.core.model.Task` step. @@ -107,75 +134,35 @@ Example:: return output -3.a The DataModule API +3.a The DataSource API ---------------------- -Now that we have defined our ``RegressionTask``, we need to load our data. -We will define a custom ``NumpyDataModule`` class subclassing :class:`~flash.data.data_module.DataModule`. -This ``NumpyDataModule`` class will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate -:class:`~flash.data.data_module.DataModule` from x, y numpy arrays. - -Here is how it would look: +Now that we have defined our ``RegressionTask``, we need to load our data. We will define a custom ``NumpyDataSource`` +which extends :class:`~flash.data.data_source.DataSource`. The ``NumpyDataSource`` contains a ``load_data`` and +``predict_load_data`` methods which handle the loading of a sequence of dictionaries from the input numpy arrays. When +loading the train data (``if self.training:``), the ``NumpyDataSource`` sets the ``num_inputs`` attribute of the +optional ``dataset`` argument. Any attributes that are set on the optional ``dataset`` argument will also be set on the +generated ``dataset``. Example:: - x, y = ... - preprocess = ... - datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess) - -Here is the ``NumpyDataModule`` implementation: - -Example:: - - from flash import DataModule - from flash.data.process import Preprocess - import numpy as np - - ND = np.ndarray - - class NumpyDataModule(DataModule): - - @classmethod - def from_xy_dataset( - cls, - x: ND, - y: ND, - preprocess: Preprocess = None, - batch_size: int = 64, - num_workers: int = 0 - ): - - preprocess = preprocess or NumpyPreprocess() - - x_train, x_test, y_train, y_test = train_test_split( - x, y, test_size=.20, random_state=0) - - # Make sure to call ``from_load_data_inputs``. - # The ``train_load_data_input`` value will be given to ``Preprocess`` - # ``train_load_data`` function. - dm = cls.from_load_data_inputs( - train_load_data_input=(x_train, y_train), - test_load_data_input=(x_test, y_test), - preprocess=preprocess, # DON'T FORGET TO PROVIDE THE PREPROCESS - batch_size=batch_size, - num_workers=num_workers - ) - # Some metatada can be accessed from ``train_ds`` directly. - dm.num_inputs = dm.train_dataset.num_inputs - return dm + class NumpyDataSource(DataSource[Tuple[ND, ND]]): + def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]: + if self.training: + dataset.num_inputs = data[0].shape[1] + return [{DefaultDataKeys.INPUT: x, DefaultDataKeys.TARGET: y} for x, y in zip(*data)] -.. note:: + def predict_load_data(self, data: ND) -> List[Dict[str, Any]]: + return [{DefaultDataKeys.INPUT: x} for x in data] - The :class:`~flash.data.data_module.DataModule` provides a ``from_load_data_inputs`` helper function. This function will take care - of connecting the provided :class:`~flash.data.process.Preprocess` with the :class:`~flash.data.data_module.DataModule`. - Make sure to instantiate your :class:`~flash.data.data_module.DataModule` with this helper if you rely on :class:`~flash.data.process.Preprocess` - objects. 3.b The Preprocess API ---------------------- -A :class:`~flash.data.process.Preprocess` object provides a series of hooks that can be overridden with custom data processing logic. +Now that we have a :class:`~flash.data.data_source.DataSource` implementation, we can define our +:class:`~flash.data.process.Preprocess`. The :class:`~flash.data.process.Preprocess` object provides a series of hooks +that can be overridden with custom data processing logic and to which transforms can be attached. It allows the user much more granular control over their data processing flow. .. note:: @@ -183,37 +170,107 @@ It allows the user much more granular control over their data processing flow. Why introduce :class:`~flash.data.process.Preprocess` ? The :class:`~flash.data.process.Preprocess` object reduces the engineering overhead to make inference on raw data or - to deploy the model in production environnement compared to traditional + to deploy the model in production environnement compared to a traditional `Dataset `_. - You can override ``predict_{hook_name}`` hooks to handle data processing logic specific for inference. + You can override ``predict_{hook_name}`` hooks or the ``default_predict_transforms`` to handle data processing logic + specific for inference. -Example:: +The recommended way to define a custom :class:`~flash.data.process.Preprocess` is as follows: - import torch - from torch import Tensor - import numpy as np +- Define an ``__init__`` which accepts transform arguments. +- Pass these arguments through to ``super().__init__`` and specify the ``data_sources`` and the ``default_data_source``. + - ``data_sources`` gives the :class:`~flash.data.data_source.DataSource` objects that work with your :class:`~flash.data.process.Preprocess` as a mapping from data source name to :class:`~flash.data.data_source.DataSource`. The data source name can be any string, but for our purposes we can use ``NUMPY`` from :class:`~flash.data.data_source.DefaultDataSources`. + - ``default_data_source`` is the name of the data source to use by default when predicting. +- Override the ``get_state_dict`` and ``load_state_dict`` methods. These methods are used to save and load your :class:`~flash.data.process.Preprocess` from a checkpoint. +- Override the ``default_{train,val,test,predict}_transforms`` methods to specify the default transforms to use in each stage (these will be used if the transforms passed in the ``__init__`` are ``None``). + - Transforms are given as a mapping from hook name to callable transforms. You should use :class:`~flash.data.transforms.ApplyToKeys` to apply each transform only to specific keys in the data dictionary. - ND = np.ndarray +Example:: class NumpyPreprocess(Preprocess): - def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]: - if self.training: - dataset.num_inputs = data[0].shape[1] - return [(x, y) for x, y in zip(*data)] + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={DefaultDataSources.NUMPY: NumpyDataSource()}, + default_data_source=DefaultDataSources.NUMPY, + ) + + @staticmethod + def to_float(x: Tensor): + return x.float() + + @staticmethod + def format_targets(x: Tensor): + return x.unsqueeze(0) + + @property + def to_tensor(self) -> Dict[str, Callable]: + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys( + DefaultDataKeys.INPUT, + torch.from_numpy, + self.to_float, + ), + ApplyToKeys( + DefaultDataKeys.TARGET, + torch.as_tensor, + self.to_float, + self.format_targets, + ), + ), + } + + @property + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor + + @property + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor + + @property + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor + + @property + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor + + def get_state_dict(self) -> Dict[str, Any]: + return self.transforms + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(*state_dict) + + +3.c The DataModule API +---------------------- + +Now that we have a :class:`~flash.data.process.Preprocess` which knows about the +:class:`~flash.data.data_source.DataSource` objects it supports, we just need to create a +:class:`~flash.data.data_module.DataModule` which has a reference to the ``preprocess_cls`` we want it to use. For any +data source whose name is in :class:`~flash.data.data_source.DefaultDataSources`, there is a standard +``DataModule.from_*`` method that provides the expected inputs. So in this case, there is the +:meth:`~flash.data.data_module.DataModule.from_numpy` that will use our numpy data source. - def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]: - x, y = sample - x = torch.from_numpy(x).float() - y = torch.tensor(y, dtype=torch.float) - return x, y +Example:: - def predict_load_data(self, data: ND) -> ND: - return data + class NumpyDataModule(flash.DataModule): - def predict_to_tensor_transform(self, sample: ND) -> ND: - return torch.from_numpy(sample).float() + preprocess_cls = NumpyPreprocess You now have a new customized Flash Task! Congratulations ! @@ -232,10 +289,10 @@ supplying the task itself, and the associated data: .. code:: python x, y = datasets.load_diabetes(return_X_y=True) - datamodule = NumpyDataModule.from_xy_dataset(x, y) - model = RegressionTask(num_inputs=datamodule.num_inputs) + datamodule = NumpyDataModule.from_numpy(x, y) + model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs) - trainer = flash.Trainer(max_epochs=1000) + trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20) trainer.fit(model, datamodule=datamodule) @@ -248,8 +305,8 @@ few examples from the test set of our data: .. code:: python predict_data = torch.tensor([ - [ 0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], - [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], + [ 0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], + [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], [ 0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094]] @@ -257,4 +314,4 @@ few examples from the test set of our data: predictions = model.predict(predict_data) print(predictions) - #out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])] + # out: [tensor([188.9760]), tensor([196.1777]), tensor([161.3590]), tensor([130.7312]), tensor([149.0340])] diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index fec10ec9d6..a3d46a794a 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -185,31 +185,6 @@ Example:: # Set ``preprocess_cls`` with your custom ``preprocess``. preprocess_cls = ImageClassificationPreprocess - @classmethod - def from_folders( - cls, - train_folder: Optional[str], - val_folder: Optional[str], - test_folder: Optional[str], - predict_folder: Optional[str], - preprocess: Optional[Preprocess] = None, - **kwargs - ): - - # Set a custom ``Preprocess`` if none was provided - preprocess = preprocess or cls.preprocess_cls() - - # {stage}_load_data_input will be given to your - # ``Preprocess`` ``{stage}_load_data`` function. - return cls.from_load_data_inputs( - train_load_data_input=train_folder, - val_load_data_input=val_folder, - test_load_data_input=test_folder, - predict_load_data_input=predict_folder, - preprocess=preprocess, # DON'T FORGET TO PASS THE CREATED PREPROCESS - **kwargs, - ) - 3. The Preprocess __________________ @@ -218,9 +193,12 @@ Finally, implement your custom ``ImageClassificationPreprocess``. Example:: + from typing import Any, Callable, Dict, Optional, Tuple, Union import os import numpy as np + from flash.data.data_source import DefaultDataSources from flash.data.process import Preprocess + from flash.vision.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource from PIL import Image import torchvision.transforms as T from torch import Tensor @@ -231,29 +209,32 @@ Example:: to_tensor = T.ToTensor() - def load_data(self, folder: str, dataset: AutoDataset) -> Iterable: - # The AutoDataset is optional but can be useful to save some metadata. - - # metadata contains the image path and its corresponding label with the following structure: - # [(image_path_1, label_1), ... (image_path_n, label_n)]. - metadata = make_dataset(folder) - - # for the train ``AutoDataset``, we want to store the ``num_classes``. - if self.training: - dataset.num_classes = len(np.unique([m[1] for m in metadata])) - - return metadata + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.PATHS: ImagePathsDataSource(), + DefaultDataSources.NUMPY: ImageNumpyDataSource(), + DefaultDataSources.TENSOR: ImageTensorDataSource(), + }, + default_data_source=DefaultDataSources.PATHS, + ) - def predict_load_data(self, predict_folder: str) -> Iterable: - # This returns [image_path_1, ... image_path_m]. - return os.listdir(folder) + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} - def load_sample(self, sample: Union[str, Tuple[str, int]]) -> Tuple[Image, int] - if self.predicting: - return Image.open(image_path) - else: - image_path, label = sample - return Image.open(image_path), label + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) def to_tensor_transform( self, @@ -285,6 +266,14 @@ __________ .. autoclass:: flash.data.data_source.DataSource :members: +.. autoclass:: flash.data.data_source.DefaultDataSources + :members: + :undoc-members: + +.. autoclass:: flash.data.data_source.DefaultDataKeys + :members: + :undoc-members: + ---------- diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 191385900d..a20f8a8c39 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -27,21 +27,20 @@ class BaseAutoDataset(Generic[DATA_TYPE]): - - DATASET_KEY = "dataset" - """This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. ``load_data`` - will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` is provided and - ``load_sample`` within ``__getitem__``. + """The ``BaseAutoDataset`` class wraps the output of a call to :meth:`~flash.data.data_source.DataSource.load_data` + and a :class:`~fash.data.data_source.DataSource` and provides the ``_call_load_sample`` method to call + :meth:`~flash.data.data_source.DataSource.load_sample` with the correct + :class:`~flash.data.utils.CurrentRunningStageFuncContext` for the current ``running_stage``. Inheriting classes are + responsible for extracting samples from ``data`` to be given to ``_call_load_sample``. Args: - - data: The output of a call to :meth:`~flash.data.data_source.load_data`. - + data: The output of a call to :meth:`~flash.data.data_source.DataSource.load_data`. data_source: The :class:`~flash.data.data_source.DataSource` which has the ``load_sample`` method. - running_stage: The current running stage. """ + DATASET_KEY = "dataset" + def __init__( self, data: DATA_TYPE, @@ -93,6 +92,8 @@ def _call_load_sample(self, sample: Any) -> Any: class AutoDataset(BaseAutoDataset[Sequence], Dataset): + """The ``AutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.Dataset`. The `data` argument + must be a ``Sequence`` (it must have a length).""" def __getitem__(self, index: int) -> Any: return self._call_load_sample(self.data[index]) @@ -102,6 +103,8 @@ def __len__(self) -> int: class IterableAutoDataset(BaseAutoDataset[Iterable], IterableDataset): + """The ``IterableAutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.IterableDataset`. The `data` + argument must be an ``Iterable``.""" def __iter__(self): self.data_iter = iter(self.data) diff --git a/flash/data/callback.py b/flash/data/callback.py index 1221046a31..c303ecab4b 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -81,7 +81,16 @@ class BaseDataFetcher(FlashCallback): from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule + from flash.data.data_source import DataSource + from flash.data.process import Preprocess + class CustomPreprocess(Preprocess): + + def __init__(**kwargs): + super().__init__( + data_sources = {"inputs": DataSource()}, + **kwargs, + ) class PrintData(BaseDataFetcher): @@ -90,6 +99,8 @@ def print(self): class CustomDataModule(DataModule): + preprocess_cls = CustomPreprocess + @staticmethod def configure_data_fetcher(): return PrintData() @@ -100,17 +111,16 @@ def from_inputs( train_data: Any, val_data: Any, test_data: Any, - predict_data: Any) -> "CustomDataModule": - - preprocess = CustomPreprocess() - - return cls.from_load_data_inputs( - train_load_data_input=train_data, - val_load_data_input=val_data, - test_load_data_input=test_data, - predict_load_data_input=predict_data, - preprocess=preprocess, - batch_size=5) + predict_data: Any, + ) -> "CustomDataModule": + return cls.from_data_source( + "inputs", + train_data=train_data, + val_data=val_data, + test_data=test_data, + predict_data=predict_data, + batch_size=5, + ) dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5)) data_fetcher = dm.data_fetcher diff --git a/flash/data/data_module.py b/flash/data/data_module.py index e36af6fa9b..e5ac91e585 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -33,18 +33,31 @@ class DataModule(pl.LightningDataModule): - """Basic DataModule class for all Flash tasks + """A basic DataModule class for all Flash tasks. This class includes references to a + :class:`~flash.data.data_source.DataSource`, :class:`~flash.data.process.Preprocess`, + :class:`~flash.data.process.Postprocess`, and a :class:`~flash.data.callback.BaseDataFetcher`. Args: train_dataset: Dataset for training. Defaults to None. val_dataset: Dataset for validating model performance during training. Defaults to None. test_dataset: Dataset to test model performance. Defaults to None. predict_dataset: Dataset for predicting. Defaults to None. - num_workers: The number of workers to use for parallelized loading. Defaults to None. + data_source: The :class:`~flash.data.data_source.DataSource` that was used to create the datasets. + preprocess: The :class:`~flash.data.process.Preprocess` to use when constructing the + :class:`~flash.data.data_pipeline.DataPipeline`. If ``None``, a + :class:`~flash.data.process.DefaultPreprocess` will be used. + postprocess: The :class:`~flash.data.process.Postprocess` to use when constructing the + :class:`~flash.data.data_pipeline.DataPipeline`. If ``None``, a plain + :class:`~flash.data.process.Postprocess` will be used. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to attach to the + :class:`~flash.data.process.Preprocess`. If ``None``, the output from + :meth:`~flash.data.data_module.DataModule.configure_data_fetcher` will be used. + val_split: An optional float which gives the relative amount of the training dataset to use for the validation + dataset. batch_size: The batch size to be used by the DataLoader. Defaults to 1. num_workers: The number of workers to use for parallelized loading. Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. + or 0 for Windows or Darwin platform. """ preprocess_cls = DefaultPreprocess @@ -346,13 +359,61 @@ def from_data_source( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.data.data_module.DataModule` object from the given inputs to + :meth:`~flash.data.data_source.DataSource.load_data` (``train_data``, ``val_data``, ``test_data``, + ``predict_data``). The data source will be resolved from the instantiated + :class:`~flash.data.process.Preprocess` using :meth:`~flash.data.process.Preprocess.data_source_of_name`. + + Args: + data_source: The name of the data source to use for the + :meth:`~flash.data.data_source.DataSource.load_data`. + train_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use when creating the train + dataset. + val_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use when creating the + validation dataset. + test_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use when creating the test + dataset. + predict_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use when creating the + predict dataset. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_data_source( + DefaultDataSources.PATHS, + train_data="train_folder", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ preprocess = preprocess or cls.preprocess_cls( train_transform, val_transform, @@ -394,13 +455,53 @@ def from_folders( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.data.data_module.DataModule` object from the given folders using the + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.PATHS` + from the passed or constructed :class:`~flash.data.process.Preprocess`. + + Args: + train_folder: The folder containing the train data. + val_folder: The folder containing the validation data. + test_folder: The folder containing the test data. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_folders( + train_folder="train_folder", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ return cls.from_data_source( DefaultDataSources.PATHS, train_folder, @@ -433,13 +534,57 @@ def from_files( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.data.data_module.DataModule` object from the given sequences of files using the + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.PATHS` + from the passed or constructed :class:`~flash.data.process.Preprocess`. + + Args: + train_files: A sequence of files to use as the train inputs. + train_targets: A sequence of targets (one per train file) to use as the train targets. + val_files: A sequence of files to use as the validation inputs. + val_targets: A sequence of targets (one per validation file) to use as the validation targets. + test_files: A sequence of files to use as the test inputs. + test_targets: A sequence of targets (one per test file) to use as the test targets. + predict_files: A sequence of files to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_files( + train_files=["image_1.png", "image_2.png", "image_3.png"], + train_targets=[1, 0, 1], + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ return cls.from_data_source( DefaultDataSources.PATHS, (train_files, train_targets), @@ -472,13 +617,57 @@ def from_tensors( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.data.data_module.DataModule` object from the given tensors using the + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.TENSOR` + from the passed or constructed :class:`~flash.data.process.Preprocess`. + + Args: + train_data: A tensor or collection of tensors to use as the train inputs. + train_targets: A sequence of targets (one per train input) to use as the train targets. + val_data: A tensor or collection of tensors to use as the validation inputs. + val_targets: A sequence of targets (one per validation input) to use as the validation targets. + test_data: A tensor or collection of tensors to use as the test inputs. + test_targets: A sequence of targets (one per test input) to use as the test targets. + predict_data: A tensor or collection of tensors to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_tensors( + train_files=torch.rand(3, 128), + train_targets=[1, 0, 1], + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ return cls.from_data_source( DefaultDataSources.TENSOR, (train_data, train_targets), @@ -511,13 +700,57 @@ def from_numpy( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.data.data_module.DataModule` object from the given numpy array using the + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.NUMPY` + from the passed or constructed :class:`~flash.data.process.Preprocess`. + + Args: + train_data: A numpy array to use as the train inputs. + train_targets: A sequence of targets (one per train input) to use as the train targets. + val_data: A numpy array to use as the validation inputs. + val_targets: A sequence of targets (one per validation input) to use as the validation targets. + test_data: A numpy array to use as the test inputs. + test_targets: A sequence of targets (one per test input) to use as the test targets. + predict_data: A numpy array to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_numpy( + train_files=np.random.rand(3, 128), + train_targets=[1, 0, 1], + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ return cls.from_data_source( DefaultDataSources.NUMPY, (train_data, train_targets), @@ -549,13 +782,57 @@ def from_json( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.data.data_module.DataModule` object from the given JSON files using the + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.JSON` + from the passed or constructed :class:`~flash.data.process.Preprocess`. + + Args: + input_fields: The field or fields in the JSON objects to use for the input. + target_fields: The field or fields in the JSON objects to use for the target. + train_file: The JSON file containing the training data. + val_file: The JSON file containing the validation data. + test_file: The JSON file containing the testing data. + predict_file: The JSON file containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_json( + "input", + "target", + train_file="train_data.json", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ return cls.from_data_source( DefaultDataSources.JSON, (train_file, input_fields, target_fields), @@ -587,13 +864,57 @@ def from_csv( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.data.data_module.DataModule` object from the given CSV files using the + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.CSV` + from the passed or constructed :class:`~flash.data.process.Preprocess`. + + Args: + input_fields: The field or fields (columns) in the CSV file to use for the input. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_file: The CSV file containing the training data. + val_file: The CSV file containing the validation data. + test_file: The CSV file containing the testing data. + predict_file: The CSV file containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_csv( + "input", + "target", + train_file="train_data.csv", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ return cls.from_data_source( DefaultDataSources.CSV, (train_file, input_fields, target_fields), diff --git a/flash/data/data_source.py b/flash/data/data_source.py index e637eab923..4238dbb514 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -39,11 +39,41 @@ def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool: @dataclass(unsafe_hash=True, frozen=True) class LabelsState(ProcessState): + """ A :class:`~flash.data.properties.ProcessState` containing ``labels``, a mapping from class index to label. """ labels: Optional[Sequence[str]] +class DefaultDataSources(LightningEnum): + """The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in + :class:`~flash.data.data_module.DataModule`.""" + + PATHS = "paths" + NUMPY = "numpy" + TENSOR = "tensor" + CSV = "csv" + JSON = "json" + + # TODO: Create a FlashEnum class??? + def __hash__(self) -> int: + return hash(self.value) + + +class DefaultDataKeys(LightningEnum): + """The ``DefaultDataKeys`` enum contains the keys that are used by built-in data sources to refer to inputs and + targets.""" + + INPUT = "input" + TARGET = "target" + + # TODO: Create a FlashEnum class??? + def __hash__(self) -> int: + return hash(self.value) + + class MockDataset: + """The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. This is passed to + :meth:`~flash.data.data_source.DataSource.load_data` so that attributes can be set on the generated data set.""" def __init__(self): self.metadata = {} @@ -51,32 +81,69 @@ def __init__(self): def __setattr__(self, key, value): if key != 'metadata': self.metadata[key] = value - else: - object.__setattr__(self, key, value) + object.__setattr__(self, key, value) DATA_TYPE = TypeVar("DATA_TYPE") class DataSource(Generic[DATA_TYPE], Properties, Module): + """The ``DataSource`` class encapsulates two hooks: ``load_data`` and ``load_sample``. The + :meth:`~flash.data.data_source.DataSource.to_datasets` method can then be used to automatically construct data sets + from the hooks.""" def load_data(self, data: DATA_TYPE, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]], Iterable[Mapping[str, Any]]]: - """Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. + """Given the ``data`` argument, the ``load_data`` hook produces a sequence or iterable of samples or + sample metadata. The ``data`` argument can be anything, but this method should return a sequence or iterable of + mappings from string (e.g. "input", "target", "bbox", etc.) to data (e.g. a target value) or metadata (e.g. a + filename). Where possible, any heavy data loading should be performed in + :meth:`~flash.data.data_source.DataSource.load_sample`. If the output is an iterable rather than a sequence + (that is, it doesn't have length) then the generated dataset will be an ``IterableDataset``. + + Args: + data: The data required to load the sequence or iterable of samples or sample metadata. + dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset + (e.g. ``num_classes``) will also be set on the generated dataset. + + Returns: + A sequence or iterable of samples or sample metadata to be used as inputs to + :meth:`~flash.data.data_source.DataSource.load_sample`. Example:: # data: "." - # output: [("./cat/1.png", 1), ..., ("./dog/10.png", 0)] + # output: [{"input": "./cat/1.png", "target": 1}, ..., {"input": "./dog/10.png", "target": 0}] - output: Mapping = load_data(data) + output: Sequence[Mapping[str, Any]] = load_data(data) """ return data def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: - """Loads single sample from dataset""" + """Given an element from the output of a call to :meth:`~flash.data.data_source.DataSource.load_data`, this hook + should load a single data sample. The keys and values in the ``sample`` argument will be same as the keys and + values in the outputs of :meth:`~flash.data.data_source.DataSource.load_data`. + + Args: + sample: An element (sample or sample metadata) from the output of a call to + :meth:`~flash.data.data_source.DataSource.load_data`. + dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset + (e.g. ``num_classes``) will also be set on the generated dataset. + + Returns: + The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the + :meth:`~flash.data.process.Preprocess.pre_tensor_transform`. + + Example:: + + # sample: {"input": "./cat/1.png", "target": 1} + # output: {"input": PIL.Image, "target": 1} + + output: Mapping[str, Any] = load_sample(sample) + + """ return sample def to_datasets( @@ -86,6 +153,25 @@ def to_datasets( test_data: Optional[DATA_TYPE] = None, predict_data: Optional[DATA_TYPE] = None, ) -> Tuple[Optional[BaseAutoDataset], ...]: + """Construct data sets (of type :class:`~flash.data.auto_dataset.BaseAutoDataset`) from this data source by + calling :meth:`~flash.data.data_source.DataSource.load_data` with each of the ``*_data`` arguments. If an + argument is given as ``None`` then no dataset will be created for that stage (``train``, ``val``, ``test``, + ``predict``). + + Args: + train_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the train + dataset. + val_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the validation + dataset. + test_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the test + dataset. + predict_data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the + predict dataset. + + Returns: + A tuple of ``train_dataset``, ``val_dataset``, ``test_dataset``, ``predict_dataset``. If any ``*_data`` + argument is not passed to this method then the corresponding ``*_dataset`` will be ``None``. + """ train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING) val_dataset = self.generate_dataset(val_data, RunningStage.VALIDATING) test_dataset = self.generate_dataset(test_data, RunningStage.TESTING) @@ -97,6 +183,16 @@ def generate_dataset( data: Optional[DATA_TYPE], running_stage: RunningStage, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: + """Generate a single dataset with the given input to :meth:`~flash.data.data_source.DataSource.load_data` for + the given ``running_stage``. + + Args: + data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the dataset. + running_stage: The running_stage for this dataset. + + Returns: + The constructed :class:`~flash.data.auto_dataset.BaseAutoDataset`. + """ is_none = data is None if isinstance(data, Sequence): @@ -129,29 +225,6 @@ def generate_dataset( return dataset -class DefaultDataSources(LightningEnum): - - PATHS = "paths" - NUMPY = "numpy" - TENSOR = "tensor" - CSV = "csv" - JSON = "json" - - # TODO: Create a FlashEnum class??? - def __hash__(self) -> int: - return hash(self.value) - - -class DefaultDataKeys(LightningEnum): - - INPUT = "input" - TARGET = "target" - - # TODO: Create a FlashEnum class??? - def __hash__(self) -> int: - return hash(self.value) - - SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE") @@ -159,6 +232,14 @@ class SequenceDataSource( Generic[SEQUENCE_DATA_TYPE], DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]], ): + """The ``SequenceDataSource`` implements default behaviours for data sources which expect the input to + :meth:`~flash.data.data_source.DataSource.load_data` to be a sequence of tuples (``(input, target)`` where target + can be ``None``). + + Args: + labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the + :class:`~flash.data.data_source.LabelsState`. + """ def __init__(self, labels: Optional[Sequence[str]] = None): super().__init__() @@ -186,17 +267,25 @@ def predict_load_data(self, data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapp return [{DefaultDataKeys.INPUT: input} for input in data] -class PathsDataSource(SequenceDataSource): # TODO: Sort out the typing here +class PathsDataSource(SequenceDataSource): + """The ``PathsDataSource`` implements default behaviours for data sources which expect the input to + :meth:`~flash.data.data_source.DataSource.load_data` to be either a directory with a subdirectory for each class or + a tuple containing list of files and corresponding list of targets. - def __init__(self, extensions: Optional[Tuple[str, ...]] = None): - super().__init__() + Args: + extensions: The file extensions supported by this data source (e.g. ``(".jpg", ".png")``). + labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the + :class:`~flash.data.data_source.LabelsState`. + """ + + def __init__(self, extensions: Optional[Tuple[str, ...]] = None, labels: Optional[Sequence[str]] = None): + super().__init__(labels=labels) self.extensions = extensions @staticmethod def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: - """ - Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. + """Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. Args: dir: Root directory path. @@ -257,8 +346,10 @@ def predict_load_data(self, class TensorDataSource(SequenceDataSource[torch.Tensor]): - """""" # TODO: Some docstring here + """The ``TensorDataSource`` is a ``SequenceDataSource`` which expects the input to + :meth:`~flash.data.data_source.DataSource.load_data` to be a sequence of ``torch.Tensor`` objects.""" class NumpyDataSource(SequenceDataSource[np.ndarray]): - """""" # TODO: Some docstring here + """The ``NumpyDataSource`` is a ``SequenceDataSource`` which expects the input to + :meth:`~flash.data.data_source.DataSource.load_data` to be a sequence of ``np.ndarray`` objects.""" diff --git a/flash/data/process.py b/flash/data/process.py index 050847dfa0..4b3ed3df53 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -49,41 +49,13 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): pass -DATA_SOURCE_TYPE = TypeVar("DATA_SOURCE_TYPE") - - class Preprocess(BasePreprocess, Properties, Module): - """ - The :class:`~flash.data.process.Preprocess` encapsulates - all the data processing and loading logic that should run before the data is passed to the model. - - It is particularly relevant when you want to provide an end to end implementation which works - with 4 different stages: ``train``, ``validation``, ``test``, and inference (``predict``). - - You can override any of the preprocessing hooks to provide custom functionality. - All hooks default to no-op (except the collate which is PyTorch default - `collate `_) + """The :class:`~flash.data.process.Preprocess` encapsulates all the data processing logic that should run before + the data is passed to the model. It is particularly useful when you want to provide an end to end implementation + which works with 4 different stages: ``train``, ``validation``, ``test``, and inference (``predict``). The :class:`~flash.data.process.Preprocess` supports the following hooks: - - ``load_data``: Function to receiving some metadata to generate a Mapping from. - Example:: - - * Input: Receive a folder path: - - * Action: Walk the folder path to find image paths and their associated labels. - - * Output: Return a list of image paths and their associated labels. - - - ``load_sample``: Function to load a sample from metadata sample. - Example:: - - * Input: Receive an image path and its label. - - * Action: Load a PIL Image from received image_path. - - * Output: Return the PIL Image and its label. - - ``pre_tensor_transform``: Performs transforms on a single data sample. Example:: @@ -142,65 +114,61 @@ class Preprocess(BasePreprocess, Properties, Module): * Output: Return a normalized augmented batch of images and their labels. - .. note:: - - By default, each hook will be no-op execpt the collate which is PyTorch default - `collate `_. - To customize them, just override the hooks and ``Flash`` will take care of calling them at the right moment. - .. note:: The ``per_sample_transform_on_device`` and ``per_batch_transform`` are mutually exclusive as it will impact performances. - To change the processing behavior only on specific stages, - you can prefix all the above hooks adding ``train``, ``val``, ``test`` or ``predict``. + Data processing can be configured by overriding hooks or through transforms. The preprocess transforms are given as + a mapping from hook names to callables. Default transforms can be configured by overriding the + `default_{train,val,test,predict}_transforms` methods. These can then be overridden by the user with the + `{train,val,test,predict}_transform` arguments to the ``Preprocess``. All of the hooks can be used in the transform + mappings, with the exception of ``collate``. + + Example:: + + class CustomPreprocess(Preprocess): - For example, is useful to encapsulate ``predict`` logic as labels aren't availabled at inference time. + def default_train_transforms() -> Mapping[str, Callable]: + return { + "pre_tensor_transform": transforms.RandomHorizontalFlip(), + "to_tensor_transform": transforms.ToTensor(), + } + + When overriding hooks for particular stages, you can prefix with ``train``, ``val``, ``test`` or ``predict``. For + example, you can achieve the same as the above example by implementing ```train_pre_tensor_transform`` and + ``train_to_tensor_transform``. Example:: class CustomPreprocess(Preprocess): - def predict_load_data(cls, data: Any, dataset: Optional[Any] = None) -> Mapping: - # logic for predict data only. + def train_pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: + return transforms.RandomHorizontalFlip()(sample) - Each hook is aware of the Trainer ``running stage`` through booleans as follow. + def train_to_tensor_transform(self, sample: PIL.Image) -> torch.Tensor: + return transforms.ToTensor()(sample) - This is useful to adapt a hook internals for a stage without duplicating code. + Each hook is aware of the Trainer ``running stage`` through booleans. These are useful for adapting functionality + for a stage without duplicating code. Example:: class CustomPreprocess(Preprocess): - def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Mapping: + def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: if self.training: - # logic for train + # logic for training elif self.validating: - # logic from validation + # logic for validation elif self.testing: - # logic for test + # logic for testing elif self.predicting: - # logic for predict - - .. note:: - - It is possible to wrap a ``Dataset`` within a :meth:`~flash.data.process.Preprocess.load_data` function. - However, we don't recommend to do as such as it is better to rely entirely on the hooks. - - Example:: - - from torchvision import datasets - - class CustomPreprocess(Preprocess): - - def load_data(cls, path_to_data: str) -> Iterable: - - return datasets.MNIST(path_to_data, download=True, transform=transforms.ToTensor()) + # logic for predicting """ @@ -210,7 +178,7 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_sources: Optional[Dict[str, 'DataSource']] = None, + data_sources: Optional[Dict[str, DataSource]] = None, default_data_source: Optional[str] = None, ): super().__init__() @@ -416,7 +384,7 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self.current_transform(batch) - def data_source_of_name(self, data_source_name: str) -> Optional[DATA_SOURCE_TYPE]: + def data_source_of_name(self, data_source_name: str) -> Optional[DataSource]: if data_source_name == "default": data_source_name = self._default_data_source data_sources = self._data_sources diff --git a/flash/data/transforms.py b/flash/data/transforms.py index 67b1229ad4..689ef2d0c9 100644 --- a/flash/data/transforms.py +++ b/flash/data/transforms.py @@ -19,6 +19,14 @@ class ApplyToKeys(nn.Sequential): + """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from the + input. When a single key is given, a single value will be passed to the transforms. When multiple keys are given, + the corresponding values will be passed to the transforms as a list. + + Args: + keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms. + args: The transforms, passed to the ``nn.Sequential`` super constructor. + """ def __init__(self, keys: Union[str, Sequence[str]], *args): super().__init__(*[convert_to_modules(arg) for arg in args]) @@ -47,7 +55,11 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: class KorniaParallelTransforms(nn.Sequential): """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when - multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask).""" + multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask). + + Args: + args: The transforms, passed to the ``nn.Sequential`` super constructor. + """ def __init__(self, *args): super().__init__(*[convert_to_modules(arg) for arg in args]) diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index 1b4ad6b9bd..915575f3f5 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -19,6 +19,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash.core.classification import LabelsState +from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess @@ -203,7 +204,7 @@ def emb_sizes(self) -> list: return list(zip(num_classes, emb_dims)) @staticmethod - def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]): + def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]): if cat_cols is None and num_cols is None: raise RuntimeError('Both `cat_cols` and `num_cols` are None!') @@ -252,95 +253,182 @@ def compute_state( @classmethod def from_data_frame( cls, - categorical_cols: List, - numerical_cols: List, - target_col: str, - train_data_frame: DataFrame, + categorical_fields: Optional[Union[str, List[str]]], + numerical_fields: Optional[Union[str, List[str]]], + target_fields: Optional[str] = None, + train_data_frame: Optional[DataFrame] = None, val_data_frame: Optional[DataFrame] = None, test_data_frame: Optional[DataFrame] = None, predict_data_frame: Optional[DataFrame] = None, - is_regression: bool = False, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, - val_split: float = None, - batch_size: int = 8, + val_split: Optional[float] = None, + batch_size: int = 4, num_workers: Optional[int] = None, + is_regression: bool = False, + **preprocess_kwargs: Any, ): - """Creates a TabularData object from pandas DataFrames. + """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. Args: - train_df: Train data DataFrame. - target_col: The column containing the class id. - categorical_cols: The list of categorical columns. - numerical_cols: The list of numerical columns. - val_df: Validation data DataFrame. - test_df: Test data DataFrame. - batch_size: The batchsize to use for parallel loading. Defaults to 64. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - val_split: Float between 0 and 1 to create a validation dataset from train dataset. - preprocess: Preprocess to be used within this DataModule DataPipeline. + categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. + numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_data_frame: The pandas ``DataFrame`` containing the training data. + val_data_frame: The pandas ``DataFrame`` containing the validation data. + test_data_frame: The pandas ``DataFrame`` containing the testing data. + predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be + formatted as integers. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. Returns: - TabularData: The constructed data module. + The constructed data module. Examples:: - text_data = TextClassificationData.from_files("train.csv", label_field="class", text_field="sentence") + data_module = TabularData.from_data_frame( + "categorical_input", + "numerical_input", + "target", + train_data_frame=train_data, + ) """ - categorical_cols, numerical_cols = cls._sanetize_cols(categorical_cols, numerical_cols) + categorical_fields, numerical_fields = cls._sanetize_cols(categorical_fields, numerical_fields) + + if not isinstance(categorical_fields, list): + categorical_fields = [categorical_fields] + + if not isinstance(numerical_fields, list): + numerical_fields = [numerical_fields] mean, std, classes, codes, target_codes = cls.compute_state( train_data_frame, val_data_frame, test_data_frame, predict_data_frame, - target_col, - numerical_cols, - categorical_cols, + target_fields, + numerical_fields, + categorical_fields, ) return cls.from_data_source( - data_source="data_frame", - train_data=train_data_frame, - val_data=val_data_frame, - test_data=test_data_frame, - predict_data=predict_data_frame, + "data_frame", + train_data_frame, + val_data_frame, + test_data_frame, + predict_data_frame, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, preprocess=preprocess, val_split=val_split, batch_size=batch_size, num_workers=num_workers, - cat_cols=categorical_cols, - num_cols=numerical_cols, - target_col=target_col, + cat_cols=categorical_fields, + num_cols=numerical_fields, + target_col=target_fields, mean=mean, std=std, codes=codes, target_codes=target_codes, classes=classes, is_regression=is_regression, + **preprocess_kwargs, ) @classmethod def from_csv( cls, - categorical_fields: Union[str, List[str]], - numerical_fields: Union[str, List[str]], - target_field: Optional[str] = None, + categorical_fields: Optional[Union[str, List[str]]], + numerical_fields: Optional[Union[str, List[str]]], + target_fields: Optional[str] = None, train_file: Optional[str] = None, val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, - is_regression: bool = False, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, + is_regression: bool = False, + **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. + + Args: + categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. + numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_file: The CSV file containing the training data. + val_file: The CSV file containing the validation data. + test_file: The CSV file containing the testing data. + predict_file: The CSV file containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be + formatted as integers. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = TabularData.from_csv( + "categorical_input", + "numerical_input", + "target", + train_file="train_data.csv", + ) + """ return cls.from_data_frame( categorical_fields, numerical_fields, - target_field, + target_fields, train_data_frame=pd.read_csv(train_file) if train_file is not None else None, val_data_frame=pd.read_csv(val_file) if val_file is not None else None, test_data_frame=pd.read_csv(test_file) if test_file is not None else None, diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 79d0fca863..5b7987dbff 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -18,7 +18,6 @@ from PIL import Image from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.utils.data._utils.collate import default_collate from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 528a74a99d..3fec5536de 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple -from torch.nn import Module from torchvision.datasets.folder import default_loader +from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess @@ -146,28 +146,69 @@ def from_coco( cls, train_folder: Optional[str] = None, train_ann_file: Optional[str] = None, - train_transform: Optional[Dict[str, Callable]] = None, val_folder: Optional[str] = None, val_ann_file: Optional[str] = None, - val_transform: Optional[Dict[str, Callable]] = None, test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - preprocess: Preprocess = None, - val_split: Optional[float] = None, + **preprocess_kwargs: Any, ): + """Creates a :class:`~flash.vision.detection.data.ObjectDetectionData` object from the given data + folders and corresponding target folders. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = SemanticSegmentationData.from_coco( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ return cls.from_data_source( - data_source="coco", - train_data=(train_folder, train_ann_file) if train_folder else None, - val_data=(val_folder, val_ann_file) if val_folder else None, - test_data=(test_folder, test_ann_file) if test_folder else None, + "coco", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, + data_fetcher=data_fetcher, preprocess=preprocess, val_split=val_split, batch_size=batch_size, num_workers=num_workers, + **preprocess_kwargs, ) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index d674205786..b479b4cfa3 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -196,13 +196,56 @@ def from_folders( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: BaseDataFetcher = None, + data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, ) -> 'DataModule': + """Creates a :class:`~flash.vision.segmentation.data.SemanticSegmentationData` object from the given data + folders and corresponding target folders. + + Args: + train_folder: The folder containing the train data. + train_target_folder: The folder containing the train targets (targets must have the same file name as their + corresponding inputs). + val_folder: The folder containing the validation data. + val_target_folder: The folder containing the validation targets (targets must have the same file name as + their corresponding inputs). + test_folder: The folder containing the test data. + test_target_folder: The folder containing the test targets (targets must have the same file name as their + corresponding inputs). + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.data.data_module.DataModule`. + preprocess: The :class:`~flash.data.data.Preprocess` to pass to the + :class:`~flash.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be constructed + and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = SemanticSegmentationData.from_folders( + train_folder="train_folder", + train_target_folder="train_masks", + ) + """ return cls.from_data_source( DefaultDataSources.PATHS, (train_folder, train_target_folder), diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py index 8fc9c3de88..0522ab009b 100644 --- a/flash_examples/custom_task.py +++ b/flash_examples/custom_task.py @@ -1,16 +1,15 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch from pytorch_lightning import seed_everything from sklearn import datasets -from sklearn.model_selection import train_test_split from torch import nn, Tensor import flash -from flash.data.auto_dataset import AutoDataset -from flash.data.data_source import DataSource +from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess +from flash.data.transforms import ApplyToKeys seed_everything(42) @@ -19,7 +18,7 @@ class RegressionTask(flash.Task): - def __init__(self, num_inputs, learning_rate=0.001, metrics=None): + def __init__(self, num_inputs, learning_rate=0.2, metrics=None): # what kind of model do we want? model = nn.Linear(num_inputs, 1) @@ -27,7 +26,7 @@ def __init__(self, num_inputs, learning_rate=0.001, metrics=None): loss_fn = torch.nn.functional.mse_loss # what optimizer to do we want? - optimizer = torch.optim.SGD + optimizer = torch.optim.Adam super().__init__( model=model, @@ -37,80 +36,135 @@ def __init__(self, num_inputs, learning_rate=0.001, metrics=None): learning_rate=learning_rate, ) + def training_step(self, batch: Any, batch_idx: int) -> Any: + return super().training_step( + (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), + batch_idx, + ) + + def validation_step(self, batch: Any, batch_idx: int) -> None: + return super().validation_step( + (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), + batch_idx, + ) + + def test_step(self, batch: Any, batch_idx: int) -> None: + return super().test_step( + (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]), + batch_idx, + ) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return super().predict_step( + batch[DefaultDataKeys.INPUT], + batch_idx, + dataloader_idx, + ) + def forward(self, x): # we don't actually need to override this method for this example return self.model(x) -class NumpyDataSource(DataSource): +class NumpyDataSource(DataSource[Tuple[ND, ND]]): - def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]: + def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]: if self.training: dataset.num_inputs = data[0].shape[1] - return [(x, y) for x, y in zip(*data)] + return [{DefaultDataKeys.INPUT: x, DefaultDataKeys.TARGET: y} for x, y in zip(*data)] - def predict_load_data(self, data: ND) -> ND: - return data + def predict_load_data(self, data: ND) -> List[Dict[str, Any]]: + return [{DefaultDataKeys.INPUT: x} for x in data] class NumpyPreprocess(Preprocess): - def __init__(self): - super().__init__(data_sources={"numpy": NumpyDataSource()}, default_data_source="numpy") - - def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]: - x, y = sample - x = torch.from_numpy(x).float() - y = torch.tensor(y, dtype=torch.float) - return x, y + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={DefaultDataSources.NUMPY: NumpyDataSource()}, + default_data_source=DefaultDataSources.NUMPY, + ) - def predict_to_tensor_transform(self, sample: ND) -> ND: - return torch.from_numpy(sample).float() + @staticmethod + def to_float(x: Tensor): + return x.float() + + @staticmethod + def format_targets(x: Tensor): + return x.unsqueeze(0) + + @property + def to_tensor(self) -> Dict[str, Callable]: + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys( + DefaultDataKeys.INPUT, + torch.from_numpy, + self.to_float, + ), + ApplyToKeys( + DefaultDataKeys.TARGET, + torch.as_tensor, + self.to_float, + self.format_targets, + ), + ), + } + + @property + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor + + @property + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor + + @property + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor + + @property + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return self.to_tensor def get_state_dict(self) -> Dict[str, Any]: - return {} + return self.transforms @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return cls() + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(*state_dict) class NumpyDataModule(flash.DataModule): - @classmethod - def from_dataset(cls, x: ND, y: ND, preprocess: Preprocess, batch_size: int = 64, num_workers: int = 0): - - preprocess = preprocess - - x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0) - - dm = cls.from_data_source( - "numpy", - train_data=(x_train, y_train), - test_data=(x_test, y_test), - preprocess=preprocess, - batch_size=batch_size, - num_workers=num_workers, - ) - dm.num_inputs = dm.train_dataset.num_inputs - return dm + preprocess_cls = NumpyPreprocess x, y = datasets.load_diabetes(return_X_y=True) -datamodule = NumpyDataModule.from_dataset(x, y, NumpyPreprocess()) -model = RegressionTask(num_inputs=datamodule.num_inputs) +datamodule = NumpyDataModule.from_numpy(x, y) +model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs) -trainer = flash.Trainer(max_epochs=10, progress_bar_refresh_rate=20) +trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20) trainer.fit(model, datamodule=datamodule) -predict_data = np.array([[0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], - [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], - [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], - [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], - [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094]]) +predict_data = np.array([ + [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], + [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], + [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], + [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], + [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], +]) predictions = model.predict(predict_data) -# out: This prediction: tensor([14.7288]) is above the threshold: 14.72 - print(predictions) -# out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])] +# out: [tensor([188.9760]), tensor([196.1777]), tensor([161.3590]), tensor([130.7312]), tensor([149.0340])] diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index ad8a949455..5c60c6f29e 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -23,8 +23,8 @@ # 2. Load the data datamodule = TabularData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - ["Fare"], - target_field="Survived", + "Fare", + target_fields="Survived", train_file="./data/titanic/titanic.csv", test_file="./data/titanic/test.csv", val_split=0.25, diff --git a/flash_examples/generic_task.py b/flash_examples/generic_task.py index e82d6d0ee2..6d4acdc03e 100644 --- a/flash_examples/generic_task.py +++ b/flash_examples/generic_task.py @@ -32,7 +32,6 @@ nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10), - nn.Softmax(), ) # 3. Load a dataset diff --git a/flash_examples/predict/image_classification.py b/flash_examples/predict/image_classification.py index fe697b2963..df201e3157 100644 --- a/flash_examples/predict/image_classification.py +++ b/flash_examples/predict/image_classification.py @@ -30,10 +30,7 @@ print(predictions) # 3b. Or generate predictions with a whole folder! -datamodule = ImageClassificationData.from_folders( - predict_folder="data/hymenoptera_data/predict/", - preprocess=model.preprocess, -) +datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/") predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index 748dacd4c7..30dd77c437 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -97,7 +97,7 @@ "datamodule = TabularData.from_csv(\n", " [\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", " [\"Fare\"],\n", - " target_field=\"Survived\",\n", + " target_fields=\"Survived\",\n", " train_file=\"./data/titanic/titanic.csv\",\n", " test_file=\"./data/titanic/test.csv\",\n", " val_split=0.25,\n", diff --git a/tests/tabular/data/test_data.py b/tests/tabular/data/test_data.py index 1a0d1e1574..b6524ab84d 100644 --- a/tests/tabular/data/test_data.py +++ b/tests/tabular/data/test_data.py @@ -87,9 +87,9 @@ def test_tabular_data(tmpdir): val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() dm = TabularData.from_data_frame( - categorical_cols=["category"], - numerical_cols=["scalar_b", "scalar_b"], - target_col="label", + categorical_fields=["category"], + numerical_fields=["scalar_b", "scalar_b"], + target_fields="label", train_data_frame=train_data_frame, val_data_frame=val_data_frame, test_data_frame=test_data_frame, @@ -114,9 +114,9 @@ def test_categorical_target(tmpdir): df["label"] = df["label"].astype(str) dm = TabularData.from_data_frame( - categorical_cols=["category"], - numerical_cols=["scalar_b", "scalar_b"], - target_col="label", + categorical_fields=["category"], + numerical_fields=["scalar_b", "scalar_b"], + target_fields="label", train_data_frame=train_data_frame, val_data_frame=val_data_frame, test_data_frame=test_data_frame, @@ -137,9 +137,9 @@ def test_from_data_frame(tmpdir): val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() dm = TabularData.from_data_frame( - categorical_cols=["category"], - numerical_cols=["scalar_b", "scalar_b"], - target_col="label", + categorical_fields=["category"], + numerical_fields=["scalar_b", "scalar_b"], + target_fields="label", train_data_frame=train_data_frame, val_data_frame=val_data_frame, test_data_frame=test_data_frame, @@ -165,7 +165,7 @@ def test_from_csv(tmpdir): dm = TabularData.from_csv( categorical_fields=["category"], numerical_fields=["scalar_b", "scalar_b"], - target_field="label", + target_fields="label", train_file=str(train_csv), val_file=str(val_csv), test_file=str(test_csv), @@ -185,9 +185,9 @@ def test_empty_inputs(): train_data_frame = TEST_DF_1.copy() with pytest.raises(RuntimeError): TabularData.from_data_frame( - numerical_cols=None, - categorical_cols=None, - target_col="label", + numerical_fields=None, + categorical_fields=None, + target_fields="label", train_data_frame=train_data_frame, num_workers=0, batch_size=1, diff --git a/tests/tabular/test_data_model_integration.py b/tests/tabular/test_data_model_integration.py index 6dcec9b6a8..a15082e7f8 100644 --- a/tests/tabular/test_data_model_integration.py +++ b/tests/tabular/test_data_model_integration.py @@ -32,9 +32,9 @@ def test_classification(tmpdir): val_data_frame = TEST_DF_1.copy() test_data_frame = TEST_DF_1.copy() data = TabularData.from_data_frame( - categorical_cols=["category"], - numerical_cols=["scalar_a", "scalar_b"], - target_col="label", + categorical_fields=["category"], + numerical_fields=["scalar_a", "scalar_b"], + target_fields="label", train_data_frame=train_data_frame, val_data_frame=val_data_frame, test_data_frame=test_data_frame,