Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

DataPipeline PoC #141

Merged
merged 191 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from 185 commits
Commits
Show all changes
191 commits
Select commit Hold shift + click to select a range
45691cd
add prototype of DataPipeline
justusschock Feb 18, 2021
135eb17
Add Prototype of PostProcessingPipeline
justusschock Feb 18, 2021
535353c
isort + pep8
justusschock Feb 18, 2021
f66f223
update post_processing_pipeline
justusschock Feb 20, 2021
67de76f
update data pipline
justusschock Feb 20, 2021
3be12a3
add new prediction part
justusschock Feb 20, 2021
17cecb8
change loader name
justusschock Feb 22, 2021
be4f505
update
tchaton Feb 22, 2021
2e2fa54
uypdate new datapipeline
justusschock Feb 23, 2021
fc34775
update model with new pipeline
justusschock Feb 23, 2021
b417683
update
tchaton Feb 23, 2021
307b210
update gitignore
justusschock Feb 23, 2021
9dc842a
add autodataset
justusschock Feb 23, 2021
77f935c
add batch processing
justusschock Feb 23, 2021
7606f93
Merge branch 'datapipeline_poc_1' of github.com:PyTorchLightning/ligh…
justusschock Feb 23, 2021
dd68bf3
update
tchaton Feb 24, 2021
69c5bc1
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Feb 24, 2021
b5b3ad0
update
tchaton Feb 24, 2021
040c3be
update
tchaton Feb 25, 2021
4edee9c
add process file
justusschock Feb 27, 2021
327f19c
make datapipeline attaching and detaching more robust
justusschock Feb 27, 2021
95e809c
resolve flake8
tchaton Feb 28, 2021
3780522
update
tchaton Feb 28, 2021
966b1a9
push curr state
justusschock Mar 2, 2021
41a5e71
Update flash/data/batch.py
justusschock Mar 4, 2021
d6b7347
resolve some bugs
tchaton Mar 8, 2021
8e96e7e
tests
justusschock Mar 8, 2021
4067356
tests
justusschock Mar 8, 2021
d40b8c9
update
tchaton Mar 8, 2021
4cbcff8
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 8, 2021
f8a3580
make everything nn.Module and check serialization
justusschock Mar 8, 2021
e388659
resolve kornia example
tchaton Mar 10, 2021
92df45a
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 10, 2021
38d2574
add prototype of DataPipeline
justusschock Feb 18, 2021
6ff8c1c
Add Prototype of PostProcessingPipeline
justusschock Feb 18, 2021
f73f45b
isort + pep8
justusschock Feb 18, 2021
66f6562
update post_processing_pipeline
justusschock Feb 20, 2021
07ab337
update data pipline
justusschock Feb 20, 2021
f92b3cb
add new prediction part
justusschock Feb 20, 2021
0e7ee40
change loader name
justusschock Feb 22, 2021
e03b0b1
update
tchaton Feb 22, 2021
e3c1582
update
tchaton Feb 23, 2021
1c915ca
update
tchaton Feb 24, 2021
9906ad4
uypdate new datapipeline
justusschock Feb 23, 2021
fba408e
update model with new pipeline
justusschock Feb 23, 2021
99ebec8
update gitignore
justusschock Feb 23, 2021
a68419c
add autodataset
justusschock Feb 23, 2021
ac25999
add batch processing
justusschock Feb 23, 2021
70ba492
update
tchaton Feb 24, 2021
0bb5fdc
update
tchaton Feb 25, 2021
ac50910
add process file
justusschock Feb 27, 2021
97a8e4e
make datapipeline attaching and detaching more robust
justusschock Feb 27, 2021
5ef3f7d
resolve flake8
tchaton Feb 28, 2021
b7275de
update
tchaton Feb 28, 2021
f7a1966
push curr state
justusschock Mar 2, 2021
d84e53f
Update flash/data/batch.py
justusschock Mar 4, 2021
31d9b6d
resolve some bugs
tchaton Mar 8, 2021
d3a7cd7
update
tchaton Mar 8, 2021
eaee810
tests
justusschock Mar 8, 2021
f7f8642
resolve kornia example
tchaton Mar 10, 2021
e290159
make everything nn.Module and check serialization
justusschock Mar 8, 2021
6133ef8
rebase_fixes
justusschock Mar 10, 2021
a5289d4
add more tests
tchaton Mar 10, 2021
9f144f4
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 10, 2021
f962401
update tabular
tchaton Mar 11, 2021
59bbe09
add new hooks
tchaton Mar 11, 2021
5604042
update tabular
tchaton Mar 12, 2021
fba7e96
update
tchaton Mar 12, 2021
ef6a9fd
Move func to data module
justusschock Mar 13, 2021
08bab33
fix vision to current version
justusschock Mar 13, 2021
07fb5e6
transfer text classification to new API
justusschock Mar 13, 2021
b744741
add more tests
tchaton Mar 13, 2021
7b782e1
update
justusschock Mar 13, 2021
1abee8a
resolve most bugs
tchaton Mar 14, 2021
0b00b22
address most comments
tchaton Mar 14, 2021
4d15e94
remove kornia example
tchaton Mar 14, 2021
a598e99
add support for summurization example
tchaton Mar 14, 2021
e8968a7
work with ObjectDetection
tchaton Mar 14, 2021
1ea587c
Update gitignore
kaushikb11 Mar 14, 2021
0ae1729
updates
tchaton Mar 14, 2021
d1da35f
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 14, 2021
3c2c08b
resolve bug
tchaton Mar 14, 2021
ab887be
update
justusschock Mar 15, 2021
ef91f81
resolve image embedder
tchaton Mar 15, 2021
c5bfc41
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 15, 2021
b2b6b54
Update Image Classifer
kaushikb11 Mar 15, 2021
382feb5
Renaming
kaushikb11 Mar 15, 2021
59365f4
fix recursion
justusschock Mar 16, 2021
b1951c8
resolve bug
tchaton Mar 16, 2021
214dd58
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 16, 2021
e8a7842
Merge branch 'datapipeline_poc_1' of github.com:PyTorchLightning/ligh…
justusschock Mar 16, 2021
a18d745
Fix DataPipeline function resolution
justusschock Mar 16, 2021
1903fa7
put back properties instead of attributes
justusschock Mar 16, 2021
832663e
fix import
justusschock Mar 16, 2021
0187b13
fix examples
justusschock Mar 16, 2021
cc4b0d5
add checks for loading
justusschock Mar 16, 2021
e55899f
fix recursion
justusschock Mar 16, 2021
2443720
fix seq2seq dataset
justusschock Mar 16, 2021
f67b209
fix dm init in tests
justusschock Mar 16, 2021
27aa8b4
fix data parts
justusschock Mar 16, 2021
3969aa0
resolve tests and flake8
tchaton Mar 17, 2021
af7beef
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 17, 2021
8b73caa
update on comments
tchaton Mar 18, 2021
23ba639
update notebooks
tchaton Mar 18, 2021
214ada8
devel
Borda Mar 18, 2021
a656fc9
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
Borda Mar 18, 2021
ca017da
update
tchaton Mar 18, 2021
8f276b2
update
tchaton Mar 18, 2021
44ffd16
update
tchaton Mar 18, 2021
3c1e433
resolve the doc
tchaton Mar 18, 2021
1471fb0
update
tchaton Mar 18, 2021
d0e599c
don't apply flake8 on notebook
tchaton Mar 18, 2021
9c24add
resolve tests
tchaton Mar 18, 2021
d16b9fd
comment a notebook
tchaton Mar 18, 2021
7baf1cb
update
tchaton Mar 18, 2021
7bd5700
update ci
tchaton Mar 18, 2021
c324c34
add fixes
tchaton Mar 18, 2021
6eabc01
updaet
tchaton Mar 18, 2021
dd35707
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 18, 2021
e4917ed
update with lightning
tchaton Mar 22, 2021
0b170b3
add a test for flash_special_arguments
tchaton Mar 22, 2021
7083679
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 22, 2021
2f381ef
add data_pipeline
tchaton Mar 22, 2021
465522d
update ci
tchaton Mar 22, 2021
819c018
delete generate .py file
tchaton Mar 22, 2021
2b4756d
update bolts
tchaton Mar 22, 2021
d291f12
udpate ci
tchaton Mar 22, 2021
ffdd258
update
tchaton Mar 22, 2021
f183f19
Merge branch 'master' into data_pipeline_1_n
tchaton Mar 22, 2021
2e7bc4b
Update flash/data/auto_dataset.py
tchaton Mar 22, 2021
2c1e412
update
tchaton Mar 22, 2021
b8d2abc
Merge branch 'data_pipeline_1_n' of https://github.com/PyTorchLightni…
tchaton Mar 22, 2021
d278382
Update tests/data/test_data_pipeline.py
tchaton Mar 22, 2021
0e32fa1
update
tchaton Mar 22, 2021
eba35f6
Merge branch 'data_pipeline_1_n' of https://github.com/PyTorchLightni…
tchaton Mar 22, 2021
8bea3dd
update
tchaton Mar 22, 2021
2990b0b
add some docstring
tchaton Mar 23, 2021
276cf40
update docstring
tchaton Mar 23, 2021
06e5a09
update on comments
tchaton Mar 23, 2021
913bb45
Fixes
carmocca Mar 24, 2021
98aa56d
Docs
carmocca Mar 24, 2021
58c147f
Docs
carmocca Mar 24, 2021
98f75c4
Merge branch 'master' into data_pipeline_1_n
carmocca Mar 24, 2021
84ce3b1
update ci
tchaton Mar 24, 2021
86669c6
update on comments
tchaton Mar 25, 2021
54d0fc3
Update flash/data/batch.py
tchaton Mar 25, 2021
637ff25
Update flash/data/data_module.py
kaushikb11 Mar 25, 2021
dd3dfdb
Update flash/data/process.py
kaushikb11 Mar 25, 2021
4c487a9
Apply suggestions from code review
Borda Mar 25, 2021
ab96ac7
cleaning
Borda Mar 25, 2021
fddd6b1
Merge branch 'data_pipeline_1_n' into datapipeline_poc_1
tchaton Mar 25, 2021
51ea5d9
add pip install
tchaton Mar 25, 2021
0f1f15f
switch back to master
tchaton Mar 25, 2021
23aaebf
update requierements
tchaton Mar 25, 2021
41dd86c
try
Borda Mar 25, 2021
7d8c955
try
Borda Mar 25, 2021
8451011
try
Borda Mar 25, 2021
40a6b33
update
tchaton Mar 25, 2021
82fcace
prune legacy
tchaton Mar 25, 2021
fc045b1
Merge branch 'data_pipeline_1_n' into datapipeline_poc_1
tchaton Mar 25, 2021
13db92e
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 25, 2021
176f14b
update
tchaton Mar 25, 2021
f4410de
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 26, 2021
d8367a2
update
tchaton Mar 26, 2021
e89037f
update to latest
tchaton Mar 26, 2021
095bdbe
delete extra files
tchaton Mar 26, 2021
a659c3d
updates to Task class
kaushikb11 Mar 26, 2021
b17b413
Update Datamodule
kaushikb11 Mar 26, 2021
1877733
resolve comments
tchaton Mar 26, 2021
a434a68
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 26, 2021
bda0ca3
update
tchaton Mar 26, 2021
b81a204
update
tchaton Mar 26, 2021
8043153
update
tchaton Mar 26, 2021
9297c5b
update
tchaton Mar 26, 2021
4f1e87d
try
tchaton Mar 26, 2021
d00382e
update
tchaton Mar 26, 2021
f11002c
udpate
tchaton Mar 26, 2021
6e05051
update
tchaton Mar 26, 2021
e72c1c3
update
tchaton Mar 26, 2021
9289632
update
tchaton Mar 26, 2021
77e3e0e
formatting
Borda Mar 28, 2021
d72d1b7
update on comments
tchaton Mar 29, 2021
503b62a
update on comments
tchaton Mar 29, 2021
25bd225
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 29, 2021
2c58006
General changes
carmocca Mar 29, 2021
816c010
General changes
carmocca Mar 29, 2021
fbfb71f
update
tchaton Mar 29, 2021
d3d4c78
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 29, 2021
0245d17
update
tchaton Mar 29, 2021
e2f24dc
add _data_pipeline back
tchaton Mar 29, 2021
de3327b
update
tchaton Mar 29, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/ci-notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ jobs:
# Look to see if there is a cache hit for the corresponding requirements file
key: flash-datasets_predict

