diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py
index 22698efc99..2253248c8b 100644
--- a/flash/tabular/__init__.py
+++ b/flash/tabular/__init__.py
@@ -1,3 +1,7 @@
-from flash.tabular.classification import TabularClassificationData, TabularClassifier  # noqa: F401
+from flash.tabular.classification import (  # noqa: F401
+    TabularClassificationData,
+    TabularClassificationPreprocess,
+    TabularClassifier,
+)
 from flash.tabular.data import TabularData  # noqa: F401
-from flash.tabular.regression import TabularRegressionData  # noqa: F401
+from flash.tabular.regression import TabularRegressionData, TabularRegressionPreprocess  # noqa: F401
diff --git a/flash/tabular/classification/__init__.py b/flash/tabular/classification/__init__.py
index 6134277abf..4ae61a97dd 100644
--- a/flash/tabular/classification/__init__.py
+++ b/flash/tabular/classification/__init__.py
@@ -1,2 +1,2 @@
-from flash.tabular.classification.data import TabularClassificationData  # noqa: F401
+from flash.tabular.classification.data import TabularClassificationData, TabularClassificationPreprocess  # noqa: F401
 from flash.tabular.classification.model import TabularClassifier  # noqa: F401
diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py
index 63cdda9ea2..52246752fc 100644
--- a/flash/tabular/classification/data.py
+++ b/flash/tabular/classification/data.py
@@ -11,7 +11,59 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from flash.tabular.data import TabularData
+from typing import Any, Callable, Dict, List, Optional
+
+from flash.core.data.process import Deserializer
+from flash.core.utilities.imports import _PANDAS_AVAILABLE
+
+if _PANDAS_AVAILABLE:
+    from pandas.core.frame import DataFrame
+else:
+    DataFrame = object
+
+from flash.tabular.data import TabularData, TabularPreprocess
+
+
+class TabularClassificationPreprocess(TabularPreprocess):
+    """Preprocess for tabular classification tasks.
+
+    Forwards every constructor argument unchanged to ``TabularPreprocess`` and
+    pins ``is_regression=False``.
+    """
+
+    def __init__(
+        self,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        cat_cols: Optional[List[str]] = None,
+        num_cols: Optional[List[str]] = None,
+        target_col: Optional[str] = None,
+        mean: Optional[DataFrame] = None,
+        std: Optional[DataFrame] = None,
+        codes: Optional[Dict[str, Any]] = None,
+        target_codes: Optional[Dict[str, Any]] = None,
+        classes: Optional[List[str]] = None,
+        deserializer: Optional[Deserializer] = None,
+    ):
+        # Pure pass-through: the only value decided here is the task flag.
+        super().__init__(
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            cat_cols=cat_cols,
+            num_cols=num_cols,
+            target_col=target_col,
+            mean=mean,
+            std=std,
+            codes=codes,
+            target_codes=target_codes,
+            classes=classes,
+            is_regression=False,
+            deserializer=deserializer,
+        )
 
 
 class TabularClassificationData(TabularData):
diff --git a/flash/tabular/regression/__init__.py b/flash/tabular/regression/__init__.py
index a93e599ff0..b722fa05d7 100644
--- a/flash/tabular/regression/__init__.py
+++ b/flash/tabular/regression/__init__.py
@@ -1 +1 @@
-from flash.tabular.regression.data import TabularRegressionData  # noqa: F401
+from flash.tabular.regression.data import TabularRegressionData, TabularRegressionPreprocess  # noqa: F401
diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py
index 04dd8cd3b4..503153c8c6 100644
--- a/flash/tabular/regression/data.py
+++ b/flash/tabular/regression/data.py
@@ -11,8 +11,61 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from flash.tabular.data import TabularData
+from typing import Any, Callable, Dict, List, Optional
+
+from flash.core.data.process import Deserializer
+from flash.core.utilities.imports import _PANDAS_AVAILABLE
+
+if _PANDAS_AVAILABLE:
+    from pandas.core.frame import DataFrame
+else:
+    DataFrame = object
+
+from flash.tabular.data import TabularData, TabularPreprocess
+
+
+class TabularRegressionPreprocess(TabularPreprocess):
+    """Preprocess for tabular regression tasks.
+
+    Forwards every constructor argument unchanged to ``TabularPreprocess`` and
+    pins ``is_regression=True``, matching ``TabularRegressionData.is_regression``.
+    """
+
+    def __init__(
+        self,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        cat_cols: Optional[List[str]] = None,
+        num_cols: Optional[List[str]] = None,
+        target_col: Optional[str] = None,
+        mean: Optional[DataFrame] = None,
+        std: Optional[DataFrame] = None,
+        codes: Optional[Dict[str, Any]] = None,
+        target_codes: Optional[Dict[str, Any]] = None,
+        classes: Optional[List[str]] = None,
+        deserializer: Optional[Deserializer] = None,
+    ):
+        super().__init__(
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            cat_cols=cat_cols,
+            num_cols=num_cols,
+            target_col=target_col,
+            mean=mean,
+            std=std,
+            codes=codes,
+            target_codes=target_codes,
+            classes=classes,
+            # Regression task: must agree with TabularRegressionData.is_regression.
+            is_regression=True,
+            deserializer=deserializer,
+        )
 
 
 class TabularRegressionData(TabularData):
     is_regression = True
+    preprocess_cls = TabularRegressionPreprocess