From 168b2319a478f981a4ae3ae9418c7db018d8c33a Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 30 Mar 2021 18:30:00 +0100 Subject: [PATCH 01/14] wip --- flash/data/base_viz.py | 54 ++++++++++++++++++++++++++ flash/data/data_pipeline.py | 4 +- tests/data/test_base_viz.py | 77 +++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 flash/data/base_viz.py create mode 100644 tests/data/test_base_viz.py diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py new file mode 100644 index 0000000000..3e271afd6d --- /dev/null +++ b/flash/data/base_viz.py @@ -0,0 +1,54 @@ +import functools +from typing import Any, Callable + +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.states import RunningStage + +from flash.data.data_module import DataModule +from flash.data.data_pipeline import DataPipeline +from flash.data.process import Postprocess, Preprocess + + +class BaseViz(Callback): + + def __init__(self, datamodule: DataModule): + self._datamodule = datamodule + self._wrap_preprocess() + + self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}} + + def _wrap_fn( + self, + fn: Callable, + running_stage: RunningStage, + ) -> Callable: + """ + """ + + @functools.wraps(fn) + def wrapper(data) -> Any: + print(data) + data = fn(data) + print(data) + batches = self.batches[running_stage.value] + if fn.__name__ not in batches: + batches[fn.__name__] = [] + batches[fn.__name__].append(data) + return data + + return wrapper + + def _wrap_functions_per_stage(self, running_stage: RunningStage): + preprocess = self._datamodule.data_pipeline._preprocess_pipeline + fn_names = { + k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess) + for k in DataPipeline.PREPROCESS_FUNCS + } + for fn_name in fn_names: + fn = getattr(preprocess, fn_name) + setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage)) + + self._datamodule._train_ds.load_sample = preprocess.load_sample + + def _wrap_preprocess(self): + self._wrap_functions_per_stage(RunningStage.TRAINING) diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index b50e468c50..226e6bf1ab 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -108,11 +108,11 @@ def forward(self, samples: Sequence[Any]): post_tensor_transform │ ┌────────────────┴───────────────────┐ -(move Data to main worker) --> │ │ +(move list to main worker) --> │ │ per_sample_transform_on_device collate │ │ collate per_batch_transform - │ │ <-- (move Data to main worker) + │ │ <-- (move batch to main worker) per_batch_transform_on_device per_batch_transform_on_device │ │ └─────────────────┬──────────────────┘ diff --git a/tests/data/test_base_viz.py b/tests/data/test_base_viz.py new file mode 100644 index 0000000000..20fc836f51 --- /dev/null +++ b/tests/data/test_base_viz.py @@ -0,0 +1,77 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest import mock + +import numpy as np +import pytest +import torch +import torchvision.transforms as T +from PIL import Image +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import Tensor, tensor +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate + +from flash.core import Task +from flash.data.auto_dataset import AutoDataset +from flash.data.base_viz import BaseViz +from flash.data.batch import _PostProcessor, _PreProcessor +from flash.data.data_module import DataModule +from flash.data.data_pipeline import _StageOrchestrator, DataPipeline +from flash.data.process import Postprocess, Preprocess +from flash.vision import ImageClassificationData + + +def _rand_image(): + return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8")) + + +class ImageClassificationDataViz(ImageClassificationData): + + def configure_vis(self): + if not hasattr(self, "viz"): + return BaseViz(self) + return self.viz + + def show_train_batch(self): + self.viz = self.configure_vis() + _ = next(iter(self.train_dataloader())) + + +def test_base_viz(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a" / "a_1.png") + _rand_image().save(tmpdir / "a" / "a_2.png") + + _rand_image().save(tmpdir / "b" / "a_1.png") + _rand_image().save(tmpdir / "b" / "a_2.png") + + img_data = ImageClassificationDataViz.from_filepaths( + train_filepaths=[tmpdir / "a", tmpdir / "b"], + train_transform=None, + train_labels=[0, 1], + batch_size=1, + num_workers=0, + ) + + img_data.show_train_batch() + assert img_data.viz.batches["train"]["load_sample"] is not None From cda64d3927412b870e500e6dcfb1c101b4b34687 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 10:23:14 +0100 Subject: [PATCH 02/14] add base_viz + new features for DataPipeline --- flash/data/auto_dataset.py | 36 +++---- flash/data/base_viz.py | 1 + flash/data/batch.py | 65 ++++++++---- flash/data/data_module.py | 4 + flash/data/data_pipeline.py | 13 ++- flash/data/process.py | 40 +++++++- flash/data/utils.py | 36 +++++++ flash/vision/classification/data.py | 154 ++++++++-------------------- tests/data/test_base_viz.py | 3 + tests/data/test_data_pipeline.py | 50 ++++++--- 10 files changed, 230 insertions(+), 172 deletions(-) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index be6e32038e..e42a4cf680 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -20,7 +20,7 @@ from torch.utils.data import Dataset from flash.data.process import Preprocess -from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES +from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, set_current_stage_and_fn if TYPE_CHECKING: from flash.data.data_pipeline import DataPipeline @@ -73,12 +73,18 @@ def running_stage(self, running_stage: str) -> None: self._running_stage = running_stage self._setup(running_stage) + @property + def _preprocess(self): + if self.data_pipeline is not None: + return self.data_pipeline._preprocess_pipeline + def _call_load_data(self, data: Any) -> Iterable: parameters = signature(self.load_data).parameters - if len(parameters) > 1 and self.DATASET_KEY in parameters: - return self.load_data(data, self) - else: - return self.load_data(data) + with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_data"): + if len(parameters) > 1 and self.DATASET_KEY in parameters: + return self.load_data(data, self) + else: + return self.load_data(data) def _call_load_sample(self, sample: Any) -> Any: parameters = signature(self.load_sample).parameters @@ -110,26 +116,16 @@ def _setup(self, stage: Optional[RunningStage]) -> None: "The load_data function of the Autogenerated Dataset changed. " "This is not expected! Preloading Data again to ensure compatibility. This may take some time." ) - with self._set_running_stage(stage): - self._preprocessed_data = self._call_load_data(self.data) + self._preprocessed_data = self._call_load_data(self.data) self._load_data_called = True - @contextmanager - def _set_running_stage(self, stage: RunningStage) -> None: - if self.load_data: - if self.data_pipeline and self.data_pipeline._preprocess_pipeline: - self.data_pipeline._preprocess_pipeline._running_stage = stage - yield - if self.load_data: - if self.data_pipeline and self.data_pipeline._preprocess_pipeline: - self.data_pipeline._preprocess_pipeline._running_stage = None - def __getitem__(self, index: int) -> Any: if not self.load_sample and not self.load_data: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") - if self.load_sample: - return self._call_load_sample(self._preprocessed_data[index]) - return self._preprocessed_data[index] + with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_sample"): + if self.load_sample: + return self._call_load_sample(self._preprocessed_data[index]) + return self._preprocessed_data[index] def __len__(self) -> int: if not self.load_sample and not self.load_data: diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index 3e271afd6d..fb50168d9f 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -48,6 +48,7 @@ def _wrap_functions_per_stage(self, running_stage: RunningStage): fn = getattr(preprocess, fn_name) setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage)) + # hack until solved self._datamodule._train_ds.load_sample = preprocess.load_sample def _wrap_preprocess(self): diff --git a/flash/data/batch.py b/flash/data/batch.py index d6262b1e49..1047e85a44 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union import torch from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor -from flash.data.utils import _contains_any_tensor, convert_to_modules +from flash.data.utils import _contains_any_tensor, convert_to_modules, set_current_fn, set_current_stage + +if TYPE_CHECKING: + from flash.data.process import Preprocess class _Sequential(torch.nn.Module): @@ -31,29 +34,40 @@ class _Sequential(torch.nn.Module): def __init__( self, + preprocess: 'Preprocess', pre_tensor_transform: Callable, to_tensor_transform: Callable, post_tensor_transform: Callable, - assert_contains_tensor: bool = False + stage: RunningStage, + assert_contains_tensor: bool = False, ): super().__init__() - + self.preprocess = preprocess self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) self.to_tensor_transform = convert_to_modules(to_tensor_transform) self.post_tensor_transform = convert_to_modules(post_tensor_transform) + self.stage = stage self.assert_contains_tensor = assert_contains_tensor def forward(self, sample: Any): - sample = self.pre_tensor_transform(sample) - sample = self.to_tensor_transform(sample) - if self.assert_contains_tensor: - if not _contains_any_tensor(sample): - raise MisconfigurationException( - "When ``to_tensor_transform`` is overriden, " - "``DataPipeline`` expects the outputs to be ``tensors``" - ) - sample = self.post_tensor_transform(sample) - return sample + with set_current_stage(self.preprocess, self.stage): + with set_current_fn(self.preprocess, "pre_tensor_transform"): + sample = self.pre_tensor_transform(sample) + + with set_current_fn(self.preprocess, "to_tensor_transform"): + sample = self.to_tensor_transform(sample) + + if self.assert_contains_tensor: + if not _contains_any_tensor(sample): + raise MisconfigurationException( + "When ``to_tensor_transform`` is overriden, " + "``DataPipeline`` expects the outputs to be ``tensors``" + ) + + with set_current_fn(self.preprocess, "post_tensor_transform"): + sample = self.post_tensor_transform(sample) + + return sample def __str__(self) -> str: repr_str = f'{self.__class__.__name__}:' @@ -87,26 +101,37 @@ class _PreProcessor(torch.nn.Module): def __init__( self, + preprocess: 'Preprocess', collate_fn: Callable, per_sample_transform: Union[Callable, _Sequential], per_batch_transform: Callable, stage: Optional[RunningStage] = None, apply_per_sample_transform: bool = True, + on_device: bool = False ): super().__init__() + self.preprocess = preprocess self.collate_fn = convert_to_modules(collate_fn) self.per_sample_transform = convert_to_modules(per_sample_transform) self.per_batch_transform = convert_to_modules(per_batch_transform) self.apply_per_sample_transform = apply_per_sample_transform self.stage = stage + self.on_device = on_device def forward(self, samples: Sequence[Any]): - if self.apply_per_sample_transform: - samples = [self.per_sample_transform(sample) for sample in samples] - samples = type(samples)(samples) - samples = self.collate_fn(samples) - samples = self.per_batch_transform(samples) - return samples + with set_current_stage(self.preprocess, self.stage): + + if self.apply_per_sample_transform: + with set_current_fn(self.preprocess, f"per_sample_transform_{'on_device' if self.on_device else ''}"): + samples = [self.per_sample_transform(sample) for sample in samples] + samples = type(samples)(samples) + + with set_current_fn(self.preprocess, "collate"): + samples = self.collate_fn(samples) + + with set_current_fn(self.preprocess, f"per_batch_transform_{'on_device' if self.on_device else ''}"): + samples = self.per_batch_transform(samples) + return samples def __str__(self) -> str: # todo: define repr function which would take object and string attributes to be shown diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 641eff21d7..f998c62ad1 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -87,6 +87,9 @@ def __init__( # this may also trigger data preloading self.set_running_stages() + def configure_vis(self): + return self + @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: if isinstance(dataset, Subset): @@ -340,4 +343,5 @@ def from_load_data_inputs( ) datamodule._preprocess = data_pipeline._preprocess_pipeline datamodule._postprocess = data_pipeline._postprocess_pipeline + datamodule.configure_vis() return datamodule diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 226e6bf1ab..2aca209f3d 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import inspect import weakref from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union @@ -181,7 +182,7 @@ def _is_overriden_recursive( if not hasattr(process_obj, current_method_name): return DataPipeline._is_overriden_recursive(method_name, process_obj, super_obj) - current_code = getattr(process_obj, current_method_name).__code__ + current_code = inspect.unwrap(getattr(process_obj, current_method_name)).__code__ has_different_code = current_code != getattr(super_obj, method_name).__code__ if not prefix: @@ -257,7 +258,7 @@ def _create_collate_preprocessors( if per_batch_transform_overriden and per_sample_transform_on_device_overriden: raise MisconfigurationException( - f'{self.__class__.__name__}: `per_batch_transform` and `gpu_per_sample_transform` ' + f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` ' f'are mutual exclusive for stage {stage}' ) @@ -282,21 +283,25 @@ def _create_collate_preprocessors( ) worker_preprocessor = _PreProcessor( - worker_collate_fn, + self._preprocess_pipeline, worker_collate_fn, _Sequential( + self._preprocess_pipeline, getattr(self._preprocess_pipeline, func_names['pre_tensor_transform']), getattr(self._preprocess_pipeline, func_names['to_tensor_transform']), getattr(self._preprocess_pipeline, func_names['post_tensor_transform']), + stage, assert_contains_tensor=assert_contains_tensor, ), getattr(self._preprocess_pipeline, func_names['per_batch_transform']), stage ) worker_preprocessor._original_collate_fn = original_collate_fn device_preprocessor = _PreProcessor( + self._preprocess_pipeline, device_collate_fn, getattr(self._preprocess_pipeline, func_names['per_sample_transform_on_device']), getattr(self._preprocess_pipeline, func_names['per_batch_transform_on_device']), stage, - apply_per_sample_transform=device_collate_fn != self._identity + apply_per_sample_transform=device_collate_fn != self._identity, + on_device=True, ) return worker_preprocessor, device_preprocessor diff --git a/flash/data/process.py b/flash/data/process.py index f61220dc11..35ecda3993 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -27,7 +27,24 @@ class Properties: - _running_stage = None + _running_stage: RunningStage = None + _current_fn: str = None + + @property + def current_fn(self) -> str: + return self._current_fn + + @current_fn.setter + def current_fn(self, current_fn: str): + self._current_fn = current_fn + + @property + def running_stage(self) -> RunningStage: + return self._running_stage + + @running_stage.setter + def running_stage(self, running_stage: RunningStage): + self._running_stage = running_stage @property def training(self) -> bool: @@ -97,6 +114,27 @@ def __init__( self.test_transform = convert_to_modules(test_transform) self.predict_transform = convert_to_modules(predict_transform) + def _identify(self, x): + return x + + def _get_transform(self, transform: Dict[str, Callable]): + if self.current_fn in transform: + return transform[self.current_fn] + return self._identify + + @property + def current_transform(self): + if self.training and self.train_transform: + return self._get_transform(self.train_transform) + elif self.validating and self.val_transform: + return self._get_transform(self.val_transform) + elif self.testing and self.test_transform: + return self._get_transform(self.test_transform) + elif self.predicting and self.predict_transform: + return self._get_transform(self.predict_transform) + else: + return self._identify + @classmethod def from_state(cls, state: PreprocessState) -> 'Preprocess': return cls(**vars(state)) diff --git a/flash/data/utils.py b/flash/data/utils.py index 4be6d177ba..cad73d3258 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -14,6 +14,7 @@ import os.path import zipfile +from contextlib import contextmanager from typing import Any, Callable, Dict, Iterable, Mapping, Type import requests @@ -32,6 +33,41 @@ _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} +@contextmanager +def set_current_stage(obj: Any, stage: RunningStage) -> None: + if obj is not None: + if getattr(obj, "_running_stage", None) == stage: + yield + else: + obj.running_stage = stage + yield + obj.running_stage = None + else: + yield + + +@contextmanager +def set_current_fn(obj: Any, current_fn: str) -> None: + if obj is not None: + obj.current_fn = current_fn + yield + obj.current_fn = None + else: + yield + + +@contextmanager +def set_current_stage_and_fn(obj: Any, stage: RunningStage, current_fn: str) -> None: + if obj is not None: + obj.running_stage = stage + obj.current_fn = current_fn + yield + obj.running_stage = None + obj.current_fn = None + else: + yield + + def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: """ Download file with progressbar diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 37baff9440..6f9cb8bb36 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -41,7 +41,6 @@ class ImageClassificationPreprocess(Preprocess): - to_tensor = torchvision_T.ToTensor() @staticmethod def _find_classes(dir: str) -> Tuple: @@ -112,7 +111,7 @@ def load_data(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable return cls._load_data_files_labels(data=data, dataset=dataset) @staticmethod - def load_sample(sample) -> Union[Image.Image, list]: + def load_sample(sample) -> Union[Image.Image]: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) if isinstance(sample, torch.Tensor): return sample @@ -138,25 +137,6 @@ def predict_load_data(cls, samples: Any) -> Iterable: return samples return cls._get_predicting_files(samples) - def _convert_tensor_to_pil(self, sample): - # some datasets provide their data as tensors. - # however, it would be better to convert those data once in load_data - if isinstance(sample, torch.Tensor): - sample = to_pil_image(sample) - return sample - - def _apply_transform( - self, sample: Any, transform: Union[Callable, Dict[str, Callable]], func_name: str - ) -> torch.Tensor: - if transform is not None: - if isinstance(transform, (Dict, ModuleDict)): - if func_name not in transform: - return sample - else: - transform = transform[func_name] - sample = transform(sample) - return sample - def collate(self, samples: Sequence) -> Any: _samples = [] # todo: Kornia transforms add batch dimension which need to be removed @@ -168,56 +148,28 @@ def collate(self, samples: Sequence) -> Any: _samples.append(sample) return default_collate(_samples) - def common_pre_tensor_transform(self, sample: Any, transform) -> Any: - return self._apply_transform(sample, transform, "pre_tensor_transform") - - def train_pre_tensor_transform(self, sample: Any) -> Any: - source, target = sample - return self.common_pre_tensor_transform(source, self.train_transform), target - - def val_pre_tensor_transform(self, sample: Any) -> Any: - source, target = sample - return self.common_pre_tensor_transform(source, self.val_transform), target - - def test_pre_tensor_transform(self, sample: Any) -> Any: - source, target = sample - return self.common_pre_tensor_transform(source, self.test_transform), target - - def predict_pre_tensor_transform(self, sample: Any) -> Any: + def common_step(self, sample: Any) -> Any: + if isinstance(sample, (list, tuple)): + source, target = sample + return self.current_transform(source), target if isinstance(sample, torch.Tensor): return sample - return self.common_pre_tensor_transform(sample, self.predict_transform) + return self.current_transform(sample) - def to_tensor_transform(self, sample) -> Any: - source, target = sample - return source if isinstance(source, torch.Tensor) else self.to_tensor(source), target + def per_tensor_transform(self, sample: Any) -> Any: + return self.common_step(sample) - def predict_to_tensor_transform(self, sample) -> Any: - if isinstance(sample, torch.Tensor): - return sample - return self.to_tensor(sample) - - def common_post_tensor_transform(self, sample: Any, transform) -> Any: - return self._apply_transform(sample, transform, "post_tensor_transform") + def to_tensor_transform(self, sample: Any) -> Any: + return self.common_step(sample) - def train_post_tensor_transform(self, sample: Any) -> Any: - source, target = sample - return self.common_post_tensor_transform(source, self.train_transform), target + def post_tensor_transform(self, sample: Any) -> Any: + return self.common_step(sample) - def val_post_tensor_transform(self, sample: Any) -> Any: - source, target = sample - return self.common_post_tensor_transform(source, self.val_transform), target + def per_batch_transform(self, sample: Any) -> Any: + return self.common_step(sample) - def test_post_tensor_transform(self, sample: Any) -> Any: - source, target = sample - return self.common_post_tensor_transform(source, self.test_transform), target - - def predict_post_tensor_transform(self, sample: Any) -> Any: - return self.common_post_tensor_transform(sample, self.predict_transform) - - def train_per_batch_transform_on_device(self, batch: Tuple) -> Tuple: - batch, target = batch - return self._apply_transform(batch, self.train_transform, "per_batch_transform_on_device"), target + def per_batch_transform_on_device(self, sample: Any) -> Any: + return self.common_step(sample) class ImageClassificationData(DataModule): @@ -285,6 +237,7 @@ def default_train_transforms(): if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": # Better approach as all transforms are applied on tensor directly return { + "to_tensor_transform": torchvision_T.ToTensor(), "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()), "per_batch_transform_on_device": nn.Sequential( K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), @@ -294,6 +247,7 @@ def default_train_transforms(): from torchvision import transforms as T # noqa F811 return { "pre_tensor_transform": nn.Sequential(T.RandomResizedCrop(image_size), T.RandomHorizontalFlip()), + "to_tensor_transform": torchvision_T.ToTensor(), "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), } @@ -303,6 +257,7 @@ def default_val_transforms(): if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": # Better approach as all transforms are applied on tensor directly return { + "to_tensor_transform": torchvision_T.ToTensor(), "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)), "per_batch_transform_on_device": nn.Sequential( K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), @@ -312,6 +267,7 @@ def default_val_transforms(): from torchvision import transforms as T # noqa F811 return { "pre_tensor_transform": T.Compose([T.RandomResizedCrop(image_size)]), + "to_tensor_transform": torchvision_T.ToTensor(), "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), } @@ -471,23 +427,24 @@ def from_filepaths( test_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, test_labels: Optional[Sequence] = None, predict_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None, - train_transform: Optional[Callable] = 'default', - val_transform: Optional[Callable] = 'default', + train_transform: Union[str, Dict] = 'default', + val_transform: Union[str, Dict] = 'default', + test_transform: Union[str, Dict] = 'default', + predict_transform: Union[str, Dict] = 'default', batch_size: int = 64, num_workers: Optional[int] = None, seed: Optional[int] = 42, + preprocess_cls: Optional[Type[Preprocess]] = None, **kwargs, ) -> 'ImageClassificationData': """ Creates a ImageClassificationData object from folders of images arranged in this way: :: - folder/dog_xxx.png folder/dog_xxy.png folder/dog_xxz.png folder/cat_123.png folder/cat_nsdf3.png folder/cat_asd932_.png - Args: train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. train_labels: Sequence of labels for training dataset. Defaults to ``None``. @@ -502,19 +459,14 @@ def from_filepaths( num_workers: The number of workers to use for parallelized loading. Defaults to ``None`` which equals the number of available CPU threads. seed: Used for the train/val splits. - Returns: ImageClassificationData: The constructed data module. - Examples: >>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP - Example when labels are in .csv file:: - train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id') val_labels = labels_from_categorical_csv(path/to/val.csv', 'my_id') test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id') - data = ImageClassificationData.from_filepaths( batch_size=2, train_filepaths='path/to/train', @@ -524,7 +476,6 @@ def from_filepaths( test_filepaths='path/to/test', test_labels=test_labels, ) - """ # enable passing in a string which loads all files in that folder as a list if isinstance(train_filepaths, str): @@ -532,59 +483,34 @@ def from_filepaths( train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)] else: train_filepaths = [train_filepaths] + if isinstance(val_filepaths, str): if os.path.isdir(val_filepaths): val_filepaths = [os.path.join(val_filepaths, x) for x in os.listdir(val_filepaths)] else: val_filepaths = [val_filepaths] + if isinstance(test_filepaths, str): if os.path.isdir(test_filepaths): test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)] else: test_filepaths = [test_filepaths] - if isinstance(predict_filepaths, str): - if os.path.isdir(predict_filepaths): - predict_filepaths = [os.path.join(predict_filepaths, x) for x in os.listdir(predict_filepaths)] - else: - predict_filepaths = [predict_filepaths] - if train_filepaths is not None and train_labels is not None: - train_dataset = cls._generate_dataset_if_possible( - list(zip(train_filepaths, train_labels)), running_stage=RunningStage.TRAINING - ) - else: - train_dataset = None - - if val_filepaths is not None and val_labels is not None: - val_dataset = cls._generate_dataset_if_possible( - list(zip(val_filepaths, val_labels)), running_stage=RunningStage.VALIDATING - ) - else: - val_dataset = None - - if test_filepaths is not None and test_labels is not None: - test_dataset = cls._generate_dataset_if_possible( - list(zip(test_filepaths, test_labels)), running_stage=RunningStage.TESTING - ) - else: - test_dataset = None - - if predict_filepaths is not None: - predict_dataset = cls._generate_dataset_if_possible( - predict_filepaths, running_stage=RunningStage.PREDICTING - ) - else: - predict_dataset = None + preprocess = cls.instantiate_preprocess( + train_transform, + val_transform, + test_transform, + predict_transform, + preprocess_cls=preprocess_cls, + ) - return cls( - train_dataset=train_dataset, - val_dataset=val_dataset, - test_dataset=test_dataset, - predict_dataset=predict_dataset, - train_transform=train_transform, - val_transform=val_transform, + return cls.from_load_data_inputs( + train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, + val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, + test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, + predict_load_data_input=predict_filepaths, batch_size=batch_size, num_workers=num_workers, - seed=seed, + preprocess=preprocess, **kwargs ) diff --git a/tests/data/test_base_viz.py b/tests/data/test_base_viz.py index 20fc836f51..c153903a76 100644 --- a/tests/data/test_base_viz.py +++ b/tests/data/test_base_viz.py @@ -75,3 +75,6 @@ def test_base_viz(tmpdir): img_data.show_train_batch() assert img_data.viz.batches["train"]["load_sample"] is not None + assert img_data.viz.batches["train"]["to_tensor_transform"] is not None + assert img_data.viz.batches["train"]["collate"] is not None + assert img_data.viz.batches["train"]["per_batch_transform"] is not None diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 7aac65b07a..c1d8ae6b62 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -498,41 +498,59 @@ def __init__(self): self.predict_load_data_called = False def train_load_data(self, sample) -> LamdaDummyDataset: + assert self.training + assert self.current_fn == "load_data" self.train_load_data_called = True return LamdaDummyDataset(lambda: (0, 1, 2, 3)) def train_pre_tensor_transform(self, sample: Any) -> Any: + assert self.training + assert self.current_fn == "pre_tensor_transform" self.train_pre_tensor_transform_called = True return sample + (5, ) def train_collate(self, samples) -> Tensor: + assert self.training + assert self.current_fn == "collate" self.train_collate_called = True return tensor([list(s) for s in samples]) def train_per_batch_transform_on_device(self, batch: Any) -> Any: + assert self.training + assert self.current_fn == "per_batch_transform_on_device" self.train_per_batch_transform_on_device_called = True assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]])) def val_load_data(self, sample, dataset) -> List[int]: + assert self.validating + assert self.current_fn == "load_data" self.val_load_data_called = True assert isinstance(dataset, AutoDataset) return list(range(5)) def val_load_sample(self, sample) -> Dict[str, Tensor]: + assert self.validating + assert self.current_fn == "load_sample" self.val_load_sample_called = True return {"a": sample, "b": sample + 1} def val_to_tensor_transform(self, sample: Any) -> Tensor: + assert self.validating + assert self.current_fn == "to_tensor_transform" self.val_to_tensor_transform_called = True return sample def val_collate(self, samples) -> Dict[str, Tensor]: + assert self.validating + assert self.current_fn == "collate" self.val_collate_called = True _count = samples[0]['a'] assert samples == [{'a': _count, 'b': _count + 1}, {'a': _count + 1, 'b': _count + 2}] return {'a': tensor([0, 1]), 'b': tensor([1, 2])} def val_per_batch_transform_on_device(self, batch: Any) -> Any: + assert self.validating + assert self.current_fn == "per_batch_transform_on_device" self.val_per_batch_transform_on_device_called = True batch = batch[0] assert torch.equal(batch["a"], tensor([0, 1])) @@ -540,18 +558,26 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: return [False] def test_load_data(self, sample) -> LamdaDummyDataset: + assert self.testing + assert self.current_fn == "load_data" self.test_load_data_called = True return LamdaDummyDataset(lambda: [torch.rand(1), torch.rand(1)]) def test_to_tensor_transform(self, sample: Any) -> Tensor: + assert self.testing + assert self.current_fn == "to_tensor_transform" self.test_to_tensor_transform_called = True return sample def test_post_tensor_transform(self, sample: Tensor) -> Tensor: + assert self.testing + assert self.current_fn == "post_tensor_transform" self.test_post_tensor_transform_called = True return sample def predict_load_data(self, sample) -> LamdaDummyDataset: + assert self.predicting + assert self.current_fn == "load_data" self.predict_load_data_called = True return LamdaDummyDataset(lambda: (["a", "b"])) @@ -563,7 +589,6 @@ def val_to_tensor_transform(self, sample: Any) -> Tensor: return {"a": tensor(sample["a"]), "b": tensor(sample["b"])} -@pytest.mark.skipif(reason="Still using DataPipeline Old API") def test_datapipeline_transformations(tmpdir): class CustomModel(Task): @@ -619,21 +644,20 @@ class CustomDataModule(DataModule): trainer.test(model) trainer.predict(model) - # todo (tchaton) resolve the lost reference. preprocess = model._preprocess - # assert preprocess.train_load_data_called - # assert preprocess.train_pre_tensor_transform_called - # assert preprocess.train_collate_called + assert preprocess.train_load_data_called + assert preprocess.train_pre_tensor_transform_called + assert preprocess.train_collate_called assert preprocess.train_per_batch_transform_on_device_called - # assert preprocess.val_load_data_called - # assert preprocess.val_load_sample_called - # assert preprocess.val_to_tensor_transform_called - # assert preprocess.val_collate_called + assert preprocess.val_load_data_called + assert preprocess.val_load_sample_called + assert preprocess.val_to_tensor_transform_called + assert preprocess.val_collate_called assert preprocess.val_per_batch_transform_on_device_called - # assert preprocess.test_load_data_called - # assert preprocess.test_to_tensor_transform_called - # assert preprocess.test_post_tensor_transform_called - # assert preprocess.predict_load_data_called + assert preprocess.test_load_data_called + assert preprocess.test_to_tensor_transform_called + assert preprocess.test_post_tensor_transform_called + assert preprocess.predict_load_data_called def test_is_overriden_recursive(tmpdir): From 2b2c49901eb5f1dc24c14e589ab894f3a62cce09 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 13:22:25 +0100 Subject: [PATCH 03/14] update --- flash/core/classification.py | 3 +- flash/core/model.py | 13 +++++++-- flash/data/base_viz.py | 43 +++++++++++++---------------- flash/data/data_module.py | 21 ++++++++++++-- flash/data/data_pipeline.py | 37 ++++++++++++++----------- flash/data/process.py | 11 ++++++++ flash/data/utils.py | 1 + flash/vision/classification/data.py | 37 ++++++++++++++++++++----- tests/data/test_base_viz.py | 9 ++---- tests/examples/test_scripts.py | 10 +++---- 10 files changed, 119 insertions(+), 66 deletions(-) diff --git a/flash/core/classification.py b/flash/core/classification.py index 86b4066410..4340f404b5 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Any import torch -from torch import Tensor from flash.core.model import Task from flash.data.process import Postprocess diff --git a/flash/core/model.py b/flash/core/model.py index 6cc7bcda5f..b03a424dd2 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -import os +import inspect +from copy import deepcopy from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch @@ -244,13 +245,21 @@ def on_fit_end(self) -> None: self.data_pipeline._detach_from_model(self) super().on_fit_end() + def _sanetize_funcs(self, obj: Any) -> Any: + if hasattr(obj, "__dict__"): + for k, v in obj.__dict__.items(): + if isinstance(v, Callable): + obj.__dict__[k] = inspect.unwrap(v) + return obj + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # TODO: Is this the best way to do this? or should we also use some kind of hparams here? # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html - if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: + self._preprocess = self._sanetize_funcs(self._preprocess) checkpoint['data_pipeline'] = self.data_pipeline + # todo (tchaton) re-wrap visualization super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index fb50168d9f..40e341196e 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -4,42 +4,43 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage -from flash.data.data_module import DataModule from flash.data.data_pipeline import DataPipeline -from flash.data.process import Postprocess, Preprocess +from flash.data.process import Preprocess class BaseViz(Callback): - def __init__(self, datamodule: DataModule): - self._datamodule = datamodule - self._wrap_preprocess() - + def __init__(self, enabled: bool = False): self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}} + self.enabled = enabled + self._datamodule = None + + def attach_to_preprocess(self, preprocess: Preprocess) -> None: + self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess) + + def attach_to_datamodule(self, datamodule) -> None: + self._datamodule = datamodule + datamodule.viz = self def _wrap_fn( self, fn: Callable, running_stage: RunningStage, ) -> Callable: - """ - """ @functools.wraps(fn) - def wrapper(data) -> Any: - print(data) - data = fn(data) - print(data) - batches = self.batches[running_stage.value] - if fn.__name__ not in batches: - batches[fn.__name__] = [] - batches[fn.__name__].append(data) + def wrapper(*args) -> Any: + data = fn(*args) + if self.enabled: + batches = self.batches[running_stage.value] + if fn.__name__ not in batches: + batches[fn.__name__] = [] + batches[fn.__name__].append(data) return data return wrapper - def _wrap_functions_per_stage(self, running_stage: RunningStage): - preprocess = self._datamodule.data_pipeline._preprocess_pipeline + def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: Preprocess): fn_names = { k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess) for k in DataPipeline.PREPROCESS_FUNCS @@ -47,9 +48,3 @@ def _wrap_functions_per_stage(self, running_stage: RunningStage): for fn_name in fn_names: fn = getattr(preprocess, fn_name) setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage)) - - # hack until solved - self._datamodule._train_ds.load_sample = preprocess.load_sample - - def _wrap_preprocess(self): - self._wrap_functions_per_stage(RunningStage.TRAINING) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index f998c62ad1..286be2b6fa 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -24,6 +24,7 @@ from torch.utils.data.dataset import Subset from flash.data.auto_dataset import AutoDataset +from flash.data.base_viz import BaseViz from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -83,12 +84,22 @@ def __init__( self._preprocess = None self._postprocess = None + self._viz = None # this may also trigger data preloading self.set_running_stages() - def configure_vis(self): - return self + @property + def viz(self) -> BaseViz: + return self._viz or DataModule.configure_vis() + + @viz.setter + def viz(self, viz: BaseViz) -> None: + self._viz = viz + + @classmethod + def configure_vis(cls) -> BaseViz: + return BaseViz() @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: @@ -322,6 +333,10 @@ def from_load_data_inputs( ) else: data_pipeline = cls(**kwargs).data_pipeline + + viz_callback = cls.configure_vis() + viz_callback.attach_to_preprocess(data_pipeline._preprocess_pipeline) + train_dataset = cls._generate_dataset_if_possible( train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline ) @@ -343,5 +358,5 @@ def from_load_data_inputs( ) datamodule._preprocess = data_pipeline._preprocess_pipeline datamodule._postprocess = data_pipeline._postprocess_pipeline - datamodule.configure_vis() + viz_callback.attach_to_datamodule(datamodule) return datamodule diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 2aca209f3d..40f9d48be8 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -18,6 +18,7 @@ from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import imports from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data._utils.collate import default_collate, default_convert from torch.utils.data.dataloader import DataLoader @@ -240,23 +241,27 @@ def _create_collate_preprocessors( if collate_fn is None: collate_fn = default_collate + preprocess = self._preprocess_pipeline + func_names = { - k: self._resolve_function_hierarchy(k, self._preprocess_pipeline, stage, Preprocess) + k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) for k in self.PREPROCESS_FUNCS } - if self._is_overriden_recursive("collate", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]): - collate_fn = getattr(self._preprocess_pipeline, func_names["collate"]) + if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]): + collate_fn = getattr(preprocess, func_names["collate"]) per_batch_transform_overriden = self._is_overriden_recursive( - "per_batch_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage] + "per_batch_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) per_sample_transform_on_device_overriden = self._is_overriden_recursive( - "per_sample_transform_on_device", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage] + "per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) - if per_batch_transform_overriden and per_sample_transform_on_device_overriden: + skip_mutual_check = preprocess.skip_mutual_check + + if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden): raise MisconfigurationException( f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` ' f'are mutual exclusive for stage {stage}' @@ -279,26 +284,26 @@ def _create_collate_preprocessors( ) else worker_collate_fn assert_contains_tensor = self._is_overriden_recursive( - "to_tensor_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage] + "to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) worker_preprocessor = _PreProcessor( - self._preprocess_pipeline, worker_collate_fn, + preprocess, worker_collate_fn, _Sequential( - self._preprocess_pipeline, - getattr(self._preprocess_pipeline, func_names['pre_tensor_transform']), - getattr(self._preprocess_pipeline, func_names['to_tensor_transform']), - getattr(self._preprocess_pipeline, func_names['post_tensor_transform']), + preprocess, + getattr(preprocess, func_names['pre_tensor_transform']), + getattr(preprocess, func_names['to_tensor_transform']), + getattr(preprocess, func_names['post_tensor_transform']), stage, assert_contains_tensor=assert_contains_tensor, - ), getattr(self._preprocess_pipeline, func_names['per_batch_transform']), stage + ), getattr(preprocess, func_names['per_batch_transform']), stage ) worker_preprocessor._original_collate_fn = original_collate_fn device_preprocessor = _PreProcessor( - self._preprocess_pipeline, + preprocess, device_collate_fn, - getattr(self._preprocess_pipeline, func_names['per_sample_transform_on_device']), - getattr(self._preprocess_pipeline, func_names['per_batch_transform_on_device']), + getattr(preprocess, func_names['per_sample_transform_on_device']), + getattr(preprocess, func_names['per_batch_transform_on_device']), stage, apply_per_sample_transform=device_collate_fn != self._identity, on_device=True, diff --git a/flash/data/process.py b/flash/data/process.py index 35ecda3993..62b23cc4a0 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -114,6 +114,17 @@ def __init__( self.test_transform = convert_to_modules(test_transform) self.predict_transform = convert_to_modules(predict_transform) + if not hasattr(self, "_skip_mutual_check"): + self._skip_mutual_check = False + + @property + def skip_mutual_check(self) -> bool: + return self._skip_mutual_check + + @skip_mutual_check.setter + def skip_mutual_check(self, skip_mutual_check: bool) -> None: + self._skip_mutual_check = skip_mutual_check + def _identify(self, x): return x diff --git a/flash/data/utils.py b/flash/data/utils.py index cad73d3258..4b7fec9122 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -33,6 +33,7 @@ _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} +# todo (tchaton) convert to class @contextmanager def set_current_stage(obj: Any, stage: RunningStage) -> None: if obj is not None: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 6f9cb8bb36..d09f467155 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import torch +import torchvision from PIL import Image from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -23,7 +24,7 @@ from torch.nn.modules import ModuleDict from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate -from torchvision import transforms as torchvision_T +from torchvision import transforms from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset from torchvision.transforms.functional import to_pil_image @@ -42,6 +43,11 @@ class ImageClassificationPreprocess(Preprocess): + # this assignement is used to skip the assert that `per_batch_transform` and `per_sample_transform_on_device` + # are mutually exclusive on the DataPipeline internals + _skip_mutual_check = True + to_tensor = torchvision.transforms.ToTensor() + @staticmethod def _find_classes(dir: str) -> Tuple: """ @@ -152,7 +158,7 @@ def common_step(self, sample: Any) -> Any: if isinstance(sample, (list, tuple)): source, target = sample return self.current_transform(source), target - if isinstance(sample, torch.Tensor): + elif isinstance(sample, torch.Tensor): return sample return self.current_transform(sample) @@ -160,6 +166,13 @@ def per_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) def to_tensor_transform(self, sample: Any) -> Any: + if self.current_transform == self._identify: + if isinstance(sample, (list, tuple)): + source, target = sample + return self.to_tensor(source), target + elif isinstance(sample, torch.Tensor): + return sample + return self.to_tensor(sample) return self.common_step(sample) def post_tensor_transform(self, sample: Any) -> Any: @@ -168,6 +181,9 @@ def post_tensor_transform(self, sample: Any) -> Any: def per_batch_transform(self, sample: Any) -> Any: return self.common_step(sample) + def per_sample_transform_on_device(self, sample: Any) -> Any: + return self.common_step(sample) + def per_batch_transform_on_device(self, sample: Any) -> Any: return self.common_step(sample) @@ -229,6 +245,11 @@ def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[ "Transform should be a dict. " f"Here are the available keys for your transforms: {DataPipeline.PREPROCESS_FUNCS}." ) + if "per_batch_transform" in transform and "per_sample_transform_on_device" in transform: + raise MisconfigurationException( + f'{transform}: `per_batch_transform` and `per_sample_transform_on_device` ' + f'are mutual exclusive.' + ) return transform @staticmethod @@ -237,7 +258,7 @@ def default_train_transforms(): if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": # Better approach as all transforms are applied on tensor directly return { - "to_tensor_transform": torchvision_T.ToTensor(), + "to_tensor_transform": torchvision.transforms.ToTensor(), "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()), "per_batch_transform_on_device": nn.Sequential( K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), @@ -247,7 +268,7 @@ def default_train_transforms(): from torchvision import transforms as T # noqa F811 return { "pre_tensor_transform": nn.Sequential(T.RandomResizedCrop(image_size), T.RandomHorizontalFlip()), - "to_tensor_transform": torchvision_T.ToTensor(), + "to_tensor_transform": torchvision.transforms.ToTensor(), "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), } @@ -257,7 +278,7 @@ def default_val_transforms(): if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1": # Better approach as all transforms are applied on tensor directly return { - "to_tensor_transform": torchvision_T.ToTensor(), + "to_tensor_transform": torchvision.transforms.ToTensor(), "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)), "per_batch_transform_on_device": nn.Sequential( K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), @@ -267,7 +288,7 @@ def default_val_transforms(): from torchvision import transforms as T # noqa F811 return { "pre_tensor_transform": T.Compose([T.RandomResizedCrop(image_size)]), - "to_tensor_transform": torchvision_T.ToTensor(), + "to_tensor_transform": torchvision.transforms.ToTensor(), "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), } @@ -323,7 +344,9 @@ def instantiate_preprocess( ) preprocess_cls = preprocess_cls or cls.preprocess_cls - return preprocess_cls(train_transform, val_transform, test_transform, predict_transform) + preprocess = preprocess_cls(train_transform, val_transform, test_transform, predict_transform) + # todo (tchaton) add check on mutually exclusive transforms + return preprocess @classmethod def _resolve_transforms( diff --git a/tests/data/test_base_viz.py b/tests/data/test_base_viz.py index c153903a76..a4e49a3caf 100644 --- a/tests/data/test_base_viz.py +++ b/tests/data/test_base_viz.py @@ -44,14 +44,10 @@ def _rand_image(): class ImageClassificationDataViz(ImageClassificationData): - def configure_vis(self): - if not hasattr(self, "viz"): - return BaseViz(self) - return self.viz - def show_train_batch(self): - self.viz = self.configure_vis() + self.viz.enabled = True _ = next(iter(self.train_dataloader())) + self.viz.enabled = False def test_base_viz(tmpdir): @@ -67,7 +63,6 @@ def test_base_viz(tmpdir): img_data = ImageClassificationDataViz.from_filepaths( train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_transform=None, train_labels=[0, 1], batch_size=1, num_workers=0, diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 669baee5a1..3db3e9384b 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -61,11 +61,11 @@ def run_test(filepath): ("finetuning", "tabular_classification.py"), ("finetuning", "text_classification.py"), # TODO: takes too long # ("finetuning", "translation.py"), # TODO: takes too long. - ("predict", "image_classification.py"), - ("predict", "tabular_classification.py"), - ("predict", "text_classification.py"), - ("predict", "image_embedder.py"), - ("predict", "summarization.py"), # TODO: takes too long + #("predict", "image_classification.py"), + #("predict", "tabular_classification.py"), + #("predict", "text_classification.py"), + #("predict", "image_embedder.py"), + #("predict", "summarization.py"), # TODO: takes too long # ("predict", "translate.py"), # TODO: takes too long ] ) From 6db6b1cb90a2a457cabfa553a93c97ee8291f870 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 13:22:52 +0100 Subject: [PATCH 04/14] resolve flake8 --- tests/examples/test_scripts.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 3db3e9384b..e8bab588f3 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -61,11 +61,11 @@ def run_test(filepath): ("finetuning", "tabular_classification.py"), ("finetuning", "text_classification.py"), # TODO: takes too long # ("finetuning", "translation.py"), # TODO: takes too long. - #("predict", "image_classification.py"), - #("predict", "tabular_classification.py"), - #("predict", "text_classification.py"), - #("predict", "image_embedder.py"), - #("predict", "summarization.py"), # TODO: takes too long + # ("predict", "image_classification.py"), + # ("predict", "tabular_classification.py"), + # ("predict", "text_classification.py"), + # ("predict", "image_embedder.py"), + # ("predict", "summarization.py"), # TODO: takes too long # ("predict", "translate.py"), # TODO: takes too long ] ) From f61deeacb9b9d1b4766a8ee2411aedfdec8fa2e5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 17:23:28 +0100 Subject: [PATCH 05/14] update --- flash/data/auto_dataset.py | 26 +++++---- flash/data/batch.py | 31 ++++++---- flash/data/process.py | 8 +-- flash/data/utils.py | 90 ++++++++++++++++++----------- flash/vision/classification/data.py | 2 + tests/examples/test_scripts.py | 2 +- 6 files changed, 98 insertions(+), 61 deletions(-) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index e42a4cf680..7a1bc0455b 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -20,7 +20,7 @@ from torch.utils.data import Dataset from flash.data.process import Preprocess -from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, set_current_stage_and_fn +from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext if TYPE_CHECKING: from flash.data.data_pipeline import DataPipeline @@ -68,9 +68,13 @@ def running_stage(self) -> Optional[RunningStage]: return self._running_stage @running_stage.setter - def running_stage(self, running_stage: str) -> None: + def running_stage(self, running_stage: RunningStage) -> None: if self._running_stage != running_stage or (not self._running_stage): self._running_stage = running_stage + self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self._preprocess) + self._load_sample_context = CurrentRunningStageFuncContext( + self._running_stage, "load_sample", self._preprocess + ) self._setup(running_stage) @property @@ -80,11 +84,10 @@ def _preprocess(self): def _call_load_data(self, data: Any) -> Iterable: parameters = signature(self.load_data).parameters - with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_data"): - if len(parameters) > 1 and self.DATASET_KEY in parameters: - return self.load_data(data, self) - else: - return self.load_data(data) + if len(parameters) > 1 and self.DATASET_KEY in parameters: + return self.load_data(data, self) + else: + return self.load_data(data) def _call_load_sample(self, sample: Any) -> Any: parameters = signature(self.load_sample).parameters @@ -116,16 +119,17 @@ def _setup(self, stage: Optional[RunningStage]) -> None: "The load_data function of the Autogenerated Dataset changed. " "This is not expected! Preloading Data again to ensure compatibility. This may take some time." ) - self._preprocessed_data = self._call_load_data(self.data) + with self._load_data_context: + self._preprocessed_data = self._call_load_data(self.data) self._load_data_called = True def __getitem__(self, index: int) -> Any: if not self.load_sample and not self.load_data: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") - with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_sample"): - if self.load_sample: + if self.load_sample: + with self._load_sample_context: return self._call_load_sample(self._preprocessed_data[index]) - return self._preprocessed_data[index] + return self._preprocessed_data[index] def __len__(self) -> int: if not self.load_sample and not self.load_data: diff --git a/flash/data/batch.py b/flash/data/batch.py index 1047e85a44..3b7bc70ab4 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor -from flash.data.utils import _contains_any_tensor, convert_to_modules, set_current_fn, set_current_stage +from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext if TYPE_CHECKING: from flash.data.process import Preprocess @@ -49,12 +49,17 @@ def __init__( self.stage = stage self.assert_contains_tensor = assert_contains_tensor + self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False) + self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess) + self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess) + self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess) + def forward(self, sample: Any): - with set_current_stage(self.preprocess, self.stage): - with set_current_fn(self.preprocess, "pre_tensor_transform"): + with self._current_stage_context: + with self._pre_tensor_transform_context: sample = self.pre_tensor_transform(sample) - with set_current_fn(self.preprocess, "to_tensor_transform"): + with self._to_tensor_transform_context: sample = self.to_tensor_transform(sample) if self.assert_contains_tensor: @@ -64,7 +69,7 @@ def forward(self, sample: Any): "``DataPipeline`` expects the outputs to be ``tensors``" ) - with set_current_fn(self.preprocess, "post_tensor_transform"): + with self._post_tensor_transform_context: sample = self.post_tensor_transform(sample) return sample @@ -105,7 +110,7 @@ def __init__( collate_fn: Callable, per_sample_transform: Union[Callable, _Sequential], per_batch_transform: Callable, - stage: Optional[RunningStage] = None, + stage: RunningStage, apply_per_sample_transform: bool = True, on_device: bool = False ): @@ -118,18 +123,24 @@ def __init__( self.stage = stage self.on_device = on_device + extension = f"{'on_device' if self.on_device else ''}" + self._current_stage_context = CurrentRunningStageContext(stage, preprocess) + self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess) + self._collate_context = CurrentFuncContext("collate", preprocess) + self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess) + def forward(self, samples: Sequence[Any]): - with set_current_stage(self.preprocess, self.stage): + with self._current_stage_context: if self.apply_per_sample_transform: - with set_current_fn(self.preprocess, f"per_sample_transform_{'on_device' if self.on_device else ''}"): + with self._per_sample_transform_context: samples = [self.per_sample_transform(sample) for sample in samples] samples = type(samples)(samples) - with set_current_fn(self.preprocess, "collate"): + with self._collate_context: samples = self.collate_fn(samples) - with set_current_fn(self.preprocess, f"per_batch_transform_{'on_device' if self.on_device else ''}"): + with self._per_batch_transform_context: samples = self.per_batch_transform(samples) return samples diff --git a/flash/data/process.py b/flash/data/process.py index 62b23cc4a0..4dc1fea4c0 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -27,11 +27,11 @@ class Properties: - _running_stage: RunningStage = None - _current_fn: str = None + _running_stage: Optional[RunningStage] = None + _current_fn: Optional[str] = None @property - def current_fn(self) -> str: + def current_fn(self) -> Optional[str]: return self._current_fn @current_fn.setter @@ -39,7 +39,7 @@ def current_fn(self, current_fn: str): self._current_fn = current_fn @property - def running_stage(self) -> RunningStage: + def running_stage(self) -> Optional[RunningStage]: return self._running_stage @running_stage.setter diff --git a/flash/data/utils.py b/flash/data/utils.py index 4b7fec9122..f3928861ba 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -14,7 +14,7 @@ import os.path import zipfile -from contextlib import contextmanager +from contextlib import ContextDecorator, contextmanager from typing import Any, Callable, Dict, Iterable, Mapping, Type import requests @@ -33,40 +33,60 @@ _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} -# todo (tchaton) convert to class -@contextmanager -def set_current_stage(obj: Any, stage: RunningStage) -> None: - if obj is not None: - if getattr(obj, "_running_stage", None) == stage: - yield - else: - obj.running_stage = stage - yield - obj.running_stage = None - else: - yield - - -@contextmanager -def set_current_fn(obj: Any, current_fn: str) -> None: - if obj is not None: - obj.current_fn = current_fn - yield - obj.current_fn = None - else: - yield - - -@contextmanager -def set_current_stage_and_fn(obj: Any, stage: RunningStage, current_fn: str) -> None: - if obj is not None: - obj.running_stage = stage - obj.current_fn = current_fn - yield - obj.running_stage = None - obj.current_fn = None - else: - yield +class CurrentRunningStageContext: + + def __init__(self, running_stage: RunningStage, obj: Any, reset: bool = True): + self._running_stage = running_stage + self._obj = obj + self._reset = reset + + def __enter__(self): + if self._obj is not None: + if getattr(self._obj, "running_stage", None) != self._running_stage: + self._obj.running_stage = self._running_stage + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._obj is not None and self._reset: + self._obj.running_stage = None + + +class CurrentFuncContext: + + def __init__(self, current_fn: str, obj: Any): + self._current_fn = current_fn + self._obj = obj + + def __enter__(self): + if self._obj is not None: + if getattr(self._obj, "current_fn", None) != self._current_fn: + self._obj.current_fn = self._current_fn + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._obj is not None: + self._obj.current_fn = None + + +class CurrentRunningStageFuncContext: + + def __init__(self, running_stage: RunningStage, current_fn: str, obj: Any): + self._running_stage = running_stage + self._current_fn = current_fn + self._obj = obj + + def __enter__(self): + if self._obj is not None: + if getattr(self._obj, "running_stage", None) != self._running_stage: + self._obj.running_stage = self._running_stage + if getattr(self._obj, "current_fn", None) != self._current_fn: + self._obj.current_fn = self._current_fn + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._obj is not None: + self._obj.running_stage = None + self._obj.current_fn = None def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index d09f467155..97497c799b 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -169,6 +169,8 @@ def to_tensor_transform(self, sample: Any) -> Any: if self.current_transform == self._identify: if isinstance(sample, (list, tuple)): source, target = sample + if isinstance(source, torch.Tensor): + return source, target return self.to_tensor(source), target elif isinstance(sample, torch.Tensor): return sample diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index e8bab588f3..d733768d65 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -53,7 +53,7 @@ def run_test(filepath): @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "folder,file", + "folder, file", [ # ("finetuning", "image_classification.py"), # ("finetuning", "object_detection.py"), # TODO: takes too long. From ffaa7c7668ba836786e50aab3e83e0d9f52158f2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 18:16:00 +0100 Subject: [PATCH 06/14] resolve tests --- flash/data/data_pipeline.py | 2 +- flash/vision/classification/data.py | 11 ++++------- tests/examples/test_scripts.py | 10 +++++----- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 40f9d48be8..f28d882a22 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -259,7 +259,7 @@ def _create_collate_preprocessors( "per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) - skip_mutual_check = preprocess.skip_mutual_check + skip_mutual_check = getattr(preprocess, "skip_mutual_check", False) if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden): raise MisconfigurationException( diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 97497c799b..39bd1a2023 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -18,15 +18,11 @@ import torch import torchvision from PIL import Image -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn -from torch.nn.modules import ModuleDict from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate -from torchvision import transforms from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset -from torchvision.transforms.functional import to_pil_image from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule @@ -158,11 +154,9 @@ def common_step(self, sample: Any) -> Any: if isinstance(sample, (list, tuple)): source, target = sample return self.current_transform(source), target - elif isinstance(sample, torch.Tensor): - return sample return self.current_transform(sample) - def per_tensor_transform(self, sample: Any) -> Any: + def pre_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) def to_tensor_transform(self, sample: Any) -> Any: @@ -175,6 +169,8 @@ def to_tensor_transform(self, sample: Any) -> Any: elif isinstance(sample, torch.Tensor): return sample return self.to_tensor(sample) + if isinstance(sample, torch.Tensor): + return sample return self.common_step(sample) def post_tensor_transform(self, sample: Any) -> Any: @@ -537,5 +533,6 @@ def from_filepaths( batch_size=batch_size, num_workers=num_workers, preprocess=preprocess, + seed=seed, **kwargs ) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index d733768d65..f04af08d54 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -55,16 +55,16 @@ def run_test(filepath): @pytest.mark.parametrize( "folder, file", [ - # ("finetuning", "image_classification.py"), + ("finetuning", "image_classification.py"), # ("finetuning", "object_detection.py"), # TODO: takes too long. # ("finetuning", "summarization.py"), # TODO: takes too long. ("finetuning", "tabular_classification.py"), - ("finetuning", "text_classification.py"), # TODO: takes too long + # ("finetuning", "text_classification.py"), # TODO: takes too long # ("finetuning", "translation.py"), # TODO: takes too long. - # ("predict", "image_classification.py"), - # ("predict", "tabular_classification.py"), + ("predict", "image_classification.py"), + ("predict", "tabular_classification.py"), # ("predict", "text_classification.py"), - # ("predict", "image_embedder.py"), + ("predict", "image_embedder.py"), # ("predict", "summarization.py"), # TODO: takes too long # ("predict", "translate.py"), # TODO: takes too long ] From 596a523998feab60c90bc284e2aaf115503dcfb4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 18:20:49 +0100 Subject: [PATCH 07/14] update --- flash/data/base_viz.py | 50 ------------------------- flash/data/data_module.py | 18 --------- tests/data/test_base_viz.py | 75 ------------------------------------- 3 files changed, 143 deletions(-) delete mode 100644 flash/data/base_viz.py delete mode 100644 tests/data/test_base_viz.py diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py deleted file mode 100644 index 40e341196e..0000000000 --- a/flash/data/base_viz.py +++ /dev/null @@ -1,50 +0,0 @@ -import functools -from typing import Any, Callable - -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.trainer.states import RunningStage - -from flash.data.data_pipeline import DataPipeline -from flash.data.process import Preprocess - - -class BaseViz(Callback): - - def __init__(self, enabled: bool = False): - self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}} - self.enabled = enabled - self._datamodule = None - - def attach_to_preprocess(self, preprocess: Preprocess) -> None: - self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess) - - def attach_to_datamodule(self, datamodule) -> None: - self._datamodule = datamodule - datamodule.viz = self - - def _wrap_fn( - self, - fn: Callable, - running_stage: RunningStage, - ) -> Callable: - - @functools.wraps(fn) - def wrapper(*args) -> Any: - data = fn(*args) - if self.enabled: - batches = self.batches[running_stage.value] - if fn.__name__ not in batches: - batches[fn.__name__] = [] - batches[fn.__name__].append(data) - return data - - return wrapper - - def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: Preprocess): - fn_names = { - k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess) - for k in DataPipeline.PREPROCESS_FUNCS - } - for fn_name in fn_names: - fn = getattr(preprocess, fn_name) - setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage)) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 286be2b6fa..f7c2e8f6d2 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -24,7 +24,6 @@ from torch.utils.data.dataset import Subset from flash.data.auto_dataset import AutoDataset -from flash.data.base_viz import BaseViz from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -84,23 +83,10 @@ def __init__( self._preprocess = None self._postprocess = None - self._viz = None # this may also trigger data preloading self.set_running_stages() - @property - def viz(self) -> BaseViz: - return self._viz or DataModule.configure_vis() - - @viz.setter - def viz(self, viz: BaseViz) -> None: - self._viz = viz - - @classmethod - def configure_vis(cls) -> BaseViz: - return BaseViz() - @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: if isinstance(dataset, Subset): @@ -334,9 +320,6 @@ def from_load_data_inputs( else: data_pipeline = cls(**kwargs).data_pipeline - viz_callback = cls.configure_vis() - viz_callback.attach_to_preprocess(data_pipeline._preprocess_pipeline) - train_dataset = cls._generate_dataset_if_possible( train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline ) @@ -358,5 +341,4 @@ def from_load_data_inputs( ) datamodule._preprocess = data_pipeline._preprocess_pipeline datamodule._postprocess = data_pipeline._postprocess_pipeline - viz_callback.attach_to_datamodule(datamodule) return datamodule diff --git a/tests/data/test_base_viz.py b/tests/data/test_base_viz.py deleted file mode 100644 index a4e49a3caf..0000000000 --- a/tests/data/test_base_viz.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple -from unittest import mock - -import numpy as np -import pytest -import torch -import torchvision.transforms as T -from PIL import Image -from pytorch_lightning import Trainer -from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor, tensor -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate - -from flash.core import Task -from flash.data.auto_dataset import AutoDataset -from flash.data.base_viz import BaseViz -from flash.data.batch import _PostProcessor, _PreProcessor -from flash.data.data_module import DataModule -from flash.data.data_pipeline import _StageOrchestrator, DataPipeline -from flash.data.process import Postprocess, Preprocess -from flash.vision import ImageClassificationData - - -def _rand_image(): - return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8")) - - -class ImageClassificationDataViz(ImageClassificationData): - - def show_train_batch(self): - self.viz.enabled = True - _ = next(iter(self.train_dataloader())) - self.viz.enabled = False - - -def test_base_viz(tmpdir): - tmpdir = Path(tmpdir) - - (tmpdir / "a").mkdir() - (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "a" / "a_2.png") - - _rand_image().save(tmpdir / "b" / "a_1.png") - _rand_image().save(tmpdir / "b" / "a_2.png") - - img_data = ImageClassificationDataViz.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - batch_size=1, - num_workers=0, - ) - - img_data.show_train_batch() - assert img_data.viz.batches["train"]["load_sample"] is not None - assert img_data.viz.batches["train"]["to_tensor_transform"] is not None - assert img_data.viz.batches["train"]["collate"] is not None - assert img_data.viz.batches["train"]["per_batch_transform"] is not None From b928fc5f7506fa3e219cc93e8dc1aeadfabf8cfa Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 19:29:34 +0100 Subject: [PATCH 08/14] resolve doc --- flash/vision/classification/data.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 39bd1a2023..d0f7e0cb63 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -459,44 +459,37 @@ def from_filepaths( **kwargs, ) -> 'ImageClassificationData': """ - Creates a ImageClassificationData object from folders of images arranged in this way: :: + Creates a ImageClassificationData object from folders of images arranged in this way: + + Examples:: + folder/dog_xxx.png folder/dog_xxy.png folder/dog_xxz.png folder/cat_123.png folder/cat_nsdf3.png folder/cat_asd932_.png + Args: + train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. train_labels: Sequence of labels for training dataset. Defaults to ``None``. val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. val_labels: Sequence of labels for validation dataset. Defaults to ``None``. test_filepaths: String or sequence of file paths for test dataset. Defaults to ``None``. test_labels: Sequence of labels for test dataset. Defaults to ``None``. - train_transform: Transforms for training dataset. Defaults to ``default``, which loads imagenet transforms. + train_transform: Transforms for training dataset. Defaults to ``default``, + which loads imagenet transforms. val_transform: Transforms for validation and testing dataset. Defaults to ``default``, which loads imagenet transforms. batch_size: The batchsize to use for parallel loading. Defaults to ``64``. num_workers: The number of workers to use for parallelized loading. Defaults to ``None`` which equals the number of available CPU threads. seed: Used for the train/val splits. + Returns: + ImageClassificationData: The constructed data module. - Examples: - >>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP - Example when labels are in .csv file:: - train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id') - val_labels = labels_from_categorical_csv(path/to/val.csv', 'my_id') - test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id') - data = ImageClassificationData.from_filepaths( - batch_size=2, - train_filepaths='path/to/train', - train_labels=train_labels, - val_filepaths='path/to/val', - val_labels=val_labels, - test_filepaths='path/to/test', - test_labels=test_labels, - ) """ # enable passing in a string which loads all files in that folder as a list if isinstance(train_filepaths, str): From 9381d412de7a3ddeac2905c24bebbf8160783f65 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 19:36:14 +0100 Subject: [PATCH 09/14] update doc --- flash/vision/classification/data.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index d0f7e0cb63..6e7f99ba16 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -459,9 +459,7 @@ def from_filepaths( **kwargs, ) -> 'ImageClassificationData': """ - Creates a ImageClassificationData object from folders of images arranged in this way: - - Examples:: + Creates a ImageClassificationData object from folders of images arranged in this way: :: folder/dog_xxx.png folder/dog_xxy.png @@ -471,7 +469,6 @@ def from_filepaths( folder/cat_asd932_.png Args: - train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. train_labels: Sequence of labels for training dataset. Defaults to ``None``. val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. @@ -488,7 +485,6 @@ def from_filepaths( seed: Used for the train/val splits. Returns: - ImageClassificationData: The constructed data module. """ # enable passing in a string which loads all files in that folder as a list From 108a7cca4d6c390615fd78ba8c73aa32d7687c4a Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 09:53:03 +0100 Subject: [PATCH 10/14] update --- tests/data/test_data_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index c1d8ae6b62..172a6793eb 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -626,7 +626,7 @@ class CustomDataModule(DataModule): batch = next(iter(datamodule.val_dataloader())) CustomDataModule.preprocess_cls = TestPreprocessTransformations2 - datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2) + datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0) batch = next(iter(datamodule.val_dataloader())) assert torch.equal(batch["a"], tensor([0, 1])) assert torch.equal(batch["b"], tensor([1, 2])) From 6da92b375dee1c9f77d1ce121330ea08ce85f982 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 10:02:25 +0100 Subject: [PATCH 11/14] update --- tests/data/test_data_pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 172a6793eb..d311011c68 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -607,7 +607,10 @@ def test_step(self, batch, batch_idx): assert batch[0].shape == torch.Size([2, 1]) def predict_step(self, batch, batch_idx, dataloader_idx): - assert batch == [('a', 'a'), ('b', 'b')] + assert batch[0][0] == 'a' + assert batch[0][1] == 'a' + assert batch[1][0] == 'b' + assert batch[1][1] == 'b' return tensor([0, 0, 0]) class CustomDataModule(DataModule): From 16deb7b0f787b6bcac15f4c942f6d724372464b2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 10:22:16 +0100 Subject: [PATCH 12/14] convert to staticmethod --- flash/core/model.py | 3 ++- tests/data/test_data_pipeline.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index b03a424dd2..78c907fc6c 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -245,7 +245,8 @@ def on_fit_end(self) -> None: self.data_pipeline._detach_from_model(self) super().on_fit_end() - def _sanetize_funcs(self, obj: Any) -> Any: + @staticmethod + def _sanetize_funcs(obj: Any) -> Any: if hasattr(obj, "__dict__"): for k, v in obj.__dict__.items(): if isinstance(v, Callable): diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index d311011c68..0154377160 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -617,7 +617,7 @@ class CustomDataModule(DataModule): preprocess_cls = TestPreprocessTransformations - datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2) + datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0) assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3) batch = next(iter(datamodule.train_dataloader())) From ff8e1adf7161c54016f2bc0f1dd2d5029435b4c9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 11:42:42 +0100 Subject: [PATCH 13/14] update on comments --- flash/data/auto_dataset.py | 22 ++++---- flash/data/batch.py | 4 +- flash/data/data_pipeline.py | 66 +++++++++++------------ flash/data/process.py | 6 +-- flash/tabular/classification/data/data.py | 2 +- flash/vision/classification/data.py | 3 +- tests/data/test_data_pipeline.py | 8 +-- 7 files changed, 55 insertions(+), 56 deletions(-) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 7a1bc0455b..bc05f8c441 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -71,14 +71,14 @@ def running_stage(self) -> Optional[RunningStage]: def running_stage(self, running_stage: RunningStage) -> None: if self._running_stage != running_stage or (not self._running_stage): self._running_stage = running_stage - self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self._preprocess) + self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self.preprocess) self._load_sample_context = CurrentRunningStageFuncContext( - self._running_stage, "load_sample", self._preprocess + self._running_stage, "load_sample", self.preprocess ) self._setup(running_stage) @property - def _preprocess(self): + def preprocess(self) -> Optional[Preprocess]: if self.data_pipeline is not None: return self.data_pipeline._preprocess_pipeline @@ -102,15 +102,15 @@ def _setup(self, stage: Optional[RunningStage]) -> None: if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage: self.load_data = getattr( - self.data_pipeline._preprocess_pipeline, + self.data_pipeline.preprocess_pipeline, self.data_pipeline._resolve_function_hierarchy( - 'load_data', self.data_pipeline._preprocess_pipeline, stage, Preprocess + 'load_data', self.data_pipeline.preprocess_pipeline, stage, Preprocess ) ) self.load_sample = getattr( - self.data_pipeline._preprocess_pipeline, + self.data_pipeline.preprocess_pipeline, self.data_pipeline._resolve_function_hierarchy( - 'load_sample', self.data_pipeline._preprocess_pipeline, stage, Preprocess + 'load_sample', self.data_pipeline.preprocess_pipeline, stage, Preprocess ) ) if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called): @@ -120,7 +120,7 @@ def _setup(self, stage: Optional[RunningStage]) -> None: "This is not expected! Preloading Data again to ensure compatibility. This may take some time." ) with self._load_data_context: - self._preprocessed_data = self._call_load_data(self.data) + self.preprocessed_data = self._call_load_data(self.data) self._load_data_called = True def __getitem__(self, index: int) -> Any: @@ -128,10 +128,10 @@ def __getitem__(self, index: int) -> Any: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") if self.load_sample: with self._load_sample_context: - return self._call_load_sample(self._preprocessed_data[index]) - return self._preprocessed_data[index] + return self._call_load_sample(self.preprocessed_data[index]) + return self.preprocessed_data[index] def __len__(self) -> int: if not self.load_sample and not self.load_data: raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.") - return len(self._preprocessed_data) + return len(self.preprocessed_data) diff --git a/flash/data/batch.py b/flash/data/batch.py index 3b7bc70ab4..9c7cce304e 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -54,7 +54,7 @@ def __init__( self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess) self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess) - def forward(self, sample: Any): + def forward(self, sample: Any) -> Any: with self._current_stage_context: with self._pre_tensor_transform_context: sample = self.pre_tensor_transform(sample) @@ -129,7 +129,7 @@ def __init__( self._collate_context = CurrentFuncContext("collate", preprocess) self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess) - def forward(self, samples: Sequence[Any]): + def forward(self, samples: Sequence[Any]) -> Any: with self._current_stage_context: if self.apply_per_sample_transform: diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index f28d882a22..7cce7fc04a 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -14,7 +14,7 @@ import functools import inspect import weakref -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import RunningStage @@ -101,31 +101,31 @@ def forward(self, samples: Sequence[Any]): General flow: - load_sample - │ - pre_tensor_transform - │ - to_tensor_transform - │ - post_tensor_transform - │ - ┌────────────────┴───────────────────┐ -(move list to main worker) --> │ │ - per_sample_transform_on_device collate - │ │ - collate per_batch_transform - │ │ <-- (move batch to main worker) - per_batch_transform_on_device per_batch_transform_on_device - │ │ - └─────────────────┬──────────────────┘ - │ - model.predict_step - │ - per_batch_transform - │ - uncollate - │ - per_sample_transform + load_sample + │ + pre_tensor_transform + │ + to_tensor_transform + │ + post_tensor_transform + │ + ┌────────────────┴───────────────────┐ +(move samples's sequence to main worker) --> │ │ + per_sample_transform_on_device collate + │ │ + collate per_batch_transform + │ │ <-- (move batch to main worker) + per_batch_transform_on_device per_batch_transform_on_device + │ │ + └─────────────────┬──────────────────┘ + │ + model.predict_step + │ + per_batch_transform + │ + uncollate + │ + per_sample_transform """ @@ -241,25 +241,25 @@ def _create_collate_preprocessors( if collate_fn is None: collate_fn = default_collate - preprocess = self._preprocess_pipeline + preprocess: Preprocess = self._preprocess_pipeline - func_names = { + func_names: Dict[str, str] = { k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) for k in self.PREPROCESS_FUNCS } if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]): - collate_fn = getattr(preprocess, func_names["collate"]) + collate_fn: Callable = getattr(preprocess, func_names["collate"]) - per_batch_transform_overriden = self._is_overriden_recursive( + per_batch_transform_overriden: bool = self._is_overriden_recursive( "per_batch_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) - per_sample_transform_on_device_overriden = self._is_overriden_recursive( + per_sample_transform_on_device_overriden: bool = self._is_overriden_recursive( "per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) - skip_mutual_check = getattr(preprocess, "skip_mutual_check", False) + skip_mutual_check: bool = getattr(preprocess, "skip_mutual_check", False) if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden): raise MisconfigurationException( @@ -562,7 +562,7 @@ def to_dataloader( return DataLoader(self._generate_auto_dataset(data), **loader_kwargs) def __str__(self) -> str: - preprocess = self._preprocess_pipeline + preprocess: Preprocess = self._preprocess_pipeline postprocess = self._postprocess_pipeline return f"{self.__class__.__name__}(preprocess={preprocess}, postprocess={postprocess})" diff --git a/flash/data/process.py b/flash/data/process.py index 4dc1fea4c0..8fd16e5359 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -125,16 +125,16 @@ def skip_mutual_check(self) -> bool: def skip_mutual_check(self, skip_mutual_check: bool) -> None: self._skip_mutual_check = skip_mutual_check - def _identify(self, x): + def _identify(self, x: Any) -> Any: return x - def _get_transform(self, transform: Dict[str, Callable]): + def _get_transform(self, transform: Dict[str, Callable]) -> Callable: if self.current_fn in transform: return transform[self.current_fn] return self._identify @property - def current_transform(self): + def current_transform(self) -> Callable: if self.training and self.train_transform: return self._get_transform(self.train_transform) elif self.validating and self.val_transform: diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index fa62d4e6ca..58f583e524 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -349,7 +349,7 @@ def from_df( is_regression, preprocess_state=preprocess_state ) - preprocess = preprocess_cls.from_state(preprocess_state) + preprocess: Preprocess = preprocess_cls.from_state(preprocess_state) return cls.from_load_data_inputs( train_load_data_input=train_df, diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 6e7f99ba16..d66c9bb355 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -342,8 +342,7 @@ def instantiate_preprocess( ) preprocess_cls = preprocess_cls or cls.preprocess_cls - preprocess = preprocess_cls(train_transform, val_transform, test_transform, predict_transform) - # todo (tchaton) add check on mutually exclusive transforms + preprocess: Preprocess = preprocess_cls(train_transform, val_transform, test_transform, predict_transform) return preprocess @classmethod diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 0154377160..ed7ebe60b9 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -130,25 +130,25 @@ def test_per_batch_transform_on_device(self, *_, **__): preprocess = CustomPreprocess() data_pipeline = DataPipeline(preprocess) - train_func_names = { + train_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } - val_func_names = { + val_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.VALIDATING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } - test_func_names = { + test_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TESTING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } - predict_func_names = { + predict_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.PREDICTING, Preprocess ) From 84eaa68c7a6c3d9f7c34dbb3553e1027d3507e7a Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 11:48:21 +0100 Subject: [PATCH 14/14] resolve bug --- flash/data/auto_dataset.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index bc05f8c441..5652496c10 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -102,16 +102,12 @@ def _setup(self, stage: Optional[RunningStage]) -> None: if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage: self.load_data = getattr( - self.data_pipeline.preprocess_pipeline, - self.data_pipeline._resolve_function_hierarchy( - 'load_data', self.data_pipeline.preprocess_pipeline, stage, Preprocess - ) + self.preprocess, + self.data_pipeline._resolve_function_hierarchy('load_data', self.preprocess, stage, Preprocess) ) self.load_sample = getattr( - self.data_pipeline.preprocess_pipeline, - self.data_pipeline._resolve_function_hierarchy( - 'load_sample', self.data_pipeline.preprocess_pipeline, stage, Preprocess - ) + self.preprocess, + self.data_pipeline._resolve_function_hierarchy('load_sample', self.preprocess, stage, Preprocess) ) if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called): if previous_load_data: