This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Feature/integration pytorch tabular #1098

Merged
Changes from all commits (50 commits)
c50e549
start integration with pytorch tabular
Nov 21, 2021
0b63d7e
adding adapter and backbones
Nov 22, 2021
60683e3
add index numeric and categorical columns and modify classification i…
Dec 30, 2021
215ebe1
first integration of pytorch tabular and flash using adapter
Jan 1, 2022
5797bce
dynamically create the correct model class by reading from config
Jan 1, 2022
a9e34b8
update backbone in order to manage different models and fix test for …
Jan 1, 2022
2c6abb0
add tabular regression integration with pytorch-tabular
Jan 1, 2022
acdc085
modify integration using AdapterTask
Jan 2, 2022
80cf104
update name parameter backbone_kwargs
Jan 2, 2022
56937a8
remove idx_cat and idx_num
Jan 2, 2022
365265a
add forward method in adapter and fix test for serving
Jan 3, 2022
f7bbda9
Merge branch 'master' into feature/integration-pytorch-tabular
Jan 3, 2022
6b8db80
add possibility to create model from datamodule
Jan 4, 2022
097d285
comment out test_cli until it works
Jan 4, 2022
5268c53
update tabular models to work with cli
Jan 5, 2022
2a9237c
fix TabularRegressor and add test and cli
Jan 5, 2022
915c0b0
update docstring
Jan 5, 2022
65ed962
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2022
a87a472
fix flake8 problems
Jan 5, 2022
2f006d4
Merge remote-tracking branch 'origin/feature/integration-pytorch-tabu…
Jan 5, 2022
92bd2be
fix flake8 problems
Jan 5, 2022
8155716
remove useless comment and in model use only num_features instead usi…
Jan 6, 2022
6853fb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2022
7ceda7b
refactoring tabular models in order to pass loss, metrics, optimizer …
Jan 6, 2022
3d02a8d
Merge remote-tracking branch 'origin/feature/integration-pytorch-tabu…
Jan 6, 2022
d9f60b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2022
6bdc514
fix flake8 problems
Jan 6, 2022
9bc516e
Merge remote-tracking branch 'origin/feature/integration-pytorch-tabu…
Jan 6, 2022
9eb1f1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2022
4115462
Merge branch 'master' into feature/integration-pytorch-tabular
ethanwharris Jan 7, 2022
23a5994
Merge branch 'master' into feature/integration-pytorch-tabular
ethanwharris Jan 7, 2022
808dcdf
Some fixes
ethanwharris Jan 7, 2022
942a72a
Move omegaconf import
ethanwharris Jan 7, 2022
811b23b
Fixes
ethanwharris Jan 10, 2022
a399ca0
Try fix
ethanwharris Jan 10, 2022
02aac44
Try fix
ethanwharris Jan 10, 2022
2d1d75c
Updates
ethanwharris Jan 10, 2022
af16e02
Updates
ethanwharris Jan 10, 2022
b11d838
Fixes
ethanwharris Jan 10, 2022
68f485e
Try something
ethanwharris Jan 10, 2022
707568d
Updates
ethanwharris Jan 10, 2022
615aa01
Fixes
ethanwharris Jan 10, 2022
e37d5b8
Fixes
ethanwharris Jan 10, 2022
9d87f6f
Fixes
ethanwharris Jan 10, 2022
ff6c2cc
Try fix
ethanwharris Jan 10, 2022
777985c
Fixes
ethanwharris Jan 10, 2022
6ea1eb4
Fixes
ethanwharris Jan 10, 2022
9f07ef4
Fix typing
ethanwharris Jan 10, 2022
07be49e
Copyright
ethanwharris Jan 10, 2022
a155609
Update CHANGELOG.md
ethanwharris Jan 10, 2022
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for multi-label, space delimited, targets ([#1076](https://github.com/PyTorchLightning/lightning-flash/pull/1076))

- Added support for tabular classification / regression backbones from PyTorch Tabular ([#1098](https://github.com/PyTorchLightning/lightning-flash/pull/1098))

- Added Flash Zero support for tabular regression ([#1098](https://github.com/PyTorchLightning/lightning-flash/pull/1098))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and separated it from the backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
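A hedged sketch of what the new changelog entries enable, assuming an installed `pytorch_tabular` and Flash's standard `available_backbones` helper:

```python
from flash.tabular import TabularClassifier

# Lists the backbone names this PR registers (see backbones.py below):
# expected to include tabnet, tabtransformer, fttransformer, autoint,
# node, and category_embedding.
print(TabularClassifier.available_backbones())
```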
20 changes: 20 additions & 0 deletions docs/source/general/finetuning.rst
@@ -104,6 +104,11 @@ The freeze strategy keeps the backbone frozen throughout.

trainer.finetune(model, datamodule, strategy="freeze")

.. testoutput:: strategies
:hide:

...

The pseudocode looks like:

.. code-block:: python
@@ -135,6 +140,11 @@ For example, to unfreeze after epoch 7:

trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))

.. testoutput:: strategies
:hide:

...

Under the hood, the pseudocode looks like:

.. code-block:: python
@@ -169,6 +179,11 @@ Here's an example where:

trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))

.. testoutput:: strategies
:hide:

...

Under the hood, the pseudocode looks like:

.. code-block:: python
@@ -220,3 +235,8 @@ For even more customization, create your own finetuning callback. Learn more abo

# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))

.. testoutput:: strategies
:hide:

...
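The hidden `testoutput` blocks added above let the existing `testcode` snippets run under Sphinx doctest while ignoring nondeterministic trainer logs (Sphinx's default doctest flags include `ELLIPSIS`, so the `...` body matches any output). For orientation, a minimal end-to-end sketch of the documented flow; the data path and backbone are illustrative only, not from this PR:

```python
import flash
from flash.image import ImageClassificationData, ImageClassifier

# hypothetical folder layout; any Flash task/datamodule pair works the same way
datamodule = ImageClassificationData.from_folders(train_folder="data/train", batch_size=4)
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = flash.Trainer(max_epochs=10)
# keep the backbone frozen for 7 epochs, then unfreeze (strategy shown above)
trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))
```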
1 change: 1 addition & 0 deletions flash/__main__.py
@@ -51,6 +51,7 @@ def wrapper(cli_args):
"flash.pointcloud.detection",
"flash.pointcloud.segmentation",
"flash.tabular.classification",
"flash.tabular.regression",
"flash.tabular.forecasting",
"flash.text.classification",
"flash.text.question_answering",
9 changes: 5 additions & 4 deletions flash/core/classification.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
@@ -25,13 +25,14 @@
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires

Classification, Classifications = None, None
if _FIFTYONE_AVAILABLE:
fol = lazy_import("fiftyone.core.labels")
if TYPE_CHECKING:
from fiftyone.core.labels import Classification, Classifications
Classification = "fiftyone.core.labels.Classification"
Classifications = "fiftyone.core.labels.Classifications"
else:
fol = None
Classification = None
Classifications = None


def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
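The change above drops the `TYPE_CHECKING` imports in favour of string placeholders, so annotations referencing FiftyOne types stay valid even when the package is absent. A minimal standalone sketch of the same pattern; `mylib` is a hypothetical optional dependency standing in for fiftyone:

```python
# Placeholder so the name always exists at module scope.
Label = None

try:
    import mylib  # hypothetical optional dependency
    Label = "mylib.Label"  # string forward reference, usable in annotations
except ImportError:
    mylib = None


def to_label(value) -> "Label":
    if mylib is None:
        raise ModuleNotFoundError("This feature requires `mylib`.")
    return mylib.Label(value)
```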
8 changes: 1 addition & 7 deletions flash/core/data/data_module.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union

import numpy as np
import pytorch_lightning as pl
@@ -33,14 +33,8 @@
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
pass
else:
SampleCollection = None


class DatasetInput(Input):
"""The ``DatasetInput`` implements default behaviours for data sources which expect the input to
12 changes: 7 additions & 5 deletions flash/core/integrations/fiftyone/utils.py
@@ -1,19 +1,21 @@
from itertools import chain
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
from typing import Dict, List, Optional, Type, Union

import flash
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires

Label, Session, SampleCollection = object, object, object
if _FIFTYONE_AVAILABLE:
fo = lazy_import("fiftyone")
fol = lazy_import("fiftyone.core.labels")
if TYPE_CHECKING:
from fiftyone import Label, Session
from fiftyone.core.collections import SampleCollection
Label = "fiftyone.Label"
Session = "fiftyone.Session"
SampleCollection = "fiftyone.core.collections.SampleCollection"
else:
fo = None
fol = None
Label = object
Session = object
SampleCollection = object


@requires("fiftyone")
Empty file.
85 changes: 85 additions & 0 deletions flash/core/integrations/pytorch_tabular/adapter.py
@@ -0,0 +1,85 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union

import torchmetrics

from flash.core.adapter import Adapter
from flash.core.data.io.input import DataKeys
from flash.core.model import Task


class PytorchTabularAdapter(Adapter):
def __init__(self, backbone):
super().__init__()

self.backbone = backbone

@classmethod
def from_task(
cls,
task: Task,
task_type,
embedding_sizes: list,
categorical_fields: list,
cat_dims: list,
num_features: int,
output_dim: int,
backbone: str,
loss_fn: Optional[Callable],
metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]],
backbone_kwargs: Optional[Dict[str, Any]] = None,
) -> Adapter:

backbone_kwargs = backbone_kwargs or {}
parameters = {
"embedding_dims": embedding_sizes,
"categorical_cols": categorical_fields,
"categorical_cardinality": cat_dims,
"categorical_dim": len(categorical_fields),
"continuous_dim": num_features - len(categorical_fields),
"output_dim": output_dim,
}
adapter = cls(
task.backbones.get(backbone)(
task_type=task_type, parameters=parameters, loss_fn=loss_fn, metrics=metrics, **backbone_kwargs
)
)

return adapter

@staticmethod
def convert_batch(batch):
new_batch = {
"continuous": batch[DataKeys.INPUT][1],
"categorical": batch[DataKeys.INPUT][0],
}
if DataKeys.TARGET in batch:
new_batch["target"] = batch[DataKeys.TARGET].reshape(-1, 1)
return new_batch

def training_step(self, batch, batch_idx) -> Any:
return self.backbone.training_step(self.convert_batch(batch), batch_idx)

def validation_step(self, batch, batch_idx):
return self.backbone.validation_step(self.convert_batch(batch), batch_idx)

def test_step(self, batch, batch_idx):
return self.backbone.validation_step(self.convert_batch(batch), batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(self.convert_batch(batch))

def forward(self, batch: Any) -> Any:
return self.backbone(batch)["logits"]
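`convert_batch` is the heart of the adapter: it maps Flash's `(categorical, continuous)` input pair onto the dict layout PyTorch Tabular models expect. A small sketch against this branch (shapes are arbitrary):

```python
import torch
from flash.core.data.io.input import DataKeys
from flash.core.integrations.pytorch_tabular.adapter import PytorchTabularAdapter

# A toy Flash tabular batch: INPUT is a (categorical, continuous) pair.
batch = {
    DataKeys.INPUT: (torch.zeros(8, 2, dtype=torch.long), torch.rand(8, 3)),
    DataKeys.TARGET: torch.ones(8),
}
converted = PytorchTabularAdapter.convert_batch(batch)

print(converted["categorical"].shape)  # torch.Size([8, 2])
print(converted["continuous"].shape)   # torch.Size([8, 3])
print(converted["target"].shape)       # torch.Size([8, 1]) -- reshaped to a column
```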
96 changes: 96 additions & 0 deletions flash/core/integrations/pytorch_tabular/backbones.py
@@ -0,0 +1,96 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Callable, List, Optional, Union

import torchmetrics

from flash.core.integrations.pytorch_tabular.adapter import PytorchTabularAdapter
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYTORCHTABULAR_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_TABULAR

if _PYTORCHTABULAR_AVAILABLE:
import pytorch_tabular.models as models
from omegaconf import DictConfig, OmegaConf
from pytorch_tabular.config import ModelConfig
from pytorch_tabular.models import (
AutoIntConfig,
CategoryEmbeddingModelConfig,
FTTransformerConfig,
NodeConfig,
TabNetModelConfig,
TabTransformerConfig,
)


PYTORCH_TABULAR_BACKBONES = FlashRegistry("backbones")


if _PYTORCHTABULAR_AVAILABLE:

def _read_parse_config(config, cls):
if isinstance(config, str):
if os.path.exists(config):
_config = OmegaConf.load(config)
if cls == ModelConfig:
cls = getattr(getattr(models, _config._module_src), _config._config_name)
config = cls(
**{
k: v
for k, v in _config.items()
if (k in cls.__dataclass_fields__.keys()) and (cls.__dataclass_fields__[k].init)
}
)
else:
raise ValueError(f"{config} is not a valid path")
config = OmegaConf.structured(config)
return config

def load_pytorch_tabular(
model_config_class,
task_type,
parameters: DictConfig,
loss_fn: Callable,
metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]],
**model_kwargs,
):
model_config = model_config_class(task=task_type, embedding_dims=parameters["embedding_dims"], **model_kwargs)
model_config = _read_parse_config(model_config, ModelConfig)
model_callable = getattr(getattr(models, model_config._module_src), model_config._model_name)
config = OmegaConf.merge(
OmegaConf.create(parameters),
OmegaConf.to_container(model_config),
)
model = model_callable(config=config, custom_loss=loss_fn, custom_metrics=metrics)
return model

for model_config_class, name in zip(
[
TabNetModelConfig,
TabTransformerConfig,
FTTransformerConfig,
AutoIntConfig,
NodeConfig,
CategoryEmbeddingModelConfig,
],
["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"],
):
PYTORCH_TABULAR_BACKBONES(
functools.partial(load_pytorch_tabular, model_config_class),
name=name,
providers=_PYTORCH_TABULAR,
adapter=PytorchTabularAdapter,
)
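Each registered name resolves to a `functools.partial` over `load_pytorch_tabular` with its config class bound. A quick sketch of inspecting the registry, assuming `FlashRegistry.available_keys` (provided by the core registry):

```python
from flash.core.integrations.pytorch_tabular.backbones import PYTORCH_TABULAR_BACKBONES

# Expected: the six names registered in the loop above, sorted alphabetically.
print(PYTORCH_TABULAR_BACKBONES.available_keys())
```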
2 changes: 2 additions & 0 deletions flash/core/registry.py
@@ -276,6 +276,8 @@ def get(
if len(external_matches) == 1:
return external_matches[0]

if len(matches) == 0 and len(external_matches) == 0:
raise KeyError("No matches found in registry.")
raise KeyError("Multiple matches from external registries, a strict lookup is not possible.")

def remove(self, key: str) -> None:
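A minimal sketch of the behaviour this guard adds; previously a lookup with no internal or external matches fell through to the misleading "multiple matches" error:

```python
from flash.core.registry import FlashRegistry

registry = FlashRegistry("backbones")
registry(lambda: "resnet", name="resnet")  # register a dummy entry

registry.get("resnet")   # returns the registered callable
registry.get("missing")  # expected with this PR: KeyError("No matches found in registry.")
```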
21 changes: 21 additions & 0 deletions flash/core/regression.py
@@ -17,6 +17,7 @@
import torch.nn.functional as F
import torchmetrics

from flash.core.adapter import AdapterTask
from flash.core.model import Task
from flash.core.utilities.types import OUTPUT_TYPE

@@ -55,3 +56,23 @@ def __init__(
output=output,
**kwargs,
)


class RegressionAdapterTask(AdapterTask, RegressionMixin):
def __init__(
self,
*args,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
output: OUTPUT_TYPE = None,
**kwargs,
) -> None:
metrics, loss_fn = RegressionMixin._build(loss_fn, metrics)

super().__init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
output=output,
**kwargs,
)
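`RegressionAdapterTask` mirrors the existing `RegressionTask`: defaults come from `RegressionMixin._build` before everything is handed to `AdapterTask`. A hedged sketch of the default-filling step, called exactly as in the diff; the concrete defaults (a mean-squared-error metric and loss) are an assumption based on the mixin's usual behaviour:

```python
from flash.core.regression import RegressionMixin

# With nothing supplied, _build fills in regression defaults
# (assumed: a torchmetrics.MeanSquaredError metric and F.mse_loss).
metrics, loss_fn = RegressionMixin._build(None, None)
print(metrics, loss_fn)
```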