This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/integration pytorch tabular (#1098)
Co-authored-by: Luca Actis Grosso <[email protected]> Co-authored-by: Ethan Harris <[email protected]> Co-authored-by: Ethan Harris <[email protected]>
- Loading branch information
1 parent
21a9d7f
commit b208689
Showing
46 changed files
with
851 additions
and
292 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Dict, List, Optional, Union | ||
|
||
import torchmetrics | ||
|
||
from flash.core.adapter import Adapter | ||
from flash.core.data.io.input import DataKeys | ||
from flash.core.model import Task | ||
|
||
|
||
class PytorchTabularAdapter(Adapter):
    """Adapter that lets a PyTorch Tabular model be driven as a Flash ``Task``.

    Every Lightning step hook simply re-shapes the Flash batch into the
    keyword layout PyTorch Tabular expects and delegates to the wrapped model.
    """

    def __init__(self, backbone):
        super().__init__()
        # The instantiated PyTorch Tabular model; all step logic is delegated to it.
        self.backbone = backbone

    @classmethod
    def from_task(
        cls,
        task: Task,
        task_type,
        embedding_sizes: list,
        categorical_fields: list,
        cat_dims: list,
        num_features: int,
        output_dim: int,
        backbone: str,
        loss_fn: Optional[Callable],
        metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]],
        backbone_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Adapter:
        """Construct the adapter by instantiating the backbone registered on ``task``.

        Args:
            task: The Flash task whose backbone registry is consulted.
            task_type: Task type string forwarded to the backbone factory.
            embedding_sizes: Embedding dimensions for the categorical columns.
            categorical_fields: Names of the categorical columns.
            cat_dims: Cardinality of each categorical column.
            num_features: Total number of input features (categorical + continuous).
            output_dim: Size of the model output.
            backbone: Registry key of the backbone to instantiate.
            loss_fn: Optional loss function passed through to the backbone.
            metrics: Optional metric or list of metrics passed through to the backbone.
            backbone_kwargs: Extra keyword arguments for the backbone factory.
        """
        if backbone_kwargs is None:
            backbone_kwargs = {}
        n_categorical = len(categorical_fields)
        parameters = {
            "embedding_dims": embedding_sizes,
            "categorical_cols": categorical_fields,
            "categorical_cardinality": cat_dims,
            "categorical_dim": n_categorical,
            # Continuous features are whatever remains after the categorical ones.
            "continuous_dim": num_features - n_categorical,
            "output_dim": output_dim,
        }
        backbone_factory = task.backbones.get(backbone)
        model = backbone_factory(
            task_type=task_type,
            parameters=parameters,
            loss_fn=loss_fn,
            metrics=metrics,
            **backbone_kwargs,
        )
        return cls(model)

    @staticmethod
    def convert_batch(batch):
        """Translate a Flash batch dict into the layout PyTorch Tabular expects."""
        categorical, continuous = batch[DataKeys.INPUT][0], batch[DataKeys.INPUT][1]
        new_batch = {"continuous": continuous, "categorical": categorical}
        if DataKeys.TARGET in batch:
            # Targets are reshaped to carry an explicit trailing dimension.
            new_batch["target"] = batch[DataKeys.TARGET].reshape(-1, 1)
        return new_batch

    def training_step(self, batch, batch_idx) -> Any:
        return self.backbone.training_step(self.convert_batch(batch), batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.backbone.validation_step(self.convert_batch(batch), batch_idx)

    def test_step(self, batch, batch_idx):
        # NOTE(review): delegates to the backbone's ``validation_step`` rather than a
        # test step — presumably intentional (the backbone may expose no dedicated
        # test hook); confirm against the pytorch_tabular API.
        return self.backbone.validation_step(self.convert_batch(batch), batch_idx)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self(self.convert_batch(batch))

    def forward(self, batch: Any) -> Any:
        # The backbone returns a dict; only the raw logits are exposed to callers.
        return self.backbone(batch)["logits"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import functools | ||
import os | ||
from typing import Callable, List, Optional, Union | ||
|
||
import torchmetrics | ||
|
||
from flash.core.integrations.pytorch_tabular.adapter import PytorchTabularAdapter | ||
from flash.core.registry import FlashRegistry | ||
from flash.core.utilities.imports import _PYTORCHTABULAR_AVAILABLE | ||
from flash.core.utilities.providers import _PYTORCH_TABULAR | ||
|
||
if _PYTORCHTABULAR_AVAILABLE: | ||
import pytorch_tabular.models as models | ||
from omegaconf import DictConfig, OmegaConf | ||
from pytorch_tabular.config import ModelConfig | ||
from pytorch_tabular.models import ( | ||
AutoIntConfig, | ||
CategoryEmbeddingModelConfig, | ||
FTTransformerConfig, | ||
NodeConfig, | ||
TabNetModelConfig, | ||
TabTransformerConfig, | ||
) | ||
|
||
|
||
PYTORCH_TABULAR_BACKBONES = FlashRegistry("backbones") | ||
|
||
|
||
if _PYTORCHTABULAR_AVAILABLE: | ||
|
||
def _read_parse_config(config, cls):
    """Normalize ``config`` to an ``OmegaConf`` structured config.

    When ``config`` is a string it is treated as a path to a saved YAML config:
    the file is loaded, the concrete config dataclass is resolved (for
    ``ModelConfig`` the saved ``_module_src``/``_config_name`` fields name it),
    and an instance is rebuilt from the loadable init fields.

    Raises:
        ValueError: If ``config`` is a string that is not an existing path.
    """
    if isinstance(config, str):
        if not os.path.exists(config):
            raise ValueError(f"{config} is not a valid path")
        loaded = OmegaConf.load(config)
        if cls == ModelConfig:
            # A saved model config records which concrete config class produced it.
            cls = getattr(getattr(models, loaded._module_src), loaded._config_name)
        # Keep only keys that are init-able fields of the target dataclass.
        init_kwargs = {
            key: value
            for key, value in loaded.items()
            if key in cls.__dataclass_fields__ and cls.__dataclass_fields__[key].init
        }
        config = cls(**init_kwargs)
    return OmegaConf.structured(config)
|
||
def load_pytorch_tabular(
    model_config_class,
    task_type,
    parameters: DictConfig,
    loss_fn: Callable,
    metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]],
    **model_kwargs,
):
    """Instantiate a PyTorch Tabular model from a config class and Flash parameters.

    Args:
        model_config_class: The pytorch_tabular model config dataclass to use.
        task_type: Task type string stored on the model config.
        parameters: Flash-computed data parameters (embedding dims, column info, ...).
        loss_fn: Custom loss forwarded to the model.
        metrics: Custom metric(s) forwarded to the model.
        **model_kwargs: Extra keyword arguments for the model config.
    """
    raw_config = model_config_class(
        task=task_type, embedding_dims=parameters["embedding_dims"], **model_kwargs
    )
    parsed_config = _read_parse_config(raw_config, ModelConfig)
    # Resolve the concrete model class named by the parsed config.
    model_cls = getattr(getattr(models, parsed_config._module_src), parsed_config._model_name)
    merged_config = OmegaConf.merge(
        OmegaConf.create(parameters),
        OmegaConf.to_container(parsed_config),
    )
    return model_cls(config=merged_config, custom_loss=loss_fn, custom_metrics=metrics)
|
||
# Register one backbone factory per supported pytorch_tabular model family.
# The mapping preserves the original registration order.
_MODEL_CONFIG_CLASSES = {
    "tabnet": TabNetModelConfig,
    "tabtransformer": TabTransformerConfig,
    "fttransformer": FTTransformerConfig,
    "autoint": AutoIntConfig,
    "node": NodeConfig,
    "category_embedding": CategoryEmbeddingModelConfig,
}
for name, model_config_class in _MODEL_CONFIG_CLASSES.items():
    PYTORCH_TABULAR_BACKBONES(
        # Bind the config class now; the remaining arguments come from the task.
        functools.partial(load_pytorch_tabular, model_config_class),
        name=name,
        providers=_PYTORCH_TABULAR,
        adapter=PytorchTabularAdapter,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.