
Feature/integration pytorch tabular (#1098)
Co-authored-by: Luca Actis Grosso <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
4 people authored Jan 10, 2022
1 parent 21a9d7f commit b208689
Showing 46 changed files with 851 additions and 292 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for multi-label, space delimited targets ([#1076](https://github.com/PyTorchLightning/lightning-flash/pull/1076))

- Added support for tabular classification / regression backbones from PyTorch Tabular ([#1098](https://github.com/PyTorchLightning/lightning-flash/pull/1098))

- Added Flash Zero support for tabular regression ([#1098](https://github.com/PyTorchLightning/lightning-flash/pull/1098))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and separated it from the backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
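The two tabular entries above land further down in this commit (new files under flash/core/integrations/pytorch_tabular/ and a flash.tabular.regression registration in flash/__main__.py). As a rough usage sketch of the backbone support, not taken from this commit, assuming the TabularClassificationData.from_csv arguments and the TabularClassifier.from_data helper documented elsewhere in Flash at this time, plus a hypothetical train.csv with the listed columns:

from flash.tabular import TabularClassificationData, TabularClassifier

# Hypothetical CSV with categorical, numerical and target columns (names assumed).
datamodule = TabularClassificationData.from_csv(
    categorical_fields=["category"],
    numerical_fields=["scalar_a", "scalar_b"],
    target_fields="label",
    train_file="train.csv",
    val_split=0.1,
    batch_size=32,
)

# "tabnet" is one of the PyTorch Tabular backbones registered by this PR; the other
# names are "tabtransformer", "fttransformer", "autoint", "node" and "category_embedding".
model = TabularClassifier.from_data(datamodule, backbone="tabnet")

The new flash.tabular.regression entry registered in flash/__main__.py below is what should expose a matching tabular_regression command in Flash Zero.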
20 changes: 20 additions & 0 deletions docs/source/general/finetuning.rst
@@ -104,6 +104,11 @@ The freeze strategy keeps the backbone frozen throughout.

trainer.finetune(model, datamodule, strategy="freeze")

.. testoutput:: strategies
    :hide:

    ...

The pseudocode looks like:

.. code-block:: python
@@ -135,6 +140,11 @@ For example, to unfreeze after epoch 7:

trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))

.. testoutput:: strategies
    :hide:

    ...

Under the hood, the pseudocode looks like:

.. code-block:: python
@@ -169,6 +179,11 @@ Here's an example where:

trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))

.. testoutput:: strategies
    :hide:

    ...

Under the hood, the pseudocode looks like:

.. code-block:: python
@@ -220,3 +235,8 @@ For even more customization, create your own finetuning callback. Learn more abo

# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))

.. testoutput:: strategies
    :hide:

    ...
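For the custom-callback route shown above, here is a minimal sketch of what such a callback can look like, assuming the standard pytorch_lightning.callbacks.BaseFinetuning hooks that Flash's bundled strategies build on, and a model exposing a backbone attribute as in the docs above:

import pytorch_lightning as pl


class UnfreezeBackboneAt(pl.callbacks.BaseFinetuning):
    """Freeze the backbone before training and unfreeze it at a chosen epoch."""

    def __init__(self, unfreeze_epoch: int = 5):
        super().__init__()
        self.unfreeze_epoch = unfreeze_epoch

    def freeze_before_training(self, pl_module):
        # Freeze every backbone parameter before the first epoch.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # At the milestone, unfreeze the backbone and give it its own parameter group.
        if current_epoch == self.unfreeze_epoch:
            self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer)

Assuming Flash accepts any BaseFinetuning instance as a strategy (as the FeatureExtractorFreezeUnfreeze example above suggests), it could then be passed the same way: trainer.finetune(model, datamodule, strategy=UnfreezeBackboneAt(unfreeze_epoch=5)).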
1 change: 1 addition & 0 deletions flash/__main__.py
@@ -51,6 +51,7 @@ def wrapper(cli_args):
"flash.pointcloud.detection",
"flash.pointcloud.segmentation",
"flash.tabular.classification",
"flash.tabular.regression",
"flash.tabular.forecasting",
"flash.text.classification",
"flash.text.question_answering",
9 changes: 5 additions & 4 deletions flash/core/classification.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
@@ -25,13 +25,14 @@
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires

Classification, Classifications = None, None
if _FIFTYONE_AVAILABLE:
    fol = lazy_import("fiftyone.core.labels")
    if TYPE_CHECKING:
        from fiftyone.core.labels import Classification, Classifications
    Classification = "fiftyone.core.labels.Classification"
    Classifications = "fiftyone.core.labels.Classifications"
else:
    fol = None
    Classification = None
    Classifications = None


def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
8 changes: 1 addition & 7 deletions flash/core/data/data_module.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union

import numpy as np
import pytorch_lightning as pl
@@ -33,14 +33,8 @@
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
    pass
else:
    SampleCollection = None


class DatasetInput(Input):
"""The ``DatasetInput`` implements default behaviours for data sources which expect the input to
12 changes: 7 additions & 5 deletions flash/core/integrations/fiftyone/utils.py
@@ -1,19 +1,21 @@
from itertools import chain
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
from typing import Dict, List, Optional, Type, Union

import flash
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires

Label, Session, SampleCollection = object, object, object
if _FIFTYONE_AVAILABLE:
    fo = lazy_import("fiftyone")
    fol = lazy_import("fiftyone.core.labels")
    if TYPE_CHECKING:
        from fiftyone import Label, Session
        from fiftyone.core.collections import SampleCollection
    Label = "fiftyone.Label"
    Session = "fiftyone.Session"
    SampleCollection = "fiftyone.core.collections.SampleCollection"
else:
    fo = None
    fol = None
    Label = object
    Session = object
    SampleCollection = object


@requires("fiftyone")
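The two fiftyone hunks (here and in flash/core/classification.py above) drop the TYPE_CHECKING-guarded imports in favour of module-level placeholders, so the modules import and their annotations resolve even when fiftyone is not installed. A minimal, self-contained sketch of the idea, using the simpler object placeholder rather than the string forward references the real code assigns:

from typing import Union

try:
    from fiftyone import Session  # optional dependency
    _FIFTYONE_AVAILABLE = True
except ImportError:
    Session = object  # placeholder keeps the name defined so annotations still resolve
    _FIFTYONE_AVAILABLE = False


def show(session: Union[Session, None] = None) -> None:
    # Fail with a clear message at call time rather than at import time.
    if not _FIFTYONE_AVAILABLE:
        raise ModuleNotFoundError("Please install `fiftyone` to use this function.")
    print("Visualizing with", session)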
Empty file.
85 changes: 85 additions & 0 deletions flash/core/integrations/pytorch_tabular/adapter.py
@@ -0,0 +1,85 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union

import torchmetrics

from flash.core.adapter import Adapter
from flash.core.data.io.input import DataKeys
from flash.core.model import Task


class PytorchTabularAdapter(Adapter):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone

    @classmethod
    def from_task(
        cls,
        task: Task,
        task_type,
        embedding_sizes: list,
        categorical_fields: list,
        cat_dims: list,
        num_features: int,
        output_dim: int,
        backbone: str,
        loss_fn: Optional[Callable],
        metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]],
        backbone_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Adapter:

        backbone_kwargs = backbone_kwargs or {}
        parameters = {
            "embedding_dims": embedding_sizes,
            "categorical_cols": categorical_fields,
            "categorical_cardinality": cat_dims,
            "categorical_dim": len(categorical_fields),
            "continuous_dim": num_features - len(categorical_fields),
            "output_dim": output_dim,
        }
        adapter = cls(
            task.backbones.get(backbone)(
                task_type=task_type, parameters=parameters, loss_fn=loss_fn, metrics=metrics, **backbone_kwargs
            )
        )

        return adapter

    @staticmethod
    def convert_batch(batch):
        new_batch = {
            "continuous": batch[DataKeys.INPUT][1],
            "categorical": batch[DataKeys.INPUT][0],
        }
        if DataKeys.TARGET in batch:
            new_batch["target"] = batch[DataKeys.TARGET].reshape(-1, 1)
        return new_batch

    def training_step(self, batch, batch_idx) -> Any:
        return self.backbone.training_step(self.convert_batch(batch), batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.backbone.validation_step(self.convert_batch(batch), batch_idx)

    def test_step(self, batch, batch_idx):
        return self.backbone.validation_step(self.convert_batch(batch), batch_idx)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self(self.convert_batch(batch))

    def forward(self, batch: Any) -> Any:
        return self.backbone(batch)["logits"]
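To make the batch conversion above concrete, here is a small illustration of the mapping from Flash's DataKeys-keyed batches to the dict layout the PyTorch Tabular backbones expect (the tensor shapes are invented for the example):

import torch

from flash.core.data.io.input import DataKeys
from flash.core.integrations.pytorch_tabular.adapter import PytorchTabularAdapter

# A fake Flash tabular batch: INPUT is a (categorical, continuous) pair of tensors.
batch = {
    DataKeys.INPUT: (torch.zeros(8, 3, dtype=torch.long), torch.randn(8, 4)),
    DataKeys.TARGET: torch.randint(0, 2, (8,)),
}

converted = PytorchTabularAdapter.convert_batch(batch)
# The backbone consumes {"categorical": ..., "continuous": ..., "target": ...},
# with the target reshaped to (batch_size, 1).
assert converted["categorical"].shape == (8, 3)
assert converted["continuous"].shape == (8, 4)
assert converted["target"].shape == (8, 1)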
96 changes: 96 additions & 0 deletions flash/core/integrations/pytorch_tabular/backbones.py
@@ -0,0 +1,96 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Callable, List, Optional, Union

import torchmetrics

from flash.core.integrations.pytorch_tabular.adapter import PytorchTabularAdapter
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYTORCHTABULAR_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_TABULAR

if _PYTORCHTABULAR_AVAILABLE:
    import pytorch_tabular.models as models
    from omegaconf import DictConfig, OmegaConf
    from pytorch_tabular.config import ModelConfig
    from pytorch_tabular.models import (
        AutoIntConfig,
        CategoryEmbeddingModelConfig,
        FTTransformerConfig,
        NodeConfig,
        TabNetModelConfig,
        TabTransformerConfig,
    )


PYTORCH_TABULAR_BACKBONES = FlashRegistry("backbones")


if _PYTORCHTABULAR_AVAILABLE:

    def _read_parse_config(config, cls):
        if isinstance(config, str):
            if os.path.exists(config):
                _config = OmegaConf.load(config)
                if cls == ModelConfig:
                    cls = getattr(getattr(models, _config._module_src), _config._config_name)
                config = cls(
                    **{
                        k: v
                        for k, v in _config.items()
                        if (k in cls.__dataclass_fields__.keys()) and (cls.__dataclass_fields__[k].init)
                    }
                )
            else:
                raise ValueError(f"{config} is not a valid path")
        config = OmegaConf.structured(config)
        return config

    def load_pytorch_tabular(
        model_config_class,
        task_type,
        parameters: DictConfig,
        loss_fn: Callable,
        metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]],
        **model_kwargs,
    ):
        model_config = model_config_class(task=task_type, embedding_dims=parameters["embedding_dims"], **model_kwargs)
        model_config = _read_parse_config(model_config, ModelConfig)
        model_callable = getattr(getattr(models, model_config._module_src), model_config._model_name)
        config = OmegaConf.merge(
            OmegaConf.create(parameters),
            OmegaConf.to_container(model_config),
        )
        model = model_callable(config=config, custom_loss=loss_fn, custom_metrics=metrics)
        return model

    for model_config_class, name in zip(
        [
            TabNetModelConfig,
            TabTransformerConfig,
            FTTransformerConfig,
            AutoIntConfig,
            NodeConfig,
            CategoryEmbeddingModelConfig,
        ],
        ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"],
    ):
        PYTORCH_TABULAR_BACKBONES(
            functools.partial(load_pytorch_tabular, model_config_class),
            name=name,
            providers=_PYTORCH_TABULAR,
            adapter=PytorchTabularAdapter,
        )
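A quick way to inspect what the registration loop above produced (a sketch, assuming pytorch_tabular is installed so the registry is populated, and using the registry's available_keys() helper):

from flash.core.integrations.pytorch_tabular.backbones import PYTORCH_TABULAR_BACKBONES

# Expected to list the six names registered above: tabnet, tabtransformer,
# fttransformer, autoint, node and category_embedding.
print(PYTORCH_TABULAR_BACKBONES.available_keys())

From the task side, the same backbones should show up via TabularClassifier.available_backbones() and TabularRegressor.available_backbones().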
2 changes: 2 additions & 0 deletions flash/core/registry.py
@@ -276,6 +276,8 @@ def get(
        if len(external_matches) == 1:
            return external_matches[0]

        if len(matches) == 0 and len(external_matches) == 0:
            raise KeyError("No matches found in registry.")
        raise KeyError("Multiple matches from external registries, a strict lookup is not possible.")

    def remove(self, key: str) -> None:
21 changes: 21 additions & 0 deletions flash/core/regression.py
@@ -17,6 +17,7 @@
import torch.nn.functional as F
import torchmetrics

from flash.core.adapter import AdapterTask
from flash.core.model import Task
from flash.core.utilities.types import OUTPUT_TYPE

@@ -55,3 +56,23 @@ def __init__(
            output=output,
            **kwargs,
        )


class RegressionAdapterTask(AdapterTask, RegressionMixin):
    def __init__(
        self,
        *args,
        loss_fn: Optional[Callable] = None,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        output: OUTPUT_TYPE = None,
        **kwargs,
    ) -> None:
        metrics, loss_fn = RegressionMixin._build(loss_fn, metrics)

        super().__init__(
            *args,
            loss_fn=loss_fn,
            metrics=metrics,
            output=output,
            **kwargs,
        )
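RegressionAdapterTask mirrors the existing ClassificationAdapterTask: it wires regression loss and metric defaults into an adapter-backed task, which is what lets the tabular regression task wrap a PyTorch Tabular backbone through the adapter above. A rough usage sketch, assuming TabularRegressionData.from_csv and TabularRegressor.from_data mirror the classification API and using hypothetical column names:

from flash import Trainer
from flash.tabular import TabularRegressionData, TabularRegressor

# Hypothetical CSV with feature columns and a numeric target column.
datamodule = TabularRegressionData.from_csv(
    categorical_fields=["category"],
    numerical_fields=["scalar_a", "scalar_b"],
    target_fields="price",
    train_file="train.csv",
    val_split=0.1,
    batch_size=32,
)

model = TabularRegressor.from_data(datamodule, backbone="fttransformer")
Trainer(max_epochs=1).fit(model, datamodule=datamodule)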