#- name: Run Notebooks
# run: |
# jupyter nbconvert --to script flash_notebooks/image_classification.ipynb
# jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb
#
# ipython flash_notebooks/image_classification.py
# ipython flash_notebooks/tabular_classification.py
- name: Run Notebooks
run: |
# jupyter nbconvert --to script flash_notebooks/image_classification.ipynb
jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb

# ipython flash_notebooks/image_classification.py
ipython flash_notebooks/tabular_classification.py
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ jobs:
run: |
python --version
pip --version
pip install -e . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install -e . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip list
shell: bash

Expand Down
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ flash_notebooks/*.py
flash_notebooks/data
MNIST*
titanic
coco128
hymenoptera_data
xsum
imdb
xsum
coco128
wmt_en_ro
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
num_cols=["Fare"],
tchaton marked this conversation as resolved.
Show resolved Hide resolved
target="Survived",
val_size=0.25,
)
Expand Down
27 changes: 0 additions & 27 deletions docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,33 +67,6 @@ for the prediction of diabetes disease progression. We can create this
``DataModule`` below, wrapping the scikit-learn `Diabetes
dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.

.. testcode::

class DiabetesPipeline(flash.core.data.TaskDataPipeline):
def after_uncollate(self, samples):
return [f"disease progression: {float(s):.2f}" for s in samples]

class DiabetesData(flash.DataModule):
def __init__(self, batch_size=64, num_workers=0):
x, y = datasets.load_diabetes(return_X_y=True)
x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float().unsqueeze(1)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)

train_ds = TensorDataset(x_train, y_train)
test_ds = TensorDataset(x_test, y_test)

super().__init__(
train_ds=train_ds,
test_ds=test_ds,
batch_size=batch_size,
num_workers=num_workers
)
self.num_inputs = x.shape[1]

@staticmethod
def default_pipeline():
return DiabetesPipeline()
tchaton marked this conversation as resolved.
Show resolved Hide resolved

You’ll notice we added a ``DataPipeline``, which will be used when we
call ``.predict()`` on our model. In this case we want to nicely format
Expand Down
50 changes: 3 additions & 47 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,6 @@ Data
DataPipeline
------------

To make tasks work for inference, one must create a ``DataPipeline``.
The ``flash.core.data.DataPipeline`` exposes 6 hooks to override:

.. code:: python

class DataPipeline:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""
This class purpose is to facilitate the conversion of raw data to processed or batched data and back.
Several hooks are provided for maximum flexibility.

collate_fn:
- before_collate
- collate
- after_collate

uncollate_fn:
- before_uncollate
- uncollate
- after_uncollate
"""

def before_collate(self, samples: Any) -> Any:
"""Override to apply transformations to samples"""
return samples

def collate(self, samples: Any) -> Any:
"""Override to convert a set of samples to a batch"""
if not isinstance(samples, Tensor):
return default_collate(samples)
return samples

def after_collate(self, batch: Any) -> Any:
"""Override to apply transformations to the batch"""
return batch

def before_uncollate(self, batch: Any) -> Any:
"""Override to apply transformations to the batch"""
return batch

def uncollate(self, batch: Any) -> ny:
"""Override to convert a batch to a set of samples"""
samples = batch
return samples

def after_uncollate(self, samples: Any) -> Any:
"""Override to apply transformations to samples"""
return samplesA
To make tasks work for inference, one must create a ``Preprocess`` and ``PostProcess``.
The ``flash.data.process.Preprocess`` exposes 9 hooks to override which can specifialzed for each stage using
``train``, ``val``, ``test``, ``predict`` prefixes.
4 changes: 1 addition & 3 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Use the :class:`~flash.vision.ImageClassifier` pretrained model for inference on
print(predictions)

# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/")
datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)

Expand Down Expand Up @@ -185,5 +185,3 @@ ImageClassificationData
.. automethod:: flash.vision.ImageClassificationData.from_filepaths

.. automethod:: flash.vision.ImageClassificationData.from_folders

.. automethod:: flash.vision.ImageClassificationData.from_folder
13 changes: 6 additions & 7 deletions docs/source/reference/tabular_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ We can use the Flash Tabular classification task to predict the probability a pa
We can create :class:`~flash.tabular.TabularData` from csv files using the :func:`~flash.tabular.TabularData.from_csv` method. We will pass in:

* **train_csv**- csv file containing the training data converted to a Pandas DataFrame
* **categorical_input**- a list of the names of columns that contain categorical data (strings or integers)
* **numerical_input**- a list of the names of columns that contain numerical continuous data (floats)
* **cat_cols**- a list of the names of columns that contain categorical data (strings or integers)
* **num_cols**- a list of the names of columns that contain numerical continuous data (floats)
Comment on lines +38 to +39
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not keep categorical_ and numerical_ when we have extremely long names elsewhere

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to start reducing it, and there were inconsistency inside the code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does it mean "Better to start reducing it"? to make it shorter? well, it was you who added new long names in the very last PR...

* **target**- the name of the column we want to predict


Expand All @@ -56,8 +56,8 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
num_cols=["Fare"],
tchaton marked this conversation as resolved.
Show resolved Hide resolved
target="Survived",
val_size=0.25,
)
Expand Down Expand Up @@ -120,8 +120,8 @@ Or you can finetune your own model and use that for prediction:
datamodule = TabularData.from_csv(
"my_data_file.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
num_cols=["Fare"],
target="Survived",
val_size=0.25,
)
Expand Down Expand Up @@ -166,4 +166,3 @@ TabularData
.. automethod:: flash.tabular.TabularData.from_csv

.. automethod:: flash.tabular.TabularData.from_df

7 changes: 2 additions & 5 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from flash import tabular, text, vision # noqa: E402
from flash.core import data, utils # noqa: E402
from flash.core.classification import ClassificationTask # noqa: E402
from flash.core.data import DataModule # noqa: E402
from flash.core.data.utils import download_data # noqa: E402
from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402
from flash.data.data_module import DataModule # noqa: E402
from flash.data.utils import download_data # noqa: E402

__all__ = [
"Task",
Expand All @@ -42,7 +41,5 @@
"vision",
"text",
"tabular",
"data",
"utils",
"download_data",
]
17 changes: 6 additions & 11 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@
import torch
from torch import Tensor

from flash.core.data import TaskDataPipeline
from flash.core.model import Task
from flash.data.process import Postprocess


class ClassificationDataPipeline(TaskDataPipeline):
class ClassificationPostprocess(Postprocess):

def before_uncollate(self, batch: Union[Tensor, tuple]) -> Tensor:
if isinstance(batch, tuple):
batch = batch[0]
return torch.softmax(batch, -1)

def after_uncollate(self, samples: Any) -> Any:
def per_sample_transform(self, samples: Any) -> Any:
return torch.argmax(samples, -1).tolist()
tchaton marked this conversation as resolved.
Show resolved Hide resolved


class ClassificationTask(Task):

@staticmethod
def default_pipeline() -> ClassificationDataPipeline:
return ClassificationDataPipeline()
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._postprocess = ClassificationPostprocess()
3 changes: 0 additions & 3 deletions flash/core/data/__init__.py

This file was deleted.

117 changes: 0 additions & 117 deletions flash/core/data/datamodule.py

This file was deleted.

Loading