Skip to content

Commit

Permalink
Added TabularRegressionData extending TabularData (Lightning-Universe…
Browse files Browse the repository at this point in the history
…#574)

* added TabularClassificationData,TabularRegressionData extending TabularData

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update flash/tabular/regression/data.py

Co-authored-by: thomas chaton <[email protected]>

* Update flash/tabular/classification/data.py

Co-authored-by: thomas chaton <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added TabularClassificationData,TabularRegressionData extending TabularData

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* PEP8 fix

* modified tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <[email protected]>
  • Loading branch information
3 people authored Jul 13, 2021
1 parent f7a86ea commit c318e4a
Show file tree
Hide file tree
Showing 11 changed files with 553 additions and 520 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,13 @@ To illustrate, say we want to build a model to predict if a passenger survived o
from torchmetrics.classification import Accuracy, Precision, Recall
import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier, TabularData
from flash.tabular import TabularClassifier, TabularClassificationData

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 2. Load the data
datamodule = TabularData.from_csv(
datamodule = TabularClassificationData.from_csv(
["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
"Fare",
target_fields="Survived",
Expand Down
4 changes: 3 additions & 1 deletion flash/tabular/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from flash.tabular.classification import TabularClassifier, TabularData # noqa: F401
from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401
from flash.tabular.data import TabularData # noqa: F401
from flash.tabular.regression import TabularRegressionData # noqa: F401
2 changes: 1 addition & 1 deletion flash/tabular/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from flash.tabular.classification.data import TabularData # noqa: F401
from flash.tabular.classification.data import TabularClassificationData # noqa: F401
from flash.tabular.classification.model import TabularClassifier # noqa: F401
Loading

0 comments on commit c318e4a

Please sign in to comment.