This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Support PL 1.5.0 #933

Merged · 17 commits · Nov 5, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -18,8 +18,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900))


- Fixed a bug where the latest versions of torchmetrics and Lightning Flash could not be installed together ([#902](https://github.com/PyTorchLightning/lightning-flash/pull/902))


- Fixed compatibility with PyTorch-Lightning 1.5 ([#933](https://github.com/PyTorchLightning/lightning-flash/pull/933))


## [0.5.1] - 2021-10-26

### Added
50 changes: 41 additions & 9 deletions flash/core/data/data_pipeline.py
@@ -18,24 +18,40 @@
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader, IterableDataset

import flash
from flash.core.data.auto_dataset import IterableAutoDataset
from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential, _SerializeProcessor
from flash.core.data.data_source import DataSource
from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer
from flash.core.data.properties import ProcessState
from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0
from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader

if TYPE_CHECKING:
from flash.core.model import Task


class DataLoaderGetter:
Member
This still wraps and patches the dataloader, right? Is there no way around this?
Will this patch be assigned back to the loader or just used internally?

Collaborator Author
Yeah, unfortunately this was the only solution for now, but the upcoming data pipeline refactor should remove this patching entirely.

"""A utility class to be used when patching the ``{stage}_dataloader`` attribute of a LightningModule."""

def __init__(self, dataloader):
self.dataloader = dataloader

# Dummy `__code__` attribute to trick is_overridden
self.__code__ = self.__call__.__code__

def __call__(self):
return self.dataloader
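A minimal sketch of how the getter behaves, assuming a Flash install that includes this change (the toy loader below is purely illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from flash.core.data.data_pipeline import DataLoaderGetter

# Wrap a throwaway loader; calling the getter hands it back unchanged.
loader = DataLoader(TensorDataset(torch.zeros(4, 1)), batch_size=2)
getter = DataLoaderGetter(loader)

assert getter() is loader
# The dummy `__code__` attribute mirrors `__call__`, which is what lets PL's
# `is_overridden` accept a patched `{stage}_dataloader` attribute.
assert getter.__code__ is DataLoaderGetter.__call__.__code__
```

Patching then amounts to assigning the getter as the `{stage}_dataloader` attribute on the LightningModule, as in the `loop.py` change further down.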


class DataPipelineState:
"""A class to store and share all process states once a :class:`.DataPipeline` has been initialized."""

@@ -315,16 +331,34 @@ def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]:
dataloader = getattr(model, loader_name)
attr_name = loader_name

elif model.trainer and hasattr(model.trainer, "datamodule") and model.trainer.datamodule:
dataloader = getattr(model, f"trainer.datamodule.{loader_name}", None)
elif (
model.trainer
and hasattr(model.trainer, "datamodule")
and model.trainer.datamodule
and is_overridden(loader_name, model.trainer.datamodule, flash.DataModule)
):
dataloader = getattr(model.trainer.datamodule, loader_name, None)
attr_name = f"trainer.datamodule.{loader_name}"

elif _PL_GREATER_EQUAL_1_5_0 and model.trainer is not None:
source = getattr(model.trainer._data_connector, f"_{loader_name}_source")
if not source.is_module():
dataloader = source.dataloader()
attr_name = loader_name

if dataloader is not None:
# Update source as wrapped loader will be attached to model
source.instance = model
source.name = loader_name

return dataloader, attr_name

@staticmethod
def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage):
if isinstance(dataloader, DataLoader):
if _PL_GREATER_EQUAL_1_4_3:
if _PL_GREATER_EQUAL_1_5_0:
dataloader = DataLoaderGetter(dataloader)
elif _PL_GREATER_EQUAL_1_4_3:
dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage])
dataloader.patch(model)
else:
@@ -369,7 +403,7 @@ def _attach_preprocess_to_model(
if not dataloader:
continue

if isinstance(dataloader, (_PatchDataLoader, Callable)):
if callable(dataloader):
dataloader = dataloader()

if dataloader is None:
@@ -504,9 +538,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin
if not dataloader:
continue

if isinstance(dataloader, _PatchDataLoader):
dataloader = dataloader()
elif isinstance(dataloader, Callable):
if callable(dataloader):
dataloader = dataloader()

if isinstance(dataloader, Sequence):
8 changes: 1 addition & 7 deletions flash/core/data/new_data_module.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
@@ -30,14 +30,8 @@
from flash.core.data.datasets import BaseDataset
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
from fiftyone.core.collections import SampleCollection
else:
SampleCollection = None


class DataModule(DataModule):
"""A basic DataModule class for all Flash tasks. This class includes references to a
17 changes: 14 additions & 3 deletions flash/core/trainer.py
@@ -25,11 +25,12 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader

import flash
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks
from flash.core.utilities.imports import _SERVE_AVAILABLE
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _SERVE_AVAILABLE


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@@ -277,14 +278,24 @@ def request_dataloader(
The dataloader
"""
model, stage, is_legacy = self._parse_request_dataloader_args(args, kwargs)

if is_legacy:
self.call_hook(f"on_{stage}_dataloader")
dataloader = getattr(model, f"{stage}_dataloader")()
else:
hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)

if is_overridden(hook, model):
dataloader = self.call_hook(hook, pl_module=model)
elif _PL_GREATER_EQUAL_1_5_0:
source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
dataloader = source.dataloader()

if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
if _PL_GREATER_EQUAL_1_5_0:
self.training_type_plugin.barrier("get_dataloaders")
else:
self.accelerator.barrier("get_dataloaders")
return dataloader
20 changes: 20 additions & 0 deletions flash/core/utilities/compatibility.py
@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer


def accelerator_connector(trainer: Trainer):
if hasattr(trainer, "_accelerator_connector"):
return trainer._accelerator_connector
return trainer.accelerator_connector
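A usage sketch for the new helper (the bare `Trainer` below is only for illustration): it returns `trainer._accelerator_connector` when that private attribute exists and falls back to the older public name otherwise.

```python
from pytorch_lightning import Trainer

from flash.core.utilities.compatibility import accelerator_connector

trainer = Trainer(max_epochs=1)

# Same connector object on both PL < 1.5 and PL >= 1.5, regardless of the attribute name.
connector = accelerator_connector(trainer)
print(connector.is_distributed)  # False for a default single-device Trainer
```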
3 changes: 2 additions & 1 deletion flash/core/utilities/imports.py
@@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = _module_available("fastface")
_FASTFACE_AVAILABLE = _module_available("fastface") and _compare_version("pytorch_lightning", operator.lt, "1.5.0")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
@@ -118,6 +118,7 @@ class Image:
if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
_PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0")

_TEXT_AVAILABLE = all(
[
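The new `_PL_GREATER_EQUAL_1_5_0` flag is consumed as an import guard in the modules touched by this PR; a minimal sketch of the pattern, mirroring the `data_pipeline.py` and `loop.py` changes:

```python
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0

if not _PL_GREATER_EQUAL_1_5_0:
    # `_PatchDataLoader` is not available on PL 1.5+, hence the version guard.
    from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
```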
12 changes: 11 additions & 1 deletion flash/image/classification/adapters.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
from collections import defaultdict
@@ -31,6 +32,7 @@
from flash.core.data.data_source import DefaultDataKeys
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE
from flash.core.utilities.providers import _LEARN2LEARN
from flash.core.utilities.url_error import catch_url_error
@@ -183,9 +185,17 @@ def __init__(

self.model = self.algorithm_cls(**algorithm_kwargs)

# Patch log to avoid error with learn2learn and PL 1.5
self.model.log = functools.partial(self._patch_log, self.model.log)

# this algorithm requires special treatment
self._algorithm_has_validated = self.algorithm_cls != l2l.algorithms.LightningPrototypicalNetworks

def _patch_log(self, log, *args, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, **kwargs):
if not on_step and not on_epoch:
on_epoch = True
return log(*args, on_step=on_step, on_epoch=on_epoch, **kwargs)
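A standalone sketch of the `log` patch above, with a toy stand-in for the learn2learn algorithm (the class and values are illustrative only):

```python
import functools
from typing import Optional


def _patch_log(log, *args, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, **kwargs):
    # Default to epoch-level logging when the wrapped call passes neither flag.
    if not on_step and not on_epoch:
        on_epoch = True
    return log(*args, on_step=on_step, on_epoch=on_epoch, **kwargs)


class _ToyAlgorithm:
    def log(self, name, value, on_step=None, on_epoch=None):
        print(f"{name}={value} on_step={on_step} on_epoch={on_epoch}")


algorithm = _ToyAlgorithm()
algorithm.log = functools.partial(_patch_log, algorithm.log)
algorithm.log("train_accuracy", 0.95)  # on_epoch is forced to True
```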

def _default_transform(self, dataset, ways: int, shots: int, queries) -> List[Callable]:
return [
l2l.data.transforms.FusedNWaysKShots(dataset, n=ways, k=shots + queries),
@@ -268,7 +278,7 @@ def _convert_dataset(
devices = 1
if isinstance(trainer.training_type_plugin, DataParallelPlugin):
# when using DP, we need to sample n tasks so they can be split across multiple devices.
devices = trainer.accelerator_connector.devices
devices = accelerator_connector(trainer).devices
dataset = TaskDataParallel(taskset, epoch_length=epoch_length, devices=devices, collate_fn=None)
self.trainer.accumulated_grad_batches = self.meta_batch_size / devices

44 changes: 33 additions & 11 deletions flash/image/classification/integrations/baal/loop.py
@@ -15,19 +15,24 @@
from typing import Any, Dict, Optional

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
from pytorch_lightning.utilities.model_helpers import is_overridden

import flash
from flash.core.data.data_pipeline import DataLoaderGetter
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.imports import requires
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.stages import RunningStage
from flash.image.classification.integrations.baal.data import ActiveLearningDataModule
from flash.image.classification.integrations.baal.dropout import InferenceMCDropoutTask

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader


class ActiveLearningLoop(Loop):
@requires("baal")
@@ -133,35 +138,52 @@ def __getattr__(self, key):
return getattr(self.fit_loop, key)
return self.__dict__[key]

def _connect(self, model: LightningModule):
if _PL_GREATER_EQUAL_1_5_0:
self.trainer.training_type_plugin.connect(model)
else:
self.trainer.accelerator.connect(model)

def _reset_fitting(self):
self.trainer.state.fn = TrainerFn.FITTING
self.trainer.training = True
self.trainer.lightning_module.on_train_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)
self.fit_loop.epoch_progress = Progress()

def _reset_predicting(self):
self.trainer.state.fn = TrainerFn.PREDICTING
self.trainer.predicting = True
self.trainer.lightning_module.on_predict_dataloader()
self.trainer.accelerator.connect(self.inference_model)
self._connect(self.inference_model)

def _reset_testing(self):
self.trainer.state.fn = TrainerFn.TESTING
self.trainer.state.status = TrainerStatus.RUNNING
self.trainer.testing = True
self.trainer.lightning_module.on_test_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)

def _reset_dataloader_for_stage(self, running_state: RunningStage):
dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
# If the dataloader exists, we reset it.
dataloader = getattr(self.trainer.datamodule, dataloader_name, None)
dataloader = (
getattr(self.trainer.datamodule, dataloader_name)
if is_overridden(dataloader_name, self.trainer.datamodule)
else None
)
if dataloader:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
if _PL_GREATER_EQUAL_1_5_0:
setattr(
self.trainer.lightning_module,
dataloader_name,
DataLoaderGetter(dataloader()),
)
else:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
setattr(self.trainer, dataloader_name, None)
getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
3 changes: 2 additions & 1 deletion flash/image/embedding/vissl/hooks.py
@@ -17,6 +17,7 @@
from pytorch_lightning.core.hooks import ModelHooks

import flash
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
@@ -48,7 +49,7 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") ->

# get around vissl distributed training by setting MockTask flags
num_nodes = lightning_module.trainer.num_nodes
accelerators_ids = lightning_module.trainer.accelerator_connector.parallel_device_ids
accelerators_ids = accelerator_connector(lightning_module.trainer).parallel_device_ids
accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1
task.world_size = num_nodes * accelerator_per_node

5 changes: 3 additions & 2 deletions flash/video/classification/model.py
@@ -29,6 +29,7 @@
from flash.core.classification import ClassificationTask, Labels
from flash.core.data.data_source import DefaultDataKeys
from flash.core.registry import FlashRegistry
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE
from flash.core.utilities.providers import _PYTORCHVIDEO
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE
@@ -146,13 +147,13 @@ def __init__(
)

def on_train_start(self) -> None:
if self.trainer.accelerator_connector.is_distributed:
if accelerator_connector(self.trainer).is_distributed:
encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
encoded_dataset._video_sampler = DistributedSampler(encoded_dataset._labeled_videos)
super().on_train_start()

def on_train_epoch_start(self) -> None:
if self.trainer.accelerator_connector.is_distributed:
if accelerator_connector(self.trainer).is_distributed:
encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch)
super().on_train_epoch_start()