Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Support from_dicts for Tabular Classification and Regression #1331

Merged
merged 17 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `from_dict` for Tabular Classification ([#1331](https://github.com/PyTorchLightning/lightning-flash/pull/1331))

- Added support for `from_dict` for Tabular Regression ([#1331](https://github.com/PyTorchLightning/lightning-flash/pull/1331))

krshrimali marked this conversation as resolved.
Show resolved Hide resolved
- Added support for using the `ImageEmbedder` SSL training for all image classifier backbones ([#1264](https://github.com/PyTorchLightning/lightning-flash/pull/1264))

- Added support for audio file formats to `AudioClassificationData` ([#1085](https://github.com/PyTorchLightning/lightning-flash/pull/1085))
Expand Down
125 changes: 124 additions & 1 deletion flash/tabular/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING
from flash.core.utilities.stages import RunningStage
from flash.tabular.classification.input import TabularClassificationCSVInput, TabularClassificationDataFrameInput
from flash.tabular.classification.input import (
TabularClassificationCSVInput,
TabularClassificationDataFrameInput,
TabularClassificationDictInput,
)
from flash.tabular.data import TabularData

if _PANDAS_AVAILABLE:
Expand Down Expand Up @@ -311,3 +315,122 @@ def from_csv(
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)

@classmethod
def from_dicts(
cls,
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_fields: Optional[Union[str, List[str]]] = None,
parameters: Optional[Dict[str, Any]] = None,
train_dict: Optional[Dict[str, List[Any]]] = None,
val_dict: Optional[Dict[str, List[Any]]] = None,
test_dict: Optional[Dict[str, List[Any]]] = None,
predict_dict: Optional[Dict[str, List[Any]]] = None,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = TabularClassificationDictInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "TabularClassificationData":
"""Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given
dictionary.

.. note::
The ``categorical_fields``, ``numerical_fields``, and ``target_fields`` do not need to be provided if
``parameters`` are passed instead. These can be obtained from the
:attr:`~flash.tabular.data.TabularData.parameters` attribute of the
:class:`~flash.tabular.data.TabularData` object that contains your training data.

The targets will be extracted from the ``target_fields`` in the dict and can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.

Args:
categorical_fields: The fields (column names) in the dictionary containing categorical data.
numerical_fields: The fields (column names) in the dictionary containing numerical data.
target_fields: The field (column name) or list of fields in the dictionary containing the targets.
parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_fields`` are not
provided (e.g. when loading data for inference or validation).
train_dict: The data to use when training.
val_dict: The data to use when validating.
test_dict: The data to use when testing.
predict_dict: The data to use when predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. See :ref:`formatting_classification_targets` for more details.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.

Returns:
The constructed :class:`~flash.tabular.classification.data.TabularClassificationData`.

Examples
________

.. testsetup::

>>> train_data = {
... "animal": ["cat", "dog", "cat"],
... "friendly": ["yes", "yes", "no"],
... "weight": [6, 10, 5],
... }
>>> predict_data = {
... "friendly": ["yes", "no", "yes"],
... "weight": [7, 12, 5],
... }

We have dictionaries ``train_data`` and ``predict_data``.

.. doctest::

>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_dicts(
... "friendly",
... "weight",
... "animal",
... train_dict=train_data,
... predict_dict=predict_data,
... batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...

.. testcleanup::

>>> del train_data
>>> del predict_data
"""
ds_kw = dict(
target_formatter=target_formatter,
categorical_fields=categorical_fields,
numerical_fields=numerical_fields,
target_fields=target_fields,
parameters=parameters,
)

train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw)
ds_kw["parameters"] = train_input.parameters if train_input else parameters
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_dict, **ds_kw),
input_cls(RunningStage.TESTING, test_dict, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_dict, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)
18 changes: 18 additions & 0 deletions flash/tabular/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,21 @@ def load_data(
return super().load_data(
read_csv(file), categorical_fields, numerical_fields, target_fields, parameters, target_formatter
)


class TabularClassificationDictInput(TabularClassificationDataFrameInput):
def load_data(
self,
data: Dict[str, Union[Any, List[Any]]],
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_fields: Optional[Union[str, List[str]]] = None,
parameters: Dict[str, Any] = None,
target_formatter: Optional[TargetFormatter] = None,
):
# Convert the data (dict) to a Pandas DataFrame
data_frame = DataFrame.from_dict(data)

return super().load_data(
data_frame, categorical_fields, numerical_fields, target_fields, parameters, target_formatter
)
116 changes: 115 additions & 1 deletion flash/tabular/regression/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING
from flash.core.utilities.stages import RunningStage
from flash.tabular.data import TabularData
from flash.tabular.regression.input import TabularRegressionCSVInput, TabularRegressionDataFrameInput
from flash.tabular.regression.input import (
TabularRegressionCSVInput,
TabularRegressionDataFrameInput,
TabularRegressionDictInput,
)

if _PANDAS_AVAILABLE:
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -289,3 +293,113 @@ def from_csv(
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)

@classmethod
def from_dicts(
cls,
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_field: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
train_dict: Optional[Dict[str, List[Any]]] = None,
val_dict: Optional[Dict[str, List[Any]]] = None,
test_dict: Optional[Dict[str, List[Any]]] = None,
predict_dict: Optional[Dict[str, List[Any]]] = None,
input_cls: Type[Input] = TabularRegressionDictInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "TabularRegressionData":
"""Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given
dictionary.

.. note::

The ``categorical_fields``, ``numerical_fields``, and ``target_field`` do not need to be provided if
``parameters`` are passed instead. These can be obtained from the
:attr:`~flash.tabular.data.TabularData.parameters` attribute of the
:class:`~flash.tabular.data.TabularData` object that contains your training data.

The targets will be extracted from the ``target_field`` in the data frames.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.

Args:
categorical_fields: The fields (column names) in the dictionary containing categorical data.
numerical_fields: The fields (column names) in the dictionary containing numerical data.
target_field: The field (column name) in the dictionary containing the targets.
parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_field`` are not
provided (e.g. when loading data for inference or validation).
train_dict: The dictionary to use when training.
val_dict: The dictionary to use when validating.
test_dict: The dictionary to use when testing.
predict_dict: The dictionary to use when predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.

Returns:
The constructed :class:`~flash.tabular.regression.data.TabularRegressionData`.

Examples
________

.. testsetup::

>>> train_data = {
... "age": [2, 4, 1],
... "animal": ["cat", "dog", "cat"],
... "weight": [6, 10, 5],
... }
>>> predict_data = {
... "animal": ["dog", "dog", "cat"],
... "weight": [7, 12, 5],
... }

We have dictionaries ``train_data`` and ``predict_data``.

.. doctest::

>>> from flash import Trainer
>>> from flash.tabular import TabularRegressor, TabularRegressionData
>>> datamodule = TabularRegressionData.from_dicts(
... "animal",
... "weight",
... "age",
... train_dict=train_data,
... predict_dict=predict_data,
... batch_size=4,
... )
>>> model = TabularRegressor.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...

.. testcleanup::

>>> del train_data
>>> del predict_data
"""
ds_kw = dict(
categorical_fields=categorical_fields,
numerical_fields=numerical_fields,
target_field=target_field,
parameters=parameters,
)

train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw)
ds_kw["parameters"] = train_input.parameters if train_input else parameters

return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_dict, **ds_kw),
input_cls(RunningStage.TESTING, test_dict, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_dict, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)
14 changes: 14 additions & 0 deletions flash/tabular/regression/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,17 @@ def load_data(
):
if file is not None:
return super().load_data(read_csv(file), categorical_fields, numerical_fields, target_field, parameters)


class TabularRegressionDictInput(TabularRegressionDataFrameInput):
def load_data(
self,
data: Dict[str, List[Any]],
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_field: Optional[str] = None,
parameters: Dict[str, Any] = None,
):
data_frame = DataFrame.from_dict(data)

return super().load_data(data_frame, categorical_fields, numerical_fields, target_field, parameters)
52 changes: 36 additions & 16 deletions tests/tabular/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,22 @@
from flash.tabular import TabularClassificationData
from flash.tabular.classification.utils import _categorize, _compute_normalization, _generate_codes, _normalize

TEST_DF_1 = pd.DataFrame(
data={
"category": ["a", "b", "c", "a", None, "c"],
"scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0],
"scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0],
"label": [0, 1, 0, 1, 0, 1],
}
)
TEST_DICT_1 = {
"category": ["a", "b", "c", "a", None, "c"],
"scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0],
"scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0],
"label": [0, 1, 0, 1, 0, 1],
}

TEST_DF_2 = pd.DataFrame(
data={
"category": ["d", "e", "f"],
"scalar_a": [0.0, 1.0, 2.0],
"scalar_b": [0.0, 4.0, 2.0],
"label": [0, 1, 0],
}
)
TEST_DICT_2 = {
"category": ["d", "e", "f"],
"scalar_a": [0.0, 1.0, 2.0],
"scalar_b": [0.0, 4.0, 2.0],
"label": [0, 1, 0],
}

TEST_DF_1 = pd.DataFrame(data=TEST_DICT_1)
TEST_DF_2 = pd.DataFrame(data=TEST_DICT_2)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required")
Expand Down Expand Up @@ -164,6 +163,27 @@ def test_from_csv(tmpdir):
assert target.shape == (1,)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required")
def test_from_dicts():
dm = TabularClassificationData.from_dicts(
categorical_fields=["category"],
numerical_fields=["scalar_a", "scalar_b"],
target_fields="label",
train_dict=TEST_DICT_1,
val_dict=TEST_DICT_2,
test_dict=TEST_DICT_2,
num_workers=0,
batch_size=1,
)
for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
data = next(iter(dl))
(cat, num) = data[DataKeys.INPUT]
target = data[DataKeys.TARGET]
assert cat.shape == (1, 1)
assert num.shape == (1, 2)
assert target.shape == (1,)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required")
def test_empty_inputs():
train_data_frame = TEST_DF_1.copy()
Expand Down
Loading