diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 7460f3e3e8..9430b3f5fc 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -323,10 +323,10 @@ def from_dict( numerical_fields: Optional[Union[str, List[str]]] = None, target_fields: Optional[Union[str, List[str]]] = None, parameters: Optional[Dict[str, Any]] = None, - train_data: Optional[DataFrame] = None, - val_data: Optional[DataFrame] = None, - test_data: Optional[DataFrame] = None, - predict_data: Optional[DataFrame] = None, + train_dict: Optional[Dict[str, List[Any]]] = None, + val_dict: Optional[Dict[str, List[Any]]] = None, + test_dict: Optional[Dict[str, List[Any]]] = None, + predict_dict: Optional[Dict[str, List[Any]]] = None, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = TabularClassificationDictInput, transform: INPUT_TRANSFORM_TYPE = InputTransform, @@ -353,10 +353,10 @@ def from_dict( target_fields: The field (column name) or list of fields in the dictionary containing the targets. parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_fields`` are not provided (e.g. when loading data for inference or validation). - train_data: The data to use when training. - val_data: The data to use when validating. - test_data: The data to use when testing. - predict_data: The data to use when predicting. + train_dict: The data to use when training. + val_dict: The data to use when validating. + test_dict: The data to use when testing. + predict_dict: The data to use when predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. See :ref:`formatting_classification_targets` for more details. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. @@ -393,8 +393,8 @@ def from_dict( ... "friendly", ... "weight", ... "animal", - ... 
train_data=train_data, - ... predict_data=predict_data, + ... train_dict=train_data, + ... predict_dict=predict_data, ... batch_size=4, ... ) >>> datamodule.num_classes @@ -421,15 +421,15 @@ def from_dict( parameters=parameters, ) - train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw) + train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, - input_cls(RunningStage.VALIDATING, val_data, **ds_kw), - input_cls(RunningStage.TESTING, test_data, **ds_kw), - input_cls(RunningStage.PREDICTING, predict_data, **ds_kw), + input_cls(RunningStage.VALIDATING, val_dict, **ds_kw), + input_cls(RunningStage.TESTING, test_dict, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dict, **ds_kw), transform=transform, transform_kwargs=transform_kwargs, **data_module_kwargs, diff --git a/flash/tabular/classification/input.py b/flash/tabular/classification/input.py index b66303d428..1b1314d38b 100644 --- a/flash/tabular/classification/input.py +++ b/flash/tabular/classification/input.py @@ -81,10 +81,5 @@ def load_data( data_frame = DataFrame.from_dict(data) return super().load_data( - data=data_frame, - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_fields=target_fields, - parameters=parameters, - target_formatter=target_formatter, + data_frame, categorical_fields, numerical_fields, target_fields, parameters, target_formatter ) diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py index 67acee7559..b6f23bb5e2 100644 --- a/flash/tabular/regression/data.py +++ b/flash/tabular/regression/data.py @@ -18,7 +18,11 @@ from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING from flash.core.utilities.stages import RunningStage from flash.tabular.data import TabularData -from flash.tabular.regression.input 
import TabularRegressionCSVInput, TabularRegressionDataFrameInput +from flash.tabular.regression.input import ( + TabularRegressionCSVInput, + TabularRegressionDataFrameInput, + TabularRegressionDictInput, +) if _PANDAS_AVAILABLE: from pandas.core.frame import DataFrame @@ -289,3 +293,113 @@ def from_csv( transform_kwargs=transform_kwargs, **data_module_kwargs, ) + + @classmethod + def from_dict( + cls, + categorical_fields: Optional[Union[str, List[str]]] = None, + numerical_fields: Optional[Union[str, List[str]]] = None, + target_field: Optional[str] = None, + parameters: Optional[Dict[str, Any]] = None, + train_dict: Optional[Dict[str, List[Any]]] = None, + val_dict: Optional[Dict[str, List[Any]]] = None, + test_dict: Optional[Dict[str, List[Any]]] = None, + predict_dict: Optional[Dict[str, List[Any]]] = None, + input_cls: Type[Input] = TabularRegressionDictInput, + transform: INPUT_TRANSFORM_TYPE = InputTransform, + transform_kwargs: Optional[Dict] = None, + **data_module_kwargs: Any, + ) -> "TabularRegressionData": + """Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given + dictionary. + + .. note:: + + The ``categorical_fields``, ``numerical_fields``, and ``target_field`` do not need to be provided if + ``parameters`` are passed instead. These can be obtained from the + :attr:`~flash.tabular.data.TabularData.parameters` attribute of the + :class:`~flash.tabular.data.TabularData` object that contains your training data. + + The targets will be extracted from the ``target_field`` in the dictionaries. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + categorical_fields: The fields (column names) in the dictionary containing categorical data. + numerical_fields: The fields (column names) in the dictionary containing numerical data. + target_field: The field (column name) in the dictionary containing the targets. 
+ parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_field`` are not + provided (e.g. when loading data for inference or validation). + train_dict: The dictionary to use when training. + val_dict: The dictionary to use when validating. + test_dict: The dictionary to use when testing. + predict_dict: The dictionary to use when predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.tabular.regression.data.TabularRegressionData`. + + Examples + ________ + + .. testsetup:: + + >>> train_data = { + ... "age": [2, 4, 1], + ... "animal": ["cat", "dog", "cat"], + ... "weight": [6, 10, 5], + ... } + >>> predict_data = { + ... "animal": ["dog", "dog", "cat"], + ... "weight": [7, 12, 5], + ... } + + We have dictionaries ``train_data`` and ``predict_data``. + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.tabular import TabularRegressor, TabularRegressionData + >>> datamodule = TabularRegressionData.from_dict( + ... "animal", + ... "weight", + ... "age", + ... train_dict=train_data, + ... predict_dict=predict_data, + ... batch_size=4, + ... ) + >>> model = TabularRegressor.from_data(datamodule, backbone="tabnet") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. 
testcleanup:: + + >>> del train_data + >>> del predict_data + """ + ds_kw = dict( + categorical_fields=categorical_fields, + numerical_fields=numerical_fields, + target_field=target_field, + parameters=parameters, + ) + + train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw) + ds_kw["parameters"] = train_input.parameters if train_input else parameters + + return cls( + train_input, + input_cls(RunningStage.VALIDATING, val_dict, **ds_kw), + input_cls(RunningStage.TESTING, test_dict, **ds_kw), + input_cls(RunningStage.PREDICTING, predict_dict, **ds_kw), + transform=transform, + transform_kwargs=transform_kwargs, + **data_module_kwargs, + ) diff --git a/flash/tabular/regression/input.py b/flash/tabular/regression/input.py index 1ff15ca4c1..9d601b8313 100644 --- a/flash/tabular/regression/input.py +++ b/flash/tabular/regression/input.py @@ -55,3 +55,17 @@ def load_data( ): if file is not None: return super().load_data(read_csv(file), categorical_fields, numerical_fields, target_field, parameters) + + +class TabularRegressionDictInput(TabularRegressionDataFrameInput): + def load_data( + self, + data: Dict[str, List[Any]], + categorical_fields: Optional[Union[str, List[str]]] = None, + numerical_fields: Optional[Union[str, List[str]]] = None, + target_field: Optional[str] = None, + parameters: Dict[str, Any] = None, + ): + data_frame = DataFrame.from_dict(data) + + return super().load_data(data_frame, categorical_fields, numerical_fields, target_field, parameters) diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index c235f1a9dc..c91ac25387 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -169,9 +169,9 @@ def test_from_dict(): categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", - train_data=TEST_DICT_1, - val_data=TEST_DICT_2, - test_data=TEST_DICT_2, + train_dict=TEST_DICT_1, + val_dict=TEST_DICT_2, + 
test_dict=TEST_DICT_2, num_workers=0, batch_size=1, ) diff --git a/tests/tabular/regression/test_data_model_integration.py b/tests/tabular/regression/test_data_model_integration.py index 8aedf75ed2..e8a2341276 100644 --- a/tests/tabular/regression/test_data_model_integration.py +++ b/tests/tabular/regression/test_data_model_integration.py @@ -20,14 +20,14 @@ if _TABULAR_AVAILABLE: import pandas as pd - TEST_DF_1 = pd.DataFrame( - data={ - "category": ["a", "b", "c", "a", None, "c"], - "scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0], - "scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0], - "label": [0.0, 1.0, 2.0, 1.0, 0.0, 1.0], - } - ) + TEST_DICT = { + "category": ["a", "b", "c", "a", None, "c"], + "scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0], + "scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0], + "label": [0.0, 1.0, 2.0, 1.0, 0.0, 1.0], + } + + TEST_DF = pd.DataFrame(data=TEST_DICT) @pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") @@ -45,10 +45,10 @@ ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), ], ) -def test_regression(backbone, fields, tmpdir): - train_data_frame = TEST_DF_1.copy() - val_data_frame = TEST_DF_1.copy() - test_data_frame = TEST_DF_1.copy() +def test_regression_data_frame(backbone, fields, tmpdir): + train_data_frame = TEST_DF.copy() + val_data_frame = TEST_DF.copy() + test_data_frame = TEST_DF.copy() data = TabularRegressionData.from_data_frame( **fields, target_field="label", @@ -61,3 +61,33 @@ def test_regression(backbone, fields, tmpdir): model = TabularRegressor.from_data(datamodule=data, backbone=backbone) trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, data) + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") +@pytest.mark.parametrize( + "backbone,fields", + [ + ("tabnet", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("tabtransformer", {"categorical_fields": ["category"], 
"numerical_fields": ["scalar_a", "scalar_b"]}), + ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + # No categorical / numerical fields + ("tabnet", {"categorical_fields": ["category"]}), + ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), + ], +) +def test_regression_dict(backbone, fields, tmpdir): + data = TabularRegressionData.from_dict( + **fields, + target_field="label", + train_dict=TEST_DICT, + val_dict=TEST_DICT, + test_dict=TEST_DICT, + num_workers=0, + batch_size=2, + ) + model = TabularRegressor.from_data(datamodule=data, backbone=backbone) + trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer.fit(model, data)