Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add TabularForecaster task based on PyTorch Forecasting #647

Merged
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
2716347
Revert "Added TabularRegressionData extending TabularData (#574)"
sumanmichael Jul 15, 2021
a34be7d
added DataModule, PreProcess, DataSource for TabularForecasting
sumanmichael Jul 15, 2021
42aa6ce
added TABULAR_FORECASTING_BACKBONES
sumanmichael Jul 16, 2021
00b43aa
[WIP] added model.py in tabular forecasting
sumanmichael Aug 10, 2021
20eb7ec
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Aug 10, 2021
c3c4282
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2021
75cc620
Updates
ethanwharris Aug 10, 2021
eec7fab
Merge branch 'feature/pytorch_forecasting' of https://github.com/suma…
ethanwharris Aug 10, 2021
5c554d4
Updates
ethanwharris Aug 10, 2021
3f02252
Try fix
ethanwharris Aug 10, 2021
f6ac528
Updates
ethanwharris Aug 10, 2021
3db1966
Rename to TabularClassificationData
ethanwharris Aug 10, 2021
f2a8cc1
Updates
ethanwharris Aug 10, 2021
e72d441
Fix embedding sizes
ethanwharris Aug 10, 2021
739d9a8
Fixes and add example
ethanwharris Aug 10, 2021
353f1f9
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Aug 23, 2021
672bf9e
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Sep 20, 2021
a3aafd0
Updates
ethanwharris Sep 20, 2021
157ef3f
Switch to an adapter
ethanwharris Sep 22, 2021
3872b34
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Sep 22, 2021
e7bca8e
Small fixes
ethanwharris Sep 23, 2021
3bad7a8
Merge branch 'feature/pytorch_forecasting' of https://github.com/suma…
ethanwharris Sep 23, 2021
fc10908
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Oct 28, 2021
86f3bf9
Add inference error
ethanwharris Oct 28, 2021
c7967ca
Add inference and refactor
ethanwharris Oct 28, 2021
7fb852f
Add interpertation example
ethanwharris Oct 28, 2021
3a9c9ab
Fix broken tests
ethanwharris Oct 28, 2021
69c4467
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Oct 28, 2021
b7846a3
Small fixes and add some tests
ethanwharris Oct 29, 2021
b7756b0
Updates
ethanwharris Oct 29, 2021
6313ffe
Update CHANGELOG.md
ethanwharris Oct 29, 2021
8976a90
Add provider
ethanwharris Oct 29, 2021
fb4a598
Update flash/core/integrations/pytorch_forecasting/adapter.py
ethanwharris Oct 29, 2021
9c213d6
Update flash/core/integrations/pytorch_forecasting/adapter.py
ethanwharris Oct 29, 2021
3cbd13e
Update on comments
ethanwharris Oct 29, 2021
c71bff7
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Oct 29, 2021
d005bb1
Merge branch 'master' into feature/pytorch_forecasting
ethanwharris Oct 29, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647))

### Changed

### Fixed
Expand Down
19 changes: 3 additions & 16 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader, IterableDataset
Expand All @@ -42,29 +41,19 @@ class DataPipelineState:

def __init__(self):
self._state: Dict[Type[ProcessState], ProcessState] = {}
self._initialized = False

def set_state(self, state: ProcessState):
"""Add the given :class:`.ProcessState` to the :class:`.DataPipelineState`."""

if not self._initialized:
self._state[type(state)] = state
else:
rank_zero_warn(
f"Attempted to add a state ({state}) after the data pipeline has already been initialized. This will"
" only have an effect when a new data pipeline is created.",
UserWarning,
)
self._state[type(state)] = state

def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]:
"""Get the :class:`.ProcessState` of the given type from the :class:`.DataPipelineState`."""

if state_type in self._state:
return self._state[state_type]
return None
return self._state.get(state_type, None)

def __str__(self) -> str:
return f"{self.__class__.__name__}(initialized={self._initialized}, state={self._state})"
return f"{self.__class__.__name__}(state={self._state})"


class DataPipeline:
Expand Down Expand Up @@ -113,13 +102,11 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) ->
:class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will
give a warning."""
data_pipeline_state = data_pipeline_state or DataPipelineState()
data_pipeline_state._initialized = False
if self.data_source is not None:
self.data_source.attach_data_pipeline_state(data_pipeline_state)
self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
self._serializer.attach_data_pipeline_state(data_pipeline_state)
data_pipeline_state._initialized = True # TODO: Not sure we need this
return data_pipeline_state

@property
Expand Down
1 change: 1 addition & 0 deletions flash/core/integrations/pytorch_forecasting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.core.integrations.pytorch_forecasting.transforms import convert_predictions # noqa: F401
117 changes: 117 additions & 0 deletions flash/core/integrations/pytorch_forecasting/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torchmetrics

from flash import Task
from flash.core.adapter import Adapter
from flash.core.data.batch import default_uncollate
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.states import CollateFn
from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE

if _PANDAS_AVAILABLE:
from pandas import DataFrame
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

if _FORECASTING_AVAILABLE:
from pytorch_forecasting import TimeSeriesDataSet
else:
TimeSeriesDataSet = object


class PatchTimeSeriesDataSet(TimeSeriesDataSet):
"""Hack to prevent index construction or data validation / conversion when instantiating model.

