diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index 1eb89405b7..fec10ec9d6 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -21,6 +21,8 @@ Here are common terms you need to be familiar with: - The :class:`~flash.data.data_module.DataModule` contains the dataset, transforms and dataloaders. * - :class:`~flash.data.data_pipeline.DataPipeline` - The :class:`~flash.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. + * - :class:`~flash.data.data_source.DataSource` + - The :class:`~flash.data.data_source.DataSource` provides a hook-based API for creating data sets. * - :class:`~flash.data.process.Preprocess` - The :class:`~flash.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic. The :class:`~flash.data.process.Preprocess` provides multiple hooks such as :meth:`~flash.data.process.Preprocess.load_data` @@ -275,6 +277,17 @@ Example:: API reference ************* +.. _data_source: + +DataSource +__________ + +.. autoclass:: flash.data.data_source.DataSource + :members: + + +---------- + .. _preprocess: Preprocess @@ -325,7 +338,6 @@ __________ .. autoclass:: flash.data.data_module.DataModule :members: - from_load_data_inputs, train_dataset, val_dataset, test_dataset, diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index ac12aea2cf..54f841ad5f 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -183,8 +183,4 @@ ImageClassificationData .. autoclass:: flash.vision.ImageClassificationData -.. automethod:: flash.vision.ImageClassificationData.from_filepaths - -.. automethod:: flash.vision.ImageClassificationData.from_folders - .. 
autoclass:: flash.vision.ImageClassificationPreprocess diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index e54356c751..9812bab90b 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -165,4 +165,4 @@ TabularData .. automethod:: flash.tabular.TabularData.from_csv -.. automethod:: flash.tabular.TabularData.from_df +.. automethod:: flash.tabular.TabularData.from_data_frame diff --git a/docs/source/reference/video_classification.rst b/docs/source/reference/video_classification.rst index e088a556ea..6b7d3c08d1 100644 --- a/docs/source/reference/video_classification.rst +++ b/docs/source/reference/video_classification.rst @@ -152,5 +152,3 @@ VideoClassificationData ----------------------- .. autoclass:: flash.video.VideoClassificationData - -.. automethod:: flash.video.VideoClassificationData.from_paths diff --git a/flash/core/classification.py b/flash/core/classification.py index 5fb983983c..b85a529b3a 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any, Callable, List, Mapping, Optional, Sequence, Union import torch @@ -20,7 +19,8 @@ from pytorch_lightning.utilities import rank_zero_warn from flash.core.model import Task -from flash.data.process import ProcessState, Serializer +from flash.data.data_source import LabelsState +from flash.data.process import Serializer def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -28,12 +28,6 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. 
return F.binary_cross_entropy_with_logits(x, y.float()) -@dataclass(unsafe_hash=True, frozen=True) -class ClassificationState(ProcessState): - - labels: Optional[List[str]] - - class ClassificationTask(Task): def __init__( @@ -130,7 +124,7 @@ class Labels(Classes): Args: labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not - provided, will attempt to get them from the :class:`.ClassificationState`. + provided, will attempt to get them from the :class:`.LabelsState`. multi_label: If true, treats outputs as multi label logits. @@ -141,13 +135,16 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False super().__init__(multi_label=multi_label, threshold=threshold) self._labels = labels + if labels is not None: + self.set_state(LabelsState(labels)) + def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: labels = None if self._labels is not None: labels = self._labels else: - state = self.get_state(ClassificationState) + state = self.get_state(LabelsState) if state is not None: labels = state.labels @@ -158,7 +155,5 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: return [labels[cls] for cls in classes] return labels[classes] else: - rank_zero_warn( - "No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning - ) + rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning) return classes diff --git a/flash/core/model.py b/flash/core/model.py index 39aa32095e..6c453ae0bf 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import functools +import inspect from importlib import import_module from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union -import inspect + import torch import torchmetrics from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim.lr_scheduler import _LRScheduler @@ -29,7 +29,8 @@ from flash.core.registry import FlashRegistry from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.utils import get_callable_dict -from flash.data.data_pipeline import DataPipeline +from flash.data.data_pipeline import DataPipeline, DataPipelineState +from flash.data.data_source import DataSource, DefaultDataSources from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping @@ -103,6 +104,9 @@ def __init__( self._postprocess: Optional[Postprocess] = postprocess self._serializer: Optional[Serializer] = None + # TODO: create enum values to define what are the exact states + self._data_pipeline_state: Optional[DataPipelineState] = None + # Explicitly set the serializer to call the setter self.serializer = serializer @@ -154,6 +158,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None: def predict( self, x: Any, + data_source: Optional[str] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ @@ -169,9 +174,9 @@ def predict( """ running_stage = RunningStage.PREDICTING - data_pipeline = self.build_data_pipeline(data_pipeline) + data_pipeline = self.build_data_pipeline(data_source or "default", data_pipeline) - x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)] + x = [x for x in data_pipeline.data_source.generate_dataset(x, running_stage)] x = data_pipeline.worker_preprocessor(running_stage)(x) # switch to 
self.device when #7188 merge in Lightning x = self.transfer_batch_to_device(x, next(self.parameters()).device) @@ -252,7 +257,11 @@ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]): serializer = SerializerMapping(serializer) self._serializer = serializer - def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]: + def build_data_pipeline( + self, + data_source: Optional[str] = None, + data_pipeline: Optional[DataPipeline] = None, + ) -> Optional[DataPipeline]: """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. These will be overridden in the following resolution order (lowest priority first): @@ -269,10 +278,11 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O Returns: The fully resolved :class:`.DataPipeline`. """ - preprocess, postprocess, serializer = None, None, None + old_data_source, preprocess, postprocess, serializer = None, None, None, None # Datamodule if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: + old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None) preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.datamodule.data_pipeline, '_serializer', None) @@ -280,9 +290,14 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O elif self.trainer is not None and hasattr( self.trainer, 'datamodule' ) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None: + old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None) preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.trainer.datamodule.data_pipeline, 
'_postprocess_pipeline', None) serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None) + else: + # TODO: we should log with low severity level that we use defaults to create + # `preprocess`, `postprocess` and `serializer`. + pass # Defaults / task attributes preprocess, postprocess, serializer = Task._resolve( @@ -305,8 +320,16 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O getattr(data_pipeline, '_serializer', None), ) - data_pipeline = DataPipeline(preprocess, postprocess, serializer) - data_pipeline.initialize() + data_source = data_source or old_data_source + + if isinstance(data_source, str): + if preprocess is None: + data_source = DataSource() # TODO: warn the user that we are not using the specified data source + else: + data_source = preprocess.data_source_of_name(data_source) + + data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) + self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) return data_pipeline @property @@ -376,12 +399,16 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: checkpoint['data_pipeline'] = self.data_pipeline + if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: + checkpoint['_data_pipeline_state'] = self._data_pipeline_state super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: super().on_load_checkpoint(checkpoint) if 'data_pipeline' in checkpoint: self.data_pipeline = checkpoint['data_pipeline'] + if '_data_pipeline_state' in checkpoint: + self._data_pipeline_state = checkpoint['_data_pipeline_state'] @classmethod def available_backbones(cls) -> List[str]: diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 2ba6dd92f4..191385900d 100644 --- 
a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -12,167 +12,100 @@ # See the License for the specific language governing permissions and # limitations under the License. from inspect import signature -from typing import Any, Callable, Iterable, Iterator, Optional, TYPE_CHECKING +from typing import Any, Callable, Generic, Iterable, Optional, Sequence, TYPE_CHECKING, TypeVar -import torch from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities.warning_utils import rank_zero_warn from torch.utils.data import Dataset, IterableDataset -from flash.data.callback import ControlFlow -from flash.data.process import Preprocess -from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext +from flash.data.utils import CurrentRunningStageFuncContext if TYPE_CHECKING: from flash.data.data_pipeline import DataPipeline + from flash.data.data_source import DataSource +DATA_TYPE = TypeVar('DATA_TYPE') -class BaseAutoDataset: + +class BaseAutoDataset(Generic[DATA_TYPE]): DATASET_KEY = "dataset" - """ - This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. - ``load_data`` will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` - is provided and ``load_sample`` within ``__getitem__``. + """This class is used to encapsulate a Preprocess Object ``load_data`` and ``load_sample`` functions. ``load_data`` + will be called within the ``__init__`` function of the AutoDataset if ``running_stage`` is provided and + ``load_sample`` within ``__getitem__``. + + Args: + + data: The output of a call to :meth:`~flash.data.data_source.load_data`. + + data_source: The :class:`~flash.data.data_source.DataSource` which has the ``load_sample`` method. + + running_stage: The current running stage. 
""" def __init__( self, - data: Any, - load_data: Optional[Callable] = None, - load_sample: Optional[Callable] = None, - data_pipeline: Optional['DataPipeline'] = None, - running_stage: Optional[RunningStage] = None + data: DATA_TYPE, + data_source: 'DataSource', + running_stage: RunningStage, ) -> None: super().__init__() - if load_data or load_sample: - if data_pipeline: - rank_zero_warn( - "``datapipeline`` is specified but load_sample and/or load_data are also specified. " - "Won't use datapipeline" - ) - # initial states - self._load_data_called = False - self._running_stage = None - self.data = data - self.data_pipeline = data_pipeline - self.load_data = load_data - self.load_sample = load_sample + self.data_source = data_source - # trigger the setup only if `running_stage` is provided + self._running_stage = None self.running_stage = running_stage @property - def running_stage(self) -> Optional[RunningStage]: + def running_stage(self) -> RunningStage: return self._running_stage @running_stage.setter def running_stage(self, running_stage: RunningStage) -> None: - if self._running_stage != running_stage or (not self._running_stage): - self._running_stage = running_stage - self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self.preprocess) - self._load_sample_context = CurrentRunningStageFuncContext( - self._running_stage, "load_sample", self.preprocess - ) - self._setup(running_stage) + from flash.data.data_pipeline import DataPipeline # noqa F811 + from flash.data.data_source import DataSource # noqa F811 # TODO: something better than this - @property - def preprocess(self) -> Optional[Preprocess]: - if self.data_pipeline is not None: - return self.data_pipeline._preprocess_pipeline + self._running_stage = running_stage - @property - def control_flow_callback(self) -> Optional[ControlFlow]: - preprocess = self.preprocess - if preprocess is not None: - return ControlFlow(preprocess.callbacks) - - def 
_call_load_data(self, data: Any) -> Iterable: - parameters = signature(self.load_data).parameters - if len(parameters) > 1 and self.DATASET_KEY in parameters: - return self.load_data(data, self) - else: - return self.load_data(data) + self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source) - def _call_load_sample(self, sample: Any) -> Any: - parameters = signature(self.load_sample).parameters - if len(parameters) > 1 and self.DATASET_KEY in parameters: - return self.load_sample(sample, self) - else: - return self.load_sample(sample) - - def _setup(self, stage: Optional[RunningStage]) -> None: - assert not stage or _STAGES_PREFIX[stage] in _STAGES_PREFIX_VALUES - previous_load_data = self.load_data.__code__ if self.load_data else None - - if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage: - self.load_data = getattr( - self.preprocess, - self.data_pipeline._resolve_function_hierarchy('load_data', self.preprocess, stage, Preprocess) + self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( + self.data_source, + DataPipeline._resolve_function_hierarchy( + 'load_sample', + self.data_source, + self.running_stage, + DataSource, ) - self.load_sample = getattr( - self.preprocess, - self.data_pipeline._resolve_function_hierarchy('load_sample', self.preprocess, stage, Preprocess) - ) - if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called): - if previous_load_data: - rank_zero_warn( - "The load_data function of the Autogenerated Dataset changed. " - "This is not expected! Preloading Data again to ensure compatibility. This may take some time." 
- ) - self.setup() - self._load_data_called = True - - def setup(self): - raise NotImplementedError + ) + def _call_load_sample(self, sample: Any) -> Any: + if self.load_sample: + if isinstance(sample, dict): + sample = dict(**sample) + with self._load_sample_context: + parameters = signature(self.load_sample).parameters + if len(parameters) > 1 and self.DATASET_KEY in parameters: + sample = self.load_sample(sample, self) + else: + sample = self.load_sample(sample) + return sample -class AutoDataset(BaseAutoDataset, Dataset): - def setup(self): - with self._load_data_context: - self.preprocessed_data = self._call_load_data(self.data) +class AutoDataset(BaseAutoDataset[Sequence], Dataset): def __getitem__(self, index: int) -> Any: - if not self.load_sample and not self.load_data: - raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") - if self.load_sample: - with self._load_sample_context: - data: Any = self._call_load_sample(self.preprocessed_data[index]) - if self.control_flow_callback: - self.control_flow_callback.on_load_sample(data, self.running_stage) - return data - return self.preprocessed_data[index] + return self._call_load_sample(self.data[index]) def __len__(self) -> int: - if not self.load_sample and not self.load_data: - raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.") - return len(self.preprocessed_data) - + return len(self.data) -class IterableAutoDataset(BaseAutoDataset, IterableDataset): - def setup(self): - with self._load_data_context: - self.dataset = self._call_load_data(self.data) - self.dataset_iter = None +class IterableAutoDataset(BaseAutoDataset[Iterable], IterableDataset): def __iter__(self): - self.dataset_iter = iter(self.dataset) + self.data_iter = iter(self.data) return self def __next__(self) -> Any: - if not self.load_sample and not self.load_data: - raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") - - data = 
next(self.dataset_iter) - - if self.load_sample: - with self._load_sample_context: - data: Any = self._call_load_sample(data) - if self.control_flow_callback: - self.control_flow_callback.on_load_sample(data, self.running_stage) - return data - return data + return self._call_load_sample(next(self.data_iter)) diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index c05cc93dcc..3ad1506257 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -5,7 +5,7 @@ from flash.core.utils import _is_overriden from flash.data.callback import BaseDataFetcher -from flash.data.utils import _PREPROCESS_FUNCS +from flash.data.utils import _CALLBACK_FUNCS class BaseVisualization(BaseDataFetcher): @@ -103,7 +103,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_li Override this function when you want to visualize a composition. """ # filter out the functions to visualise - func_names_set: Set[str] = set(func_names_list) & set(_PREPROCESS_FUNCS) + func_names_set: Set[str] = set(func_names_list) & set(_CALLBACK_FUNCS) if len(func_names_set) == 0: raise MisconfigurationException(f"Invalid function names: {func_names_list}.") diff --git a/flash/data/batch.py b/flash/data/batch.py index ea6ce1e9ca..739f4704ea 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -57,6 +57,8 @@ def __init__( self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess) def forward(self, sample: Any) -> Any: + self.callback.on_load_sample(sample, self.stage) + with self._current_stage_context: with self._pre_tensor_transform_context: sample = self.pre_tensor_transform(sample) diff --git a/flash/data/callback.py b/flash/data/callback.py index a479a6e59e..1221046a31 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -190,9 +190,6 @@ def enable(self): yield self.enabled = False - def attach_to_datamodule(self, datamodule) -> None: - datamodule.data_fetcher = self - def attach_to_preprocess(self, 
preprocess: 'flash.data.process.Preprocess') -> None: preprocess.add_callbacks([self]) self._preprocess = preprocess diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 874bcd8132..f64c25284a 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -13,22 +13,21 @@ # limitations under the License. import os import platform -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import pytorch_lightning as pl import torch -from datasets.splits import SplitInfo from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.nn import Module from torch.utils.data import DataLoader, Dataset from torch.utils.data.dataset import IterableDataset, Subset -from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset +from flash.data.auto_dataset import BaseAutoDataset, IterableAutoDataset from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_pipeline import DataPipeline, DefaultPreprocess, Postprocess, Preprocess +from flash.data.data_source import DataSource, DefaultDataSources from flash.data.splits import SplitDataset from flash.data.utils import _STAGES_PREFIX @@ -57,16 +56,34 @@ def __init__( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, + data_source: Optional[DataSource] = None, + preprocess: Optional[Preprocess] = None, + postprocess: Optional[Postprocess] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + val_split: Optional[float] = None, batch_size: int = 1, - num_workers: Optional[int] = 0, + num_workers: Optional[int] = None, ) -> None: super().__init__() + + self._data_source: DataSource = data_source + self._preprocess: 
Optional[Preprocess] = preprocess + self._postprocess: Optional[Postprocess] = postprocess + self._viz: Optional[BaseVisualization] = None + self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() + + # TODO: Preprocess can change + self.data_fetcher.attach_to_preprocess(self.preprocess) + self._train_ds = train_dataset self._val_ds = val_dataset self._test_ds = test_dataset self._predict_ds = predict_dataset + if self._train_ds is not None and (val_split is not None and self._val_ds is None): + self._train_ds, self._val_ds = self._split_train_val(self._train_ds, val_split) + if self._train_ds: self.train_dataloader = self._train_dataloader @@ -89,12 +106,6 @@ def __init__( num_workers = os.cpu_count() self.num_workers = num_workers - self._preprocess: Optional[Preprocess] = None - self._postprocess: Optional[Postprocess] = None - self._viz: Optional[BaseVisualization] = None - self._data_fetcher: Optional[BaseDataFetcher] = None - - # this may also trigger data preloading self.set_running_stages() @property @@ -141,7 +152,7 @@ def data_fetcher(self) -> BaseDataFetcher: def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None: self._data_fetcher = data_fetcher - def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]: + def _reset_iterator(self, stage: str) -> Iterable[Any]: iter_name = f"_{stage}_iter" # num_workers has to be set to 0 to work properly num_workers = self.num_workers @@ -152,7 +163,7 @@ def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]: setattr(self, iter_name, iterator) return iterator - def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], reset: bool = True) -> None: + def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool = True) -> None: """ This function is used to handle transforms profiling for batch visualization. 
""" @@ -278,11 +289,6 @@ def _predict_dataloader(self) -> DataLoader: collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) ) - def generate_auto_dataset(self, *args, **kwargs): - if all(a is None for a in args) and len(kwargs) == 0: - return None - return self.data_pipeline._generate_auto_dataset(*args, **kwargs) - @property def num_classes(self) -> Optional[int]: return ( @@ -290,6 +296,10 @@ def num_classes(self) -> Optional[int]: or getattr(self.test_dataset, "num_classes", None) ) + @property + def data_source(self) -> Optional[DataSource]: + return self._data_source + @property def preprocess(self) -> Preprocess: return self._preprocess or self.preprocess_cls() @@ -300,55 +310,11 @@ def postprocess(self) -> Postprocess: @property def data_pipeline(self) -> DataPipeline: - return DataPipeline(self.preprocess, self.postprocess) + return DataPipeline(self.data_source, self.preprocess, self.postprocess) @staticmethod - def _check_transforms(transform: Dict[str, Union[Module, Callable]]) -> Dict[str, Union[Module, Callable]]: - if not isinstance(transform, dict): - raise MisconfigurationException( - "Transform should be a dict. Here are the available keys " - f"for your transforms: {DataPipeline.PREPROCESS_FUNCS}." 
- ) - return transform - - @classmethod - def autogenerate_dataset( - cls, - data: Any, - running_stage: RunningStage, - whole_data_load_fn: Optional[Callable] = None, - per_sample_load_fn: Optional[Callable] = None, - data_pipeline: Optional[DataPipeline] = None, - use_iterable_auto_dataset: bool = False, - ) -> BaseAutoDataset: - """ - This function is used to generate an ``BaseAutoDataset`` from a ``DataPipeline`` if provided - or from the provided ``whole_data_load_fn``, ``per_sample_load_fn`` functions directly - """ - - preprocess = getattr(data_pipeline, '_preprocess_pipeline', None) - - if whole_data_load_fn is None: - whole_data_load_fn = getattr( - preprocess, - DataPipeline._resolve_function_hierarchy('load_data', preprocess, running_stage, Preprocess) - ) - - if per_sample_load_fn is None: - per_sample_load_fn = getattr( - preprocess, - DataPipeline._resolve_function_hierarchy('load_sample', preprocess, running_stage, Preprocess) - ) - if use_iterable_auto_dataset: - return IterableAutoDataset( - data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage - ) - return BaseAutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage) - - @classmethod def _split_train_val( - cls, - train_dataset: Union[AutoDataset, IterableAutoDataset], + train_dataset: Dataset, val_split: float, ) -> Tuple[Any, Any]: @@ -357,7 +323,7 @@ def _split_train_val( if isinstance(train_dataset, IterableAutoDataset): raise MisconfigurationException( - "`val_split` should be `None` when the dataset is built with an IterativeDataset." + "`val_split` should be `None` when the dataset is built with an IterableDataset." 
) train_num_samples = len(train_dataset) @@ -367,113 +333,279 @@ def _split_train_val( return SplitDataset(train_dataset, train_indices), SplitDataset(train_dataset, val_indices) @classmethod - def _generate_dataset_if_possible( + def from_data_source( cls, - data: Optional[Any], - running_stage: RunningStage, - whole_data_load_fn: Optional[Callable] = None, - per_sample_load_fn: Optional[Callable] = None, - data_pipeline: Optional[DataPipeline] = None, - use_iterable_auto_dataset: bool = False, - ) -> Optional[BaseAutoDataset]: - if data is None: - return - - if data_pipeline: - return data_pipeline._generate_auto_dataset( - data, - running_stage=running_stage, - use_iterable_auto_dataset=use_iterable_auto_dataset, - ) + data_source: str, + train_data: Any = None, + val_data: Any = None, + test_data: Any = None, + predict_data: Any = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> 'DataModule': + preprocess = preprocess or cls.preprocess_cls( + train_transform, + val_transform, + test_transform, + predict_transform, + **preprocess_kwargs, + ) - return cls.autogenerate_dataset( - data, - running_stage, - whole_data_load_fn, - per_sample_load_fn, - data_pipeline, - use_iterable_auto_dataset=use_iterable_auto_dataset, + data_source = preprocess.data_source_of_name(data_source) + + train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets( + train_data, + val_data, + test_data, + predict_data, + ) + + return cls( + train_dataset, + val_dataset, + test_dataset, + predict_dataset, + data_source=data_source, + preprocess=preprocess, + 
    @classmethod
    def from_folders(
        cls,
        train_folder: Optional[str] = None,
        val_folder: Optional[str] = None,
        test_folder: Optional[str] = None,
        predict_folder: Optional[str] = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: BaseDataFetcher = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **preprocess_kwargs: Any,
    ) -> 'DataModule':
        """Build a data module from one folder per stage.

        Thin wrapper: every argument is forwarded unchanged to ``cls.from_data_source`` with
        ``DefaultDataSources.PATHS`` selected as the data source.

        Args:
            train_folder: Value passed as the training data; ``None`` when omitted.
            val_folder: Value passed as the validation data; ``None`` when omitted.
            test_folder: Value passed as the test data; ``None`` when omitted.
            predict_folder: Value passed as the prediction data; ``None`` when omitted.
            train_transform: Forwarded as the ``train_transform`` keyword.
            val_transform: Forwarded as the ``val_transform`` keyword.
            test_transform: Forwarded as the ``test_transform`` keyword.
            predict_transform: Forwarded as the ``predict_transform`` keyword.
            data_fetcher: Forwarded as the ``data_fetcher`` keyword.
            preprocess: Forwarded as the ``preprocess`` keyword.
            val_split: Forwarded as the ``val_split`` keyword.
            batch_size: Forwarded as the ``batch_size`` keyword (default ``4``).
            num_workers: Forwarded as the ``num_workers`` keyword.
            preprocess_kwargs: Any remaining keyword arguments, forwarded as-is.

        Returns:
            Whatever ``cls.from_data_source`` returns (annotated as a ``DataModule``).
        """
        return cls.from_data_source(
            DefaultDataSources.PATHS,
            train_folder,
            val_folder,
            test_folder,
            predict_folder,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )

    @classmethod
    def from_files(
        cls,
        train_files: Optional[Sequence[str]] = None,
        train_targets: Optional[Sequence[Any]] = None,
        val_files: Optional[Sequence[str]] = None,
        val_targets: Optional[Sequence[Any]] = None,
        test_files: Optional[Sequence[str]] = None,
        test_targets: Optional[Sequence[Any]] = None,
        predict_files: Optional[Sequence[str]] = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: BaseDataFetcher = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **preprocess_kwargs: Any,
    ) -> 'DataModule':
        """Build a data module from explicit file lists and optional targets.

        Thin wrapper around ``cls.from_data_source`` using ``DefaultDataSources.PATHS``. For the train,
        validation and test stages the files and targets are bundled into a ``(files, targets)`` tuple;
        the prediction files are passed on their own, without a tuple.

        Args:
            train_files: First element of the training tuple.
            train_targets: Second element of the training tuple.
            val_files: First element of the validation tuple.
            val_targets: Second element of the validation tuple.
            test_files: First element of the test tuple.
            test_targets: Second element of the test tuple.
            predict_files: Passed directly as the prediction data.
            train_transform: Forwarded as the ``train_transform`` keyword.
            val_transform: Forwarded as the ``val_transform`` keyword.
            test_transform: Forwarded as the ``test_transform`` keyword.
            predict_transform: Forwarded as the ``predict_transform`` keyword.
            data_fetcher: Forwarded as the ``data_fetcher`` keyword.
            preprocess: Forwarded as the ``preprocess`` keyword.
            val_split: Forwarded as the ``val_split`` keyword.
            batch_size: Forwarded as the ``batch_size`` keyword (default ``4``).
            num_workers: Forwarded as the ``num_workers`` keyword.
            preprocess_kwargs: Any remaining keyword arguments, forwarded as-is.

        Returns:
            Whatever ``cls.from_data_source`` returns (annotated as a ``DataModule``).
        """
        return cls.from_data_source(
            DefaultDataSources.PATHS,
            (train_files, train_targets),
            (val_files, val_targets),
            (test_files, test_targets),
            predict_files,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )

    @classmethod
    def from_tensors(
        cls,
        train_data: Optional[Collection[torch.Tensor]] = None,
        train_targets: Optional[Collection[Any]] = None,
        val_data: Optional[Collection[torch.Tensor]] = None,
        val_targets: Optional[Sequence[Any]] = None,
        test_data: Optional[Collection[torch.Tensor]] = None,
        test_targets: Optional[Sequence[Any]] = None,
        predict_data: Optional[Collection[torch.Tensor]] = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: BaseDataFetcher = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **preprocess_kwargs: Any,
    ) -> 'DataModule':
        """Build a data module from in-memory tensors and optional targets.

        Thin wrapper around ``cls.from_data_source`` using ``DefaultDataSources.TENSOR``. Train,
        validation and test data are each bundled with their targets as a ``(data, targets)`` tuple;
        the prediction data is passed on its own.

        Args:
            train_data: First element of the training tuple.
            train_targets: Second element of the training tuple.
            val_data: First element of the validation tuple.
            val_targets: Second element of the validation tuple.
            test_data: First element of the test tuple.
            test_targets: Second element of the test tuple.
            predict_data: Passed directly as the prediction data.
            train_transform: Forwarded as the ``train_transform`` keyword.
            val_transform: Forwarded as the ``val_transform`` keyword.
            test_transform: Forwarded as the ``test_transform`` keyword.
            predict_transform: Forwarded as the ``predict_transform`` keyword.
            data_fetcher: Forwarded as the ``data_fetcher`` keyword.
            preprocess: Forwarded as the ``preprocess`` keyword.
            val_split: Forwarded as the ``val_split`` keyword.
            batch_size: Forwarded as the ``batch_size`` keyword (default ``4``).
            num_workers: Forwarded as the ``num_workers`` keyword.
            preprocess_kwargs: Any remaining keyword arguments, forwarded as-is.

        Returns:
            Whatever ``cls.from_data_source`` returns (annotated as a ``DataModule``).
        """
        return cls.from_data_source(
            DefaultDataSources.TENSOR,
            (train_data, train_targets),
            (val_data, val_targets),
            (test_data, test_targets),
            predict_data,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )

    @classmethod
    def from_numpy(
        cls,
        train_data: Optional[Collection[np.ndarray]] = None,
        train_targets: Optional[Collection[Any]] = None,
        val_data: Optional[Collection[np.ndarray]] = None,
        val_targets: Optional[Sequence[Any]] = None,
        test_data: Optional[Collection[np.ndarray]] = None,
        test_targets: Optional[Sequence[Any]] = None,
        predict_data: Optional[Collection[np.ndarray]] = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: BaseDataFetcher = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **preprocess_kwargs: Any,
    ) -> 'DataModule':
        """Build a data module from in-memory NumPy arrays and optional targets.

        Thin wrapper around ``cls.from_data_source`` using ``DefaultDataSources.NUMPY``. Train,
        validation and test data are each bundled with their targets as a ``(data, targets)`` tuple;
        the prediction data is passed on its own.

        Args:
            train_data: First element of the training tuple.
            train_targets: Second element of the training tuple.
            val_data: First element of the validation tuple.
            val_targets: Second element of the validation tuple.
            test_data: First element of the test tuple.
            test_targets: Second element of the test tuple.
            predict_data: Passed directly as the prediction data.
            train_transform: Forwarded as the ``train_transform`` keyword.
            val_transform: Forwarded as the ``val_transform`` keyword.
            test_transform: Forwarded as the ``test_transform`` keyword.
            predict_transform: Forwarded as the ``predict_transform`` keyword.
            data_fetcher: Forwarded as the ``data_fetcher`` keyword.
            preprocess: Forwarded as the ``preprocess`` keyword.
            val_split: Forwarded as the ``val_split`` keyword.
            batch_size: Forwarded as the ``batch_size`` keyword (default ``4``).
            num_workers: Forwarded as the ``num_workers`` keyword.
            preprocess_kwargs: Any remaining keyword arguments, forwarded as-is.

        Returns:
            Whatever ``cls.from_data_source`` returns (annotated as a ``DataModule``).
        """
        return cls.from_data_source(
            DefaultDataSources.NUMPY,
            (train_data, train_targets),
            (val_data, val_targets),
            (test_data, test_targets),
            predict_data,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )

    @classmethod
    def from_json(
        cls,
        input_fields: Union[str, Sequence[str]],
        target_fields: Optional[Union[str, Sequence[str]]] = None,
        train_file: Optional[str] = None,
        val_file: Optional[str] = None,
        test_file: Optional[str] = None,
        predict_file: Optional[str] = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: BaseDataFetcher = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **preprocess_kwargs: Any,
    ) -> 'DataModule':
        """Build a data module from one JSON file per stage.

        Thin wrapper around ``cls.from_data_source`` using ``DefaultDataSources.JSON``. Every stage,
        prediction included, is passed as a ``(file, input_fields, target_fields)`` tuple, so the same
        field selection is shared by all four stages.

        Args:
            input_fields: Required. Second element of every stage tuple.
            target_fields: Third element of every stage tuple; ``None`` when omitted.
            train_file: First element of the training tuple.
            val_file: First element of the validation tuple.
            test_file: First element of the test tuple.
            predict_file: First element of the prediction tuple.
            train_transform: Forwarded as the ``train_transform`` keyword.
            val_transform: Forwarded as the ``val_transform`` keyword.
            test_transform: Forwarded as the ``test_transform`` keyword.
            predict_transform: Forwarded as the ``predict_transform`` keyword.
            data_fetcher: Forwarded as the ``data_fetcher`` keyword.
            preprocess: Forwarded as the ``preprocess`` keyword.
            val_split: Forwarded as the ``val_split`` keyword.
            batch_size: Forwarded as the ``batch_size`` keyword (default ``4``).
            num_workers: Forwarded as the ``num_workers`` keyword.
            preprocess_kwargs: Any remaining keyword arguments, forwarded as-is.

        Returns:
            Whatever ``cls.from_data_source`` returns (annotated as a ``DataModule``).
        """
        return cls.from_data_source(
            DefaultDataSources.JSON,
            (train_file, input_fields, target_fields),
            (val_file, input_fields, target_fields),
            (test_file, input_fields, target_fields),
            (predict_file, input_fields, target_fields),
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )

    @classmethod
    def from_csv(
        cls,
        input_fields: Union[str, Sequence[str]],
        target_fields: Optional[Union[str, Sequence[str]]] = None,
        train_file: Optional[str] = None,
        val_file: Optional[str] = None,
        test_file: Optional[str] = None,
        predict_file: Optional[str] = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: BaseDataFetcher = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **preprocess_kwargs: Any,
    ) -> 'DataModule':
        """Build a data module from one CSV file per stage.

        Thin wrapper around ``cls.from_data_source`` using ``DefaultDataSources.CSV``. Every stage,
        prediction included, is passed as a ``(file, input_fields, target_fields)`` tuple, so the same
        field selection is shared by all four stages.

        Args:
            input_fields: Required. Second element of every stage tuple.
            target_fields: Third element of every stage tuple; ``None`` when omitted.
            train_file: First element of the training tuple.
            val_file: First element of the validation tuple.
            test_file: First element of the test tuple.
            predict_file: First element of the prediction tuple.
            train_transform: Forwarded as the ``train_transform`` keyword.
            val_transform: Forwarded as the ``val_transform`` keyword.
            test_transform: Forwarded as the ``test_transform`` keyword.
            predict_transform: Forwarded as the ``predict_transform`` keyword.
            data_fetcher: Forwarded as the ``data_fetcher`` keyword.
            preprocess: Forwarded as the ``preprocess`` keyword.
            val_split: Forwarded as the ``val_split`` keyword.
            batch_size: Forwarded as the ``batch_size`` keyword (default ``4``).
            num_workers: Forwarded as the ``num_workers`` keyword.
            preprocess_kwargs: Any remaining keyword arguments, forwarded as-is.

        Returns:
            Whatever ``cls.from_data_source`` returns (annotated as a ``DataModule``).
        """
        return cls.from_data_source(
            DefaultDataSources.CSV,
            (train_file, input_fields, target_fields),
            (val_file, input_fields, target_fields),
            (test_file, input_fields, target_fields),
            (predict_file, input_fields, target_fields),
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )
    def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> DataPipelineState:
        """Attach a shared :class:`.DataPipelineState` to every component of this pipeline.

        The state is handed, in order, to the data source (only when one is set), the
        :class:`.Preprocess`, the :class:`.Postprocess` and the :class:`.Serializer`, each through
        its ``attach_data_pipeline_state`` method.

        Args:
            data_pipeline_state: An existing state to reuse. When it is ``None`` (or otherwise falsy)
                a fresh :class:`.DataPipelineState` is created.

        Returns:
            The state object that was attached, with ``_initialized`` set to ``True``.

        Note:
            The previous documentation of this method states that adding new state after this call
            gives a warning; that check is not performed in this method itself.
        """
        data_pipeline_state = data_pipeline_state or DataPipelineState()
        # Clear the flag while components attach; it is set again once all of them are attached.
        data_pipeline_state._initialized = False
        if self.data_source is not None:
            self.data_source.attach_data_pipeline_state(data_pipeline_state)
        self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
        self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
        self._serializer.attach_data_pipeline_state(data_pipeline_state)
        data_pipeline_state._initialized = True  # TODO: Not sure we need this
        return data_pipeline_state
collate_fn: - loader_kwargs['collate_fn'] = collate_fn - - else: - loader_kwargs['collate_fn'] = default_collate if auto_collate else default_convert - - return DataLoader(self._generate_auto_dataset(data), **loader_kwargs) - def __str__(self) -> str: preprocess: Preprocess = self._preprocess_pipeline postprocess: Postprocess = self._postprocess_pipeline diff --git a/flash/data/data_source.py b/flash/data/data_source.py new file mode 100644 index 0000000000..e637eab923 --- /dev/null +++ b/flash/data/data_source.py @@ -0,0 +1,264 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class DataSource(Generic[DATA_TYPE], Properties, Module):
    """Hook-based base class that turns raw per-stage data into datasets.

    Subclasses override :meth:`load_data` and/or :meth:`load_sample`; the defaults here return
    their input unchanged. :meth:`to_datasets` / :meth:`generate_dataset` run the ``load_data``
    hook and wrap its result in an ``AutoDataset`` or ``IterableAutoDataset``.
    """

    def load_data(self,
                  data: DATA_TYPE,
                  dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]], Iterable[Mapping[str, Any]]]:
        """Load the entire data for one stage.

        The input ``data`` can be anything; the annotated return is a sequence or iterable of
        mappings, one per sample. This base implementation returns ``data`` unchanged.

        Args:
            data: The raw data handed in for the stage.
            dataset: Optional object on which the hook may set attributes; unused here.

        Returns:
            ``data`` itself.
        """
        return data

    def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any:
        """Load a single sample; this base implementation returns ``sample`` unchanged."""
        return sample

    def to_datasets(
        self,
        train_data: Optional[DATA_TYPE] = None,
        val_data: Optional[DATA_TYPE] = None,
        test_data: Optional[DATA_TYPE] = None,
        predict_data: Optional[DATA_TYPE] = None,
    ) -> Tuple[Optional[BaseAutoDataset], ...]:
        """Generate one dataset per stage via :meth:`generate_dataset`.

        Returns:
            A 4-tuple ``(train, val, test, predict)``; an entry is ``None`` when
            :meth:`generate_dataset` produced no dataset for that stage.
        """
        train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING)
        val_dataset = self.generate_dataset(val_data, RunningStage.VALIDATING)
        test_dataset = self.generate_dataset(test_data, RunningStage.TESTING)
        predict_dataset = self.generate_dataset(predict_data, RunningStage.PREDICTING)
        return train_dataset, val_dataset, test_dataset, predict_dataset

    def generate_dataset(
        self,
        data: Optional[DATA_TYPE],
        running_stage: RunningStage,
    ) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
        """Run the ``load_data`` hook for ``running_stage`` and wrap the result in a dataset.

        Args:
            data: Raw data for the stage. ``None``, or a sequence whose first element is ``None``,
                counts as "no data".
            running_stage: The stage the dataset is generated for.

        Returns:
            An ``AutoDataset`` when the loaded data supports ``len()``, otherwise an
            ``IterableAutoDataset``; ``None`` (implicitly) when there is no data.
        """
        is_none = data is None

        # For sequences only the first element decides, e.g. a ``(None, targets)`` tuple is "no data".
        # NOTE(review): an empty sequence (including an empty string) raises IndexError here —
        # confirm whether callers can pass one.
        if isinstance(data, Sequence):
            is_none = data[0] is None

        if not is_none:
            # Imported inside the method rather than at module level — presumably to avoid a circular
            # import with ``flash.data.data_pipeline``; TODO confirm.
            from flash.data.data_pipeline import DataPipeline

            # Stand-in passed as the hook's ``dataset`` argument; its recorded ``metadata`` is copied
            # onto the real dataset below. The cast only changes the static type.
            mock_dataset = typing.cast(AutoDataset, MockDataset())
            with CurrentRunningStageFuncContext(running_stage, "load_data", self):
                # Look up the hook by the name the pipeline resolves for this stage — presumably a
                # stage-prefixed override when the subclass defines one; verify against
                # ``DataPipeline._resolve_function_hierarchy``.
                load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
                    self, DataPipeline._resolve_function_hierarchy(
                        "load_data",
                        self,
                        running_stage,
                        DataSource,
                    )
                )
                parameters = signature(load_data).parameters
                # Only pass the stand-in dataset to hooks that declare a ``dataset`` parameter.
                if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                    data = load_data(data, mock_dataset)
                else:
                    data = load_data(data)

            if has_len(data):
                dataset = AutoDataset(data, self, running_stage)
            else:
                dataset = IterableAutoDataset(data, self, running_stage)
            # Copy whatever the hook recorded on the stand-in onto the real dataset.
            dataset.__dict__.update(mock_dataset.metadata)
            return dataset
class SequenceDataSource(
    Generic[SEQUENCE_DATA_TYPE],
    DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]],
):
    """Data source for an ``(inputs, targets)`` pair of parallel sequences.

    Each sample becomes a mapping keyed by :class:`DefaultDataKeys`. When ``labels`` are given
    they are published through a :class:`LabelsState`.
    """

    def __init__(self, labels: Optional[Sequence[str]] = None):
        """Store ``labels`` and, when they are not ``None``, register them as a ``LabelsState``."""
        super().__init__()

        self.labels = labels

        if self.labels is not None:
            self.set_state(LabelsState(self.labels))

    def load_data(
        self,
        data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]],
        dataset: Optional[Any] = None,
    ) -> Sequence[Mapping[str, Any]]:
        """Pair every input with its target.

        Args:
            data: An ``(inputs, targets)`` tuple. ``targets`` may be ``None``.
            dataset: Unused by this implementation.

        Returns:
            One mapping per sample holding ``DefaultDataKeys.INPUT`` and ``DefaultDataKeys.TARGET``.
            When ``targets`` is ``None`` the result of :meth:`predict_load_data` on the inputs is
            returned instead (mappings with the input key only).
        """
        # TODO: Bring back the code to work out how many classes there are
        inputs, targets = data
        if targets is None:
            # Hand over only the inputs: ``predict_load_data`` iterates its argument as the inputs,
            # so passing the whole ``(inputs, None)`` tuple would yield the tuple's two elements as
            # samples.
            return self.predict_load_data(inputs)
        return [{
            DefaultDataKeys.INPUT: sample_input,
            DefaultDataKeys.TARGET: sample_target
        } for sample_input, sample_target in zip(inputs, targets)]

    def predict_load_data(self, data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapping[str, Any]]:
        """Wrap every element of ``data`` in a mapping holding only ``DefaultDataKeys.INPUT``."""
        return [{DefaultDataKeys.INPUT: sample_input} for sample_input in data]
class TensorDataSource(SequenceDataSource[torch.Tensor]):
    """:class:`SequenceDataSource` parameterised for ``torch.Tensor`` inputs.

    Defines no attributes or hooks of its own; all behaviour is inherited.
    """


class NumpyDataSource(SequenceDataSource[np.ndarray]):
    """:class:`SequenceDataSource` parameterised for ``numpy.ndarray`` inputs.

    Defines no attributes or hooks of its own; all behaviour is inherited.
    """
import os from abc import ABC, abstractclassmethod, abstractmethod -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, TypeVar, Union import torch from pytorch_lightning.trainer.states import RunningStage @@ -25,110 +24,14 @@ from flash.data.batch import default_uncollate from flash.data.callback import FlashCallback +from flash.data.data_source import DataSource +from flash.data.properties import Properties from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules if TYPE_CHECKING: from flash.data.data_pipeline import DataPipelineState -@dataclass(unsafe_hash=True, frozen=True) -class ProcessState: - """ - Base class for all process states - """ - pass - - -STATE_TYPE = TypeVar('STATE_TYPE', bound=ProcessState) - - -class Properties: - - def __init__(self): - super().__init__() - - self._running_stage: Optional[RunningStage] = None - self._current_fn: Optional[str] = None - self._data_pipeline_state: Optional['DataPipelineState'] = None - self._state: Dict[Type[ProcessState], ProcessState] = {} - - def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]: - if self._data_pipeline_state is not None: - return self._data_pipeline_state.get_state(state_type) - else: - return None - - def set_state(self, state: ProcessState): - self._state[type(state)] = state - if self._data_pipeline_state is not None: - self._data_pipeline_state.set_state(state) - - def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'): - self._data_pipeline_state = data_pipeline_state - for state in self._state.values(): - self._data_pipeline_state.set_state(state) - - @property - def current_fn(self) -> Optional[str]: - return self._current_fn - - @current_fn.setter - def current_fn(self, current_fn: str): - self._current_fn = current_fn - - @property - def 
running_stage(self) -> Optional[RunningStage]: - return self._running_stage - - @running_stage.setter - def running_stage(self, running_stage: RunningStage): - self._running_stage = running_stage - - @property - def training(self) -> bool: - return self._running_stage == RunningStage.TRAINING - - @training.setter - def training(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TRAINING - elif self.training: - self._running_stage = None - - @property - def testing(self) -> bool: - return self._running_stage == RunningStage.TESTING - - @testing.setter - def testing(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.TESTING - elif self.testing: - self._running_stage = None - - @property - def predicting(self) -> bool: - return self._running_stage == RunningStage.PREDICTING - - @predicting.setter - def predicting(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.PREDICTING - elif self.predicting: - self._running_stage = None - - @property - def validating(self) -> bool: - return self._running_stage == RunningStage.VALIDATING - - @validating.setter - def validating(self, val: bool) -> None: - if val: - self._running_stage = RunningStage.VALIDATING - elif self.validating: - self._running_stage = None - - class BasePreprocess(ABC): @abstractmethod @@ -146,6 +49,9 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): pass +DATA_SOURCE_TYPE = TypeVar("DATA_SOURCE_TYPE") + + class Preprocess(BasePreprocess, Properties, Module): """ The :class:`~flash.data.process.Preprocess` encapsulates @@ -304,9 +210,17 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + data_sources: Optional[Dict[str, 'DataSource']] = None, + default_data_source: Optional[str] = None, ): super().__init__() + # resolve the default transforms + train_transform = train_transform or 
self.default_train_transforms + val_transform = val_transform or self.default_val_transforms + test_transform = test_transform or self.default_test_transforms + predict_transform = predict_transform or self.default_predict_transforms + # used to keep track of provided transforms self._train_collate_in_worker_from_transform: Optional[bool] = None self._val_collate_in_worker_from_transform: Optional[bool] = None @@ -314,17 +228,44 @@ def __init__( self._test_collate_in_worker_from_transform: Optional[bool] = None # store the transform before conversion to modules. - self._train_transform = self._check_transforms(train_transform, RunningStage.TRAINING) - self._val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING) - self._test_transform = self._check_transforms(test_transform, RunningStage.TESTING) - self._predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING) + self.train_transform = self._check_transforms(train_transform, RunningStage.TRAINING) + self.val_transform = self._check_transforms(val_transform, RunningStage.VALIDATING) + self.test_transform = self._check_transforms(test_transform, RunningStage.TESTING) + self.predict_transform = self._check_transforms(predict_transform, RunningStage.PREDICTING) + + self._train_transform = convert_to_modules(self.train_transform) + self._val_transform = convert_to_modules(self.val_transform) + self._test_transform = convert_to_modules(self.test_transform) + self._predict_transform = convert_to_modules(self.predict_transform) + + self._data_sources = data_sources + self._default_data_source = default_data_source + self._callbacks: List[FlashCallback] = [] - self.train_transform = convert_to_modules(self._train_transform) - self.val_transform = convert_to_modules(self._val_transform) - self.test_transform = convert_to_modules(self._test_transform) - self.predict_transform = convert_to_modules(self._predict_transform) + @property + def default_train_transforms(self) -> 
    @property
    def current_transform(self) -> Callable:
        """Return the transform to apply for the current running stage.

        The first stage flag that is set (training, validating, testing, predicting — checked in
        that order) and whose stored transform is truthy selects that transform, which is then
        resolved through ``self._get_transform``. If no stage matches, or the matching stage has no
        transform, ``self._identity`` is returned.
        """
        if self.training and self._train_transform:
            return self._get_transform(self._train_transform)
        elif self.validating and self._val_transform:
            return self._get_transform(self._val_transform)
        elif self.testing and self._test_transform:
            return self._get_transform(self._test_transform)
        elif self.predicting and self._predict_transform:
            return self._get_transform(self._predict_transform)
        else:
            # No stage set, or nothing configured for it.
            return self._identity
