Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Tests added for `from_dict`; minor fix to the base class of `TabularClassificationDictInput`
Browse files (browse the repository at this point in the history)
  • Loading branch information
krshrimali committed May 9, 2022
1 parent 337e977 commit e5699aa
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
2 changes: 1 addition & 1 deletion flash/tabular/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def load_data(
)


class TabularClassificationDictInput(TabularDataFrameInput):
class TabularClassificationDictInput(TabularClassificationDataFrameInput):
def load_data(
self,
data: Dict[str, Union[Any, List[Any]]],
Expand Down
52 changes: 36 additions & 16 deletions tests/tabular/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,22 @@
from flash.tabular import TabularClassificationData
from flash.tabular.classification.utils import _categorize, _compute_normalization, _generate_codes, _normalize

# Shared fixtures for the tabular classification data tests.
# NOTE(review): this span looks like a rendered diff hunk with the +/- markers
# stripped, so the inline-dict DataFrame definitions below are presumably the
# removed lines and the TEST_DICT_* / DataFrame(data=TEST_DICT_*) lines the
# added ones -- confirm against the commit before treating this as one file.

# Six rows with one categorical column, two numerical columns and a label;
# the fifth row holds None in every feature column.
TEST_DF_1 = pd.DataFrame(
data={
"category": ["a", "b", "c", "a", None, "c"],
"scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0],
"scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0],
"label": [0, 1, 0, 1, 0, 1],
}
)
# Plain-dict form of the same six rows (values identical to the literal above).
TEST_DICT_1 = {
"category": ["a", "b", "c", "a", None, "c"],
"scalar_a": [0.0, 1.0, 2.0, 3.0, None, 5.0],
"scalar_b": [5.0, 4.0, 3.0, 2.0, None, 1.0],
"label": [0, 1, 0, 1, 0, 1],
}

# Three rows with the same columns and no missing values; the category values
# ("d", "e", "f") do not overlap with those in the first fixture.
TEST_DF_2 = pd.DataFrame(
data={
"category": ["d", "e", "f"],
"scalar_a": [0.0, 1.0, 2.0],
"scalar_b": [0.0, 4.0, 2.0],
"label": [0, 1, 0],
}
)
# Plain-dict form of the same three rows (values identical to the literal above).
TEST_DICT_2 = {
"category": ["d", "e", "f"],
"scalar_a": [0.0, 1.0, 2.0],
"scalar_b": [0.0, 4.0, 2.0],
"label": [0, 1, 0],
}

# Rebinds both DataFrame fixtures, building them from the dict fixtures so the
# dict and DataFrame variants share a single source of values.
TEST_DF_1 = pd.DataFrame(data=TEST_DICT_1)
TEST_DF_2 = pd.DataFrame(data=TEST_DICT_2)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required")
Expand Down Expand Up @@ -164,6 +163,27 @@ def test_from_csv(tmpdir):
assert target.shape == (1,)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required")
def test_from_dict():
    """Check ``TabularClassificationData.from_dict`` on plain-dict inputs.

    Builds a datamodule from the module-level dict fixtures (one categorical
    field, two numerical fields and a ``label`` target) with ``batch_size=1``,
    then asserts the shapes of the first batch drawn from the train, val and
    test dataloaders.
    """
    datamodule = TabularClassificationData.from_dict(
        categorical_fields=["category"],
        numerical_fields=["scalar_a", "scalar_b"],
        target_fields="label",
        train_data=TEST_DICT_1,
        val_data=TEST_DICT_2,
        test_data=TEST_DICT_2,
        num_workers=0,
        batch_size=1,
    )

    # Create all three loaders up front, then inspect one batch from each.
    loaders = [
        datamodule.train_dataloader(),
        datamodule.val_dataloader(),
        datamodule.test_dataloader(),
    ]
    for loader in loaders:
        batch = next(iter(loader))
        categorical, numerical = batch[DataKeys.INPUT]
        labels = batch[DataKeys.TARGET]

        # batch_size=1 with one categorical and two numerical columns.
        assert categorical.shape == (1, 1)
        assert numerical.shape == (1, 2)
        assert labels.shape == (1,)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required")
def test_empty_inputs():
train_data_frame = TEST_DF_1.copy()
Expand Down

0 comments on commit e5699aa

Please sign in to comment.