[feat] Add support for running_stage and current_fn in all Preprocess hook (1 / 2) (#200)

* wip

* add base_viz + new features for DataPipeline

* update

* resolve flake8

* update

* resolve tests

* update

* resolve doc

* update doc

* update

* update

* convert to staticmethod

* update on comments

* resolve bug
tchaton authored Apr 1, 2021
1 parent 62472d7 commit 3b6a5de
Showing 12 changed files with 360 additions and 217 deletions.
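
Note on the feature itself: this commit threads the active RunningStage and the name of the hook being executed ("current_fn") through to the attached Preprocess, so user overrides can branch on both. A minimal sketch, assuming Preprocess exposes the running_stage and current_fn attributes that the new context managers set (train_augment is a hypothetical helper, not part of this commit):

from pytorch_lightning.trainer.states import RunningStage

from flash.data.process import Preprocess


class MyPreprocess(Preprocess):

    def pre_tensor_transform(self, sample):
        # The DataPipeline now runs every hook inside contexts that record
        # the active stage and hook name on the Preprocess.
        assert self.current_fn == "pre_tensor_transform"
        if self.running_stage == RunningStage.TRAINING:
            sample = self.train_augment(sample)  # hypothetical augmentation
        return sample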
3 changes: 1 addition & 2 deletions flash/core/classification.py
@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Union
+from typing import Any

 import torch
-from torch import Tensor

 from flash.core.model import Task
 from flash.data.process import Postprocess
14 changes: 12 additions & 2 deletions flash/core/model.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import os
+import inspect
+from copy import deepcopy
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

 import torch
@@ -244,13 +245,22 @@ def on_fit_end(self) -> None:
         self.data_pipeline._detach_from_model(self)
         super().on_fit_end()

+    @staticmethod
+    def _sanetize_funcs(obj: Any) -> Any:
+        if hasattr(obj, "__dict__"):
+            for k, v in obj.__dict__.items():
+                if isinstance(v, Callable):
+                    obj.__dict__[k] = inspect.unwrap(v)
+        return obj
+
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # TODO: Is this the best way to do this? or should we also use some kind of hparams here?
         # This may be an issue since here we create the same problems with pickle as in
         # https://pytorch.org/docs/stable/notes/serialization.html
-
         if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
+            self._preprocess = self._sanetize_funcs(self._preprocess)
             checkpoint['data_pipeline'] = self.data_pipeline
+            # todo (tchaton) re-wrap visualization
         super().on_save_checkpoint(checkpoint)

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
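
For context on _sanetize_funcs: inspect.unwrap follows the __wrapped__ attribute that functools.wraps records, so stripping the DataPipeline's decorated hooks back to their originals keeps the Preprocess picklable when it is stored in the checkpoint. A standalone sketch of that mechanism (not Flash code):

import functools
import inspect


def with_context(fn):
    @functools.wraps(fn)  # records fn as wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def load_data(path):
    return list(path)


wrapped = with_context(load_data)
# inspect.unwrap walks the __wrapped__ chain back to the original function:
assert inspect.unwrap(wrapped) is load_data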
46 changes: 21 additions & 25 deletions flash/data/auto_dataset.py
@@ -20,7 +20,7 @@
 from torch.utils.data import Dataset

 from flash.data.process import Preprocess
-from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES
+from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext

 if TYPE_CHECKING:
     from flash.data.data_pipeline import DataPipeline
@@ -68,11 +68,20 @@ def running_stage(self) -> Optional[RunningStage]:
         return self._running_stage

     @running_stage.setter
-    def running_stage(self, running_stage: str) -> None:
+    def running_stage(self, running_stage: RunningStage) -> None:
         if self._running_stage != running_stage or (not self._running_stage):
             self._running_stage = running_stage
+            self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self.preprocess)
+            self._load_sample_context = CurrentRunningStageFuncContext(
+                self._running_stage, "load_sample", self.preprocess
+            )
             self._setup(running_stage)

+    @property
+    def preprocess(self) -> Optional[Preprocess]:
+        if self.data_pipeline is not None:
+            return self.data_pipeline._preprocess_pipeline
+
     def _call_load_data(self, data: Any) -> Iterable:
         parameters = signature(self.load_data).parameters
         if len(parameters) > 1 and self.DATASET_KEY in parameters:
@@ -93,45 +102,32 @@ def _setup(self, stage: Optional[RunningStage]) -> None:

         if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage:
             self.load_data = getattr(
-                self.data_pipeline._preprocess_pipeline,
-                self.data_pipeline._resolve_function_hierarchy(
-                    'load_data', self.data_pipeline._preprocess_pipeline, stage, Preprocess
-                )
+                self.preprocess,
+                self.data_pipeline._resolve_function_hierarchy('load_data', self.preprocess, stage, Preprocess)
             )
             self.load_sample = getattr(
-                self.data_pipeline._preprocess_pipeline,
-                self.data_pipeline._resolve_function_hierarchy(
-                    'load_sample', self.data_pipeline._preprocess_pipeline, stage, Preprocess
-                )
+                self.preprocess,
+                self.data_pipeline._resolve_function_hierarchy('load_sample', self.preprocess, stage, Preprocess)
             )
         if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called):
             if previous_load_data:
                 rank_zero_warn(
                     "The load_data function of the Autogenerated Dataset changed. "
                     "This is not expected! Preloading Data again to ensure compatibility. This may take some time."
                 )
-            with self._set_running_stage(stage):
-                self._preprocessed_data = self._call_load_data(self.data)
+            with self._load_data_context:
+                self.preprocessed_data = self._call_load_data(self.data)
             self._load_data_called = True

-    @contextmanager
-    def _set_running_stage(self, stage: RunningStage) -> None:
-        if self.load_data:
-            if self.data_pipeline and self.data_pipeline._preprocess_pipeline:
-                self.data_pipeline._preprocess_pipeline._running_stage = stage
-        yield
-        if self.load_data:
-            if self.data_pipeline and self.data_pipeline._preprocess_pipeline:
-                self.data_pipeline._preprocess_pipeline._running_stage = None
-
     def __getitem__(self, index: int) -> Any:
         if not self.load_sample and not self.load_data:
             raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
         if self.load_sample:
-            return self._call_load_sample(self._preprocessed_data[index])
-        return self._preprocessed_data[index]
+            with self._load_sample_context:
+                return self._call_load_sample(self.preprocessed_data[index])
+        return self.preprocessed_data[index]

     def __len__(self) -> int:
         if not self.load_sample and not self.load_data:
             raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.")
-        return len(self._preprocessed_data)
+        return len(self.preprocessed_data)
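
CurrentRunningStageFuncContext itself lives in flash/data/utils.py and is not part of this diff. Judging from how it is used above, it behaves roughly like the context manager sketched below (the attribute names are assumptions); it replaces the deleted _set_running_stage, which tracked only the stage and not the current function:

from contextlib import contextmanager


@contextmanager
def current_running_stage_func(stage, current_fn, preprocess):
    # Hypothetical stand-in for CurrentRunningStageFuncContext: expose the
    # active stage and hook name on the Preprocess for the duration of the
    # wrapped call, then reset them.
    preprocess._running_stage = stage
    preprocess._current_fn = current_fn
    try:
        yield
    finally:
        preprocess._current_fn = None
        preprocess._running_stage = None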
82 changes: 59 additions & 23 deletions flash/data/batch.py
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union

 import torch
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import Tensor

-from flash.data.utils import _contains_any_tensor, convert_to_modules
+from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext

+if TYPE_CHECKING:
+    from flash.data.process import Preprocess


 class _Sequential(torch.nn.Module):
@@ -31,29 +34,45 @@ class _Sequential(torch.nn.Module):

     def __init__(
         self,
+        preprocess: 'Preprocess',
         pre_tensor_transform: Callable,
         to_tensor_transform: Callable,
         post_tensor_transform: Callable,
-        assert_contains_tensor: bool = False
+        stage: RunningStage,
+        assert_contains_tensor: bool = False,
     ):
         super().__init__()
-
+        self.preprocess = preprocess
         self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
         self.to_tensor_transform = convert_to_modules(to_tensor_transform)
         self.post_tensor_transform = convert_to_modules(post_tensor_transform)
+        self.stage = stage
         self.assert_contains_tensor = assert_contains_tensor

-    def forward(self, sample: Any):
-        sample = self.pre_tensor_transform(sample)
-        sample = self.to_tensor_transform(sample)
-        if self.assert_contains_tensor:
-            if not _contains_any_tensor(sample):
-                raise MisconfigurationException(
-                    "When ``to_tensor_transform`` is overriden, "
-                    "``DataPipeline`` expects the outputs to be ``tensors``"
-                )
-        sample = self.post_tensor_transform(sample)
-        return sample
+        self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)
+        self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess)
+        self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess)
+        self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess)
+
+    def forward(self, sample: Any) -> Any:
+        with self._current_stage_context:
+            with self._pre_tensor_transform_context:
+                sample = self.pre_tensor_transform(sample)
+
+            with self._to_tensor_transform_context:
+                sample = self.to_tensor_transform(sample)
+
+            if self.assert_contains_tensor:
+                if not _contains_any_tensor(sample):
+                    raise MisconfigurationException(
+                        "When ``to_tensor_transform`` is overriden, "
+                        "``DataPipeline`` expects the outputs to be ``tensors``"
+                    )
+
+            with self._post_tensor_transform_context:
+                sample = self.post_tensor_transform(sample)
+
+            return sample

     def __str__(self) -> str:
         repr_str = f'{self.__class__.__name__}:'
@@ -87,26 +106,43 @@ class _PreProcessor(torch.nn.Module):

     def __init__(
         self,
+        preprocess: 'Preprocess',
         collate_fn: Callable,
         per_sample_transform: Union[Callable, _Sequential],
         per_batch_transform: Callable,
-        stage: Optional[RunningStage] = None,
+        stage: RunningStage,
         apply_per_sample_transform: bool = True,
         on_device: bool = False
     ):
         super().__init__()
+        self.preprocess = preprocess
         self.collate_fn = convert_to_modules(collate_fn)
         self.per_sample_transform = convert_to_modules(per_sample_transform)
         self.per_batch_transform = convert_to_modules(per_batch_transform)
         self.apply_per_sample_transform = apply_per_sample_transform
+        self.stage = stage
         self.on_device = on_device
+
+        extension = f"{'on_device' if self.on_device else ''}"
+        self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
+        self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess)
+        self._collate_context = CurrentFuncContext("collate", preprocess)
+        self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess)

-    def forward(self, samples: Sequence[Any]):
-        if self.apply_per_sample_transform:
-            samples = [self.per_sample_transform(sample) for sample in samples]
-            samples = type(samples)(samples)
-        samples = self.collate_fn(samples)
-        samples = self.per_batch_transform(samples)
-        return samples
+    def forward(self, samples: Sequence[Any]) -> Any:
+        with self._current_stage_context:
+
+            if self.apply_per_sample_transform:
+                with self._per_sample_transform_context:
+                    samples = [self.per_sample_transform(sample) for sample in samples]
+                samples = type(samples)(samples)
+
+            with self._collate_context:
+                samples = self.collate_fn(samples)
+
+            with self._per_batch_transform_context:
+                samples = self.per_batch_transform(samples)
+            return samples

     def __str__(self) -> str:
         # todo: define repr function which would take object and string attributes to be shown
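
Taken together, _Sequential and _PreProcessor now execute every transform inside nested stage and function contexts. A hedged usage sketch, assuming Preprocess surfaces running_stage and current_fn and allows overriding collate (as the hooks above suggest):

from flash.data.process import Preprocess


class LoggingPreprocess(Preprocess):

    def collate(self, samples):
        # By the time this hook runs, _PreProcessor.forward has entered
        # _current_stage_context and _collate_context, so both fields
        # identify the call site.
        print(f"{self.running_stage}: {self.current_fn}")  # e.g. "RunningStage.TRAINING: collate"
        return super().collate(samples)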
1 change: 1 addition & 0 deletions flash/data/data_module.py
@@ -319,6 +319,7 @@ def from_load_data_inputs(
             )
         else:
             data_pipeline = cls(**kwargs).data_pipeline
+
         train_dataset = cls._generate_dataset_if_possible(
             train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
         )