Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

fix: set inputs as optional #109

Merged
merged 18 commits into from
Feb 13, 2021
Merged
18 changes: 12 additions & 6 deletions flash/tabular/classification/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ class TabularData(DataModule):
def __init__(
self,
train_df: DataFrame,
categorical_input: List,
numerical_input: List,
target: str,
categorical_input: Optional[List] = None,
numerical_input: Optional[List] = None,
valid_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
batch_size: int = 2,
Expand All @@ -82,6 +82,12 @@ def __init__(
dfs = [train_df]
self._test_df = None

if categorical_input is None and numerical_input is None:
raise RuntimeError('Both `categorical_input` and `numerical_input` are None!')

categorical_input = categorical_input if categorical_input is not None else []
numerical_input = numerical_input if numerical_input is not None else []

if valid_df is not None:
dfs.append(valid_df)

Expand Down Expand Up @@ -133,8 +139,8 @@ def from_df(
cls,
train_df: DataFrame,
target: str,
categorical_input: List,
numerical_input: List,
categorical_input: Optional[List] = None,
numerical_input: Optional[List] = None,
valid_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
batch_size: int = 8,
Expand Down Expand Up @@ -194,8 +200,8 @@ def from_csv(
cls,
train_csv: str,
target: str,
categorical_input: List,
numerical_input: List,
categorical_input: Optional[List] = None,
numerical_input: Optional[List] = None,
valid_csv: Optional[str] = None,
test_csv: Optional[str] = None,
batch_size: int = 8,
Expand Down
9 changes: 9 additions & 0 deletions tests/tabular/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
import pytest

from flash.tabular import TabularData
from flash.tabular.classification.data.dataset import _categorize, _normalize
Expand Down Expand Up @@ -169,3 +170,11 @@ def test_from_csv(tmpdir):
assert cat.shape == (1, 1)
assert num.shape == (1, 2)
assert target.shape == (1, )


def test_empty_inputs():
    """``TabularData.from_df`` must raise ``RuntimeError`` when both
    ``categorical_input`` and ``numerical_input`` are ``None``."""
    frame = TEST_DF_1.copy()
    # Neither feature list is supplied, so there are no input columns at all.
    missing_inputs = dict(categorical_input=None, numerical_input=None)
    loader_options = dict(target="label", num_workers=0, batch_size=1)
    with pytest.raises(RuntimeError):
        TabularData.from_df(frame, **missing_inputs, **loader_options)