Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
change running_stage (#872)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Oct 14, 2021
1 parent b22e786 commit 4f608b0
Show file tree
Hide file tree
Showing 34 changed files with 126 additions and 52 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ timit/
urban8k_images/
__MACOSX
*-v2.0.json
cifar-10*
2 changes: 1 addition & 1 deletion flash/core/data/auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from inspect import signature
from typing import Any, Callable, Generic, Iterable, Optional, Sequence, TypeVar

from pytorch_lightning.trainer.states import RunningStage
from torch.utils.data import Dataset, IterableDataset

import flash
from flash.core.data.utils import CurrentRunningStageFuncContext
from flash.core.utilities.stages import RunningStage

DATA_TYPE = TypeVar("DATA_TYPE")

Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
from typing import Any, Dict, List, Set

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.utils import _CALLBACK_FUNCS
from flash.core.utilities.apply_func import _is_overriden
from flash.core.utilities.stages import RunningStage


class BaseVisualization(BaseDataFetcher):
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor

Expand All @@ -26,6 +25,7 @@
CurrentFuncContext,
CurrentRunningStageContext,
)
from flash.core.utilities.stages import RunningStage

if TYPE_CHECKING:
from flash.core.data.process import Deserializer, Preprocess, Serializer
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, List, Sequence

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor

import flash
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.stages import RunningStage


class FlashCallback(Callback):
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import IterableDataset, Subset
Expand All @@ -45,6 +44,7 @@
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, requires
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
from fiftyone.core.collections import SampleCollection
Expand Down
15 changes: 2 additions & 13 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand All @@ -32,6 +31,7 @@
from flash.core.data.properties import ProcessState
from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3
from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage

if TYPE_CHECKING:
from flash.core.model import Task
Expand Down Expand Up @@ -582,17 +582,6 @@ def __str__(self) -> str:


class _StageOrchestrator:

# This is used to map ``SANITY_CHECKING`` to ``VALIDATING``
internal_mapping = {
RunningStage.TRAINING: RunningStage.TRAINING,
RunningStage.SANITY_CHECKING: RunningStage.VALIDATING,
RunningStage.VALIDATING: RunningStage.VALIDATING,
RunningStage.TESTING: RunningStage.TESTING,
RunningStage.PREDICTING: RunningStage.PREDICTING,
RunningStage.TUNING: RunningStage.TUNING,
}

def __init__(self, func_to_wrap: Callable, model: "Task") -> None:
self.func = func_to_wrap

Expand All @@ -609,7 +598,7 @@ def __call__(self, *args, **kwargs):
except AttributeError:
stage = self.model.trainer.state.stage

internal_running_state = self.internal_mapping[stage]
internal_running_state = _RUNNING_STAGE_MAPPING[stage]
additional_func = self._stage_mapping.get(internal_running_state, None)

if additional_func:
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import numpy as np
import pandas as pd
import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.enums import LightningEnum
from torch.nn import Module
from torch.utils.data.dataset import Dataset
Expand All @@ -49,6 +48,7 @@
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.utils import CurrentRunningStageFuncContext
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.core.utilities.stages import RunningStage

SampleCollection = None
if _FIFTYONE_AVAILABLE:
Expand Down
4 changes: 2 additions & 2 deletions flash/core/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
from functools import partial
from typing import Any, Callable, Iterable, Mapping, Optional, Type, Union

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Dataset, IterableDataset

from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.data.properties import Properties
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage

