Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 14, 2021
1 parent cecfebb commit 4c480f1
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 73 deletions.
4 changes: 2 additions & 2 deletions flash/tabular/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from flash.tabular.classification import ( # noqa: F401
TabularClassificationData,
TabularClassificationPreprocess,
TabularClassificationDataFrameDataSource,
TabularClassificationPreprocess,
TabularClassifier,
)
from flash.tabular.data import TabularData # noqa: F401
from flash.tabular.regression import ( # noqa: F401
TabularRegressionData,
TabularRegressionPreprocess,
TabularRegressionDataFrameDataSource,
TabularRegressionPreprocess,
)
3 changes: 1 addition & 2 deletions flash/tabular/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from flash.tabular.classification.data import ( # noqa: F401
TabularClassificationData,
TabularClassificationDataFrameDataSource,
TabularClassificationPreprocess,
TabularClassificationDataFrameDataSource
)

from flash.tabular.classification.model import TabularClassifier # noqa: F401
21 changes: 11 additions & 10 deletions flash/tabular/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@
else:
DataFrame = object

from flash.tabular.data import TabularData, TabularPreprocess, TabularDataFrameDataSource
from flash.tabular.data import TabularData, TabularDataFrameDataSource, TabularPreprocess


class TabularClassificationDataFrameDataSource(TabularDataFrameDataSource):

def __init__(
self,
cat_cols: Optional[List[str]] = None,
num_cols: Optional[List[str]] = None,
target_col: Optional[str] = None,
mean: Optional[DataFrame] = None,
std: Optional[DataFrame] = None,
codes: Optional[Dict[str, Any]] = None,
target_codes: Optional[Dict[str, Any]] = None,
classes: Optional[List[str]] = None
self,
cat_cols: Optional[List[str]] = None,
num_cols: Optional[List[str]] = None,
target_col: Optional[str] = None,
mean: Optional[DataFrame] = None,
std: Optional[DataFrame] = None,
codes: Optional[Dict[str, Any]] = None,
target_codes: Optional[Dict[str, Any]] = None,
classes: Optional[List[str]] = None
):
super(TabularClassificationDataFrameDataSource, self).__init__(
cat_cols=cat_cols,
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flash.tabular.regression.data import ( # noqa: F401
TabularRegressionData,
TabularRegressionDataFrameDataSource,
TabularRegressionPreprocess,
TabularRegressionDataFrameDataSource
)
127 changes: 69 additions & 58 deletions flash/tabular/regression/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,27 @@
else:
DataFrame = object

from flash.tabular.data import TabularData, TabularPreprocess, TabularDataFrameDataSource
from pytorch_forecasting import TimeSeriesDataSet

from flash.tabular.data import TabularData, TabularDataFrameDataSource, TabularPreprocess


class TabularRegressionDataFrameDataSource(TabularDataFrameDataSource):

def __init__(
self,
time_idx: str,
target: Union[str, List[str]],
group_ids: List[str],
cat_cols: Optional[List[str]] = None,
num_cols: Optional[List[str]] = None,
target_col: Optional[str] = None,
mean: Optional[DataFrame] = None,
std: Optional[DataFrame] = None,
codes: Optional[Dict[str, Any]] = None,
target_codes: Optional[Dict[str, Any]] = None,
classes: Optional[List[str]] = None,
**data_source_kwargs: Any
self,
time_idx: str,
target: Union[str, List[str]],
group_ids: List[str],
cat_cols: Optional[List[str]] = None,
num_cols: Optional[List[str]] = None,
target_col: Optional[str] = None,
mean: Optional[DataFrame] = None,
std: Optional[DataFrame] = None,
codes: Optional[Dict[str, Any]] = None,
target_codes: Optional[Dict[str, Any]] = None,
classes: Optional[List[str]] = None,
**data_source_kwargs: Any
):
self.time_idx = time_idx
self.target = target
Expand All @@ -59,28 +61,29 @@ def __init__(
)

def load_data(self, data: DataFrame, dataset: Optional[Any] = None):
return TimeSeriesDataSet(data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target,
**self.data_source_kwargs)
return TimeSeriesDataSet(
data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.data_source_kwargs
)


class TabularRegressionPreprocess(TabularPreprocess):

def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
cat_cols: Optional[List[str]] = None,
num_cols: Optional[List[str]] = None,
target_col: Optional[str] = None,
mean: Optional[DataFrame] = None,
std: Optional[DataFrame] = None,
codes: Optional[Dict[str, Any]] = None,
target_codes: Optional[Dict[str, Any]] = None,
classes: Optional[List[str]] = None,
deserializer: Optional[Deserializer] = None,
**data_source_kwargs: Any
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
cat_cols: Optional[List[str]] = None,
num_cols: Optional[List[str]] = None,
target_col: Optional[str] = None,
mean: Optional[DataFrame] = None,
std: Optional[DataFrame] = None,
codes: Optional[Dict[str, Any]] = None,
target_codes: Optional[Dict[str, Any]] = None,
classes: Optional[List[str]] = None,
deserializer: Optional[Deserializer] = None,
**data_source_kwargs: Any
):
super(TabularRegressionPreprocess, self).__init__(
train_transform=train_transform,
Expand All @@ -89,7 +92,15 @@ def __init__(
predict_transform=predict_transform,
data_sources={
"data_frame": TabularRegressionDataFrameDataSource(
cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression=True,
cat_cols,
num_cols,
target_col,
mean,
std,
codes,
target_codes,
classes,
is_regression=True,
**data_source_kwargs
),
},
Expand All @@ -112,32 +123,32 @@ class TabularRegressionData(TabularData):

@classmethod
def from_data_frame(
cls,
group_ids: Optional[List[str]] = None,
target: Optional[str] = None,
time_idx: Optional[str] = None,
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_fields: Optional[str] = None,
train_data_frame: Optional[DataFrame] = None,
val_data_frame: Optional[DataFrame] = None,
test_data_frame: Optional[DataFrame] = None,
predict_data_frame: Optional[DataFrame] = None,
min_encoder_length: Optional[int] = None,
max_encoder_length: Optional[int] = None,
min_prediction_length: Optional[int] = None,
max_prediction_length: Optional[int] = None,
time_varying_unknown_reals: Optional[List[str]] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[TabularRegressionPreprocess] = None,
val_split: Optional[float] = None,
batch_size: int = None,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
cls,
group_ids: Optional[List[str]] = None,
target: Optional[str] = None,
time_idx: Optional[str] = None,
categorical_fields: Optional[Union[str, List[str]]] = None,
numerical_fields: Optional[Union[str, List[str]]] = None,
target_fields: Optional[str] = None,
train_data_frame: Optional[DataFrame] = None,
val_data_frame: Optional[DataFrame] = None,
test_data_frame: Optional[DataFrame] = None,
predict_data_frame: Optional[DataFrame] = None,
min_encoder_length: Optional[int] = None,
max_encoder_length: Optional[int] = None,
min_prediction_length: Optional[int] = None,
max_prediction_length: Optional[int] = None,
time_varying_unknown_reals: Optional[List[str]] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[TabularRegressionPreprocess] = None,
val_split: Optional[float] = None,
batch_size: int = None,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
):
super().from_data_frame(
time_idx=time_idx,
Expand Down

0 comments on commit 4c480f1

Please sign in to comment.