"""Loads entire data from Dataset. The input ``data`` can be anything, but you need to return a Mapping. - - Example:: - - # data: "." - # output: [("./cat/1.png", 1), ..., ("./dog/10.png", 0)] - - output: Mapping = load_data(data) - - """ - return data - - @classmethod - def load_sample(cls, sample: Any, dataset: Optional[Any] = None) -> Any: - """Loads single sample from dataset""" - return sample - def pre_tensor_transform(self, sample: Any) -> Any: """Transforms to apply on a single object.""" - return sample + return self.current_transform(sample) def to_tensor_transform(self, sample: Any) -> Tensor: """Transforms to convert single object to a tensor.""" - return sample + return self.current_transform(sample) def post_tensor_transform(self, sample: Tensor) -> Tensor: """Transforms to apply on a tensor.""" - return sample + return self.current_transform(sample) def per_batch_transform(self, batch: Any) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). @@ -463,7 +385,7 @@ def per_batch_transform(self, batch: Any) -> Any: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are specified, uncollation has to be applied. """ - return batch + return self.current_transform(batch) def collate(self, samples: Sequence) -> Any: return default_collate(samples) @@ -481,7 +403,7 @@ def per_sample_transform_on_device(self, sample: Any) -> Any: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). 
""" - return sample + return self.current_transform(sample) def per_batch_transform_on_device(self, batch: Any) -> Any: """ @@ -492,13 +414,40 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). """ - return batch + return self.current_transform(batch) + + def data_source_of_name(self, data_source_name: str) -> Optional[DATA_SOURCE_TYPE]: + if data_source_name == "default": + data_source_name = self._default_data_source + data_sources = self._data_sources + if data_source_name in data_sources: + return data_sources[data_source_name] + return None class DefaultPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_sources: Optional[Dict[str, 'DataSource']] = None, + default_data_source: Optional[str] = None, + ): + from flash.data.data_source import DataSource + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources=data_sources or {"default": DataSource()}, + default_data_source=default_data_source or "default", + ) + def get_state_dict(self) -> Dict[str, Any]: - return {} + return {**self.transforms} @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): diff --git a/flash/data/properties.py b/flash/data/properties.py new file mode 100644 index 0000000000..2a2934a3d9 --- /dev/null +++ b/flash/data/properties.py @@ -0,0 +1,120 @@ +# Copyright The PyTorch Lightning team. 
class Properties:
    """Mixin holding the running stage, the current hook name and per-type process state.

    State set on an instance is kept locally (one entry per state type) and mirrored to the
    attached data pipeline state, if any.
    """

    def __init__(self):
        super().__init__()

        self._running_stage: Optional[RunningStage] = None
        self._current_fn: Optional[str] = None
        self._data_pipeline_state: Optional['DataPipelineState'] = None
        self._state: Dict[Type[ProcessState], ProcessState] = {}

    def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]:
        """Return the state of ``state_type``: local first, then the attached pipeline state, else ``None``."""
        local_states = self._state
        if state_type not in local_states:
            shared = self._data_pipeline_state
            return shared.get_state(state_type) if shared is not None else None
        return local_states[state_type]

    def set_state(self, state: ProcessState):
        """Store ``state`` under its own type and forward it to the attached pipeline state, if any."""
        self._state[type(state)] = state
        shared = self._data_pipeline_state
        if shared is None:
            return
        shared.set_state(state)

    def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'):
        """Remember ``data_pipeline_state`` and push every locally stored state into it."""
        self._data_pipeline_state = data_pipeline_state
        for stored in self._state.values():
            self._data_pipeline_state.set_state(stored)

    @property
    def current_fn(self) -> Optional[str]:
        """Name of the hook currently being executed, if one was recorded."""
        return self._current_fn

    @current_fn.setter
    def current_fn(self, current_fn: str):
        self._current_fn = current_fn

    @property
    def running_stage(self) -> Optional[RunningStage]:
        """The stage currently recorded on this object, or ``None``."""
        return self._running_stage

    @running_stage.setter
    def running_stage(self, running_stage: RunningStage):
        self._running_stage = running_stage

    @property
    def training(self) -> bool:
        """``True`` while the recorded stage is ``RunningStage.TRAINING``."""
        return self._running_stage == RunningStage.TRAINING

    @training.setter
    def training(self, val: bool) -> None:
        # A truthy value enters the stage; a falsy one only clears the stage if it is the active one.
        if val:
            self._running_stage = RunningStage.TRAINING
            return
        if self.training:
            self._running_stage = None

    @property
    def testing(self) -> bool:
        """``True`` while the recorded stage is ``RunningStage.TESTING``."""
        return self._running_stage == RunningStage.TESTING

    @testing.setter
    def testing(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TESTING
            return
        if self.testing:
            self._running_stage = None

    @property
    def predicting(self) -> bool:
        """``True`` while the recorded stage is ``RunningStage.PREDICTING``."""
        return self._running_stage == RunningStage.PREDICTING

    @predicting.setter
    def predicting(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.PREDICTING
            return
        if self.predicting:
            self._running_stage = None

    @property
    def validating(self) -> bool:
        """``True`` while the recorded stage is ``RunningStage.VALIDATING``."""
        return self._running_stage == RunningStage.VALIDATING

    @validating.setter
    def validating(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.VALIDATING
            return
        if self.validating:
            self._running_stage = None
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Mapping, Sequence, Union + +from torch import nn + +from flash.data.utils import convert_to_modules + + +class ApplyToKeys(nn.Sequential): + + def __init__(self, keys: Union[str, Sequence[str]], *args): + super().__init__(*[convert_to_modules(arg) for arg in args]) + if isinstance(keys, str): + keys = [keys] + self.keys = keys + + def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: + inputs = [x[key] for key in filter(lambda key: key in x, self.keys)] + if len(inputs) > 0: + outputs = super().forward(*inputs) + if not isinstance(outputs, tuple): + outputs = (outputs, ) + + result = {} + result.update(x) + for i, key in enumerate(self.keys): + result[key] = outputs[i] + return result + return x diff --git a/flash/data/utils.py b/flash/data/utils.py index 48bac51a93..9a329beb78 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -32,9 +32,12 @@ } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} -_PREPROCESS_FUNCS: Set[str] = { +_DATASOURCE_FUNCS: Set[str] = { "load_data", "load_sample", +} + +_PREPROCESS_FUNCS: Set[str] = { "pre_tensor_transform", "to_tensor_transform", "post_tensor_transform", @@ -44,6 +47,11 @@ "collate", } +_CALLBACK_FUNCS: Set[str] = { + "load_sample", + *_PREPROCESS_FUNCS, +} + _POSTPROCESS_FUNCS: Set[str] = { "per_batch_transform", "uncollate", diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index ec40abc82d..1b4ad6b9bd 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -11,18 
+11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd from pandas.core.frame import DataFrame from pytorch_lightning.utilities.exceptions import MisconfigurationException -from sklearn.model_selection import train_test_split -from torch.utils.data import Dataset -from flash.core.classification import ClassificationState -from flash.data.auto_dataset import AutoDataset +from flash.core.classification import LabelsState from flash.data.data_module import DataModule +from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess from flash.tabular.classification.data.dataset import ( _compute_normalization, @@ -33,23 +31,21 @@ ) -class TabularPreprocess(Preprocess): +class TabularDataFrameDataSource(DataSource[DataFrame]): def __init__( self, - cat_cols: List[str], - num_cols: List[str], - target_col: str, - mean: DataFrame, - std: DataFrame, - codes: Dict[str, Any], - target_codes: Optional[Dict[str, Any]], - classes: List[str], - num_classes: int, - is_regression: bool, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, ): super().__init__() - self.set_state(ClassificationState(classes)) self.cat_cols = cat_cols self.num_cols = num_cols @@ -58,28 +54,16 @@ def __init__( self.std = std self.codes = codes self.target_codes = target_codes - self.num_classes = num_classes self.is_regression = is_regression - def 
get_state_dict(self, strict: bool = False) -> Dict[str, Any]: - return { - "cat_cols": self.cat_cols, - "num_cols": self.num_cols, - "target_col": self.target_col, - "mean": self.mean, - "std": self.std, - "codes": self.codes, - "target_codes": self.target_codes, - "classes": self.num_classes, - "num_classes": self.num_classes, - "is_regression": self.is_regression, - } + self.set_state(LabelsState(classes)) + self.num_classes = len(classes) - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': - return cls(**state_dict) - - def common_load_data(self, df: DataFrame, dataset: AutoDataset): + def common_load_data( + self, + df: DataFrame, + dataset: Optional[Any] = None, + ): # impute_data # compute train dataset stats dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, @@ -87,181 +71,126 @@ def common_load_data(self, df: DataFrame, dataset: AutoDataset): df = dfs[0] - dataset.num_samples = len(df) + if dataset is not None: + dataset.num_samples = len(df) + cat_vars = _to_cat_vars_numpy(df, self.cat_cols) num_vars = _to_num_vars_numpy(df, self.num_cols) - cat_vars = np.stack(cat_vars, 1) if len(cat_vars) else np.zeros((len(self), 0)) - num_vars = np.stack(num_vars, 1) if len(num_vars) else np.zeros((len(self), 0)) + cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0)) + num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0)) return df, cat_vars, num_vars - def load_data(self, df: DataFrame, dataset: AutoDataset): - df, cat_vars, num_vars = self.common_load_data(df, dataset) + def load_data(self, data: DataFrame, dataset: Optional[Any] = None): + df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) - return [((c, n), t) for c, n, t in zip(cat_vars, num_vars, target)] + return [{ + 
DefaultDataKeys.INPUT: (c, n), + DefaultDataKeys.TARGET: t + } for c, n, t in zip(cat_vars, num_vars, target)] - def predict_load_data(self, sample: Union[str, DataFrame], dataset: AutoDataset): - df = pd.read_csv(sample) if isinstance(sample, str) else sample - _, cat_vars, num_vars = self.common_load_data(df, dataset) - return list(zip(cat_vars, num_vars)) + def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): + _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) + return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] - @classmethod - def from_data( - cls, - train_df: DataFrame, - val_df: Optional[DataFrame], - test_df: Optional[DataFrame], - predict_df: Optional[DataFrame], - target_col: str, - num_cols: List[str], - cat_cols: List[str], - is_regression: bool, - ) -> 'TabularPreprocess': - if train_df is None: - raise MisconfigurationException("train_df is required to instantiate the TabularPreprocess") +class TabularCSVDataSource(TabularDataFrameDataSource): - dfs = [train_df] + def load_data(self, data: str, dataset: Optional[Any] = None): + return super().load_data(pd.read_csv(data), dataset=dataset) - if val_df is not None: - dfs += [val_df] + def predict_load_data(self, data: str, dataset: Optional[Any] = None): + return super().predict_load_data(pd.read_csv(data), dataset=dataset) - if test_df is not None: - dfs += [test_df] - if predict_df is not None: - dfs += [predict_df] +class TabularPreprocess(Preprocess): - mean, std = _compute_normalization(dfs[0], num_cols) - classes = list(dfs[0][target_col].unique()) - num_classes = len(classes) - if dfs[0][target_col].dtype == object: - # if the target_col is a category, not an int - target_codes = _generate_codes(dfs, [target_col]) - else: - target_codes = None - codes = _generate_codes(dfs, cat_cols) + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + 
test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, + ): + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.classes = classes + self.is_regression = is_regression - return cls( - cat_cols, - num_cols, - target_col, - mean, - std, - codes, - target_codes, - classes, - num_classes, - is_regression, + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: TabularCSVDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + "data_frame": TabularDataFrameDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + }, + default_data_source=DefaultDataSources.CSV, ) + def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: + return { + **self.transforms, + "cat_cols": self.cat_cols, + "num_cols": self.num_cols, + "target_col": self.target_col, + "mean": self.mean, + "std": self.std, + "codes": self.codes, + "target_codes": self.target_codes, + "classes": self.classes, + "is_regression": self.is_regression, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': + return cls(**state_dict) + class TabularData(DataModule): """Data module for tabular tasks""" preprocess_cls = TabularPreprocess - def __init__( - self, - train_dataset: 
Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - test_dataset: Optional[Dataset] = None, - predict_dataset: Optional[Dataset] = None, - batch_size: int = 1, - num_workers: Optional[int] = 0, - ) -> None: - super().__init__( - train_dataset, - val_dataset, - test_dataset, - predict_dataset, - batch_size=batch_size, - num_workers=num_workers, - ) - - self._preprocess: Optional[Preprocess] = None - @property def codes(self) -> Dict[str, str]: - return self._preprocess.codes + return self._data_source.codes @property def num_classes(self) -> int: - return self._preprocess.num_classes + return self._data_source.num_classes @property def cat_cols(self) -> Optional[List[str]]: - return self._preprocess.cat_cols + return self._data_source.cat_cols @property def num_cols(self) -> Optional[List[str]]: - return self._preprocess.num_cols + return self._data_source.num_cols @property def num_features(self) -> int: return len(self.cat_cols) + len(self.num_cols) - @classmethod - def from_csv( - cls, - target_col: str, - train_csv: Optional[str] = None, - categorical_cols: Optional[List] = None, - numerical_cols: Optional[List] = None, - val_csv: Optional[str] = None, - test_csv: Optional[str] = None, - predict_csv: Optional[str] = None, - batch_size: int = 8, - num_workers: Optional[int] = None, - val_size: Optional[float] = None, - test_size: Optional[float] = None, - preprocess: Optional[Preprocess] = None, - **pandas_kwargs, - ): - """Creates a TextClassificationData object from pandas DataFrames. - - Args: - train_csv: Train data csv file. - target_col: The column containing the class id. - categorical_cols: The list of categorical columns. - numerical_cols: The list of numerical columns. - val_csv: Validation data csv file. - test_csv: Test data csv file. - batch_size: The batchsize to use for parallel loading. Defaults to 64. - num_workers: The number of workers to use for parallelized loading. 
- Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - val_size: Float between 0 and 1 to create a validation dataset from train dataset. - test_size: Float between 0 and 1 to create a test dataset from train validation. - preprocess: Preprocess to be used within this DataModule DataPipeline. - - Returns: - TabularData: The constructed data module. - - Examples:: - - text_data = TabularData.from_files("train.csv", label_field="class", text_field="sentence") - """ - train_df = pd.read_csv(train_csv, **pandas_kwargs) - val_df = pd.read_csv(val_csv, **pandas_kwargs) if val_csv else None - test_df = pd.read_csv(test_csv, **pandas_kwargs) if test_csv else None - predict_df = pd.read_csv(predict_csv, **pandas_kwargs) if predict_csv else None - - return cls.from_df( - train_df, - target_col, - categorical_cols, - numerical_cols, - val_df, - test_df, - predict_df, - batch_size, - num_workers, - val_size, - test_size, - preprocess=preprocess, - ) - @property def emb_sizes(self) -> list: """Recommended embedding sizes.""" @@ -273,25 +202,6 @@ def emb_sizes(self) -> list: emb_dims = [max(int(n**0.25), 16) for n in num_classes] return list(zip(num_classes, emb_dims)) - @staticmethod - def _split_dataframe( - train_df: DataFrame, - val_df: Optional[DataFrame] = None, - test_df: Optional[DataFrame] = None, - val_size: float = None, - test_size: float = None, - ): - if val_df is None and isinstance(val_size, float) and isinstance(test_size, float): - assert 0 < val_size < 1 - assert 0 < test_size < 1 - train_df, val_df = train_test_split(train_df, test_size=(val_size + test_size)) - - if test_df is None and isinstance(test_size, float): - assert 0 < test_size < 1 - val_df, test_df = train_test_split(val_df, test_size=test_size) - - return train_df, val_df, test_df - @staticmethod def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]): if cat_cols is None and num_cols is None: @@ -300,21 +210,60 @@ def 
_sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]): return cat_cols or [], num_cols or [] @classmethod - def from_df( + def compute_state( cls, - train_df: DataFrame, + train_data_frame: DataFrame, + val_data_frame: Optional[DataFrame], + test_data_frame: Optional[DataFrame], + predict_data_frame: Optional[DataFrame], target_col: str, - categorical_cols: Optional[List] = None, - numerical_cols: Optional[List] = None, - val_df: Optional[DataFrame] = None, - test_df: Optional[DataFrame] = None, - predict_df: Optional[DataFrame] = None, - batch_size: int = 8, - num_workers: Optional[int] = None, - val_size: float = None, - test_size: float = None, + num_cols: List[str], + cat_cols: List[str], + ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: + + if train_data_frame is None: + raise MisconfigurationException( + "train_data_frame is required to instantiate the TabularDataFrameDataSource" + ) + + data_frames = [train_data_frame] + + if val_data_frame is not None: + data_frames += [val_data_frame] + + if test_data_frame is not None: + data_frames += [test_data_frame] + + if predict_data_frame is not None: + data_frames += [predict_data_frame] + + mean, std = _compute_normalization(data_frames[0], num_cols) + classes = list(data_frames[0][target_col].unique()) + + if data_frames[0][target_col].dtype == object: + # if the target_col is a category, not an int + target_codes = _generate_codes(data_frames, [target_col]) + else: + target_codes = None + codes = _generate_codes(data_frames, cat_cols) + + return mean, std, classes, codes, target_codes + + @classmethod + def from_data_frame( + cls, + categorical_cols: List, + numerical_cols: List, + target_col: str, + train_data_frame: DataFrame, + val_data_frame: Optional[DataFrame] = None, + test_data_frame: Optional[DataFrame] = None, + predict_data_frame: Optional[DataFrame] = None, is_regression: bool = False, preprocess: Optional[Preprocess] = None, + val_split: float = None, + batch_size: 
int = 8, + num_workers: Optional[int] = None, ): """Creates a TabularData object from pandas DataFrames. @@ -329,8 +278,7 @@ def from_df( num_workers: The number of workers to use for parallelized loading. Defaults to None which equals the number of available CPU threads, or 0 for Darwin platform. - val_size: Float between 0 and 1 to create a validation dataset from train dataset. - test_size: Float between 0 and 1 to create a test dataset from train validation. + val_split: Float between 0 and 1 to create a validation dataset from train dataset. preprocess: Preprocess to be used within this DataModule DataPipeline. Returns: @@ -342,25 +290,64 @@ def from_df( """ categorical_cols, numerical_cols = cls._sanetize_cols(categorical_cols, numerical_cols) - train_df, val_df, test_df = cls._split_dataframe(train_df, val_df, test_df, val_size, test_size) - - preprocess = preprocess or cls.preprocess_cls.from_data( - train_df, - val_df, - test_df, - predict_df, + mean, std, classes, codes, target_codes = cls.compute_state( + train_data_frame, + val_data_frame, + test_data_frame, + predict_data_frame, target_col, numerical_cols, categorical_cols, - is_regression, ) - return cls.from_load_data_inputs( - train_load_data_input=train_df, - val_load_data_input=val_df, - test_load_data_input=test_df, - predict_load_data_input=predict_df, + return cls.from_data_source( + data_source="data_frame", + train_data=train_data_frame, + val_data=val_data_frame, + test_data=test_data_frame, + predict_data=predict_data_frame, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + cat_cols=categorical_cols, + num_cols=numerical_cols, + target_col=target_col, + mean=mean, + std=std, + codes=codes, + target_codes=target_codes, + classes=classes, + is_regression=is_regression, + ) + + @classmethod + def from_csv( + cls, + categorical_fields: Union[str, List[str]], + numerical_fields: Union[str, List[str]], + target_field: Optional[str] = None, + 
train_file: Optional[str] = None, + val_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, + is_regression: bool = False, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + ) -> 'DataModule': + return cls.from_data_frame( + categorical_fields, + numerical_fields, + target_field, + train_data_frame=pd.read_csv(train_file) if train_file is not None else None, + val_data_frame=pd.read_csv(val_file) if val_file is not None else None, + test_data_frame=pd.read_csv(test_file) if test_file is not None else None, + predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None, + is_regression=is_regression, + preprocess=preprocess, + val_split=val_split, batch_size=batch_size, num_workers=num_workers, - preprocess=preprocess ) diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 7e399aaaa2..6fde330784 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -18,6 +18,7 @@ from torchmetrics import Metric from flash.core.classification import ClassificationTask +from flash.data.data_source import DefaultDataKeys from flash.data.process import Serializer from flash.utils.imports import _TABNET_AVAILABLE @@ -80,9 +81,22 @@ def __init__( def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) x = torch.cat([x for x in x_in if x.numel()], dim=1) - return F.softmax(self.model(x)[0], -1) + return self.model(x)[0] + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) 
+ + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) return self(batch) @classmethod diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 7982ab7af0..7f867fb76c 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -13,138 +13,62 @@ # limitations under the License. import os from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union from datasets import DatasetDict, load_dataset -from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from transformers import AutoTokenizer, default_data_collator from transformers.modeling_outputs import SequenceClassifierOutput -from flash.core.classification import ClassificationState from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule +from flash.data.data_source import DataSource, DefaultDataSources, LabelsState from flash.data.process import Postprocess, Preprocess -class TextClassificationPreprocess(Preprocess): - - def __init__( - self, - input: str, - backbone: str, - max_length: int, - target: str, - filetype: str, - train_file: Optional[str] = None, - label_to_class_mapping: Optional[Dict[str, int]] = None, - ): - """ - This class contains the preprocessing logic for text classification - - Args: - # tokenizer: Hugging Face Tokenizer. # TODO: Add back a tokenizer argument and make backbone optional? - input: The field storing the text to be classified. - max_length: Maximum number of tokens within a single sentence. - target: The field storing the class id of the associated text. 
- filetype: .csv or .json format type. - label_to_class_mapping: Dictionary mapping target labels to class indexes. - """ +class TextDataSource(DataSource): + def __init__(self, backbone: str, max_length: int = 128): super().__init__() - if label_to_class_mapping is None: - if train_file is not None: - label_to_class_mapping = self.get_label_to_class_mapping(train_file, target, filetype) - else: - raise MisconfigurationException( - "Either ``label_to_class_mapping`` or ``train_file`` needs to be provided" - ) - - self.backbone = backbone self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - self.input = input - self.filetype = filetype self.max_length = max_length - self.label_to_class_mapping = label_to_class_mapping - self.target = target - - self._tokenize_fn = partial( - self._tokenize_fn, - tokenizer=self.tokenizer, - input=self.input, - max_length=self.max_length, - truncation=True, - padding="max_length" - ) - - class_to_label_mapping = ['CLASS_UNKNOWN'] * (max(self.label_to_class_mapping.values()) + 1) - for label, cls in self.label_to_class_mapping.items(): - class_to_label_mapping[cls] = label - self.set_state(ClassificationState(class_to_label_mapping)) - - def get_state_dict(self) -> Dict[str, Any]: - return { - "input": self.input, - "backbone": self.backbone, - "max_length": self.max_length, - "target": self.target, - "filetype": self.filetype, - "label_to_class_mapping": self.label_to_class_mapping, - } - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return cls(**state_dict) - def per_batch_transform(self, batch: Any) -> Any: - if "labels" not in batch: - # todo: understand why an extra dimension has been added. 
- if batch["input_ids"].dim() == 3: - batch["input_ids"] = batch["input_ids"].squeeze(0) - return batch - - @staticmethod def _tokenize_fn( + self, ex: Union[Dict[str, str], str], - tokenizer=None, - input: str = None, - max_length: int = None, - **kwargs + input: Optional[str] = None, ) -> Callable: """This function is used to tokenize sentences using the provided tokenizer.""" if isinstance(ex, dict): ex = ex[input] - return tokenizer(ex, max_length=max_length, **kwargs) + return self.tokenizer(ex, max_length=self.max_length, truncation=True, padding="max_length") - def collate(self, samples: Any) -> Tensor: - """Override to convert a set of samples to a batch""" - if isinstance(samples, dict): - samples = [samples] - return default_data_collator(samples) - - def _transform_label(self, ex: Dict[str, str]): - ex[self.target] = self.label_to_class_mapping[ex[self.target]] + def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str, ex: Dict[str, Union[int, str]]): + ex[target] = label_to_class_mapping[ex[target]] return ex - @staticmethod - def get_label_to_class_mapping(file: str, target: str, filetype: str) -> Dict[str, int]: - data_files = {'train': file} - dataset_dict = load_dataset(filetype, data_files=data_files) - label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(dataset_dict['train'][target])))))} - return label_to_class_mapping + +class TextFileDataSource(TextDataSource): + + def __init__(self, filetype: str, backbone: str, max_length: int = 128): + super().__init__(backbone, max_length=max_length) + + self.filetype = filetype def load_data( self, - filepath: str, - dataset: AutoDataset, + data: Tuple[str, Union[str, List[str]], Union[str, List[str]]], + dataset: Optional[Any] = None, columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), - use_full: bool = True - ): + use_full: bool = True, + ) -> Union[Sequence[Mapping[str, Any]]]: + csv_file, input, target = data + data_files = {} 
- stage = dataset.running_stage.value - data_files[stage] = str(filepath) + stage = self.running_stage.value + data_files[stage] = str(csv_file) # FLASH_TESTING is set in the CI to run faster. if use_full and os.getenv("FLASH_TESTING", "0") == "0": @@ -155,37 +79,112 @@ def load_data( stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] }) - dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) + if self.training: + labels = list(sorted(list(set(dataset_dict[stage][target])))) + dataset.num_classes = len(labels) + self.set_state(LabelsState(labels)) + + labels = self.get_state(LabelsState) # convert labels to ids - if not self.predicting: - dataset_dict = dataset_dict.map(self._transform_label) + # if not self.predicting: + if labels is not None: + labels = labels.labels + label_to_class_mapping = {v: k for k, v in enumerate(labels)} + dataset_dict = dataset_dict.map(partial(self._transform_label, label_to_class_mapping, target)) - dataset_dict = dataset_dict.map(self._tokenize_fn, batched=True) + dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input), batched=True) # Hugging Face models expect target to be named ``labels``. 
- if not self.predicting and self.target != "labels": - dataset_dict.rename_column_(self.target, "labels") + if not self.predicting and target != "labels": + dataset_dict.rename_column_(target, "labels") dataset_dict.set_format("torch", columns=columns) - if not self.predicting: - dataset.num_classes = len(self.label_to_class_mapping) - return dataset_dict[stage] - def predict_load_data(self, sample: Any, dataset: AutoDataset): - if isinstance(sample, str) and os.path.isfile(sample) and sample.endswith(".csv"): - return self.load_data(sample, dataset, columns=["input_ids", "attention_mask"]) - else: - if isinstance(sample, str): - sample = [sample] + def predict_load_data(self, data: Any, dataset: AutoDataset): + return self.load_data(data, dataset, columns=["input_ids", "attention_mask"]) + + +class TextCSVDataSource(TextFileDataSource): + + def __init__(self, backbone: str, max_length: int = 128): + super().__init__("csv", backbone, max_length=max_length) + + +class TextJSONDataSource(TextFileDataSource): + + def __init__(self, backbone: str, max_length: int = 128): + super().__init__("json", backbone, max_length=max_length) + + +class TextSentencesDataSource(TextDataSource): - if isinstance(sample, list) and all(isinstance(s, str) for s in sample): - return [self._tokenize_fn(s) for s in sample] + def __init__(self, backbone: str, max_length: int = 128): + super().__init__(backbone, max_length=max_length) - else: - raise MisconfigurationException("Currently, we support only list of sentences") + def load_data( + self, + data: Union[str, List[str]], + dataset: Optional[Any] = None, + ) -> Union[Sequence[Mapping[str, Any]]]: + + if isinstance(data, str): + data = [data] + return [self._tokenize_fn(s, ) for s in data] + + +class TextClassificationPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + 
predict_transform: Optional[Dict[str, Callable]] = None, + backbone: str = "prajjwal1/bert-tiny", + max_length: int = 128, + ): + self.backbone = backbone + self.max_length = max_length + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length), + DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length), + "sentences": TextSentencesDataSource(self.backbone, max_length=max_length), + }, + default_data_source="sentences", + ) + + def get_state_dict(self) -> Dict[str, Any]: + return { + **self.transforms, + "backbone": self.backbone, + "max_length": self.max_length, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + return cls(**state_dict) + + def per_batch_transform(self, batch: Any) -> Any: + if "labels" not in batch: + # todo: understand why an extra dimension has been added. 
+ if batch["input_ids"].dim() == 3: + batch["input_ids"] = batch["input_ids"].squeeze(0) + return batch + + def collate(self, samples: Any) -> Tensor: + """Override to convert a set of samples to a batch""" + if isinstance(samples, dict): + samples = [samples] + return default_data_collator(samples) class TextClassificationPostProcess(Postprocess): @@ -201,118 +200,3 @@ class TextClassificationData(DataModule): preprocess_cls = TextClassificationPreprocess postprocess_cls = TextClassificationPostProcess - - @property - def num_classes(self) -> int: - return len(self._preprocess.label_to_class_mapping) - - @classmethod - def from_files( - cls, - train_file: Optional[str], - input: Optional[str] = 'input', - target: Optional[str] = 'labels', - filetype: str = "csv", - backbone: str = "prajjwal1/bert-tiny", - val_file: Optional[str] = None, - test_file: Optional[str] = None, - predict_file: Optional[str] = None, - max_length: int = 128, - label_to_class_mapping: Optional[dict] = None, - batch_size: int = 16, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ) -> 'TextClassificationData': - """Creates a TextClassificationData object from files. - - Args: - train_file: Path to training data. - input: The field storing the text to be classified. - target: The field storing the class id of the associated text. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - batch_size: the batchsize to use for parallel loading. Defaults to 64. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - - Returns: - TextClassificationData: The constructed data module. 
- - Examples:: - - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) - - """ - preprocess = preprocess or cls.preprocess_cls( - input, - backbone, - max_length, - target, - filetype, - train_file, - label_to_class_mapping, - ) - - postprocess = postprocess or cls.postprocess_cls() - - return cls.from_load_data_inputs( - train_load_data_input=train_file, - val_load_data_input=val_file, - test_load_data_input=test_file, - predict_load_data_input=predict_file, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) - - @classmethod - def from_file( - cls, - predict_file: str, - input: str, - backbone="bert-base-cased", - filetype="csv", - max_length: int = 128, - label_to_class_mapping: Optional[dict] = None, - batch_size: int = 16, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ) -> 'TextClassificationData': - """Creates a TextClassificationData object from files. - - Args: - - predict_file: Path to training data. - input: The field storing the text to be classified. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - batch_size: the batchsize to use for parallel loading. Defaults to 64. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. 
- """ - return cls.from_files( - None, - input=input, - target=None, - filetype=filetype, - backbone=backbone, - val_file=None, - test_file=None, - predict_file=predict_file, - max_length=max_length, - label_to_class_mapping=label_to_class_mapping, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index f317c4fade..f7968ee4a7 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -18,90 +18,73 @@ import datasets import torch from datasets import DatasetDict, load_dataset -from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor from transformers import AutoTokenizer, default_data_collator from flash.data.data_module import DataModule -from flash.data.process import Postprocess, Preprocess +from flash.data.data_source import DataSource, DefaultDataSources +from flash.data.process import Preprocess -class Seq2SeqPreprocess(Preprocess): +class Seq2SeqDataSource(DataSource): def __init__( self, backbone: str, - input: str, - filetype: str, - target: Optional[str] = None, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'longest', - use_fast: bool = True, + padding: Union[str, bool] = 'max_length' ): super().__init__() - self.backbone = backbone - self.use_fast = use_fast - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=use_fast) - self.input = input - self.filetype = filetype - self.target = target - self.max_target_length = max_target_length + + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) self.max_source_length = max_source_length self.max_target_length = max_target_length self.padding = padding - self._tokenize_fn_wrapped = partial( - self._tokenize_fn, - tokenizer=self.tokenizer, - input=self.input, - target=self.target, - max_source_length=self.max_source_length, + def 
_tokenize_fn( + self, + ex: Union[Dict[str, str], str], + input: Optional[str] = None, + target: Optional[str] = None, + ) -> Callable: + if isinstance(ex, dict): + ex_input = ex[input] + ex_target = ex[target] if target else None + else: + ex_input = ex + ex_target = None + + return self.tokenizer.prepare_seq2seq_batch( + src_texts=ex_input, + tgt_texts=ex_target, + max_length=self.max_source_length, max_target_length=self.max_target_length, - padding=self.padding + padding=self.padding, ) - def get_state_dict(self) -> Dict[str, Any]: - return { - "backbone": self.backbone, - "use_fast": self.use_fast, - "input": self.input, - "filetype": self.filetype, - "target": self.target, - "max_source_length": self.max_source_length, - "max_target_length": self.max_target_length, - "padding": self.padding, - } - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return cls(**state_dict) +class Seq2SeqFileDataSource(Seq2SeqDataSource): - @staticmethod - def _tokenize_fn( - ex, - tokenizer, - input: str, - target: Optional[str], - max_source_length: int, - max_target_length: int, - padding: Union[str, bool], - ) -> Callable: - output = tokenizer.prepare_seq2seq_batch( - src_texts=ex[input], - tgt_texts=ex[target] if target else None, - max_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - ) - return output + def __init__( + self, + filetype: str, + backbone: str, + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = 'max_length', + ): + super().__init__(backbone, max_source_length, max_target_length, padding) + + self.filetype = filetype def load_data( self, - file: str, - use_full: bool = True, + data: Any, + use_full: bool = False, columns: List[str] = ["input_ids", "attention_mask", "labels"] ) -> 'datasets.Dataset': + file, input, target = data data_files = {} stage = self._running_stage.value data_files[stage] = str(file) @@ -118,155 +101,128 @@ def load_data( 
except AssertionError: dataset_dict = load_dataset(self.filetype, data_files=data_files) - dataset_dict = dataset_dict.map(self._tokenize_fn_wrapped, batched=True) + dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True) dataset_dict.set_format(columns=columns) return dataset_dict[stage] - def predict_load_data(self, sample: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: - if isinstance(sample, str) and os.path.isfile(sample) and sample.endswith(".csv"): - return self.load_data(sample, use_full=True, columns=["input_ids", "attention_mask"]) - else: - if isinstance(sample, (list, tuple)) and len(sample) > 0 and all(isinstance(s, str) for s in sample): - return [self._tokenize_fn_wrapped({self.input: s, self.target: None}) for s in sample] - else: - raise MisconfigurationException("Currently, we support only list of sentences") + def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: + return self.load_data(data, use_full=False, columns=["input_ids", "attention_mask"]) - def collate(self, samples: Any) -> Tensor: - """Override to convert a set of samples to a batch""" - return default_data_collator(samples) - -class Seq2SeqPostprocess(Postprocess): +class Seq2SeqCSVDataSource(Seq2SeqFileDataSource): def __init__( self, backbone: str, - use_fast: bool = True, + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = 'max_length', ): - super().__init__() - self.backbone = backbone - self.use_fast = use_fast - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=use_fast) - - def uncollate(self, generated_tokens: Any) -> Any: - pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) - pred_str = [str.strip(s) for s in pred_str] - return pred_str - + super().__init__( + "csv", + backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ) 
-class Seq2SeqData(DataModule): - """Data module for Seq2Seq tasks.""" - preprocess_cls = Seq2SeqPreprocess - postprocess_cls = Seq2SeqPostprocess +class Seq2SeqJSONDataSource(Seq2SeqFileDataSource): - @classmethod - def from_files( - cls, - train_file: Optional[str], - input: str = 'input', - target: Optional[str] = None, - filetype: str = "csv", - backbone: str = "sshleifer/tiny-mbart", - use_fast: bool = True, - val_file: Optional[str] = None, - test_file: Optional[str] = None, - predict_file: Optional[str] = None, + def __init__( + self, + backbone: str, max_source_length: int = 128, max_target_length: int = 128, padding: Union[str, bool] = 'max_length', - batch_size: int = 32, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, ): - """Creates a Seq2SeqData object from files. - Args: - train_file: Path to training data. - input: The field storing the source translation text. - target: The field storing the target translation text. - filetype: ``csv`` or ``json`` File - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 32. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - Returns: - Seq2SeqData: The constructed data module. 
- Examples:: - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, - target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) - """ - preprocess = preprocess or cls.preprocess_cls( - backbone, input, filetype, target, max_source_length, max_target_length, padding, use_fast=use_fast + super().__init__( + "json", + backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, ) - postprocess = postprocess or cls.postprocess_cls(backbone, use_fast=use_fast) - - return cls.from_load_data_inputs( - train_load_data_input=train_file, - val_load_data_input=val_file, - test_load_data_input=test_file, - predict_load_data_input=predict_file, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) - @classmethod - def from_file( - cls, - predict_file: str, - input: str = 'input', - target: Optional[str] = None, +class Seq2SeqSentencesDataSource(Seq2SeqDataSource): + + def load_data( + self, + data: Union[str, List[str]], + dataset: Optional[Any] = None, + ) -> List[Any]: + + if isinstance(data, str): + data = [data] + return [self._tokenize_fn(s) for s in data] + + +class Seq2SeqPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, backbone: str = "sshleifer/tiny-mbart", - filetype: str = "csv", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', - batch_size: int = 32, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, + padding: Union[str, bool] = 'max_length' ): - """Creates a TextClassificationData object from files. - Args: - predict_file: Path to prediction input file. 
- input: The field storing the source translation text. - target: The field storing the target translation text. - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - filetype: Csv or json. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 32. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - Returns: - Seq2SeqData: The constructed data module. - """ - return cls.from_files( - train_file=None, - input=input, - target=target, - filetype=filetype, - backbone=backbone, - predict_file=predict_file, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, + self.backbone = backbone + self.max_target_length = max_target_length + self.max_source_length = max_source_length + self.padding = padding + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: Seq2SeqCSVDataSource( + self.backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ), + DefaultDataSources.JSON: Seq2SeqJSONDataSource( + self.backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ), + "sentences": Seq2SeqSentencesDataSource( + self.backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ), + }, + default_data_source="sentences", ) + + def 
get_state_dict(self) -> Dict[str, Any]: + return { + **self.transforms, + "backbone": self.backbone, + "max_source_length": self.max_source_length, + "max_target_length": self.max_target_length, + "padding": self.padding, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + return cls(**state_dict) + + def collate(self, samples: Any) -> Tensor: + """Override to convert a set of samples to a batch""" + return default_data_collator(samples) + + +class Seq2SeqData(DataModule): + """Data module for Seq2Seq tasks.""" + + preprocess_cls = Seq2SeqPreprocess diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 8971584bde..3caec065ca 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -120,7 +120,7 @@ def _initialize_model_specific_parameters(self): @property def tokenizer(self) -> PreTrainedTokenizerBase: - return self.data_pipeline._preprocess_pipeline.tokenizer + return self.data_pipeline.data_source.tokenizer def tokenize_labels(self, labels: Tensor) -> List[str]: label_str = self.tokenizer.batch_decode(labels, skip_special_tokens=True) diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 1fab5a30a4..791c98a32f 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -11,130 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional, Union +from typing import Any -from flash.data.process import Postprocess, Preprocess -from flash.text.seq2seq.core.data import Seq2SeqData +from transformers import AutoTokenizer +from flash.data.process import Postprocess +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess -class SummarizationData(Seq2SeqData): - - @classmethod - def from_files( - cls, - train_file: Optional[str] = None, - input: str = 'input', - target: Optional[str] = None, - filetype: str = "csv", - backbone: str = "t5-small", - use_fast: bool = True, - val_file: str = None, - test_file: str = None, - predict_file: str = None, - max_source_length: int = 512, - max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', - batch_size: int = 16, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ): - """Creates a SummarizationData object from files. - - Args: - train_file: Path to training data. - input: The field storing the source translation text. - target: The field storing the target translation text. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 16. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - - Returns: - SummarizationData: The constructed data module. 
- Examples:: +class SummarizationPostprocess(Postprocess): - train_df = pd.read_csv("train_data.csv") - tab_data = TabularData.from_df(train_df, target="fraud", - num_cols=["account_value"], - cat_cols=["account_type"]) - - """ + def __init__( + self, + backbone: str = "sshleifer/tiny-mbart", + ): + super().__init__() - return super().from_files( - train_file=train_file, - input=input, - target=target, - filetype=filetype, - backbone=backbone, - use_fast=use_fast, - val_file=val_file, - test_file=test_file, - predict_file=predict_file, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) + # TODO: Should share the backbone or tokenizer over state + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - @classmethod - def from_file( - cls, - predict_file: str, - input: str = 'src_text', - target: Optional[str] = None, - backbone: str = "t5-small", - filetype: str = "csv", - max_source_length: int = 512, - max_target_length: int = 128, - padding: Union[str, bool] = 'longest', - batch_size: int = 16, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ): - """Creates a SummarizationData object from files. + def uncollate(self, generated_tokens: Any) -> Any: + pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + pred_str = [str.strip(s) for s in pred_str] + return pred_str - Args: - predict_file: Path to prediction input file. - input: The field storing the source translation text. - target: The field storing the target translation text. - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - filetype: csv or json. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. 
Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 16. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - Returns: - SummarizationData: The constructed data module. +class SummarizationData(Seq2SeqData): - """ - return super().from_file( - predict_file=predict_file, - input=input, - target=target, - backbone=backbone, - filetype=filetype, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) + preprocess_cls = Seq2SeqPreprocess + postprocess_cls = SummarizationPostprocess diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index a3c9142bb5..04e763780b 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union +from typing import Callable, Dict, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch @@ -37,7 +37,7 @@ class SummarizationTask(Seq2SeqTask): def __init__( self, - backbone: str = "t5-small", + backbone: str = "sshleifer/tiny-mbart", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None, diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index ba724c0387..057ce41869 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -11,135 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Callable, Dict, Optional, Union -from flash.data.process import Postprocess, Preprocess -from flash.text.seq2seq.core.data import Seq2SeqData +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess -class TranslationData(Seq2SeqData): - """Data module for Translation tasks.""" +class TranslationPreprocess(Seq2SeqPreprocess): - @classmethod - def from_files( - cls, - train_file, - input: str = 'input', - target: Optional[str] = None, - filetype="csv", - backbone="Helsinki-NLP/opus-mt-en-ro", - use_fast: bool = True, - val_file=None, - test_file=None, - predict_file=None, + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + backbone: str = "t5-small", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', - 
batch_size: int = 8, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, + padding: Union[str, bool] = 'max_length' ): - """Creates a TranslateData object from files. - - Args: - train_file: Path to training data. - input: The field storing the source translation text. - target: The field storing the target translation text. - filetype: .csv or .json - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - val_file: Path to validation data. - test_file: Path to test data. - predict_file: Path to predict data. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 8. - num_workers: The number of workers to use for parallelized loading. - Defaults to None which equals the number of available CPU threads, - or 0 for Darwin platform. - - Returns: - TranslateData: The constructed data module. 
- - Examples:: - - datamodule = TranslationData.from_files( - train_file="data/wmt_en_ro/train.csv", - val_file="data/wmt_en_ro/valid.csv", - test_file="data/wmt_en_ro/test.csv", - input="input", - target="target", - batch_size=1, - ) - - """ - return super().from_files( - train_file=train_file, - val_file=val_file, - test_file=test_file, - predict_file=predict_file, - input=input, - target=target, + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, backbone=backbone, - use_fast=use_fast, - filetype=filetype, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, ) - @classmethod - def from_file( - cls, - predict_file: str, - input: str = 'input', - target: Optional[str] = None, - backbone="facebook/mbart-large-en-ro", - filetype="csv", - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = 'longest', - batch_size: int = 8, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - postprocess: Optional[Postprocess] = None, - ): - """Creates a TranslationData object from files. - Args: - predict_file: Path to prediction input file. - input: The field storing the source translation text. - target: The field storing the target translation text. - backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer. - filetype: csv or json. - max_source_length: Maximum length of the source text. Any text longer will be truncated. - max_target_length: Maximum length of the target text. Any text longer will be truncated. - padding: Padding strategy for batches. Default is pad to maximum length. - batch_size: The batchsize to use for parallel loading. Defaults to 8. - num_workers: The number of workers to use for parallelized loading. 
- Defaults to None which equals the number of available CPU threads, - - - Returns: - Seq2SeqData: The constructed data module. +class TranslationData(Seq2SeqData): + """Data module for Translation tasks.""" - """ - return super().from_file( - predict_file=predict_file, - input=input, - target=target, - backbone=backbone, - filetype=filetype, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - batch_size=batch_size, - num_workers=num_workers, - preprocess=preprocess, - postprocess=postprocess, - ) + preprocess_cls = TranslationPreprocess diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index 1ae64d3e11..9eba02d753 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -37,7 +37,7 @@ class TranslationTask(Seq2SeqTask): def __init__( self, - backbone: str = "Helsinki-NLP/opus-mt-en-ro", + backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None, diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 3bac7e92ed..5aefd5d14a 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os import pathlib from typing import Any, Callable, Dict, List, Optional, Type, Union @@ -19,18 +18,15 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import RandomSampler, Sampler -from torch.utils.data.dataset import IterableDataset -from flash.core.classification import ClassificationState from flash.data.data_module import DataModule +from flash.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, PathsDataSource from flash.data.process import Preprocess from flash.utils.imports import _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE if _KORNIA_AVAILABLE: import kornia.augmentation as K - import kornia.geometry.transform as T -else: - from torchvision import transforms as T + if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler from pytorchvideo.data.encoded_video import EncodedVideo @@ -43,75 +39,24 @@ _PYTORCHVIDEO_DATA = Dict[str, Union[str, torch.Tensor, int, float, List]] -class VideoClassificationPreprocess(Preprocess): - - EXTENSIONS = ("mp4", "avi") - - @staticmethod - def default_predict_transform() -> Dict[str, 'Compose']: - return { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([ - UniformTemporalSubsample(8), - RandomShortSideScale(min_size=256, max_size=320), - RandomCrop(244), - RandomHorizontalFlip(p=0.5), - ]), - ), - ]), - "per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), - } +class VideoClassificationPathsDataSource(PathsDataSource): def __init__( self, clip_sampler: 'ClipSampler', - video_sampler: Type[Sampler], - decode_audio: bool, - decoder: str, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = 
None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = True, + decoder: str = "pyav", ): - # Make sure to provide your transform to the Preprocess Class - super().__init__( - train_transform, val_transform, test_transform, predict_transform or self.default_predict_transform() - ) + super().__init__(extensions=("mp4", "avi")) self.clip_sampler = clip_sampler self.video_sampler = video_sampler self.decode_audio = decode_audio self.decoder = decoder - def get_state_dict(self) -> Dict[str, Any]: - return { - 'clip_sampler': self.clip_sampler, - 'video_sampler': self.video_sampler, - 'decode_audio': self.decode_audio, - 'decoder': self.decoder, - 'train_transform': self._train_transform, - 'val_transform': self._val_transform, - 'test_transform': self._test_transform, - 'predict_transform': self._predict_transform, - } - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> 'VideoClassificationPreprocess': - return cls(**state_dict) - - def load_data(self, data: Any, dataset: IterableDataset) -> 'EncodedVideoDataset': + def load_data(self, data: str, dataset: Optional[Any] = None) -> 'EncodedVideoDataset': ds: EncodedVideoDataset = labeled_encoded_video_dataset( - data, + pathlib.Path(data), self.clip_sampler, video_sampler=self.video_sampler, decode_audio=self.decode_audio, @@ -119,21 +64,10 @@ def load_data(self, data: Any, dataset: IterableDataset) -> 'EncodedVideoDataset ) if self.training: label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels} - self.set_state(ClassificationState(label_to_class_mapping)) + self.set_state(LabelsState(label_to_class_mapping)) dataset.num_classes = len(np.unique([s[1]['label'] for s in ds._labeled_videos])) return ds - def predict_load_data(self, folder_or_file: Union[str, List[str]]) -> List[str]: - if 
isinstance(folder_or_file, list) and all(os.path.exists(p) for p in folder_or_file): - return folder_or_file - elif os.path.isdir(folder_or_file): - return [f for f in os.listdir(folder_or_file) if f.lower().endswith(self.EXTENSIONS)] - elif os.path.exists(folder_or_file) and folder_or_file.lower().endswith(self.EXTENSIONS): - return [folder_or_file] - raise MisconfigurationException( - f"The provided predict output should be a folder or a path. Found: {folder_or_file}" - ) - def _encoded_video_to_dict(self, video) -> Dict[str, Any]: ( clip_start, @@ -167,91 +101,32 @@ def _encoded_video_to_dict(self, video) -> Dict[str, Any]: } if audio_samples is not None else {}), } - def predict_load_sample(self, video_path: str) -> "EncodedVideo": - return self._encoded_video_to_dict(EncodedVideo.from_path(video_path)) - - def pre_tensor_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) - - def to_tensor_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) - - def post_tensor_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) - - def per_batch_transform(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) - - def per_batch_transform_on_device(self, sample: _PYTORCHVIDEO_DATA) -> _PYTORCHVIDEO_DATA: - return self.current_transform(sample) + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + return self._encoded_video_to_dict(EncodedVideo.from_path(sample[DefaultDataKeys.INPUT])) -class VideoClassificationData(DataModule): - """Data module for Video classification tasks.""" - - preprocess_cls = VideoClassificationPreprocess +class VideoClassificationPreprocess(Preprocess): - @classmethod - def from_paths( - cls, - train_data_path: Optional[Union[str, pathlib.Path]] = None, - val_data_path: Optional[Union[str, pathlib.Path]] = None, - 
test_data_path: Optional[Union[str, pathlib.Path]] = None, - predict_data_path: Union[str, pathlib.Path] = None, + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, clip_sampler: Union[str, 'ClipSampler'] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, - video_sampler: Type[Sampler] = RandomSampler, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - preprocess: Optional[Preprocess] = None, - **kwargs, - ) -> 'DataModule': - """ - - Creates a VideoClassificationData object from folders of videos arranged in this way: :: - - train/class_x/xxx.ext - train/class_x/xxy.ext - train/class_x/xxz.ext - train/class_y/123.ext - train/class_y/nsdf3.ext - train/class_y/asd932_.ext - - Args: - train_data_path: Path to training folder. Default: None. - val_data_path: Path to validation folder. Default: None. - test_data_path: Path to test folder. Default: None. - predict_data_path: Path to predict folder. Default: None. - clip_sampler: ClipSampler to be used on videos. - clip_duration: Clip duration for the clip sampler. - clip_sampler_kwargs: Extra ClipSampler keyword arguments. - video_sampler: Sampler for the internal video container. - This defines the order videos are decoded and, if necessary, the distributed split. - decode_audio: Whether to decode the audio with the video clip. - decoder: Defines what type of decoder used to decode a video. 
- train_transform: Video clip dictionary transform to use for training set. - val_transform: Video clip dictionary transform to use for validation set. - test_transform: Video clip dictionary transform to use for test set. - predict_transform: Video clip dictionary transform to use for predict set. - batch_size: Batch size for data loading. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. - preprocess: VideoClassifierPreprocess to handle the data processing. - - Returns: - VideoClassificationData: the constructed data module - - Examples: - >>> videos = VideoClassificationData.from_paths("train/") # doctest: +SKIP + ): + self.clip_sampler = clip_sampler + self.clip_duration = clip_duration + self.clip_sampler_kwargs = clip_sampler_kwargs + self.video_sampler = video_sampler + self.decode_audio = decode_audio + self.decoder = decoder - """ if not _PYTORCHVIDEO_AVAILABLE: raise ModuleNotFoundError("Please, run `pip install pytorchvideo`.") @@ -265,19 +140,65 @@ def from_paths( clip_sampler = make_clip_sampler(clip_sampler, clip_duration, **clip_sampler_kwargs) - preprocess: Preprocess = preprocess or cls.preprocess_cls( - clip_sampler, video_sampler, decode_audio, decoder, train_transform, val_transform, test_transform, - predict_transform + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.PATHS: VideoClassificationPathsDataSource( + clip_sampler, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + }, + default_data_source=DefaultDataSources.PATHS, ) - return cls.from_load_data_inputs( - train_load_data_input=train_data_path, - val_load_data_input=val_data_path, - test_load_data_input=test_data_path, - predict_load_data_input=predict_data_path, - batch_size=batch_size, - num_workers=num_workers, - 
preprocess=preprocess, - use_iterable_auto_dataset=True, - **kwargs, - ) + def get_state_dict(self) -> Dict[str, Any]: + return { + **self.transforms, + "clip_sampler": self.clip_sampler, + "clip_duration": self.clip_duration, + "clip_sampler_kwargs": self.clip_sampler_kwargs, + "video_sampler": self.video_sampler, + "decode_audio": self.decode_audio, + "decoder": self.decoder, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> 'VideoClassificationPreprocess': + return cls(**state_dict) + + @staticmethod + def default_predict_transform() -> Dict[str, 'Compose']: + return { + "post_tensor_transform": Compose([ + ApplyTransformToKey( + key="video", + transform=Compose([ + UniformTemporalSubsample(8), + RandomShortSideScale(min_size=256, max_size=320), + RandomCrop(244), + RandomHorizontalFlip(p=0.5), + ]), + ), + ]), + "per_batch_transform_on_device": Compose([ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + data_format="BCTHW", + same_on_frame=False + ) + ), + ]), + } + + +class VideoClassificationData(DataModule): + """Data module for Video classification tasks.""" + + preprocess_cls = VideoClassificationPreprocess diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index ac492461de..928605b244 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -11,32 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import pathlib -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch -import torchvision from PIL import Image from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate -from torchvision import transforms as T -from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset -from flash.core.classification import ClassificationState -from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule +from flash.data.data_source import DefaultDataKeys, DefaultDataSources from flash.data.process import Preprocess -from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE - -if _KORNIA_AVAILABLE: - import kornia as K +from flash.utils.imports import _MATPLOTLIB_AVAILABLE +from flash.vision.classification.transforms import default_train_transforms, default_val_transforms +from flash.vision.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt @@ -46,285 +37,66 @@ class ImageClassificationPreprocess(Preprocess): - to_tensor = T.ToTensor() - def __init__( self, - train_transform: Optional[Union[Dict[str, Callable]]] = None, - val_transform: Optional[Union[Dict[str, Callable]]] = None, - test_transform: Optional[Union[Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Dict[str, Callable]]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, 
Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), ): - """ - Preprocess pipeline for image classification tasks. + self.image_size = image_size - Args: - train_transform: Dictionary with the set of transforms to apply during training. - val_transform: Dictionary with the set of transforms to apply during validation. - test_transform: Dictionary with the set of transforms to apply during testing. - predict_transform: Dictionary with the set of transforms to apply during prediction. - image_size: A tuple with the expected output image size. - """ - train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( - train_transform, val_transform, test_transform, predict_transform, image_size + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.PATHS: ImagePathsDataSource(), + DefaultDataSources.NUMPY: ImageNumpyDataSource(), + DefaultDataSources.TENSOR: ImageTensorDataSource(), + }, + default_data_source=DefaultDataSources.PATHS, ) - self.image_size = image_size - super().__init__(train_transform, val_transform, test_transform, predict_transform) def get_state_dict(self) -> Dict[str, Any]: - return { - "train_transform": self._train_transform, - "val_transform": self._val_transform, - "test_transform": self._test_transform, - "predict_transform": self._predict_transform, - "image_size": self.image_size - } + return {**self.transforms, "image_size": self.image_size} @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) - @staticmethod - def _find_classes(dir: str) -> Tuple: - """ - Finds the class folders in a dataset. - Args: - dir: Root directory path. 
- Returns: - tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - Ensures: - No class is a subdirectory of another. - """ - classes = [d.name for d in os.scandir(dir) if d.is_dir()] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx - - @staticmethod - def _get_predicting_files(samples: Union[Sequence, str]) -> List[str]: - files = [] - if isinstance(samples, str): - samples = [samples] - - if isinstance(samples, (list, tuple)) and all(os.path.isdir(s) for s in samples): - files = [os.path.join(sp, f) for sp in samples for f in os.listdir(sp)] - - elif isinstance(samples, (list, tuple)) and all(os.path.isfile(s) for s in samples): - files = samples - - files = list(filter(lambda p: has_file_allowed_extension(p, IMG_EXTENSIONS), files)) - - return files - - def default_train_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]: - if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": - # Better approach as all transforms are applied on tensor directly - return { - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": nn.Sequential( - # TODO (Edgar): replace with resize once kornia is fixed - K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), - K.augmentation.RandomHorizontalFlip(), - ), - "per_batch_transform_on_device": nn.Sequential( - K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ) - } - else: - from torchvision import transforms as T # noqa F811 - return { - "pre_tensor_transform": nn.Sequential(T.Resize(image_size), T.RandomHorizontalFlip()), - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - } - - def default_val_transforms(self, image_size: Tuple[int, int]) -> Dict[str, Callable]: - if _KORNIA_AVAILABLE 
and not os.getenv("FLASH_TESTING", "0") == "1": - # Better approach as all transforms are applied on tensor directly - return { - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": nn.Sequential( - # TODO (Edgar): replace with resize once kornia is fixed - K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), - ), - "per_batch_transform_on_device": nn.Sequential( - K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ) - } - else: - from torchvision import transforms as T # noqa F811 - return { - "pre_tensor_transform": T.Compose([T.Resize(image_size)]), - "to_tensor_transform": torchvision.transforms.ToTensor(), - "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - } - - def _resolve_transforms( - self, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', - image_size: Tuple[int, int] = (196, 196), - ): - - if not train_transform or train_transform == 'default': - train_transform = self.default_train_transforms(image_size) - - if not val_transform or val_transform == 'default': - val_transform = self.default_val_transforms(image_size) - - if not test_transform or test_transform == 'default': - test_transform = self.default_val_transforms(image_size) - - if not predict_transform or predict_transform == 'default': - predict_transform = self.default_val_transforms(image_size) - - return ( - train_transform, - val_transform, - test_transform, - predict_transform, - ) - - @classmethod - def _load_data_dir( - cls, - data: Any, - dataset: Optional[AutoDataset] = None, - ) -> Tuple[Optional[List[str]], List[Tuple[str, int]]]: - if isinstance(data, list): - # TODO: define num_classes elsewhere. 
This is a bad assumption since the list of - # labels might not contain the complete set of ids so that you can infer the total - # number of classes to train in your dataset. - dataset.num_classes = len(data) - out: List[Tuple[str, int]] = [] - for p, label in data: - if os.path.isdir(p): - # TODO: there is an issue here when a path is provided along with labels. - # os.listdir cannot assure the same file order as the passed labels list. - files_list: List[str] = os.listdir(p) - if len(files_list) > 1: - raise ValueError( - f"The provided directory contains more than one file." - f"Directory: {p} -> Contains: {files_list}" - ) - for f in files_list: - if has_file_allowed_extension(f, IMG_EXTENSIONS): - out.append([os.path.join(p, f), label]) - elif os.path.isfile(p) and has_file_allowed_extension(str(p), IMG_EXTENSIONS): - out.append([p, label]) - else: - raise TypeError(f"Unexpected file path type: {p}.") - return None, out - else: - classes, class_to_idx = cls._find_classes(data) - # TODO: define num_classes elsewhere. This is a bad assumption since the list of - # labels might not contain the complete set of ids so that you can infer the total - # number of classes to train in your dataset. 
- dataset.num_classes = len(classes) - return classes, make_dataset(data, class_to_idx, IMG_EXTENSIONS, None) - - @classmethod - def _load_data_files_labels(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Any: - _classes = [tmp[1] for tmp in data] - - _classes = torch.stack([ - torch.tensor(int(_cls)) if not isinstance(_cls, torch.Tensor) else _cls.view(-1) for _cls in _classes - ]).unique() - - dataset.num_classes = len(_classes) - - return data - - def load_data(self, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable: - if isinstance(data, (str, pathlib.Path, list)): - classes, data = self._load_data_dir(data=data, dataset=dataset) - state = ClassificationState(classes) - self.set_state(state) - return data - return self._load_data_files_labels(data=data, dataset=dataset) - - @staticmethod - def load_sample(sample) -> Union[Image.Image, torch.Tensor, Tuple[Image.Image, torch.Tensor]]: - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - if isinstance(sample, torch.Tensor): - out: torch.Tensor = sample - return out - - path: str = "" - if isinstance(sample, (tuple, list)): - path = sample[0] - sample = list(sample) - else: - path = sample - - with open(path, "rb") as f, Image.open(f) as img: - img_out: Image.Image = img.convert("RGB") - - if isinstance(sample, list): - # return a tuple with the PIL image and tensor with the labels. - # returning the tensor helps later to easily collate the batch - # for single/multi label at the same time. 
- out: Tuple[Image.Image, torch.Tensor] = (img_out, torch.as_tensor(sample[1])) - return out - - return img_out - - @classmethod - def predict_load_data(cls, samples: Any) -> Iterable: - if isinstance(samples, torch.Tensor): - return samples - return cls._get_predicting_files(samples) - - def collate(self, samples: Sequence) -> Any: - _samples = [] + def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: # todo: Kornia transforms add batch dimension which need to be removed for sample in samples: - if isinstance(sample, tuple): - sample = (sample[0].squeeze(0), ) + sample[1:] - else: - sample = sample.squeeze(0) - _samples.append(sample) - return default_collate(_samples) - - def common_step(self, sample: Any) -> Any: - if isinstance(sample, (list, tuple)): - source, target = sample - return self.current_transform(source), target - return self.current_transform(sample) - - def pre_tensor_transform(self, sample: Any) -> Any: - return self.common_step(sample) + for key in sample.keys(): + if torch.is_tensor(sample[key]): + sample[key] = sample[key].squeeze(0) + return default_collate(samples) - def to_tensor_transform(self, sample: Any) -> Any: - if self.current_transform == self._identity: - if isinstance(sample, (list, tuple)): - source, target = sample - if isinstance(source, torch.Tensor): - return source, target - return self.to_tensor(source), target - elif isinstance(sample, torch.Tensor): - return sample - return self.to_tensor(sample) - if isinstance(sample, torch.Tensor): - return sample - return self.common_step(sample) + @property + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return default_train_transforms(self.image_size) - def post_tensor_transform(self, sample: Any) -> Any: - return self.common_step(sample) + @property + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) - def per_batch_transform(self, sample: Any) -> Any: - return 
self.common_step(sample) + @property + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) - def per_batch_transform_on_device(self, sample: Any) -> Any: - return self.common_step(sample) + @property + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) class ImageClassificationData(DataModule): """Data module for image classification tasks.""" + preprocess_cls = ImageClassificationPreprocess + def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value @@ -333,179 +105,6 @@ def set_block_viz_window(self, value: bool) -> None: def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: return MatplotlibVisualization(*args, **kwargs) - def _get_num_classes(self, dataset: torch.utils.data.Dataset): - num_classes = self.get_dataset_attribute(dataset, "num_classes", None) - if num_classes is None: - num_classes = torch.tensor([dataset[idx][1] for idx in range(len(dataset))]).unique().numel() - - return num_classes - - @classmethod - def from_folders( - cls, - train_folder: Optional[Union[str, pathlib.Path]] = None, - val_folder: Optional[Union[str, pathlib.Path]] = None, - test_folder: Optional[Union[str, pathlib.Path]] = None, - predict_folder: Union[str, pathlib.Path] = None, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', - batch_size: int = 4, - num_workers: Optional[int] = None, - data_fetcher: BaseDataFetcher = None, - preprocess: Optional[Preprocess] = None, - **kwargs, - ) -> 'DataModule': - """ - Creates a ImageClassificationData object from folders of images arranged in this way: :: - - train/dog/xxx.png - train/dog/xxy.png - 
train/dog/xxz.png - train/cat/123.png - train/cat/nsdf3.png - train/cat/asd932.png - - Args: - train_folder: Path to training folder. Default: None. - val_folder: Path to validation folder. Default: None. - test_folder: Path to test folder. Default: None. - predict_folder: Path to predict folder. Default: None. - val_transform: Image transform to use for validation and test set. - train_transform: Image transform to use for training set. - val_transform: Image transform to use for validation set. - test_transform: Image transform to use for test set. - predict_transform: Image transform to use for predict set. - batch_size: Batch size for data loading. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. - - Returns: - ImageClassificationData: the constructed data module - - Examples: - >>> img_data = ImageClassificationData.from_folders("train/") # doctest: +SKIP - - """ - preprocess = preprocess or ImageClassificationPreprocess( - train_transform, - val_transform, - test_transform, - predict_transform, - ) - - return cls.from_load_data_inputs( - train_load_data_input=train_folder, - val_load_data_input=val_folder, - test_load_data_input=test_folder, - predict_load_data_input=predict_folder, - batch_size=batch_size, - num_workers=num_workers, - data_fetcher=data_fetcher, - preprocess=preprocess, - **kwargs, - ) - - @classmethod - def from_filepaths( - cls, - train_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - train_labels: Optional[Sequence] = None, - val_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - val_labels: Optional[Sequence] = None, - test_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - test_labels: Optional[Sequence] = None, - predict_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - 
train_transform: Union[str, Dict] = 'default', - val_transform: Union[str, Dict] = 'default', - test_transform: Union[str, Dict] = 'default', - predict_transform: Union[str, Dict] = 'default', - image_size: Tuple[int, int] = (196, 196), - batch_size: int = 64, - num_workers: Optional[int] = None, - seed: Optional[int] = 42, - data_fetcher: BaseDataFetcher = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - **kwargs, - ) -> 'ImageClassificationData': - """ - Creates a ImageClassificationData object from folders of images arranged in this way: :: - - folder/dog_xxx.png - folder/dog_xxy.png - folder/dog_xxz.png - folder/cat_123.png - folder/cat_nsdf3.png - folder/cat_asd932_.png - - Args: - - train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. - train_labels: Sequence of labels for training dataset. Defaults to ``None``. - val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. - val_labels: Sequence of labels for validation dataset. Defaults to ``None``. - test_filepaths: String or sequence of file paths for test dataset. Defaults to ``None``. - test_labels: Sequence of labels for test dataset. Defaults to ``None``. - train_transform: Image transform to use for the train set. Defaults to ``default``, which loads imagenet - transforms. - val_transform: Image transform to use for the validation set. Defaults to ``default``, which loads - imagenet transforms. - test_transform: Image transform to use for the test set. Defaults to ``default``, which loads imagenet - transforms. - predict_transform: Image transform to use for the predict set. Defaults to ``default``, which loads imagenet - transforms. - batch_size: The batchsize to use for parallel loading. Defaults to ``64``. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. - seed: Used for the train/val splits. 
- - Returns: - - ImageClassificationData: The constructed data module. - """ - # enable passing in a string which loads all files in that folder as a list - if isinstance(train_filepaths, str): - if os.path.isdir(train_filepaths): - train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)] - else: - train_filepaths = [train_filepaths] - - if isinstance(val_filepaths, str): - if os.path.isdir(val_filepaths): - val_filepaths = [os.path.join(val_filepaths, x) for x in os.listdir(val_filepaths)] - else: - val_filepaths = [val_filepaths] - - if isinstance(test_filepaths, str): - if os.path.isdir(test_filepaths): - test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)] - else: - test_filepaths = [test_filepaths] - - preprocess = preprocess or ImageClassificationPreprocess( - train_transform, - val_transform, - test_transform, - predict_transform, - image_size=image_size, - ) - - return cls.from_load_data_inputs( - train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, - val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, - test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, - predict_load_data_input=predict_filepaths, - batch_size=batch_size, - num_workers=num_workers, - data_fetcher=data_fetcher, - preprocess=preprocess, - seed=seed, - val_split=val_split, - **kwargs - ) - class MatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. 
@@ -539,10 +138,9 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs.ravel()): # unpack images and labels if isinstance(data, list): - _img, _label = data[i] - elif isinstance(data, tuple): - imgs, labels = data - _img, _label = imgs[i], labels[i] + _img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET] + elif isinstance(data, dict): + _img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i] else: raise TypeError(f"Unknown data type. Got: {type(data)}.") # convert images to numpy @@ -573,4 +171,4 @@ def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningS def show_per_batch_transform(self, batch: List[Any], running_stage): win_title: str = f"{running_stage} - show_per_batch_transform" - self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title) + self._show_images_and_labels(batch[0], batch[0][DefaultDataKeys.INPUT].shape[0], win_title) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 916c7c2d90..36fb1808fa 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -21,6 +21,7 @@ from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry +from flash.data.data_source import DefaultDataKeys from flash.data.process import Serializer from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES @@ -106,6 +107,22 @@ def __init__( head = head(num_features, num_classes) if isinstance(head, FunctionType) else head self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), ) + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + 
return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + def forward(self, x) -> torch.Tensor: x = self.backbone(x) if x.dim() == 4: diff --git a/flash/vision/classification/transforms.py b/flash/vision/classification/transforms.py new file mode 100644 index 0000000000..3eff2f4c2c --- /dev/null +++ b/flash/vision/classification/transforms.py @@ -0,0 +1,92 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from typing import Callable, Dict, Tuple + +import torch +import torchvision +from torch import nn +from torchvision import transforms as T + +from flash.data.data_source import DefaultDataKeys +from flash.data.transforms import ApplyToKeys +from flash.utils.imports import _KORNIA_AVAILABLE + +if _KORNIA_AVAILABLE: + import kornia as K + + +def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": + # Better approach as all transforms are applied on tensor directly + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + DefaultDataKeys.INPUT, + # TODO (Edgar): replace with resize once kornia is fixed + K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), + K.augmentation.RandomHorizontalFlip(), + ), + "per_batch_transform_on_device": ApplyToKeys( + DefaultDataKeys.INPUT, + K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), + ) + } + else: + return { + "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size), T.RandomHorizontalFlip()), + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + DefaultDataKeys.INPUT, + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ), + } + + +def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": + # Better approach as all transforms are applied on tensor directly + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + 
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + DefaultDataKeys.INPUT, + # TODO (Edgar): replace with resize once kornia is fixed + K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)), + ), + "per_batch_transform_on_device": ApplyToKeys( + DefaultDataKeys.INPUT, + K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), + ) + } + else: + return { + "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)), + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + DefaultDataKeys.INPUT, + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ), + } diff --git a/flash/vision/data.py b/flash/vision/data.py new file mode 100644 index 0000000000..056e856468 --- /dev/null +++ b/flash/vision/data.py @@ -0,0 +1,44 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Dict, Optional + +import torch +from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS +from torchvision.transforms.functional import to_pil_image + +from flash.data.data_source import DefaultDataKeys, NumpyDataSource, PathsDataSource, TensorDataSource + + +class ImagePathsDataSource(PathsDataSource): + + def __init__(self): + super().__init__(extensions=IMG_EXTENSIONS) + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + sample[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) + return sample + + +class ImageTensorDataSource(TensorDataSource): + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + sample[DefaultDataKeys.INPUT] = to_pil_image(sample[DefaultDataKeys.INPUT]) + return sample + + +class ImageNumpyDataSource(NumpyDataSource): + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + sample[DefaultDataKeys.INPUT] = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) + return sample diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 35905d683b..528a74a99d 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -12,168 +12,129 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union -import torch -from PIL import Image -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor, tensor -from torch._six import container_abcs from torch.nn import Module -from torch.utils.data._utils.collate import default_collate -from torchvision import transforms as T +from torchvision.datasets.folder import default_loader -from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule -from flash.data.process import DefaultPreprocess, Preprocess -from flash.data.utils import _contains_any_tensor +from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.data.process import Preprocess from flash.utils.imports import _COCO_AVAILABLE -from flash.vision.utils import pil_loader +from flash.vision.data import ImagePathsDataSource +from flash.vision.detection.transforms import default_transforms if _COCO_AVAILABLE: from pycocotools.coco import COCO -class CustomCOCODataset(torch.utils.data.Dataset): +class COCODataSource(DataSource[Tuple[str, str]]): - def __init__( - self, - root: str, - ann_file: str, - transforms: Optional[Callable] = None, - loader: Optional[Callable] = pil_loader, - ): - if not _COCO_AVAILABLE: - raise ImportError("Kindly install the COCO API `pycocotools` to use the Dataset") + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root, ann_file = data - self.root = root - self.transforms = transforms - self.coco = COCO(ann_file) - self.ids = list(sorted(self.coco.imgs.keys())) - self.loader = loader + coco = COCO(ann_file) - @property - def num_classes(self) -> int: - categories = self.coco.loadCats(self.coco.getCatIds()) - if not categories: - raise ValueError("No Categories found") - return categories[-1]["id"] + 1 - - def 
__getitem__(self, index: int) -> Tuple[Any, Any]: - coco = self.coco - img_idx = self.ids[index] - - ann_ids = coco.getAnnIds(imgIds=img_idx) - annotations = coco.loadAnns(ann_ids) - - image_path = coco.loadImgs(img_idx)[0]["file_name"] - img = Image.open(os.path.join(self.root, image_path)) - - boxes = [] - labels = [] - areas = [] - iscrowd = [] - - for obj in annotations: - xmin = obj["bbox"][0] - ymin = obj["bbox"][1] - xmax = xmin + obj["bbox"][2] - ymax = ymin + obj["bbox"][3] - - bbox = [xmin, ymin, xmax, ymax] - keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) - if keep: - boxes.append(bbox) - labels.append(obj["category_id"]) - areas.append(obj["area"]) - iscrowd.append(obj["iscrowd"]) - - target = dict( - boxes=torch.as_tensor(boxes, dtype=torch.float32), - labels=torch.as_tensor(labels, dtype=torch.int64), - image_id=tensor([img_idx]), - area=torch.as_tensor(areas, dtype=torch.float32), - iscrowd=torch.as_tensor(iscrowd, dtype=torch.int64) - ) + categories = coco.loadCats(coco.getCatIds()) + if categories: + dataset.num_classes = categories[-1]["id"] + 1 - if self.transforms: - img = self.transforms(img) + img_ids = list(sorted(coco.imgs.keys())) + paths = coco.loadImgs(img_ids) - return img, target + data = [] - def __len__(self) -> int: - return len(self.ids) + for img_id, path in zip(img_ids, paths): + path = path["file_name"] + ann_ids = coco.getAnnIds(imgIds=img_id) + annotations = coco.loadAnns(ann_ids) -def _coco_remove_images_without_annotations(dataset): - # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py + boxes, labels, areas, iscrowd = [], [], [], [] - def _has_only_empty_bbox(annot: List): - return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annot) + # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py + if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations): + continue - def _has_valid_annotation(annot: List): - # if it's empty, 
there is no annotation - if not annot: - return False - # if all boxes have close to zero area, there is no annotation - if _has_only_empty_bbox(annot): - return False - return True + for obj in annotations: + xmin = obj["bbox"][0] + ymin = obj["bbox"][1] + xmax = xmin + obj["bbox"][2] + ymax = ymin + obj["bbox"][3] - ids = [] - for ds_idx, img_id in enumerate(dataset.ids): - ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) - anno = dataset.coco.loadAnns(ann_ids) - if _has_valid_annotation(anno): - ids.append(ds_idx) + bbox = [xmin, ymin, xmax, ymax] + keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) + if keep: + boxes.append(bbox) + labels.append(obj["category_id"]) + areas.append(obj["area"]) + iscrowd.append(obj["iscrowd"]) - dataset = torch.utils.data.Subset(dataset, ids) - return dataset + data.append( + dict( + input=os.path.join(root, path), + target=dict( + boxes=boxes, + labels=labels, + image_id=img_id, + area=areas, + iscrowd=iscrowd, + ) + ) + ) + return data + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) + return sample -class ObjectDetectionPreprocess(DefaultPreprocess): - to_tensor = T.ToTensor() +class ObjectDetectionPreprocess(Preprocess): - def load_data(self, metadata: Any, dataset: AutoDataset) -> CustomCOCODataset: - # Extract folder, coco annotation file and the transform to be applied on the images - folder, ann_file, transform = metadata - ds = CustomCOCODataset(folder, ann_file, transform) - if self.training: - dataset.num_classes = ds.num_classes - ds = _coco_remove_images_without_annotations(ds) - return ds + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + 
val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.PATHS: ImagePathsDataSource(), + "coco": COCODataSource(), + }, + default_data_source=DefaultDataSources.PATHS, + ) - def predict_load_data(self, samples): - return samples + def collate(self, samples: Any) -> Any: + return {key: [sample[key] for sample in samples] for key in samples[0]} - def pre_tensor_transform(self, samples: Any) -> Any: - if _contains_any_tensor(samples): - return samples + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} - if isinstance(samples, str): - samples = [samples] + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) - if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): - outputs = [] - for sample in samples: - outputs.append(pil_loader(sample)) - return outputs - raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.") + @property + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() - def to_tensor_transform(self, sample) -> Any: - return self.to_tensor(sample[0]), sample[1] + @property + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() - def predict_to_tensor_transform(self, sample) -> Any: - return self.to_tensor(sample[0]) + @property + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() - def collate(self, samples: Any) -> Any: - if not isinstance(samples, Tensor): - elem = samples[0] - if isinstance(elem, container_abcs.Sequence): - return tuple(zip(*samples)) - return default_collate(samples) - return samples.unsqueeze(dim=0) + @property + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms() class ObjectDetectionData(DataModule): @@ -185,32 
+146,28 @@ def from_coco( cls, train_folder: Optional[str] = None, train_ann_file: Optional[str] = None, - train_transform: Optional[Dict[str, Module]] = None, + train_transform: Optional[Dict[str, Callable]] = None, val_folder: Optional[str] = None, val_ann_file: Optional[str] = None, - val_transform: Optional[Dict[str, Module]] = None, + val_transform: Optional[Dict[str, Callable]] = None, test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, - test_transform: Optional[Dict[str, Module]] = None, - predict_transform: Optional[Dict[str, Module]] = None, + test_transform: Optional[Dict[str, Callable]] = None, batch_size: int = 4, num_workers: Optional[int] = None, preprocess: Preprocess = None, - **kwargs + val_split: Optional[float] = None, ): - preprocess = preprocess or cls.preprocess_cls( - train_transform, - val_transform, - test_transform, - predict_transform, - ) - - return cls.from_load_data_inputs( - train_load_data_input=(train_folder, train_ann_file, train_transform), - val_load_data_input=(val_folder, val_ann_file, val_transform) if val_folder else None, - test_load_data_input=(test_folder, test_ann_file, test_transform) if test_folder else None, + return cls.from_data_source( + data_source="coco", + train_data=(train_folder, train_ann_file) if train_folder else None, + val_data=(val_folder, val_ann_file) if val_folder else None, + test_data=(test_folder, test_ann_file) if test_folder else None, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + preprocess=preprocess, + val_split=val_split, batch_size=batch_size, num_workers=num_workers, - preprocess=preprocess, - **kwargs ) diff --git a/flash/vision/detection/model.py b/flash/vision/detection/model.py index a7eed0e105..dba922b8f9 100644 --- a/flash/vision/detection/model.py +++ b/flash/vision/detection/model.py @@ -24,6 +24,7 @@ from flash.core import Task from flash.core.registry import FlashRegistry +from flash.data.data_source import 
DefaultDataKeys from flash.vision.backbones import OBJ_DETECTION_BACKBONES from flash.vision.detection.finetuning import ObjectDetectionFineTuning @@ -156,7 +157,7 @@ def get_model( def training_step(self, batch, batch_idx) -> Any: """The training step. Overrides ``Task.training_step`` """ - images, targets = batch + images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] targets = [{k: v for k, v in t.items()} for t in targets] # fasterrcnn takes both images and targets for training, returns loss_dict @@ -166,7 +167,7 @@ def training_step(self, batch, batch_idx) -> Any: return loss def validation_step(self, batch, batch_idx): - images, targets = batch + images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] # fasterrcnn takes only images for eval() mode outs = self.model(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() @@ -178,7 +179,7 @@ def validation_epoch_end(self, outs): return {"avg_val_iou": avg_iou, "log": logs} def test_step(self, batch, batch_idx): - images, targets = batch + images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] # fasterrcnn takes only images for eval() mode outs = self.model(images) iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() @@ -189,5 +190,9 @@ def test_epoch_end(self, outs): logs = {"test_iou": avg_iou} return {"avg_test_iou": avg_iou, "log": logs} + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + images = batch[DefaultDataKeys.INPUT] + return self.model(images) + def configure_finetune_callback(self): return [ObjectDetectionFineTuning(train_bn=True)] diff --git a/flash/vision/detection/transforms.py b/flash/vision/detection/transforms.py new file mode 100644 index 0000000000..735f9db305 --- /dev/null +++ b/flash/vision/detection/transforms.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict + +import torch +import torchvision +from torch import nn + +from flash.data.transforms import ApplyToKeys + + +def default_transforms() -> Dict[str, Callable]: + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys('input', torchvision.transforms.ToTensor()), + ApplyToKeys( + 'target', + nn.Sequential( + ApplyToKeys('boxes', torch.as_tensor), + ApplyToKeys('labels', torch.as_tensor), + ApplyToKeys('image_id', torch.as_tensor), + ApplyToKeys('area', torch.as_tensor), + ApplyToKeys('iscrowd', torch.as_tensor), + ) + ), + ), + } diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py index f43dabcfaa..1392228e37 100644 --- a/flash/vision/embedding/model.py +++ b/flash/vision/embedding/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union import torch from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -21,6 +21,7 @@ from flash.core import Task from flash.core.registry import FlashRegistry +from flash.data.data_source import DefaultDataKeys from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess @@ -108,3 +109,19 @@ def forward(self, x) -> torch.Tensor: x = self.head(x) return x + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py index 4c8e5b9c04..8fc9c3de88 100644 --- a/flash_examples/custom_task.py +++ b/flash_examples/custom_task.py @@ -9,7 +9,8 @@ import flash from flash.data.auto_dataset import AutoDataset -from flash.data.process import Postprocess, Preprocess +from flash.data.data_source import DataSource +from flash.data.process import Preprocess seed_everything(42) @@ -41,22 +42,28 @@ def forward(self, x): return self.model(x) -class NumpyPreprocess(Preprocess): +class NumpyDataSource(DataSource): def load_data(self, data: Tuple[ND, ND], dataset: 
AutoDataset) -> List[Tuple[ND, float]]: if self.training: dataset.num_inputs = data[0].shape[1] return [(x, y) for x, y in zip(*data)] + def predict_load_data(self, data: ND) -> ND: + return data + + +class NumpyPreprocess(Preprocess): + + def __init__(self): + super().__init__(data_sources={"numpy": NumpyDataSource()}, default_data_source="numpy") + def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]: x, y = sample x = torch.from_numpy(x).float() y = torch.tensor(y, dtype=torch.float) return x, y - def predict_load_data(self, data: ND) -> ND: - return data - def predict_to_tensor_transform(self, sample: ND) -> ND: return torch.from_numpy(sample).float() @@ -77,12 +84,13 @@ def from_dataset(cls, x: ND, y: ND, preprocess: Preprocess, batch_size: int = 64 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0) - dm = cls.from_load_data_inputs( - train_load_data_input=(x_train, y_train), - test_load_data_input=(x_test, y_test), + dm = cls.from_data_source( + "numpy", + train_data=(x_train, y_train), + test_data=(x_test, y_test), preprocess=preprocess, batch_size=batch_size, - num_workers=num_workers + num_workers=num_workers, ) dm.num_inputs = dm.train_dataset.num_inputs return dm diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index eb863f4e9f..2ebc668f95 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -48,7 +48,7 @@ def fn_resnet(pretrained: bool = True): print(ImageClassifier.available_backbones()) # 4. Build the model -model = ImageClassifier(backbone="dino_vitb16", num_classes=datamodule.num_classes) +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) # 5. Create the trainer. 
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) @@ -66,13 +66,9 @@ def fn_resnet(pretrained: bool = True): "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", ]) - print(predictions) -datamodule = ImageClassificationData.from_folders( - predict_folder="data/hymenoptera_data/predict/", - preprocess=model.preprocess, -) +datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/") # 7b. Or generate predictions with a whole folder! predictions = Trainer().predict(model, datamodule=datamodule) diff --git a/flash_examples/finetuning/image_classification_multi_label.py b/flash_examples/finetuning/image_classification_multi_label.py index ca2360c519..5ae824d74f 100644 --- a/flash_examples/finetuning/image_classification_multi_label.py +++ b/flash_examples/finetuning/image_classification_multi_label.py @@ -39,16 +39,16 @@ def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], L [[int(row[genre]) for genre in genres] for _, row in metadata.iterrows()]) -train_filepaths, train_labels = load_data('train') -test_filepaths, test_labels = load_data('test') - -datamodule = ImageClassificationData.from_filepaths( - train_filepaths=train_filepaths, - train_labels=train_labels, - test_filepaths=test_filepaths, - test_labels=test_labels, - preprocess=ImageClassificationPreprocess(image_size=(128, 128)), +train_files, train_targets = load_data('train') +test_files, test_targets = load_data('test') + +datamodule = ImageClassificationData.from_files( + train_files=train_files, + train_targets=train_targets, + test_files=test_files, + test_targets=test_targets, val_split=0.1, # Use 10 % of the train dataset to generate validation one. + image_size=(128, 128), ) # 3. 
Build the model diff --git a/flash_examples/finetuning/object_detection.py b/flash_examples/finetuning/object_detection.py index 4d013c37ac..eee289cd0d 100644 --- a/flash_examples/finetuning/object_detection.py +++ b/flash_examples/finetuning/object_detection.py @@ -23,14 +23,14 @@ datamodule = ObjectDetectionData.from_coco( train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", - batch_size=2 + batch_size=2, ) # 3. Build the model model = ObjectDetector(num_classes=datamodule.num_classes) # 4. Create the trainer -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, limit_train_batches=1) # 5. Finetune the model trainer.finetune(model, datamodule) diff --git a/flash_examples/finetuning/summarization.py b/flash_examples/finetuning/summarization.py index d2ecc726f3..78757ebacc 100644 --- a/flash_examples/finetuning/summarization.py +++ b/flash_examples/finetuning/summarization.py @@ -21,12 +21,12 @@ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the data -datamodule = SummarizationData.from_files( +datamodule = SummarizationData.from_csv( + "input", + "target", train_file="data/xsum/train.csv", val_file="data/xsum/valid.csv", test_file="data/xsum/test.csv", - input="input", - target="target" ) # 3. Build the model diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index 9d5b8ad256..ad8a949455 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -22,12 +22,12 @@ # 2. 
Load the data datamodule = TabularData.from_csv( - target_col="Survived", - train_csv="./data/titanic/titanic.csv", - test_csv="./data/titanic/test.csv", - categorical_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - numerical_cols=["Fare"], - val_size=0.25, + ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + ["Fare"], + target_field="Survived", + train_file="./data/titanic/titanic.csv", + test_file="./data/titanic/test.csv", + val_split=0.25, ) # 3. Build the model diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index efbcac71ea..5f0000cf40 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -19,12 +19,12 @@ download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/") # 2. Load the data -datamodule = TextClassificationData.from_files( +datamodule = TextClassificationData.from_csv( train_file="data/imdb/train.csv", val_file="data/imdb/valid.csv", test_file="data/imdb/test.csv", - input="review", - target="sentiment", + input_fields="review", + target_fields="sentiment", batch_size=16, ) diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index be91ea057d..2a3b1bebf9 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -21,20 +21,24 @@ download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") # 2. Load the data -datamodule = TranslationData.from_files( +datamodule = TranslationData.from_csv( + "input", + "target", train_file="data/wmt_en_ro/train.csv", val_file="data/wmt_en_ro/valid.csv", test_file="data/wmt_en_ro/test.csv", - input="input", - target="target", - batch_size=1 + batch_size=1, ) # 3. Build the model model = TranslationTask() # 4. 
Create the trainer -trainer = flash.Trainer(precision=32, gpus=int(torch.cuda.is_available()), fast_dev_run=True) +trainer = flash.Trainer( + precision=16 if torch.cuda.is_available() else 32, + gpus=int(torch.cuda.is_available()), + fast_dev_run=True, +) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule) diff --git a/flash_examples/finetuning/video_classification.py b/flash_examples/finetuning/video_classification.py index 0e30141a61..c9ede4f043 100644 --- a/flash_examples/finetuning/video_classification.py +++ b/flash_examples/finetuning/video_classification.py @@ -35,7 +35,7 @@ if __name__ == '__main__': - _PATH_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + _PATH_ROOT = os.path.dirname(os.path.abspath(__file__)) # 1. Download a video clip dataset. Find more dataset at https://pytorchvideo.readthedocs.io/en/latest/data.html download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip") @@ -72,19 +72,18 @@ def make_transform( } # 3. Load the data from directories. - datamodule = VideoClassificationData.from_paths( - train_data_path=os.path.join(_PATH_ROOT, "data/kinetics/train"), - val_data_path=os.path.join(_PATH_ROOT, "data/kinetics/val"), - predict_data_path=os.path.join(_PATH_ROOT, "data/kinetics/predict"), - clip_sampler="uniform", - clip_duration=2, - video_sampler=RandomSampler, - decode_audio=False, + datamodule = VideoClassificationData.from_folders( + train_folder=os.path.join(_PATH_ROOT, "data/kinetics/train"), + val_folder=os.path.join(_PATH_ROOT, "data/kinetics/val"), + predict_folder=os.path.join(_PATH_ROOT, "data/kinetics/predict"), train_transform=make_transform(train_post_tensor_transform), val_transform=make_transform(val_post_tensor_transform), predict_transform=make_transform(val_post_tensor_transform), - num_workers=8, batch_size=8, + clip_sampler="uniform", + clip_duration=2, + video_sampler=RandomSampler, + decode_audio=False, ) # 4. 
List the available models @@ -97,12 +96,11 @@ def make_transform( model.serializer = Labels() # 6. Finetune the model - trainer = flash.Trainer(max_epochs=3, gpus=1) + trainer = flash.Trainer(max_epochs=3) trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze()) trainer.save_checkpoint("video_classification.pt") # 7. Make a prediction - val_folder = os.path.join(_PATH_ROOT, os.path.join(_PATH_ROOT, "data/kinetics/predict")) - predictions = model.predict([os.path.join(val_folder, f) for f in os.listdir(val_folder)]) + predictions = model.predict(os.path.join(_PATH_ROOT, "data/kinetics/predict")) print(predictions) diff --git a/flash_examples/predict/image_classification_multi_label.py b/flash_examples/predict/image_classification_multi_label.py index c20f78172b..59e4c7da9e 100644 --- a/flash_examples/predict/image_classification_multi_label.py +++ b/flash_examples/predict/image_classification_multi_label.py @@ -33,7 +33,7 @@ class CustomViz(BaseVisualization): def show_per_batch_transform(self, batch: Any, _) -> None: - images = batch[0] + images = batch[0]["input"] image = make_grid(images, nrow=2) image = T.to_pil_image(image, 'RGB') image.show() @@ -56,7 +56,7 @@ def show_per_batch_transform(self, batch: Any, _) -> None: datamodule = ImageClassificationData.from_folders( predict_folder="data/movie_posters/predict/", data_fetcher=CustomViz(), - preprocess=model.preprocess, + image_size=(128, 128), ) predictions = Trainer().predict(model, datamodule=datamodule) diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py index 04bb155361..54df44a736 100644 --- a/flash_examples/predict/image_embedder.py +++ b/flash_examples/predict/image_embedder.py @@ -33,7 +33,7 @@ random_image = torch.randn(1, 3, 244, 244) # 6. Generate an embedding from this random image. -embeddings = embedder.predict(random_image) +embeddings = embedder.predict(random_image, data_source="tensor") # 7. 
Print embeddings shape print(embeddings[0].shape) diff --git a/flash_examples/predict/summarization.py b/flash_examples/predict/summarization.py index 6d16ebfcaf..ff59c6cfa3 100644 --- a/flash_examples/predict/summarization.py +++ b/flash_examples/predict/summarization.py @@ -48,9 +48,9 @@ print(predictions) # 2b. Or generate summaries from a sheet file! -datamodule = SummarizationData.from_files( +datamodule = SummarizationData.from_csv( + "input", predict_file="data/xsum/predict.csv", - input="input", ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) diff --git a/flash_examples/predict/text_classification.py b/flash_examples/predict/text_classification.py index e81fd17c52..372250b21f 100644 --- a/flash_examples/predict/text_classification.py +++ b/flash_examples/predict/text_classification.py @@ -36,11 +36,9 @@ print(predictions) # 2b. Or generate predictions from a sheet file! -datamodule = TextClassificationData.from_file( +datamodule = TextClassificationData.from_csv( + "review", predict_file="data/imdb/predict.csv", - input="review", - # use the same data pre-processing values we used to predict in 2a - preprocess=model.preprocess, ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions) diff --git a/flash_examples/predict/translation.py b/flash_examples/predict/translation.py index a210f267ae..cd6009f4db 100644 --- a/flash_examples/predict/translation.py +++ b/flash_examples/predict/translation.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Trainer - from flash.data.utils import download_data -from flash.text import TranslationData, TranslationTask +from flash.text import TranslationTask # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") @@ -22,7 +20,7 @@ # 2. 
Load the model from a checkpoint model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") -# 2. Translate a few sentences! +# 3. Translate a few sentences! predictions = model.predict([ "BBC News went to meet one of the project's first graduates.", "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", diff --git a/flash_examples/predict/video_classification.py b/flash_examples/predict/video_classification.py index 0fd790b492..2bf8bff520 100644 --- a/flash_examples/predict/video_classification.py +++ b/flash_examples/predict/video_classification.py @@ -11,25 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import sys -from typing import Callable, List -import torch -from torch.utils.data.sampler import RandomSampler - -import flash -from flash.core.classification import Labels -from flash.core.finetuning import NoFreeze from flash.data.utils import download_data from flash.utils.imports import _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE -from flash.video import VideoClassificationData, VideoClassifier +from flash.video import VideoClassifier -if _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE: - import kornia.augmentation as K - from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample - from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip -else: +if not (_PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE): print("Please, run `pip install torchvideo kornia`") sys.exit(0) @@ -41,6 +29,5 @@ ) # 2. 
Make a prediction -predict_folder = "data/kinetics/predict/" -predictions = model.predict([os.path.join(predict_folder, f) for f in os.listdir(predict_folder)]) +predictions = model.predict("data/kinetics/predict/") print(predictions) diff --git a/requirements.txt b/requirements.txt index 8aaa1ec97d..bce2efd675 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ pytorch-lightning>=1.3.0rc1 lightning-bolts>=0.3.3 PyYAML>=5.1 Pillow>=7.2 -transformers>=4.0 +transformers>=4.5 pytorch-tabnet==3.1 datasets>=1.2, <1.3 pandas>=1.1 diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index df12d85ca2..e86838b558 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -25,7 +25,7 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index: int) -> Any: - return torch.rand(3, 64, 64), torch.randint(10, size=(1, )).item() + return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1, )).item()} def __len__(self) -> int: return 100 diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 6a60071f74..df7dc89e33 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -29,7 +29,7 @@ import flash from flash.core.classification import ClassificationTask -from flash.data.process import Postprocess +from flash.data.process import DefaultPreprocess, Postprocess from flash.tabular import TabularClassifier from flash.text import SummarizationTask, TextClassifier from flash.utils.imports import _TRANSFORMERS_AVAILABLE @@ -75,7 +75,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): def test_classificationtask_task_predict(): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) - task = ClassificationTask(model) + task = ClassificationTask(model, preprocess=DefaultPreprocess()) ds = DummyDataset() expected = list(range(10)) # single item diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index 
2d50e671e4..b235c3cf38 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -16,13 +16,14 @@ import pytest from pytorch_lightning.trainer.states import RunningStage -from flash.data.auto_dataset import AutoDataset +from flash.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset from flash.data.callback import FlashCallback from flash.data.data_pipeline import DataPipeline +from flash.data.data_source import DataSource from flash.data.process import Preprocess -class _AutoDatasetTestPreprocess(Preprocess): +class _AutoDatasetTestDataSource(DataSource): def __init__(self, with_dset: bool): self._callbacks: List[FlashCallback] = [] @@ -48,13 +49,6 @@ def __init__(self, with_dset: bool): self.train_load_data = self.train_load_data_no_dset self.train_load_sample = self.train_load_sample_no_dset - def get_state_dict(self) -> Dict[str, Any]: - return {"with_dset": self.with_dset} - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): - return _AutoDatasetTestPreprocess(state_dict["with_dset"]) - def load_data_no_dset(self, data): self.load_data_count += 1 return data @@ -92,85 +86,82 @@ def train_load_data_with_dataset(self, data, dataset): return data -@pytest.mark.parametrize( - "with_dataset,with_running_stage", - [ - (True, False), - (True, True), - (False, False), - (False, True), - ], -) -def test_autodataset_with_functions( - with_dataset: bool, - with_running_stage: bool, -): +# TODO: we should test the different data types +@pytest.mark.parametrize("running_stage", [RunningStage.TRAINING, RunningStage.TESTING, RunningStage.VALIDATING]) +def test_base_autodataset_smoke(running_stage): + dt = range(10) + ds = DataSource() + dset = BaseAutoDataset(data=dt, data_source=ds, running_stage=running_stage) + assert dset is not None + assert dset.running_stage == running_stage - functions = _AutoDatasetTestPreprocess(with_dataset) + # check on members + assert dset.data == dt + assert 
dset.data_source == ds - load_sample_func = functions.load_sample - load_data_func = functions.load_data + # test set the running stage + dset.running_stage = RunningStage.PREDICTING + assert dset.running_stage == RunningStage.PREDICTING - if with_running_stage: - running_stage = RunningStage.TRAINING - else: - running_stage = None - dset = AutoDataset( - range(10), - load_data=load_data_func, - load_sample=load_sample_func, - running_stage=running_stage, - ) + # check on methods + assert dset.load_sample is not None + assert dset.load_sample == ds.load_sample - assert len(dset) == 10 - for idx in range(len(dset)): - dset[idx] +def test_autodataset_smoke(): + num_samples = 20 + dt = range(num_samples) + ds = DataSource() - if with_dataset: - assert dset.load_sample_was_called - assert dset.load_data_was_called - assert functions.load_sample_with_dataset_count == len(dset) - assert functions.load_data_with_dataset_count == 1 - else: - assert functions.load_data_count == 1 - assert functions.load_sample_count == len(dset) + dset = AutoDataset(data=dt, data_source=ds, running_stage=RunningStage.TRAINING) + assert dset is not None + assert dset.running_stage == RunningStage.TRAINING + # check on members + assert dset.data == dt + assert dset.data_source == ds -def test_autodataset_warning(): - with pytest.warns( - UserWarning, match="``datapipeline`` is specified but load_sample and/or load_data are also specified" - ): - AutoDataset(range(10), load_data=lambda x: x, load_sample=lambda x: x, data_pipeline=DataPipeline()) + # test set the running stage + dset.running_stage = RunningStage.PREDICTING + assert dset.running_stage == RunningStage.PREDICTING + # check on methods + assert dset.load_sample is not None + assert dset.load_sample == ds.load_sample -@pytest.mark.parametrize( - "with_dataset", - [ - True, - False, - ], -) -def test_preprocessing_data_pipeline_with_running_stage(with_dataset): - pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset)) + # check 
getters + assert len(dset) == num_samples + assert dset[0] == 0 + assert dset[9] == 9 + assert dset[11] == 11 - running_stage = RunningStage.TRAINING - dataset = pipe._generate_auto_dataset(range(10), running_stage=running_stage) +def test_iterable_autodataset_smoke(): + num_samples = 20 + dt = range(num_samples) + ds = DataSource() - assert len(dataset) == 10 + dset = IterableAutoDataset(data=dt, data_source=ds, running_stage=RunningStage.TRAINING) + assert dset is not None + assert dset.running_stage == RunningStage.TRAINING - for idx in range(len(dataset)): - dataset[idx] + # check on members + assert dset.data == dt + assert dset.data_source == ds - if with_dataset: - assert dataset.train_load_sample_was_called - assert dataset.train_load_data_was_called - assert pipe._preprocess_pipeline.train_load_sample_with_dataset_count == len(dataset) - assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1 - else: - assert pipe._preprocess_pipeline.train_load_sample_count == len(dataset) - assert pipe._preprocess_pipeline.train_load_data_count == 1 + # test set the running stage + dset.running_stage = RunningStage.PREDICTING + assert dset.running_stage == RunningStage.PREDICTING + + # check on methods + assert dset.load_sample is not None + assert dset.load_sample == ds.load_sample + + # check getters + itr = iter(dset) + assert next(itr) == 0 + assert next(itr) == 1 + assert next(itr) == 2 @pytest.mark.parametrize( @@ -180,29 +171,22 @@ def test_preprocessing_data_pipeline_with_running_stage(with_dataset): False, ], ) -def test_preprocessing_data_pipeline_no_running_stage(with_dataset): - pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset)) - - dataset = pipe._generate_auto_dataset(range(10), running_stage=None) +def test_preprocessing_data_source_with_running_stage(with_dataset): + data_source = _AutoDatasetTestDataSource(with_dataset) + running_stage = RunningStage.TRAINING - with pytest.raises(RuntimeError, match='`__len__` for 
`load_sample`'): - for idx in range(len(dataset)): - dataset[idx] + dataset = data_source.generate_dataset(range(10), running_stage=running_stage) - # will be triggered when running stage is set - if with_dataset: - assert not hasattr(dataset, 'load_sample_was_called') - assert not hasattr(dataset, 'load_data_was_called') - assert pipe._preprocess_pipeline.load_sample_with_dataset_count == 0 - assert pipe._preprocess_pipeline.load_data_with_dataset_count == 0 - else: - assert pipe._preprocess_pipeline.load_sample_count == 0 - assert pipe._preprocess_pipeline.load_data_count == 0 + assert len(dataset) == 10 - dataset.running_stage = RunningStage.TRAINING + for idx in range(len(dataset)): + dataset[idx] if with_dataset: - assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1 + assert dataset.train_load_sample_was_called assert dataset.train_load_data_was_called + assert data_source.train_load_sample_with_dataset_count == len(dataset) + assert data_source.train_load_data_with_dataset_count == 1 else: - assert pipe._preprocess_pipeline.train_load_data_count == 1 + assert data_source.train_load_sample_count == len(dataset) + assert data_source.train_load_data_count == 1 diff --git a/tests/data/test_callback.py b/tests/data/test_callback.py index 0bc47a91cd..26b1a941a0 100644 --- a/tests/data/test_callback.py +++ b/tests/data/test_callback.py @@ -11,19 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Sequence, Tuple from unittest import mock -from unittest.mock import ANY, call, MagicMock, Mock +from unittest.mock import ANY, call, MagicMock import torch -from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.states import RunningStage -from torch import Tensor from flash.core.model import Task from flash.core.trainer import Trainer from flash.data.data_module import DataModule -from flash.data.process import Preprocess +from flash.data.process import DefaultPreprocess @mock.patch("torch.save") # need to mock torch.save or we get pickle error @@ -33,7 +30,9 @@ def test_flash_callback(_, tmpdir): callback_mock = MagicMock() inputs = [[torch.rand(1), torch.rand(1)]] - dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm = DataModule.from_data_source( + "default", inputs, inputs, inputs, None, preprocess=DefaultPreprocess(), batch_size=1, num_workers=0 + ) dm.preprocess.callbacks += [callback_mock] _ = next(iter(dm.train_dataloader())) @@ -59,7 +58,9 @@ def __init__(self): limit_train_batches=1, progress_bar_refresh_rate=0, ) - dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm = DataModule.from_data_source( + "default", inputs, inputs, inputs, None, preprocess=DefaultPreprocess(), batch_size=1, num_workers=0 + ) dm.preprocess.callbacks += [callback_mock] trainer.fit(CustomModel(), datamodule=dm) diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 46a9347cfa..f4748a5149 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -25,8 +25,9 @@ from flash.data.base_viz import BaseVisualization from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule -from flash.data.process import DefaultPreprocess, Preprocess -from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX +from flash.data.data_source import DefaultDataKeys +from flash.data.process import 
DefaultPreprocess +from flash.data.utils import _CALLBACK_FUNCS, _STAGES_PREFIX from flash.vision import ImageClassificationData @@ -60,20 +61,25 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat preprocess = DefaultPreprocess() - return cls.from_load_data_inputs( - train_load_data_input=train_data, - val_load_data_input=val_data, - test_load_data_input=test_data, - predict_load_data_input=predict_data, + return cls.from_data_source( + "default", + train_data=train_data, + val_data=val_data, + test_data=test_data, + predict_data=predict_data, preprocess=preprocess, - batch_size=5 + batch_size=5, ) dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5)) data_fetcher: CheckData = dm.data_fetcher + if not hasattr(dm, "_val_iter"): + dm._reset_iterator("val") + with data_fetcher.enable(): - _ = next(iter(dm.val_dataloader())) + assert data_fetcher.enabled + _ = next(dm._val_iter) data_fetcher.check() data_fetcher.reset() @@ -133,14 +139,14 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: B: int = 2 # batch_size - dm = CustomImageClassificationData.from_filepaths( - train_filepaths=train_images, - train_labels=[0, 1], - val_filepaths=train_images, - val_labels=[2, 3], - test_filepaths=train_images, - test_labels=[4, 5], - predict_filepaths=train_images, + dm = CustomImageClassificationData.from_files( + train_files=train_images, + train_targets=[0, 1], + val_files=train_images, + val_targets=[2, 3], + test_files=train_images, + test_targets=[4, 5], + predict_files=train_images, batch_size=B, num_workers=0, ) @@ -150,16 +156,14 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: for stage in _STAGES_PREFIX.values(): for _ in range(num_tests): - for fcn_name in _PREPROCESS_FUNCS: + for fcn_name in _CALLBACK_FUNCS: fcn = getattr(dm, f"show_{stage}_batch") fcn(fcn_name, reset=True) is_predict = stage == "predict" def _extract_data(data): - if not is_predict: - return 
data[0][0] - return data[0] + return data[0][DefaultDataKeys.INPUT] def _get_result(function_name: str): return dm.data_fetcher.batches[stage][function_name] @@ -170,7 +174,7 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("load_sample") - assert isinstance(res[0][1], torch.Tensor) + assert isinstance(res[0][DefaultDataKeys.TARGET], int) res = _get_result("to_tensor_transform") assert len(res) == B @@ -178,21 +182,21 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("to_tensor_transform") - assert isinstance(res[0][1], torch.Tensor) + assert isinstance(res[0][DefaultDataKeys.TARGET], torch.Tensor) res = _get_result("collate") assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("collate") - assert res[0][1].shape == torch.Size([2]) + assert res[0][DefaultDataKeys.TARGET].shape == torch.Size([2]) res = _get_result("per_batch_transform") assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("per_batch_transform") - assert res[0][1].shape == (B, ) + assert res[0][DefaultDataKeys.TARGET].shape == (B, ) assert dm.data_fetcher.show_load_sample_called assert dm.data_fetcher.show_pre_tensor_transform_called diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 6b7ae78def..721213a816 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -32,6 +32,7 @@ from flash.data.batch import _PostProcessor, _PreProcessor from flash.data.data_module import DataModule from flash.data.data_pipeline import _StageOrchestrator, DataPipeline +from flash.data.data_source import DataSource from flash.data.process import DefaultPreprocess, Postprocess, Preprocess @@ -64,14 +65,16 @@ class SubPostprocess(Postprocess): pass data_pipeline = DataPipeline( - SubPreprocess() if use_preprocess else None, - SubPostprocess() if use_postprocess else None, + preprocess=SubPreprocess() if use_preprocess else None, + 
postprocess=SubPostprocess() if use_postprocess else None, ) - assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess) + assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else DefaultPreprocess) assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess) model = CustomModel(postprocess=Postprocess()) model.data_pipeline = data_pipeline + # TODO: the line below should make the same effect but it's not + # data_pipeline._attach_to_model(model) if use_preprocess: assert isinstance(model._preprocess, SubPreprocess) @@ -88,21 +91,6 @@ def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): class CustomPreprocess(DefaultPreprocess): - def load_data(self, *_, **__): - pass - - def test_load_data(self, *_, **__): - pass - - def predict_load_data(self, *_, **__): - pass - - def predict_load_sample(self, *_, **__): - pass - - def val_load_sample(self, *_, **__): - pass - def val_pre_tensor_transform(self, *_, **__): pass @@ -125,7 +113,8 @@ def test_per_batch_transform_on_device(self, *_, **__): pass preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) + train_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess @@ -150,17 +139,6 @@ def test_per_batch_transform_on_device(self, *_, **__): ) for k in data_pipeline.PREPROCESS_FUNCS } - # load_data - assert train_func_names["load_data"] == "load_data" - assert val_func_names["load_data"] == "load_data" - assert test_func_names["load_data"] == "test_load_data" - assert predict_func_names["load_data"] == "predict_load_data" - - # load_sample - assert train_func_names["load_sample"] == "load_sample" - assert val_func_names["load_sample"] == "val_load_sample" - assert test_func_names["load_sample"] == "load_sample" 
- assert predict_func_names["load_sample"] == "predict_load_sample" # pre_tensor_transform assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform" @@ -271,7 +249,7 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) data_pipeline.worker_preprocessor(RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="are mutual exclusive"): @@ -293,7 +271,7 @@ def train_dataloader(self) -> Any: return DataLoader(DummyDataset()) preprocess = CustomPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) model = CustomModel() model.data_pipeline = data_pipeline @@ -343,7 +321,7 @@ class SubPreprocess(DefaultPreprocess): pass preprocess = SubPreprocess() - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) class CustomModel(Task): @@ -491,7 +469,7 @@ def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _PostProc model.predict_step = self._model_predict_step_wrapper(model.predict_step, _postprocesssor, model) return model - data_pipeline = CustomDataPipeline(preprocess) + data_pipeline = CustomDataPipeline(preprocess=preprocess) _postprocesssor = data_pipeline._create_uncollate_postprocessors(RunningStage.PREDICTING) data_pipeline._attach_postprocess_to_model(model, _postprocesssor) assert model.predict_step._original == _original_predict_step @@ -512,23 +490,15 @@ def __len__(self) -> int: return 5 -class TestPreprocessTransformations(DefaultPreprocess): +class TestPreprocessTransformationsDataSource(DataSource): def __init__(self): super().__init__() self.train_load_data_called = False - self.train_pre_tensor_transform_called = False - self.train_collate_called = False - 
self.train_per_batch_transform_on_device_called = False self.val_load_data_called = False self.val_load_sample_called = False - self.val_to_tensor_transform_called = False - self.val_collate_called = False - self.val_per_batch_transform_on_device_called = False self.test_load_data_called = False - self.test_to_tensor_transform_called = False - self.test_post_tensor_transform_called = False self.predict_load_data_called = False @staticmethod @@ -546,6 +516,53 @@ def train_load_data(self, sample) -> LamdaDummyDataset: self.train_load_data_called = True return LamdaDummyDataset(self.fn_train_load_data) + def val_load_data(self, sample, dataset) -> List[int]: + assert self.validating + assert self.current_fn == "load_data" + self.val_load_data_called = True + return list(range(5)) + + def val_load_sample(self, sample) -> Dict[str, Tensor]: + assert self.validating + assert self.current_fn == "load_sample" + self.val_load_sample_called = True + return {"a": sample, "b": sample + 1} + + @staticmethod + def fn_test_load_data() -> List[torch.Tensor]: + return [torch.rand(1), torch.rand(1)] + + def test_load_data(self, sample) -> LamdaDummyDataset: + assert self.testing + assert self.current_fn == "load_data" + self.test_load_data_called = True + return LamdaDummyDataset(self.fn_test_load_data) + + @staticmethod + def fn_predict_load_data() -> List[str]: + return (["a", "b"]) + + def predict_load_data(self, sample) -> LamdaDummyDataset: + assert self.predicting + assert self.current_fn == "load_data" + self.predict_load_data_called = True + return LamdaDummyDataset(self.fn_predict_load_data) + + +class TestPreprocessTransformations(DefaultPreprocess): + + def __init__(self): + super().__init__(data_sources={"default": TestPreprocessTransformationsDataSource()}) + + self.train_pre_tensor_transform_called = False + self.train_collate_called = False + self.train_per_batch_transform_on_device_called = False + self.val_to_tensor_transform_called = False + self.val_collate_called 
= False + self.val_per_batch_transform_on_device_called = False + self.test_to_tensor_transform_called = False + self.test_post_tensor_transform_called = False + def train_pre_tensor_transform(self, sample: Any) -> Any: assert self.training assert self.current_fn == "pre_tensor_transform" @@ -564,19 +581,6 @@ def train_per_batch_transform_on_device(self, batch: Any) -> Any: self.train_per_batch_transform_on_device_called = True assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]])) - def val_load_data(self, sample, dataset) -> List[int]: - assert self.validating - assert self.current_fn == "load_data" - self.val_load_data_called = True - assert isinstance(dataset, AutoDataset) - return list(range(5)) - - def val_load_sample(self, sample) -> Dict[str, Tensor]: - assert self.validating - assert self.current_fn == "load_sample" - self.val_load_sample_called = True - return {"a": sample, "b": sample + 1} - def val_to_tensor_transform(self, sample: Any) -> Tensor: assert self.validating assert self.current_fn == "to_tensor_transform" @@ -601,16 +605,6 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: assert torch.equal(batch["b"], tensor([1, 2])) return [False] - @staticmethod - def fn_test_load_data() -> List[torch.Tensor]: - return [torch.rand(1), torch.rand(1)] - - def test_load_data(self, sample) -> LamdaDummyDataset: - assert self.testing - assert self.current_fn == "load_data" - self.test_load_data_called = True - return LamdaDummyDataset(self.fn_test_load_data) - def test_to_tensor_transform(self, sample: Any) -> Tensor: assert self.testing assert self.current_fn == "to_tensor_transform" @@ -623,16 +617,6 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor: self.test_post_tensor_transform_called = True return sample - @staticmethod - def fn_predict_load_data() -> List[str]: - return (["a", "b"]) - - def predict_load_data(self, sample) -> LamdaDummyDataset: - assert self.predicting - assert self.current_fn == 
"load_data" - self.predict_load_data_called = True - return LamdaDummyDataset(self.fn_predict_load_data) - class TestPreprocessTransformations2(TestPreprocessTransformations): @@ -668,8 +652,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx): def test_datapipeline_transformations(tmpdir): - datamodule = DataModule.from_load_data_inputs( - 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations() + datamodule = DataModule.from_data_source( + "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations() ) assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3) @@ -681,8 +665,8 @@ def test_datapipeline_transformations(tmpdir): with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"): batch = next(iter(datamodule.val_dataloader())) - datamodule = DataModule.from_load_data_inputs( - 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations2() + datamodule = DataModule.from_data_source( + "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations2() ) batch = next(iter(datamodule.val_dataloader())) assert torch.equal(batch["a"], tensor([0, 1])) @@ -702,19 +686,20 @@ def test_datapipeline_transformations(tmpdir): trainer.predict(model) preprocess = model._preprocess - assert preprocess.train_load_data_called + data_source = preprocess.data_source_of_name("default") + assert data_source.train_load_data_called assert preprocess.train_pre_tensor_transform_called assert preprocess.train_collate_called assert preprocess.train_per_batch_transform_on_device_called - assert preprocess.val_load_data_called - assert preprocess.val_load_sample_called + assert data_source.val_load_data_called + assert data_source.val_load_sample_called assert preprocess.val_to_tensor_transform_called assert preprocess.val_collate_called assert preprocess.val_per_batch_transform_on_device_called - assert preprocess.test_load_data_called + 
assert data_source.test_load_data_called assert preprocess.test_to_tensor_transform_called assert preprocess.test_post_tensor_transform_called - assert preprocess.predict_load_data_called + assert data_source.predict_load_data_called def test_is_overriden_recursive(tmpdir): @@ -741,12 +726,7 @@ def val_collate(self, *_): @mock.patch("torch.save") # need to mock torch.save or we get pickle error def test_dummy_example(tmpdir): - class ImageClassificationPreprocess(DefaultPreprocess): - - def __init__(self, to_tensor_transform, train_per_sample_transform_on_device): - super().__init__() - self._to_tensor = to_tensor_transform - self._train_per_sample_transform_on_device = train_per_sample_transform_on_device + class ImageDataSource(DataSource): def load_data(self, folder: str): # from folder -> return files paths @@ -757,6 +737,27 @@ def load_sample(self, path: str) -> Image.Image: img8Bit = np.uint8(np.random.uniform(0, 1, (64, 64, 3)) * 255.0) return Image.fromarray(img8Bit) + class ImageClassificationPreprocess(DefaultPreprocess): + + def __init__( + self, + train_transform=None, + val_transform=None, + test_transform=None, + predict_transform=None, + to_tensor_transform=None, + train_per_sample_transform_on_device=None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={"default": ImageDataSource()}, + ) + self._to_tensor = to_tensor_transform + self._train_per_sample_transform_on_device = train_per_sample_transform_on_device + def to_tensor_transform(self, pil_image: Image.Image) -> Tensor: # convert pil image into a tensor return self._to_tensor(pil_image) @@ -783,32 +784,15 @@ class CustomDataModule(DataModule): preprocess_cls = ImageClassificationPreprocess - @property - def preprocess(self): - return self.preprocess_cls(self.to_tensor_transform, self.train_per_sample_transform_on_device) - - @classmethod - def from_folders( - cls, 
train_folder: Optional[str], val_folder: Optional[str], test_folder: Optional[str], - predict_folder: Optional[str], to_tensor_transform: torch.nn.Module, - train_per_sample_transform_on_device: torch.nn.Module, batch_size: int - ): - - # attach the arguments for the preprocess onto the cls - cls.to_tensor_transform = to_tensor_transform - cls.train_per_sample_transform_on_device = train_per_sample_transform_on_device - - # call ``from_load_data_inputs`` - return cls.from_load_data_inputs( - train_load_data_input=train_folder, - val_load_data_input=val_folder, - test_load_data_input=test_folder, - predict_load_data_input=predict_folder, - batch_size=batch_size - ) - - datamodule = CustomDataModule.from_folders( - "train_folder", "val_folder", "test_folder", None, T.ToTensor(), T.RandomHorizontalFlip(), batch_size=2 + datamodule = CustomDataModule.from_data_source( + "default", + "train_folder", + "val_folder", + "test_folder", + None, + batch_size=2, + to_tensor_transform=T.ToTensor(), + train_per_sample_transform_on_device=T.RandomHorizontalFlip(), ) assert isinstance(datamodule.train_dataloader().dataset[0], Image.Image) @@ -865,10 +849,10 @@ def test_preprocess_transforms(tmpdir): assert preprocess._test_collate_in_worker_from_transform is None assert preprocess._predict_collate_in_worker_from_transform is False - train_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.TRAINING) - val_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.VALIDATING) - test_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.TESTING) - predict_preprocessor = DataPipeline(preprocess).worker_preprocessor(RunningStage.PREDICTING) + train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING) + val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.VALIDATING) + test_preprocessor = 
DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TESTING) + predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING) assert train_preprocessor.collate_fn.func == default_collate assert val_preprocessor.collate_fn.func == default_collate @@ -893,7 +877,7 @@ def per_batch_transform(self, batch: Any) -> Any: assert preprocess._test_collate_in_worker_from_transform is None assert preprocess._predict_collate_in_worker_from_transform is False - data_pipeline = DataPipeline(preprocess) + data_pipeline = DataPipeline(preprocess=preprocess) train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"): @@ -908,14 +892,12 @@ def per_batch_transform(self, batch: Any) -> Any: def test_iterable_auto_dataset(tmpdir): - class CustomPreprocess(DefaultPreprocess): + class CustomDataSource(DataSource): def load_sample(self, index: int) -> Dict[str, int]: return {"index": index} - data_pipeline = DataPipeline(CustomPreprocess()) - - ds = IterableAutoDataset(range(10), running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline) + ds = IterableAutoDataset(range(10), data_source=CustomDataSource(), running_stage=RunningStage.TRAINING) for index, v in enumerate(ds): assert v == {"index": index} diff --git a/tests/data/test_process.py b/tests/data/test_process.py index 8e4544081f..66df027b5b 100644 --- a/tests/data/test_process.py +++ b/tests/data/test_process.py @@ -19,9 +19,10 @@ from torch.utils.data import DataLoader from flash import Task, Trainer -from flash.core.classification import ClassificationState, Labels +from flash.core.classification import Labels, LabelsState from flash.data.data_pipeline import DataPipeline, DataPipelineState, DefaultPreprocess -from flash.data.process import ProcessState, Properties, Serializer, SerializerMapping +from flash.data.process import 
Serializer, SerializerMapping +from flash.data.properties import ProcessState, Properties def test_properties_data_pipeline_state(): @@ -120,7 +121,7 @@ def __init__(self): serializer = Labels(["a", "b"]) model = CustomModel() trainer = Trainer(fast_dev_run=True) - data_pipeline = DataPipeline(DefaultPreprocess(), serializer=serializer) + data_pipeline = DataPipeline(preprocess=DefaultPreprocess(), serializer=serializer) data_pipeline.initialize() model.data_pipeline = data_pipeline assert isinstance(model.preprocess, DefaultPreprocess) @@ -128,5 +129,5 @@ def __init__(self): trainer.fit(model, train_dataloader=dummy_data) trainer.save_checkpoint(checkpoint_file) model = CustomModel.load_from_checkpoint(checkpoint_file) - assert isinstance(model.preprocess._data_pipeline_state, DataPipelineState) - # assert model.preprocess._data_pipeline_state._state[ClassificationState] == ClassificationState(['a', 'b']) + assert isinstance(model._data_pipeline_state, DataPipelineState) + assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"]) diff --git a/tests/data/test_serialization.py b/tests/data/test_serialization.py index 54d5ae40e6..bb166eeec8 100644 --- a/tests/data/test_serialization.py +++ b/tests/data/test_serialization.py @@ -53,7 +53,7 @@ def test_serialization_data_pipeline(tmpdir): loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) assert loaded_model.data_pipeline - model.data_pipeline = DataPipeline(CustomPreprocess()) + model.data_pipeline = DataPipeline(preprocess=CustomPreprocess()) assert isinstance(model.preprocess, CustomPreprocess) trainer.fit(model, dummy_data) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index ba5dd7d82b..2fc4ee18f3 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -62,7 +62,7 @@ def run_test(filepath): ("finetuning", "tabular_classification.py"), # ("finetuning", "video_classification.py"), # ("finetuning", 
"text_classification.py"), # TODO: takes too long - # ("finetuning", "translation.py"), # TODO: takes too long. + ("finetuning", "translation.py"), ("predict", "image_classification.py"), ("predict", "image_classification_multi_label.py"), ("predict", "tabular_classification.py"), @@ -70,7 +70,7 @@ def run_test(filepath): ("predict", "image_embedder.py"), ("predict", "video_classification.py"), # ("predict", "summarization.py"), # TODO: takes too long - # ("predict", "translate.py"), # TODO: takes too long + ("predict", "translation.py"), ] ) def test_example(tmpdir, folder, file): diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index a1055f2711..393597118f 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -14,6 +14,7 @@ import torch from pytorch_lightning import Trainer +from flash.data.data_source import DefaultDataKeys from flash.tabular import TabularClassifier # ======== Mock functions ======== @@ -30,7 +31,7 @@ def __getitem__(self, index): target = torch.randint(0, 10, size=(1, )).item() cat_vars = torch.randint(0, 10, size=(self.num_cat, )) num_vars = torch.rand(self.num_num) - return (cat_vars, num_vars), target + return {DefaultDataKeys.INPUT: (cat_vars, num_vars), DefaultDataKeys.TARGET: target} def __len__(self) -> int: return 100 diff --git a/tests/tabular/data/test_data.py b/tests/tabular/data/test_data.py index 1a181a5487..1a0d1e1574 100644 --- a/tests/tabular/data/test_data.py +++ b/tests/tabular/data/test_data.py @@ -18,6 +18,7 @@ import pandas as pd import pytest +from flash.data.data_source import DefaultDataKeys from flash.tabular import TabularData from flash.tabular.classification.data.dataset import _categorize, _normalize @@ -82,67 +83,73 @@ def test_emb_sizes(): def test_tabular_data(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_2.copy() - test_df = TEST_DF_2.copy() - dm = TabularData.from_df( - train_df, + 
train_data_frame = TEST_DF_1.copy() + val_data_frame = TEST_DF_2.copy() + test_data_frame = TEST_DF_2.copy() + dm = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", - val_df=val_df, - test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) def test_categorical_target(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_2.copy() - test_df = TEST_DF_2.copy() - for df in [train_df, val_df, test_df]: + train_data_frame = TEST_DF_1.copy() + val_data_frame = TEST_DF_2.copy() + test_data_frame = TEST_DF_2.copy() + for df in [train_data_frame, val_data_frame, test_data_frame]: # change int label to string df["label"] = df["label"].astype(str) - dm = TabularData.from_df( - train_df, + dm = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", - val_df=val_df, - test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) -def test_from_df(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_2.copy() - test_df = TEST_DF_2.copy() - dm = TabularData.from_df( - train_df, +def test_from_data_frame(tmpdir): + train_data_frame = TEST_DF_1.copy() + 
val_data_frame = TEST_DF_2.copy() + test_data_frame = TEST_DF_2.copy() + dm = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_b", "scalar_b"], target_col="label", - val_df=val_df, - test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=1 ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) @@ -156,25 +163,32 @@ def test_from_csv(tmpdir): TEST_DF_2.to_csv(test_csv) dm = TabularData.from_csv( - train_csv=train_csv, - categorical_cols=["category"], - numerical_cols=["scalar_b", "scalar_b"], - target_col="label", - val_csv=val_csv, - test_csv=test_csv, + categorical_fields=["category"], + numerical_fields=["scalar_b", "scalar_b"], + target_field="label", + train_file=str(train_csv), + val_file=str(val_csv), + test_file=str(test_csv), num_workers=0, batch_size=1 ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - (cat, num), target = next(iter(dl)) + data = next(iter(dl)) + (cat, num) = data[DefaultDataKeys.INPUT] + target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1, ) def test_empty_inputs(): - train_df = TEST_DF_1.copy() + train_data_frame = TEST_DF_1.copy() with pytest.raises(RuntimeError): - TabularData.from_df( - train_df, numerical_cols=None, categorical_cols=None, target_col="label", num_workers=0, batch_size=1 + TabularData.from_data_frame( + numerical_cols=None, + categorical_cols=None, + target_col="label", + train_data_frame=train_data_frame, + num_workers=0, + batch_size=1, ) diff --git a/tests/tabular/test_data_model_integration.py b/tests/tabular/test_data_model_integration.py index 
6c022eba0f..6dcec9b6a8 100644 --- a/tests/tabular/test_data_model_integration.py +++ b/tests/tabular/test_data_model_integration.py @@ -28,16 +28,16 @@ def test_classification(tmpdir): - train_df = TEST_DF_1.copy() - val_df = TEST_DF_1.copy() - test_df = TEST_DF_1.copy() - data = TabularData.from_df( - train_df, + train_data_frame = TEST_DF_1.copy() + val_data_frame = TEST_DF_1.copy() + test_data_frame = TEST_DF_1.copy() + data = TabularData.from_data_frame( categorical_cols=["category"], numerical_cols=["scalar_a", "scalar_b"], target_col="label", - val_df=val_df, - test_df=test_df, + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, num_workers=0, batch_size=2, ) diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 3df3360030..866b9d9328 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -48,9 +48,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) - dm = TextClassificationData.from_files( - backbone=TEST_BACKBONE, train_file=csv_path, input="sentence", target="label", batch_size=1 - ) + dm = TextClassificationData.from_csv("sentence", "label", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert batch["labels"].item() in [0, 1] assert "input_ids" in batch @@ -59,13 +57,13 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_test_valid(tmpdir): csv_path = csv_data(tmpdir) - dm = TextClassificationData.from_files( + dm = TextClassificationData.from_csv( + "sentence", + "label", backbone=TEST_BACKBONE, train_file=csv_path, val_file=csv_path, test_file=csv_path, - input="sentence", - target="label", batch_size=1 ) batch = next(iter(dm.val_dataloader())) @@ -80,9 +78,7 @@ def 
test_test_valid(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): json_path = json_data(tmpdir) - dm = TextClassificationData.from_files( - backbone=TEST_BACKBONE, train_file=json_path, input="sentence", target="lab", filetype="json", batch_size=1 - ) + dm = TextClassificationData.from_json("sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert batch["labels"].item() in [0, 1] assert "input_ids" in batch diff --git a/tests/text/summarization/test_data.py b/tests/text/summarization/test_data.py index 616a9d6f53..67b88bc937 100644 --- a/tests/text/summarization/test_data.py +++ b/tests/text/summarization/test_data.py @@ -48,9 +48,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) - dm = SummarizationData.from_files( - backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 - ) + dm = SummarizationData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch @@ -59,13 +57,13 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) - dm = SummarizationData.from_files( + dm = SummarizationData.from_csv( + "input", + "target", backbone=TEST_BACKBONE, train_file=csv_path, val_file=csv_path, test_file=csv_path, - input="input", - target="target", batch_size=1 ) batch = next(iter(dm.val_dataloader())) @@ -80,9 +78,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): json_path = json_data(tmpdir) - dm = SummarizationData.from_files( - backbone=TEST_BACKBONE, 
train_file=json_path, input="input", target="target", filetype="json", batch_size=1 - ) + dm = SummarizationData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch diff --git a/tests/text/test_data_model_integration.py b/tests/text/test_data_model_integration.py index 7aeadba7de..91c10bb049 100644 --- a/tests/text/test_data_model_integration.py +++ b/tests/text/test_data_model_integration.py @@ -39,11 +39,11 @@ def test_classification(tmpdir): csv_path = csv_data(tmpdir) - data = TextClassificationData.from_files( + data = TextClassificationData.from_csv( + "sentence", + "label", backbone=TEST_BACKBONE, train_file=csv_path, - input="sentence", - target="label", num_workers=0, batch_size=2, ) diff --git a/tests/text/translation/test_data.py b/tests/text/translation/test_data.py index d9e17105ce..859bd1fe7a 100644 --- a/tests/text/translation/test_data.py +++ b/tests/text/translation/test_data.py @@ -48,9 +48,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) - dm = TranslationData.from_files( - backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 - ) + dm = TranslationData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch @@ -59,13 +57,13 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) - dm = TranslationData.from_files( + dm = TranslationData.from_csv( + "input", + "target", backbone=TEST_BACKBONE, train_file=csv_path, val_file=csv_path, test_file=csv_path, - input="input", - target="target", batch_size=1 ) batch = 
next(iter(dm.val_dataloader())) @@ -80,9 +78,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): json_path = json_data(tmpdir) - dm = TranslationData.from_files( - backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1 - ) + dm = TranslationData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch diff --git a/tests/video/test_video_classifier.py b/tests/video/test_video_classifier.py index a5c3db023f..e4ed5cae88 100644 --- a/tests/video/test_video_classifier.py +++ b/tests/video/test_video_classifier.py @@ -105,15 +105,15 @@ def test_image_classifier_finetune(tmpdir): half_duration = total_duration / 2 - 1e-9 - datamodule = VideoClassificationData.from_paths( - train_data_path=mock_csv, + datamodule = VideoClassificationData.from_folders( + train_folder=mock_csv, clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, decode_audio=False, ) - for sample in datamodule.train_dataset.dataset: + for sample in datamodule.train_dataset.data: expected_t_shape = 5 assert sample["video"].shape[1] == expected_t_shape @@ -144,8 +144,8 @@ def test_image_classifier_finetune(tmpdir): ]), } - datamodule = VideoClassificationData.from_paths( - train_data_path=mock_csv, + datamodule = VideoClassificationData.from_folders( + train_folder=mock_csv, clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index ad21f53aca..19f49b672a 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -18,9 +18,13 @@ import kornia as K import numpy as np import torch +import torch.nn as nn +import torchvision from PIL import 
Image +from flash.data.data_source import DefaultDataKeys from flash.data.data_utils import labels_from_categorical_csv +from flash.data.transforms import ApplyToKeys from flash.vision import ImageClassificationData @@ -43,9 +47,14 @@ def test_from_filepaths_smoke(tmpdir): _rand_image().save(tmpdir / "a_1.png") _rand_image().save(tmpdir / "b_1.png") - img_data = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a_1.png", tmpdir / "b_1.png"], - train_labels=[1, 2], + train_images = [ + str(tmpdir / "a_1.png"), + str(tmpdir / "b_1.png"), + ] + + img_data = ImageClassificationData.from_files( + train_files=train_images, + train_targets=[1, 2], batch_size=2, num_workers=0, ) @@ -54,7 +63,7 @@ def test_from_filepaths_smoke(tmpdir): assert img_data.test_dataloader() is None data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert sorted(list(labels.numpy())) == [1, 2] @@ -72,20 +81,20 @@ def test_from_filepaths_list_image_paths(tmpdir): str(tmpdir / "e_1.png"), ] - img_data = ImageClassificationData.from_filepaths( - train_filepaths=train_images, - train_labels=[0, 3, 6], - val_filepaths=train_images, - val_labels=[1, 4, 7], - test_filepaths=train_images, - test_labels=[2, 5, 8], + img_data = ImageClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], batch_size=2, num_workers=0, ) # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here @@ -93,14 +102,14 @@ def test_from_filepaths_list_image_paths(tmpdir): # check validation data data = next(iter(img_data.val_dataloader())) - 
imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [1, 4] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [2, 5] @@ -109,27 +118,33 @@ def test_from_filepaths_list_image_paths(tmpdir): def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) - (tmpdir / "a").mkdir() - (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "b" / "b_1.png") - - dm = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - val_filepaths=[tmpdir / "b", tmpdir / "a"], - val_labels=[0, 2], - test_filepaths=[tmpdir / "b", tmpdir / "b"], - test_labels=[2, 1], + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] + + dm = ImageClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], batch_size=2, + num_workers=0, ) + # disable visualisation for testing assert dm.data_fetcher.block_viz_window is True dm.set_block_viz_window(False) assert dm.data_fetcher.block_viz_window is False # call show functions - dm.show_train_batch() + # dm.show_train_batch() dm.show_train_batch("pre_tensor_transform") dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) @@ -139,17 +154,22 @@ def test_from_filepaths_visualise_multilabel(tmpdir): (tmpdir / "a").mkdir() (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "b" / "b_1.png") - - dm = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir 
/ "a", tmpdir / "b"], - train_labels=[[0, 1, 0], [0, 1, 1]], - val_filepaths=[tmpdir / "b", tmpdir / "a"], - val_labels=[[1, 1, 0], [0, 0, 1]], - test_filepaths=[tmpdir / "b", tmpdir / "b"], - test_labels=[[0, 0, 1], [1, 1, 0]], + + image_a = str(tmpdir / "a" / "a_1.png") + image_b = str(tmpdir / "b" / "b_1.png") + + _rand_image().save(image_a) + _rand_image().save(image_b) + + dm = ImageClassificationData.from_files( + train_files=[image_a, image_b], + train_targets=[[0, 1, 0], [0, 1, 1]], + val_files=[image_b, image_a], + val_targets=[[1, 1, 0], [0, 0, 1]], + test_files=[image_b, image_b], + test_targets=[[0, 0, 1], [1, 1, 0]], batch_size=2, + image_size=(64, 64), ) # disable visualisation for testing assert dm.data_fetcher.block_viz_window is True @@ -181,18 +201,17 @@ def test_from_filepaths_splits(tmpdir): assert len(train_filepaths) == len(train_labels) - def preprocess(x): - out = K.image_to_tensor(np.array(x)) - return out - _to_tensor = { - "to_tensor_transform": lambda x: preprocess(x), + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor) + ), } def run(transform: Any = None): - img_data = ImageClassificationData.from_filepaths( - train_filepaths=train_filepaths, - train_labels=train_labels, + dm = ImageClassificationData.from_files( + train_files=train_filepaths, + train_targets=train_labels, train_transform=transform, val_transform=transform, batch_size=B, @@ -200,86 +219,14 @@ def run(transform: Any = None): val_split=val_split, image_size=img_size, ) - data = next(iter(img_data.train_dataloader())) - imgs, labels = data + data = next(iter(dm.train_dataloader())) + imgs, labels = data['input'], data['target'] assert imgs.shape == (B, 3, H, W) assert labels.shape == (B, ) - run() run(_to_tensor) -def test_categorical_csv_labels(tmpdir): - train_dir = Path(tmpdir / "some_dataset") - train_dir.mkdir() - - (train_dir / "train").mkdir() - 
_rand_image().save(train_dir / "train" / "train_1.png") - _rand_image().save(train_dir / "train" / "train_2.png") - - (train_dir / "valid").mkdir() - _rand_image().save(train_dir / "valid" / "val_1.png") - _rand_image().save(train_dir / "valid" / "val_2.png") - - (train_dir / "test").mkdir() - _rand_image().save(train_dir / "test" / "test_1.png") - _rand_image().save(train_dir / "test" / "test_2.png") - - train_csv = os.path.join(tmpdir, 'some_dataset', 'train.csv') - text_file = open(train_csv, 'w') - text_file.write( - 'my_id,label_a,label_b,label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n' - ) - text_file.close() - - val_csv = os.path.join(tmpdir, 'some_dataset', 'valid.csv') - text_file = open(val_csv, 'w') - text_file.write('my_id,label_a,label_b,label_c\n"val_1.png", 0, 1, 0\n"val_2.png", 0, 0, 1\n"val_3.png", 1, 0, 0\n') - text_file.close() - - test_csv = os.path.join(tmpdir, 'some_dataset', 'test.csv') - text_file = open(test_csv, 'w') - text_file.write( - 'my_id,label_a,label_b,label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n' - ) - text_file.close() - - def index_col_collate_fn(x): - return os.path.splitext(x)[0] - - train_labels = labels_from_categorical_csv( - train_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn - ) - val_labels = labels_from_categorical_csv( - val_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn - ) - test_labels = labels_from_categorical_csv( - test_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn - ) - B: int = 2 # batch_size - data = ImageClassificationData.from_filepaths( - batch_size=B, - train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'), - train_labels=train_labels.values(), - val_filepaths=os.path.join(tmpdir, 'some_dataset', 'valid'), - val_labels=val_labels.values(), - 
test_filepaths=os.path.join(tmpdir, 'some_dataset', 'test'), - test_labels=test_labels.values(), - ) - - for (x, y) in data.train_dataloader(): - assert len(x) == 2 - assert sorted(list(y.numpy())) == sorted(list(train_labels.values())[:B]) - - for (x, y) in data.val_dataloader(): - assert len(x) == 2 - assert sorted(list(y.numpy())) == sorted(list(val_labels.values())[:B]) - - for (x, y) in data.test_dataloader(): - assert len(x) == 2 - assert sorted(list(y.numpy())) == sorted(list(test_labels.values())[:B]) - - def test_from_folders_only_train(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -295,7 +242,7 @@ def test_from_folders_only_train(tmpdir): img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, ) @@ -324,18 +271,18 @@ def test_from_folders_train_val(tmpdir): ) data = next(iter(img_data.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) data = next(iter(img_data.val_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [0, 0] data = next(iter(img_data.test_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, ) assert list(labels.numpy()) == [0, 0] @@ -353,30 +300,30 @@ def test_from_filepaths_multilabel(tmpdir): valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]] test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]] - dm = ImageClassificationData.from_filepaths( - train_filepaths=train_images, - train_labels=train_labels, - val_filepaths=train_images, - val_labels=valid_labels, - test_filepaths=train_images, - 
test_labels=test_labels, + dm = ImageClassificationData.from_files( + train_files=train_images, + train_targets=train_labels, + val_files=train_images, + val_targets=valid_labels, + test_files=train_images, + test_targets=test_labels, batch_size=2, num_workers=0, ) data = next(iter(dm.train_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) data = next(iter(dm.val_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) data = next(iter(dm.test_dataloader())) - imgs, labels = data + imgs, labels = data['input'], data['target'] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(test_labels)) diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/vision/classification/test_data_model_integration.py index 4bd70455ec..2425e3f760 100644 --- a/tests/vision/classification/test_data_model_integration.py +++ b/tests/vision/classification/test_data_model_integration.py @@ -34,15 +34,19 @@ def test_classification(tmpdir): (tmpdir / "a").mkdir() (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "b" / "a_1.png") - data = ImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - train_transform={"per_batch_transform": lambda x: x}, + image_a = str(tmpdir / "a" / "a_1.png") + image_b = str(tmpdir / "b" / "b_1.png") + + _rand_image().save(image_a) + _rand_image().save(image_b) + + data = ImageClassificationData.from_files( + train_files=[image_a, image_b], + train_targets=[0, 1], num_workers=0, batch_size=2, + image_size=(64, 64), ) model = ImageClassifier(num_classes=2, backbone="resnet18") trainer = 
Trainer(default_root_dir=tmpdir, fast_dev_run=True) diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py index 0aa3ab1835..94e26d889d 100644 --- a/tests/vision/classification/test_model.py +++ b/tests/vision/classification/test_model.py @@ -17,6 +17,7 @@ from flash import Trainer from flash.core.classification import Probabilities +from flash.data.data_source import DefaultDataKeys from flash.vision import ImageClassifier # ======== Mock functions ======== @@ -25,7 +26,10 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): - return torch.rand(3, 224, 224), torch.randint(10, size=(1, )).item() + return { + DefaultDataKeys.INPUT: torch.rand(3, 224, 224), + DefaultDataKeys.TARGET: torch.randint(10, size=(1, )).item(), + } def __len__(self) -> int: return 100 @@ -37,7 +41,10 @@ def __init__(self, num_classes: int): self.num_classes = num_classes def __getitem__(self, index): - return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, )) + return { + DefaultDataKeys.INPUT: torch.rand(3, 224, 224), + DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes, )), + } def __len__(self) -> int: return 100 @@ -90,8 +97,8 @@ def test_multilabel(tmpdir): train_dl = torch.utils.data.DataLoader(ds, batch_size=2) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, train_dl, strategy="freeze_unfreeze") - image, label = ds[0] - predictions = model.predict(image.unsqueeze(0)) + image, label = ds[0][DefaultDataKeys.INPUT], ds[0][DefaultDataKeys.TARGET] + predictions = model.predict([{DefaultDataKeys.INPUT: image}]) assert (torch.tensor(predictions) > 1).sum() == 0 assert (torch.tensor(predictions) < 0).sum() == 0 assert len(predictions[0]) == num_classes == len(label) diff --git a/tests/vision/detection/test_data.py b/tests/vision/detection/test_data.py index fec4b9a5e8..39f8a191eb 100644 --- a/tests/vision/detection/test_data.py +++ 
b/tests/vision/detection/test_data.py @@ -6,6 +6,7 @@ from PIL import Image from pytorch_lightning.utilities import _module_available +from flash.data.data_source import DefaultDataKeys from flash.utils.imports import _COCO_AVAILABLE from flash.vision.detection.data import ObjectDetectionData @@ -83,7 +84,7 @@ def test_image_detector_data_from_coco(tmpdir): datamodule = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) @@ -101,11 +102,11 @@ def test_image_detector_data_from_coco(tmpdir): test_folder=train_folder, test_ann_file=coco_ann_path, batch_size=1, - num_workers=0 + num_workers=0, ) data = next(iter(datamodule.val_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) @@ -113,7 +114,7 @@ def test_image_detector_data_from_coco(tmpdir): assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] data = next(iter(datamodule.test_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert len(imgs) == 1 assert imgs[0].shape == (3, 1080, 1920) diff --git a/tests/vision/detection/test_data_model_integration.py b/tests/vision/detection/test_data_model_integration.py index 8f90279959..8c71115671 100644 --- a/tests/vision/detection/test_data_model_integration.py +++ b/tests/vision/detection/test_data_model_integration.py @@ -43,5 +43,5 @@ def test_detection(tmpdir, model, backbone): Image.new('RGB', (512, 512)).save(test_image_one) Image.new('RGB', (512, 512)).save(test_image_two) - test_images = [test_image_one, test_image_two] + test_images = [str(test_image_one), str(test_image_two)] model.predict(test_images) 
diff --git a/tests/vision/detection/test_model.py b/tests/vision/detection/test_model.py index 110b55d43c..9d3f0a5dc6 100644 --- a/tests/vision/detection/test_model.py +++ b/tests/vision/detection/test_model.py @@ -16,11 +16,12 @@ from pytorch_lightning import Trainer from torch.utils.data import DataLoader, Dataset +from flash.data.data_source import DefaultDataKeys from flash.vision import ObjectDetector -def collate_fn(batch): - return tuple(zip(*batch)) +def collate_fn(samples): + return {key: [sample[key] for sample in samples] for key in samples[0]} class DummyDetectionDataset(Dataset): @@ -45,7 +46,7 @@ def __getitem__(self, idx): img = torch.rand(self.img_shape) boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) labels = torch.randint(self.num_classes, (self.num_boxes, )) - return img, {"boxes": boxes, "labels": labels} + return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}} def test_init(): @@ -55,7 +56,8 @@ def test_init(): batch_size = 2 ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) dl = DataLoader(ds, collate_fn=collate_fn, batch_size=batch_size) - img, target = next(iter(dl)) + data = next(iter(dl)) + img = data[DefaultDataKeys.INPUT] out = model(img)