__all__ = [
"BaseDataset",
Expand Down Expand Up @@ -116,7 +116,7 @@ def from_data(
if not running_stage:
raise MisconfigurationException(
"You should provide a running_stage to your dataset"
" `from pytorch_lightning.trainer.states import RunningStage`."
" `from flash.core.utilities.stages import RunningStage`."
)
flash_dataset = cls(**dataset_kwargs, running_stage=running_stage, transform=transform)
flash_dataset.pass_args_to_load_data(*load_data_args)
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data._utils.collate import default_collate
Expand All @@ -26,6 +25,7 @@
from flash.core.data.states import CollateFn
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage

INPUT_TRANSFORM_TYPE = Optional[
Union["InputTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/new_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader
Expand All @@ -32,6 +31,7 @@
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
from fiftyone.core.collections import SampleCollection
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.utils.data._utils.collate import default_collate
Expand All @@ -30,6 +29,7 @@
from flash.core.data.states import CollateFn
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext
from flash.core.utilities.stages import RunningStage


class BasePreprocess(ABC):
Expand Down
26 changes: 18 additions & 8 deletions flash/core/data/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
from dataclasses import dataclass
from typing import Dict, Optional, Type, TypeVar

from pytorch_lightning.trainer.states import RunningStage

import flash
from flash.core.utilities.stages import RunningStage


@dataclass(unsafe_hash=True, frozen=True)
Expand Down Expand Up @@ -80,6 +79,17 @@ def training(self, val: bool) -> None:
elif self.training:
self._running_stage = None

@property
def validating(self) -> bool:
return self._running_stage == RunningStage.VALIDATING

@validating.setter
def validating(self, val: bool) -> None:
if val:
self._running_stage = RunningStage.VALIDATING
elif self.validating:
self._running_stage = None

@property
def testing(self) -> bool:
return self._running_stage == RunningStage.TESTING
Expand All @@ -103,12 +113,12 @@ def predicting(self, val: bool) -> None:
self._running_stage = None

@property
def validating(self) -> bool:
return self._running_stage == RunningStage.VALIDATING
def serving(self) -> bool:
return self._running_stage == RunningStage.SERVING

@validating.setter
def validating(self, val: bool) -> None:
@serving.setter
def serving(self, val: bool) -> None:
if val:
self._running_stage = RunningStage.VALIDATING
elif self.validating:
self._running_stage = RunningStage.SERVING
elif self.serving:
self._running_stage = None
3 changes: 2 additions & 1 deletion flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

import requests
import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch import Tensor
from tqdm.auto import tqdm as tq

from flash.core.utilities.stages import RunningStage

_STAGES_PREFIX = {
RunningStage.TRAINING: "train",
RunningStage.TESTING: "test",
Expand Down
2 changes: 1 addition & 1 deletion flash/core/integrations/labelstudio/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.cloud_io import get_filesystem

from flash import DataSource
from flash.core.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.core.data.data_source import DefaultDataKeys, has_len
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader
Expand Down
2 changes: 1 addition & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import torchmetrics
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -51,6 +50,7 @@
from flash.core.serve import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires
from flash.core.utilities.stages import RunningStage


class ModuleWrapperBase:
Expand Down
2 changes: 1 addition & 1 deletion flash/core/serve/flash_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, Callable, Mapping

import torch
from pytorch_lightning.trainer.states import RunningStage

from flash.core.data.data_source import DefaultDataKeys
from flash.core.serve import expose, ModelComponent
from flash.core.serve.types.base import BaseType
from flash.core.utilities.stages import RunningStage


class FlashInputs(BaseType):
Expand Down
72 changes: 72 additions & 0 deletions flash/core/utilities/stages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from pytorch_lightning.utilities.enums import LightningEnum


class RunningStage(LightningEnum):
"""Enum for the current running stage.
This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
More than one running stage value can be set while a :class:`TrainerFn` is running:
- ``TrainerFn.FITTING`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
- ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- ``TrainerFn.SERVING`` - ``RunningStage.SERVING``
- ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
"""

TRAINING = "train"
SANITY_CHECKING = "sanity_check"
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
SERVING = "serve"
TUNING = "tune"

@property
def evaluating(self) -> bool:
return self in (self.VALIDATING, self.TESTING)

@property
def dataloader_prefix(self) -> Optional[str]:
if self in (self.SANITY_CHECKING, self.TUNING):
return None
if self == self.VALIDATING:
return "val"
return self.value


_STAGES_PREFIX = {
RunningStage.TRAINING: "train",
RunningStage.TESTING: "test",
RunningStage.VALIDATING: "val",
RunningStage.PREDICTING: "predict",
RunningStage.SERVING: "serve",
}

_STAGES_PREFIX_VALUES = {"train", "test", "val", "predict", "serve"}

_RUNNING_STAGE_MAPPING = {
RunningStage.TRAINING: RunningStage.TRAINING,
RunningStage.SANITY_CHECKING: RunningStage.VALIDATING,
RunningStage.VALIDATING: RunningStage.VALIDATING,
RunningStage.TESTING: RunningStage.TESTING,
RunningStage.PREDICTING: RunningStage.PREDICTING,
RunningStage.SERVING: RunningStage.SERVING,
RunningStage.TUNING: RunningStage.TUNING,
}
Loading

0 comments on commit 4f608b0

Please sign in to comment.