Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add from_dict for regression as well, tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
krshrimali committed May 9, 2022
1 parent e5699aa commit a42c9de
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 36 deletions.
28 changes: 14 additions & 14 deletions flash/tabular/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ def from_dict(
numerical_fields: Optional[Union[str, List[str]]] = None,
target_fields: Optional[Union[str, List[str]]] = None,
parameters: Optional[Dict[str, Any]] = None,
train_data: Optional[DataFrame] = None,
val_data: Optional[DataFrame] = None,
test_data: Optional[DataFrame] = None,
predict_data: Optional[DataFrame] = None,
train_dict: Optional[Dict[str, List[Any]]] = None,
val_dict: Optional[Dict[str, List[Any]]] = None,
test_dict: Optional[Dict[str, List[Any]]] = None,
predict_dict: Optional[Dict[str, List[Any]]] = None,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = TabularClassificationDictInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
Expand All @@ -353,10 +353,10 @@ def from_dict(
target_fields: The field (column name) or list of fields in the dictionary containing the targets.
parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_fields`` are not
provided (e.g. when loading data for inference or validation).
train_data: The data to use when training.
val_data: The data to use when validating.
test_data: The data to use when testing.
predict_data: The data to use when predicting.
train_dict: The data to use when training.
val_dict: The data to use when validating.
test_dict: The data to use when testing.
predict_dict: The data to use when predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. See :ref:`formatting_classification_targets` for more details.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
Expand Down Expand Up @@ -393,8 +393,8 @@ def from_dict(
... "friendly",
... "weight",
... "animal",
... train_data=train_data,
... predict_data=predict_data,
... train_dict=train_data,
... predict_dict=predict_data,
... batch_size=4,
... )
>>> datamodule.num_classes
Expand All @@ -421,15 +421,15 @@ def from_dict(
parameters=parameters,
)

train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw)
train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw)
ds_kw["parameters"] = train_input.parameters if train_input else parameters
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_data, **ds_kw),
input_cls(RunningStage.TESTING, test_data, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, **ds_kw),
input_cls(RunningStage.VALIDATING, val_dict, **ds_kw),
input_cls(RunningStage.TESTING, test_dict, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_dict, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
Expand Down
7 changes: 1 addition & 6 deletions flash/tabular/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,5 @@ def load_data(
data_frame = DataFrame.from_dict(data)

return super().load_data(
data=data_frame,
categorical_fields=categorical_fields,
numerical_fields=numerical_fields,
target_fields=target_fields,
parameters=parameters,
target_formatter=target_formatter,
data_frame, categorical_fields, numerical_fields, target_fields, parameters, target_formatter
)
116 changes: 115 additions & 1 deletion flash/tabular/regression/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING
from flash.core.utilities.stages import RunningStage
from flash.tabular.data import TabularData
from flash.tabular.regression.input import TabularRegressionCSVInput, TabularRegressionDataFrameInput
from flash.tabular.regression.input import (
TabularRegressionCSVInput,
TabularRegressionDataFrameInput,
TabularRegressionDictInput,
)

if _PANDAS_AVAILABLE:
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -289,3 +293,113 @@ def from_csv(
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)

@classmethod
def from_dict(
cls,
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_field: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
train_dict: Optional[Dict[str, List[Any]]] = None,
val_dict: Optional[Dict[str, List[Any]]] = None,
test_dict: Optional[Dict[str, List[Any]]] = None,
predict_dict: Optional[Dict[str, List[Any]]] = None,
input_cls: Type[Input] = TabularRegressionDictInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "TabularRegressionData":
"""Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given
dictionary.
.. note::
The ``categorical_fields``, ``numerical_fields``, and ``target_field`` do not need to be provided if
``parameters`` are passed instead. These can be obtained from the
:attr:`~flash.tabular.data.TabularData.parameters` attribute of the
:class:`~flash.tabular.data.TabularData` object that contains your training data.
The targets will be extracted from the ``target_field`` in the data frames.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
categorical_fields: The fields (column names) in the dictionary containing categorical data.
numerical_fields: The fields (column names) in the dictionary containing numerical data.
target_field: The field (column name) in the dictionary containing the targets.
parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_field`` are not
provided (e.g. when loading data for inference or validation).
train_dict: The dictionary to use when training.
val_dict: The dictionary to use when validating.
test_dict: The dictionary to use when testing.
predict_dict: The dictionary to use when predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.tabular.regression.data.TabularRegressionData`.
Examples
________
.. testsetup::
>>> train_data = {
... "age": [2, 4, 1],
... "animal": ["cat", "dog", "cat"],
... "weight": [6, 10, 5],
... }
>>> predict_data = {
... "animal": ["dog", "dog", "cat"],
... "weight": [7, 12, 5],
... }
We have dictionaries ``train_data`` and ``predict_data``.
.. doctest::
>>> from flash import Trainer
>>> from flash.tabular import TabularRegressor, TabularRegressionData
>>> datamodule = TabularRegressionData.from_dict(
... "animal",
... "weight",
... "age",
... train_dict=train_data,
... predict_dict=predict_data,
... batch_size=4,
... )
>>> model = TabularRegressor.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> del train_data
>>> del predict_data
"""
ds_kw = dict(
categorical_fields=categorical_fields,
numerical_fields=numerical_fields,
target_field=target_field,
parameters=parameters,
)

train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw)
ds_kw["parameters"] = train_input.parameters if train_input else parameters

return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_dict, **ds_kw),
input_cls(RunningStage.TESTING, test_dict, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_dict, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)
14 changes: 14 additions & 0 deletions flash/tabular/regression/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,17 @@ def load_data(
):
if file is not None:
return super().load_data(read_csv(file), categorical_fields, numerical_fields, target_field, parameters)


class TabularRegressionDictInput(TabularRegressionDataFrameInput):
def load_data(
self,
data: Dict[str, List[Any]],
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_field: Optional[str] = None,
parameters: Dict[str, Any] = None,
):
data_frame = DataFrame.from_dict(data)

return super().load_data(data_frame, categorical_fields, numerical_fields, target_field, parameters)
6 changes: 3 additions & 3 deletions tests/tabular/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def test_from_dict():
categorical_fields=["category"],
numerical_fields=["scalar_a", "scalar_b"],
target_fields="label",
train_data=TEST_DICT_1,
val_data=TEST_DICT_2,
test_data=TEST_DICT_2,
train_dict=TEST_DICT_1,
val_dict=TEST_DICT_2,
test_dict=TEST_DICT_2,
num_workers=0,
batch_size=1,
)
Expand Down
54 changes: 42 additions & 12 deletions tests/tabular/regression/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
if _TABULAR_AVAILABLE:
import pandas as pd