This enables the ``TimeSeriesDataSet`` to be created from a single row of data.
"""

def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame:
return DataFrame()

def _data_to_tensors(self, data: DataFrame) -> Dict[str, torch.Tensor]:
return {}


class PyTorchForecastingAdapter(Adapter):
"""The ``PyTorchForecastingAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with PyTorch
Forecasting."""

def __init__(self, backbone):
super().__init__()

self.backbone = backbone

@staticmethod
def _collate_fn(collate_fn, samples):
samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples]
batch = collate_fn(samples)
return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]}

@classmethod
def from_task(
cls,
task: Task,
parameters: Dict[str, Any],
backbone: str,
loss_fn: Optional[Callable] = None,
metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]] = None,
**backbone_kwargs,
) -> Adapter:
parameters = copy(parameters)
data = parameters.pop("data_sample")
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data)

backbone_kwargs["loss"] = loss_fn

if metrics is not None and not isinstance(metrics, list):
metrics = [metrics]
backbone_kwargs["logging_metrics"] = metrics

if not backbone_kwargs:
backbone_kwargs = {}
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

adapter = cls(task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs))

# Attach the required collate function
adapter.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn)))

return adapter

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return self.backbone.training_step(batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return self.backbone.validation_step(batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> None:
raise NotImplementedError(
"Backbones provided by PyTorch Forecasting don't support testing. Use validation instead."
)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
result = dict(self.backbone(batch[DefaultDataKeys.INPUT]))
result[DefaultDataKeys.INPUT] = default_uncollate(batch[DefaultDataKeys.INPUT])
return default_uncollate(result)

def training_epoch_end(self, outputs) -> None:
self.backbone.training_epoch_end(outputs)

def validation_epoch_end(self, outputs) -> None:
self.backbone.validation_epoch_end(outputs)
49 changes: 49 additions & 0 deletions flash/core/integrations/pytorch_forecasting/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

from flash.core.integrations.pytorch_forecasting.adapter import PyTorchForecastingAdapter
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FORECASTING_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_FORECASTING

if _FORECASTING_AVAILABLE:
from pytorch_forecasting import (
DecoderMLP,
DeepAR,
NBeats,
RecurrentNetwork,
TemporalFusionTransformer,
TimeSeriesDataSet,
)


PYTORCH_FORECASTING_BACKBONES = FlashRegistry("backbones")


if _FORECASTING_AVAILABLE:

def load_torch_forecasting(model, time_series_dataset: TimeSeriesDataSet, **kwargs):
return model.from_dataset(time_series_dataset, **kwargs)

for model, name in zip(
[TemporalFusionTransformer, NBeats, RecurrentNetwork, DeepAR, DecoderMLP],
["temporal_fusion_transformer", "n_beats", "recurrent_network", "deep_ar", "decoder_mlp"],
):
PYTORCH_FORECASTING_BACKBONES(
functools.partial(load_torch_forecasting, model),
name=name,
providers=_PYTORCH_FORECASTING,
adapter=PyTorchForecastingAdapter,
)
30 changes: 30 additions & 0 deletions flash/core/integrations/pytorch_forecasting/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Tuple

from torch.utils.data._utils.collate import default_collate

from flash.core.data.data_source import DefaultDataKeys


def convert_predictions(predictions: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], List]:
# Flatten list if batches were used
if all(isinstance(fl, list) for fl in predictions):
unrolled_predictions = []
for prediction_batch in predictions:
unrolled_predictions.extend(prediction_batch)
predictions = unrolled_predictions
result = default_collate(predictions)
inputs = result.pop(DefaultDataKeys.INPUT)
return result, inputs
3 changes: 2 additions & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _compare_version(package: str, op, version) -> bool:
_PANDAS_AVAILABLE = _module_available("pandas")
_SKLEARN_AVAILABLE = _module_available("sklearn")
_TABNET_AVAILABLE = _module_available("pytorch_tabnet")
_FORECASTING_AVAILABLE = _module_available("pytorch_forecasting")
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
Expand Down Expand Up @@ -126,7 +127,7 @@ class Image:
_DATASETS_AVAILABLE,
]
)
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE
_VIDEO_AVAILABLE = _TORCHVISION_AVAILABLE and _PIL_AVAILABLE and _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE
_IMAGE_AVAILABLE = all(
[
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ def __str__(self):
_OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML")
_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl")
_PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting")
5 changes: 5 additions & 0 deletions flash/tabular/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401
from flash.tabular.data import TabularData # noqa: F401
from flash.tabular.forecasting.data import ( # noqa: F401
TabularForecastingData,
TabularForecastingDataFrameDataSource,
TabularForecastingPreprocess,
)
from flash.tabular.regression import TabularRegressionData # noqa: F401
3 changes: 2 additions & 1 deletion flash/tabular/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
from flash.tabular import TabularClassificationData, TabularClassifier
from flash.tabular.classification.data import TabularClassificationData
from flash.tabular.classification.model import TabularClassifier

__all__ = ["tabular_classification"]

Expand Down
2 changes: 2 additions & 0 deletions flash/tabular/forecasting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.tabular.forecasting.data import TabularForecastingData # noqa: F401
from flash.tabular.forecasting.model import TabularForecaster # noqa: F401
Loading