
[feat] Add support for running_stage and current_fn in all Preprocess hook (1 / 2) #200

Merged: 15 commits, Apr 1, 2021
3 changes: 1 addition & 2 deletions flash/core/classification.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union
from typing import Any

import torch
from torch import Tensor

from flash.core.model import Task
from flash.data.process import Postprocess
14 changes: 12 additions & 2 deletions flash/core/model.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
@@ -244,13 +245,22 @@ def on_fit_end(self) -> None:
self.data_pipeline._detach_from_model(self)
super().on_fit_end()

@staticmethod
def _sanetize_funcs(obj: Any) -> Any:
if hasattr(obj, "__dict__"):
for k, v in obj.__dict__.items():
if isinstance(v, Callable):
obj.__dict__[k] = inspect.unwrap(v)
return obj

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# TODO: Is this the best way to do this? or should we also use some kind of hparams here?
# This may be an issue since here we create the same problems with pickle as in
# https://pytorch.org/docs/stable/notes/serialization.html

if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
self._preprocess = self._sanetize_funcs(self._preprocess)
checkpoint['data_pipeline'] = self.data_pipeline
# todo (tchaton) re-wrap visualization
super().on_save_checkpoint(checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
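The new `_sanetize_funcs` helper unwraps decorated callables stored on the `Preprocess` instance before the `DataPipeline` is pickled into the checkpoint, so pickling does not choke on wrapper functions. Below is a minimal, self-contained sketch of the failure mode it guards against; the decorator and hook names are invented for illustration and are not the wrappers Flash actually applies:

```python
import functools
import inspect
import pickle


def wrap(fn):
    # Stand-in for a decorator that might be applied to a Preprocess hook.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def load_data(data):
    return list(data)


wrapped = wrap(load_data)

try:
    pickle.dumps(wrapped)  # the wrapper is not the module-level `load_data`, so pickling fails
except pickle.PicklingError as err:
    print(f"wrapped hook is not picklable: {err}")

pickle.dumps(inspect.unwrap(wrapped))  # unwrapping restores the original, picklable function
```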
28 changes: 14 additions & 14 deletions flash/data/auto_dataset.py
@@ -20,7 +20,7 @@
from torch.utils.data import Dataset

from flash.data.process import Preprocess
from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES
from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext

if TYPE_CHECKING:
from flash.data.data_pipeline import DataPipeline
@@ -68,11 +68,20 @@ def running_stage(self) -> Optional[RunningStage]:
return self._running_stage

@running_stage.setter
def running_stage(self, running_stage: str) -> None:
def running_stage(self, running_stage: RunningStage) -> None:
if self._running_stage != running_stage or (not self._running_stage):
self._running_stage = running_stage
self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self._preprocess)
self._load_sample_context = CurrentRunningStageFuncContext(
self._running_stage, "load_sample", self._preprocess
)
self._setup(running_stage)

@property
def _preprocess(self):
if self.data_pipeline is not None:
return self.data_pipeline._preprocess_pipeline

def _call_load_data(self, data: Any) -> Iterable:
parameters = signature(self.load_data).parameters
if len(parameters) > 1 and self.DATASET_KEY in parameters:
@@ -110,25 +119,16 @@ def _setup(self, stage: Optional[RunningStage]) -> None:
"The load_data function of the Autogenerated Dataset changed. "
"This is not expected! Preloading Data again to ensure compatibility. This may take some time."
)
with self._set_running_stage(stage):
with self._load_data_context:
self._preprocessed_data = self._call_load_data(self.data)
self._load_data_called = True

@contextmanager
def _set_running_stage(self, stage: RunningStage) -> None:
if self.load_data:
if self.data_pipeline and self.data_pipeline._preprocess_pipeline:
self.data_pipeline._preprocess_pipeline._running_stage = stage
yield
if self.load_data:
if self.data_pipeline and self.data_pipeline._preprocess_pipeline:
self.data_pipeline._preprocess_pipeline._running_stage = None

def __getitem__(self, index: int) -> Any:
if not self.load_sample and not self.load_data:
raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
if self.load_sample:
return self._call_load_sample(self._preprocessed_data[index])
with self._load_sample_context:
return self._call_load_sample(self._preprocessed_data[index])
return self._preprocessed_data[index]

def __len__(self) -> int:
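`AutoDataset` now wraps its `load_data` and `load_sample` calls in `CurrentRunningStageFuncContext` objects instead of the removed `_set_running_stage` context manager. The class itself lives in `flash/data/utils.py` and is not shown in this diff; the sketch below only illustrates the general pattern, and the attributes it sets on the preprocess are assumptions:

```python
from contextlib import contextmanager

from pytorch_lightning.trainer.states import RunningStage


@contextmanager
def running_stage_func_context(stage: RunningStage, func_name: str, preprocess):
    # Temporarily record which stage and which hook are active on the preprocess,
    # then clear that state once the wrapped call returns.
    preprocess._running_stage = stage
    preprocess._current_fn = func_name
    try:
        yield
    finally:
        preprocess._current_fn = None
        preprocess._running_stage = None
```

Usage mirrors the setter above: build the context once per stage, then enter it around every `load_data` / `load_sample` call.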
78 changes: 57 additions & 21 deletions flash/data/batch.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor

from flash.data.utils import _contains_any_tensor, convert_to_modules
from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext

if TYPE_CHECKING:
from flash.data.process import Preprocess


class _Sequential(torch.nn.Module):
@@ -31,29 +34,45 @@ class _Sequential(torch.nn.Module):

def __init__(
self,
preprocess: 'Preprocess',
pre_tensor_transform: Callable,
to_tensor_transform: Callable,
post_tensor_transform: Callable,
assert_contains_tensor: bool = False
stage: RunningStage,
assert_contains_tensor: bool = False,
):
super().__init__()

self.preprocess = preprocess
self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
self.to_tensor_transform = convert_to_modules(to_tensor_transform)
self.post_tensor_transform = convert_to_modules(post_tensor_transform)
self.stage = stage
self.assert_contains_tensor = assert_contains_tensor

self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)
self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess)
self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess)
self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess)

def forward(self, sample: Any):
sample = self.pre_tensor_transform(sample)
sample = self.to_tensor_transform(sample)
if self.assert_contains_tensor:
if not _contains_any_tensor(sample):
raise MisconfigurationException(
"When ``to_tensor_transform`` is overriden, "
"``DataPipeline`` expects the outputs to be ``tensors``"
)
sample = self.post_tensor_transform(sample)
return sample
with self._current_stage_context:
with self._pre_tensor_transform_context:
sample = self.pre_tensor_transform(sample)

with self._to_tensor_transform_context:
sample = self.to_tensor_transform(sample)

if self.assert_contains_tensor:
if not _contains_any_tensor(sample):
raise MisconfigurationException(
"When ``to_tensor_transform`` is overriden, "
"``DataPipeline`` expects the outputs to be ``tensors``"
)

with self._post_tensor_transform_context:
sample = self.post_tensor_transform(sample)

return sample

def __str__(self) -> str:
repr_str = f'{self.__class__.__name__}:'
@@ -87,26 +106,43 @@ class _PreProcessor(torch.nn.Module):

def __init__(
self,
preprocess: 'Preprocess',
collate_fn: Callable,
per_sample_transform: Union[Callable, _Sequential],
per_batch_transform: Callable,
stage: Optional[RunningStage] = None,
stage: RunningStage,
apply_per_sample_transform: bool = True,
on_device: bool = False
):
super().__init__()
self.preprocess = preprocess
self.collate_fn = convert_to_modules(collate_fn)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.apply_per_sample_transform = apply_per_sample_transform
self.stage = stage
self.on_device = on_device

extension = f"{'on_device' if self.on_device else ''}"
self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess)
self._collate_context = CurrentFuncContext("collate", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess)

def forward(self, samples: Sequence[Any]):
if self.apply_per_sample_transform:
samples = [self.per_sample_transform(sample) for sample in samples]
samples = type(samples)(samples)
samples = self.collate_fn(samples)
samples = self.per_batch_transform(samples)
return samples
with self._current_stage_context:

if self.apply_per_sample_transform:
with self._per_sample_transform_context:
samples = [self.per_sample_transform(sample) for sample in samples]
samples = type(samples)(samples)

with self._collate_context:
samples = self.collate_fn(samples)

with self._per_batch_transform_context:
samples = self.per_batch_transform(samples)
return samples

def __str__(self) -> str:
# todo: define repr function which would take object and string attributes to be shown
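Beyond the new contexts, `_PreProcessor.forward` keeps the same worker-side ordering as before: per-sample transforms, then collation, then the per-batch transform. A toy, self-contained example of that ordering follows; the transforms are made up, and only `default_collate` is the real torch helper (the same fallback the pipeline uses when no collate function is provided):

```python
import torch
from torch.utils.data._utils.collate import default_collate

samples = [1.0, 2.0, 3.0]
samples = [torch.tensor(s * 2) for s in samples]  # per_sample_transform applied element-wise
batch = default_collate(samples)                  # collate the list into a single tensor
batch = batch - batch.mean()                      # per_batch_transform on the collated batch
print(batch)  # tensor([-2., 0., 2.])
```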
1 change: 1 addition & 0 deletions flash/data/data_module.py
@@ -319,6 +319,7 @@ def from_load_data_inputs(
)
else:
data_pipeline = cls(**kwargs).data_pipeline

train_dataset = cls._generate_dataset_if_possible(
train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
)
48 changes: 29 additions & 19 deletions flash/data/data_pipeline.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import weakref
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union

from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import imports
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data._utils.collate import default_collate, default_convert
from torch.utils.data.dataloader import DataLoader
@@ -108,11 +110,11 @@ def forward(self, samples: Sequence[Any]):
post_tensor_transform
┌────────────────┴───────────────────┐
(move Data to main worker) --> │ │
(move list to main worker) --> │ │
per_sample_transform_on_device collate
│ │
collate per_batch_transform
│ │ <-- (move Data to main worker)
│ │ <-- (move batch to main worker)
per_batch_transform_on_device per_batch_transform_on_device
│ │
└─────────────────┬──────────────────┘
@@ -181,7 +183,7 @@ def _is_overriden_recursive(
if not hasattr(process_obj, current_method_name):
return DataPipeline._is_overriden_recursive(method_name, process_obj, super_obj)

current_code = getattr(process_obj, current_method_name).__code__
current_code = inspect.unwrap(getattr(process_obj, current_method_name)).__code__
has_different_code = current_code != getattr(super_obj, method_name).__code__

if not prefix:
@@ -239,25 +241,29 @@ def _create_collate_preprocessors(
if collate_fn is None:
collate_fn = default_collate

preprocess = self._preprocess_pipeline

func_names = {
k: self._resolve_function_hierarchy(k, self._preprocess_pipeline, stage, Preprocess)
k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess)
for k in self.PREPROCESS_FUNCS
}

if self._is_overriden_recursive("collate", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]):
collate_fn = getattr(self._preprocess_pipeline, func_names["collate"])
if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]):
collate_fn = getattr(preprocess, func_names["collate"])

per_batch_transform_overriden = self._is_overriden_recursive(
"per_batch_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
"per_batch_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
)

per_sample_transform_on_device_overriden = self._is_overriden_recursive(
"per_sample_transform_on_device", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
"per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
)

if per_batch_transform_overriden and per_sample_transform_on_device_overriden:
skip_mutual_check = getattr(preprocess, "skip_mutual_check", False)

if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden):
raise MisconfigurationException(
f'{self.__class__.__name__}: `per_batch_transform` and `gpu_per_sample_transform` '
f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
f'are mutual exclusive for stage {stage}'
)

@@ -278,25 +284,29 @@ def _create_collate_preprocessors(
) else worker_collate_fn

assert_contains_tensor = self._is_overriden_recursive(
"to_tensor_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
"to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
)

worker_preprocessor = _PreProcessor(
worker_collate_fn,
preprocess, worker_collate_fn,
_Sequential(
getattr(self._preprocess_pipeline, func_names['pre_tensor_transform']),
getattr(self._preprocess_pipeline, func_names['to_tensor_transform']),
getattr(self._preprocess_pipeline, func_names['post_tensor_transform']),
preprocess,
getattr(preprocess, func_names['pre_tensor_transform']),
getattr(preprocess, func_names['to_tensor_transform']),
getattr(preprocess, func_names['post_tensor_transform']),
stage,
assert_contains_tensor=assert_contains_tensor,
), getattr(self._preprocess_pipeline, func_names['per_batch_transform']), stage
), getattr(preprocess, func_names['per_batch_transform']), stage
)
worker_preprocessor._original_collate_fn = original_collate_fn
device_preprocessor = _PreProcessor(
preprocess,
device_collate_fn,
getattr(self._preprocess_pipeline, func_names['per_sample_transform_on_device']),
getattr(self._preprocess_pipeline, func_names['per_batch_transform_on_device']),
getattr(preprocess, func_names['per_sample_transform_on_device']),
getattr(preprocess, func_names['per_batch_transform_on_device']),
stage,
apply_per_sample_transform=device_collate_fn != self._identity
apply_per_sample_transform=device_collate_fn != self._identity,
on_device=True,
)
return worker_preprocessor, device_preprocessor

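The switch to `inspect.unwrap(...).__code__` in `_is_overriden_recursive` matters because a decorated hook reports the wrapper's code object, which would make the override check compare the wrong code. An illustrative, self-contained sketch of that effect; the class and decorator names are made up and are not Flash APIs:

```python
import functools
import inspect


class Base:

    def collate(self, samples):
        return samples


class Child(Base):
    pass  # does not override `collate`


def mark(fn):
    # Stand-in for a decorator applied to a hook at runtime.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


child = Child()
child.collate = mark(child.collate)

# The wrapper hides the original code object...
assert child.collate.__code__ is not Base.collate.__code__
# ...while unwrapping recovers it, correctly reporting "not overridden".
assert inspect.unwrap(child.collate).__code__ is Base.collate.__code__
```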