TEST_DF_1 = pd.DataFrame(
data={
"category": ["a", "b", "c", "a", None, "c"],
"scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0],
"scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0],
"label": [0.0, 1.0, 2.0, 1.0, 0.0, 1.0],
}
)
TEST_DICT = {
"category": ["a", "b", "c", "a", None, "c"],
"scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0],
"scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0],
"label": [0.0, 1.0, 2.0, 1.0, 0.0, 1.0],
}

TEST_DF = pd.DataFrame(data=TEST_DICT)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.")
Expand All @@ -45,10 +45,10 @@
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
],
)
def test_regression(backbone, fields, tmpdir):
train_data_frame = TEST_DF_1.copy()
val_data_frame = TEST_DF_1.copy()
test_data_frame = TEST_DF_1.copy()
def test_regression_data_frame(backbone, fields, tmpdir):
train_data_frame = TEST_DF.copy()
val_data_frame = TEST_DF.copy()
test_data_frame = TEST_DF.copy()
data = TabularRegressionData.from_data_frame(
**fields,
target_field="label",
Expand All @@ -61,3 +61,33 @@ def test_regression(backbone, fields, tmpdir):
model = TabularRegressor.from_data(datamodule=data, backbone=backbone)
trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, data)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.")
@pytest.mark.parametrize(
"backbone,fields",
[
("tabnet", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("tabtransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
],
)
def test_regression_dict(backbone, fields, tmpdir):
data = TabularRegressionData.from_dict(
**fields,
target_field="label",
train_dict=TEST_DICT,
val_dict=TEST_DICT,
test_dict=TEST_DICT,
num_workers=0,
batch_size=2,
)
model = TabularRegressor.from_data(datamodule=data, backbone=backbone)
trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, data)

0 comments on commit a42c9de

Please sign in to comment.