diff --git a/flash/tabular/classification/input.py b/flash/tabular/classification/input.py index 4b7b57b575..b66303d428 100644 --- a/flash/tabular/classification/input.py +++ b/flash/tabular/classification/input.py @@ -67,7 +67,7 @@ def load_data( ) -class TabularClassificationDictInput(TabularDataFrameInput): +class TabularClassificationDictInput(TabularClassificationDataFrameInput): def load_data( self, data: Dict[str, Union[Any, List[Any]]], diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index d4a9c9a164..c235f1a9dc 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -26,23 +26,22 @@ from flash.tabular import TabularClassificationData from flash.tabular.classification.utils import _categorize, _compute_normalization, _generate_codes, _normalize - TEST_DF_1 = pd.DataFrame( - data={ - "category": ["a", "b", "c", "a", None, "c"], - "scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0], - "scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0], - "label": [0, 1, 0, 1, 0, 1], - } - ) + TEST_DICT_1 = { + "category": ["a", "b", "c", "a", None, "c"], + "scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0], + "scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0], + "label": [0, 1, 0, 1, 0, 1], + } - TEST_DF_2 = pd.DataFrame( - data={ - "category": ["d", "e", "f"], - "scalar_a": [0.0, 1.0, 2.0], - "scalar_b": [0.0, 4.0, 2.0], - "label": [0, 1, 0], - } - ) + TEST_DICT_2 = { + "category": ["d", "e", "f"], + "scalar_a": [0.0, 1.0, 2.0], + "scalar_b": [0.0, 4.0, 2.0], + "label": [0, 1, 0], + } + + TEST_DF_1 = pd.DataFrame(data=TEST_DICT_1) + TEST_DF_2 = pd.DataFrame(data=TEST_DICT_2) @pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") @@ -164,6 +163,27 @@ def test_from_csv(tmpdir): assert target.shape == (1,) +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +def test_from_dict(): + dm = TabularClassificationData.from_dict( + categorical_fields=["category"], + numerical_fields=["scalar_a", "scalar_b"], + target_fields="label", + train_data=TEST_DICT_1, + val_data=TEST_DICT_2, + test_data=TEST_DICT_2, + num_workers=0, + batch_size=1, + ) + for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: + data = next(iter(dl)) + (cat, num) = data[DataKeys.INPUT] + target = data[DataKeys.TARGET] + assert cat.shape == (1, 1) + assert num.shape == (1, 2) + assert target.shape == (1,) + + @pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") def test_empty_inputs(): train_data_frame = TEST_DF_1.copy()