This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Support PL 1.5.0 #933

Merged 17 commits on Nov 5, 2021
35 changes: 26 additions & 9 deletions flash/core/data/data_pipeline.py
@@ -18,24 +18,40 @@
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader, IterableDataset

import flash
from flash.core.data.auto_dataset import IterableAutoDataset
from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential, _SerializeProcessor
from flash.core.data.data_source import DataSource
from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer
from flash.core.data.properties import ProcessState
from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0
from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader

if TYPE_CHECKING:
from flash.core.model import Task


class DataLoaderGetter:
Member:

This still wraps and patches the dataloader, right? Is there no way around this?
Will this patch be assigned back to the loader or just used internally?

Collaborator Author:

Yeah, unfortunately this was the only solution for now, but the upcoming data pipeline refactor should remove this patching entirely.

"""A utility class to be used when patching the ``{stage}_dataloader`` attribute of a LightningModule."""

def __init__(self, dataloader):
self.dataloader = dataloader

# Dummy `__code__` attribute to trick is_overridden
self.__code__ = self.__call__.__code__

def __call__(self):
return self.dataloader
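
For context (not part of the diff), here is a minimal sketch of how the `DataLoaderGetter` class added above is meant to behave once it is assigned in place of a `{stage}_dataloader` hook. The toy `DataLoader` is purely illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy loader standing in for whatever the user handed to the DataModule.
loader = DataLoader(TensorDataset(torch.zeros(4, 1)))
getter = DataLoaderGetter(loader)

# Calling the patched attribute simply returns the wrapped loader.
assert getter() is loader

# The dummy ``__code__`` mirrors ``__call__``, which is what lets PL's
# ``is_overridden(...)`` check (a code-object comparison) treat the patched
# attribute as a user-defined hook.
assert getter.__code__ is DataLoaderGetter.__call__.__code__
```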


class DataPipelineState:
"""A class to store and share all process states once a :class:`.DataPipeline` has been initialized."""

@@ -316,15 +332,18 @@ def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]:
attr_name = loader_name

elif model.trainer and hasattr(model.trainer, "datamodule") and model.trainer.datamodule:
dataloader = getattr(model, f"trainer.datamodule.{loader_name}", None)
attr_name = f"trainer.datamodule.{loader_name}"
if is_overridden(loader_name, model.trainer.datamodule, flash.DataModule):
dataloader = getattr(model.trainer.datamodule, loader_name, None)
attr_name = f"trainer.datamodule.{loader_name}"

return dataloader, attr_name

@staticmethod
def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage):
if isinstance(dataloader, DataLoader):
if _PL_GREATER_EQUAL_1_4_3:
if _PL_GREATER_EQUAL_1_5_0:
dataloader = DataLoaderGetter(dataloader)
elif _PL_GREATER_EQUAL_1_4_3:
dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage])
dataloader.patch(model)
else:
@@ -369,7 +388,7 @@ def _attach_preprocess_to_model(
if not dataloader:
continue

if isinstance(dataloader, (_PatchDataLoader, Callable)):
if callable(dataloader):
dataloader = dataloader()

if dataloader is None:
@@ -504,9 +523,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin
if not dataloader:
continue

if isinstance(dataloader, _PatchDataLoader):
dataloader = dataloader()
elif isinstance(dataloader, Callable):
if callable(dataloader):
dataloader = dataloader()

if isinstance(dataloader, Sequence):
8 changes: 1 addition & 7 deletions flash/core/data/new_data_module.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
@@ -30,14 +30,8 @@
from flash.core.data.datasets import BaseDataset
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
from fiftyone.core.collections import SampleCollection
else:
SampleCollection = None


class DataModule(DataModule):
"""A basic DataModule class for all Flash tasks. This class includes references to a
22 changes: 19 additions & 3 deletions flash/core/trainer.py
@@ -28,8 +28,10 @@
from torch.utils.data import DataLoader

import flash
from flash.core.data.data_module import DataModule
from flash.core.data.new_data_module import DataModule as NewDataModule
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks
from flash.core.utilities.imports import _SERVE_AVAILABLE
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _SERVE_AVAILABLE


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@@ -277,14 +279,28 @@ def request_dataloader(
The dataloader
"""
model, stage, is_legacy = self._parse_request_dataloader_args(args, kwargs)
dataloader = None
if is_legacy:
self.call_hook(f"on_{stage}_dataloader")
dataloader = getattr(model, f"{stage}_dataloader")()
else:
elif _PL_GREATER_EQUAL_1_5_0:
source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
if (
not source.is_module()
or not isinstance(source.instance, DataModule)
or isinstance(source.instance, NewDataModule)
):
dataloader = source.dataloader()

if dataloader is None:
hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)

if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
if _PL_GREATER_EQUAL_1_5_0:
self.training_type_plugin.barrier("get_dataloaders")
else:
self.accelerator.barrier("get_dataloaders")
return dataloader
3 changes: 2 additions & 1 deletion flash/core/utilities/imports.py
@@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = _module_available("fastface")
_FASTFACE_AVAILABLE = _module_available("fastface") and _compare_version("pytorch_lightning", operator.lt, "1.5.0")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
@@ -118,6 +118,7 @@ class Image:
if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
_PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0")

_TEXT_AVAILABLE = all(
[
37 changes: 27 additions & 10 deletions flash/image/classification/integrations/baal/loop.py
@@ -15,19 +15,23 @@
from typing import Any, Dict, Optional

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus

import flash
from flash.core.data.data_pipeline import DataLoaderGetter
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.imports import requires
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.stages import RunningStage
from flash.image.classification.integrations.baal.data import ActiveLearningDataModule
from flash.image.classification.integrations.baal.dropout import InferenceMCDropoutTask

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader


class ActiveLearningLoop(Loop):
@requires("baal")
@@ -133,35 +137,48 @@ def __getattr__(self, key):
return getattr(self.fit_loop, key)
return self.__dict__[key]

def _connect(self, model: LightningModule):
if _PL_GREATER_EQUAL_1_5_0:
self.trainer.training_type_plugin.connect(model)
else:
self.trainer.accelerator.connect(model)

def _reset_fitting(self):
self.trainer.state.fn = TrainerFn.FITTING
self.trainer.training = True
self.trainer.lightning_module.on_train_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)
self.fit_loop.epoch_progress = Progress()

def _reset_predicting(self):
self.trainer.state.fn = TrainerFn.PREDICTING
self.trainer.predicting = True
self.trainer.lightning_module.on_predict_dataloader()
self.trainer.accelerator.connect(self.inference_model)
self._connect(self.inference_model)

def _reset_testing(self):
self.trainer.state.fn = TrainerFn.TESTING
self.trainer.state.status = TrainerStatus.RUNNING
self.trainer.testing = True
self.trainer.lightning_module.on_test_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)

def _reset_dataloader_for_stage(self, running_state: RunningStage):
dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
# If the dataloader exists, we reset it.
dataloader = getattr(self.trainer.datamodule, dataloader_name, None)
if dataloader:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
if _PL_GREATER_EQUAL_1_5_0:
setattr(
self.trainer.lightning_module,
dataloader_name,
DataLoaderGetter(dataloader()),
)
else:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
setattr(self.trainer, dataloader_name, None)
getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
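
As an aside (not part of the diff), the PL >= 1.5.0 branch of `_reset_dataloader_for_stage` amounts to the following hypothetical helper when specialised to the train stage; the function name and the hard-coded stage are illustrative assumptions:

```python
from pytorch_lightning import Trainer

from flash.core.data.data_pipeline import DataLoaderGetter


def reset_train_dataloader_for_active_learning(trainer: Trainer) -> None:
    """Hypothetical helper mirroring the PL >= 1.5.0 branch above for the train stage."""
    # Build a fresh loader from the datamodule and wrap it so that calling
    # ``lightning_module.train_dataloader()`` returns it.
    loader = trainer.datamodule.train_dataloader()
    trainer.lightning_module.train_dataloader = DataLoaderGetter(loader)
    # Clear the cached loader on the trainer and ask it to rebuild from the patched hook.
    trainer.train_dataloader = None
    trainer.reset_train_dataloader(trainer.lightning_module)
```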
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ packaging
numpy
torch>=1.7.1
torchmetrics>=0.4.0,!=0.5.1
pytorch-lightning==1.4.9
pytorch-lightning>=1.4.0
pyDeprecate
pandas<1.3.0
jsonargparse[signatures]>=3.17.0
2 changes: 1 addition & 1 deletion tests/core/data/test_data_pipeline.py
@@ -18,12 +18,12 @@
import numpy as np
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor, tensor
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

from flash import Trainer
from flash.core.data.auto_dataset import IterableAutoDataset
from flash.core.data.batch import _Postprocessor, _Preprocessor
from flash.core.data.data_module import DataModule
5 changes: 0 additions & 5 deletions tests/image/classification/test_data.py
@@ -80,8 +80,6 @@ def test_from_filepaths_smoke(tmpdir):
num_workers=0,
)
assert img_data.train_dataloader() is not None
assert img_data.val_dataloader() is None
assert img_data.test_dataloader() is None

data = next(iter(img_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
@@ -275,9 +273,6 @@ def test_from_folders_only_train(tmpdir):
assert imgs.shape == (1, 3, 196, 196)
assert labels.shape == (1,)

assert img_data.val_dataloader() is None
assert img_data.test_dataloader() is None


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders_train_val(tmpdir):