Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

fix: set inputs as optional #109

Merged
merged 18 commits into from
Feb 13, 2021
Merged
19 changes: 14 additions & 5 deletions flash/tabular/classification/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ class TabularData(DataModule):
def __init__(
self,
train_df: DataFrame,
target: str,
categorical_input: List,
numerical_input: List,
target: str,
valid_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
batch_size: int = 2,
Expand All @@ -82,6 +82,15 @@ def __init__(
dfs = [train_df]
self._test_df = None

if not categorical_input and not numerical_input:
raise RuntimeError('Both `categorical_input` and `numerical_input` are None!')

if categorical_input is None:
categorical_input = []

if numerical_input is None:
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved
numerical_input = []

if valid_df is not None:
dfs.append(valid_df)

Expand Down Expand Up @@ -133,8 +142,8 @@ def from_df(
cls,
train_df: DataFrame,
target: str,
categorical_input: List,
numerical_input: List,
categorical_input: List = (),
numerical_input: List = (),
valid_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
batch_size: int = 8,
Expand Down Expand Up @@ -194,8 +203,8 @@ def from_csv(
cls,
train_csv: str,
target: str,
categorical_input: List,
numerical_input: List,
categorical_input: List = (),
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
numerical_input: List = (),
valid_csv: Optional[str] = None,
test_csv: Optional[str] = None,
batch_size: int = 8,
Expand Down
9 changes: 9 additions & 0 deletions tests/tabular/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
import pytest

from flash.tabular import TabularData
from flash.tabular.classification.data.dataset import _categorize, _normalize
Expand Down Expand Up @@ -169,3 +170,11 @@ def test_from_csv(tmpdir):
assert cat.shape == (1, 1)
assert num.shape == (1, 2)
assert target.shape == (1, )


def test_empty_inputs():
    """Check that ``TabularData.from_df`` rejects a call with no input columns.

    Both ``categorical_input`` and ``numerical_input`` are passed as empty
    lists; the constructor is expected to raise ``RuntimeError``.
    """
    frame = TEST_DF_1.copy()
    # Collect the keyword arguments up front so the ``raises`` context wraps
    # only the call under test.
    empty_input_kwargs = {
        "categorical_input": [],
        "numerical_input": [],
        "target": "label",
        "num_workers": 0,
        "batch_size": 1,
    }
    with pytest.raises(RuntimeError):
        TabularData.from_df(frame, **empty_input_kwargs)