This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add missing arguments and docs #245

Merged · 3 commits · Apr 26, 2021
79 changes: 76 additions & 3 deletions docs/source/general/predictions.rst
@@ -15,9 +15,8 @@ You can pass in a sample of data (image file path, a string of text, etc) to the

.. code-block:: python

-    from flash import Trainer
    from flash.data.utils import download_data
-    from flash.vision import ImageClassificationData, ImageClassifier
+    from flash.vision import ImageClassifier


    # 1. Download the data set
@@ -45,9 +44,83 @@ Predict on a csv file

    # 2. Load the model from a checkpoint
    model = TabularClassifier.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt"
    )

    # 3. Generate predictions from a csv file! Who would survive?
    predictions = model.predict("data/titanic/titanic.csv")
    print(predictions)


Serializing predictions
=======================

To change how predictions are serialized, you can attach a :class:`~flash.data.process.Serializer` to your
:class:`~flash.Task`. For example, you can choose to serialize outputs as probabilities (for more options see the API
reference below).


.. code-block:: python

    from flash.core.classification import Probabilities
    from flash.data.utils import download_data
    from flash.vision import ImageClassifier


    # 1. Download the data set
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

    # 2. Load the model from a checkpoint
    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

    # 3. Attach the Serializer
    model.serializer = Probabilities()

    # 4. Predict whether the image contains an ant or a bee
    predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
    print(predictions)
    # out: [[0.5926494598388672, 0.40735048055648804]]


------


******************************************
Classification serializers - API reference
******************************************

.. _logits:

Logits
---------------

.. autoclass:: flash.core.classification.Logits
:members:
:exclude-members: serialize

.. _probabilities:

Probabilities
-----------------------

.. autoclass:: flash.core.classification.Probabilities
:members:
:exclude-members: serialize

.. _classes:

Classes
-----------------------

.. autoclass:: flash.core.classification.Classes
:members:
:exclude-members: serialize

.. _labels:

Labels
-----------------------

.. autoclass:: flash.core.classification.Labels
:members:
:exclude-members: serialize
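
As a quick illustration of the difference between these serializers, here is a sketch that reuses the
checkpoint from the example above (the ``["ants", "bees"]`` label list is an assumption for this
dataset, not something defined by these docs):

.. code-block:: python

    from flash.core.classification import Classes, Labels, Logits
    from flash.vision import ImageClassifier

    model = ImageClassifier.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/image_classification_model.pt"
    )
    image = "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"

    model.serializer = Logits()
    print(model.predict(image))  # raw, unnormalized model outputs

    model.serializer = Classes()
    print(model.predict(image))  # argmax class index, e.g. [0]

    # Labels maps each predicted index to a human-readable name
    model.serializer = Labels(["ants", "bees"])
    print(model.predict(image))  # e.g. ['ants']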
24 changes: 22 additions & 2 deletions flash/core/classification.py
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
-from typing import Any, List, Mapping, Optional, Union
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Union

import torch
+import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.model import Task
from flash.data.process import ProcessState, Serializer


+def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision."""
+    return F.binary_cross_entropy_with_logits(x, y.float())


@dataclass(unsafe_hash=True, frozen=True)
class ClassificationState(ProcessState):

@@ -33,10 +39,24 @@ class ClassificationTask(Task):
    def __init__(
        self,
        *args,
+        loss_fn: Optional[Callable] = None,
+        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
+        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
        **kwargs,
    ) -> None:
-        super().__init__(*args, serializer=serializer or Classes(), **kwargs)
+        if metrics is None:
+            metrics = torchmetrics.Accuracy(subset_accuracy=multi_label)
+
+        if loss_fn is None:
+            loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
+        super().__init__(
+            *args,
+            loss_fn=loss_fn,
+            metrics=metrics,
+            serializer=serializer or Classes(multi_label=multi_label),
+            **kwargs,
+        )

    def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
        if getattr(self.hparams, "multi_label", False):
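Taken together, these defaults mean a bare `ClassificationTask` now configures itself for either a single-label or a multi-label problem. A minimal sketch with toy data (the model and dataset below are illustrative, mirroring the test setup further down; they are not part of this PR):

    import pytorch_lightning as pl
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from flash.core.classification import ClassificationTask

    # Toy multi-label problem: 16 features, 3 non-exclusive classes.
    x = torch.randn(64, 16)
    y = torch.randint(0, 2, (64, 3))

    # No loss_fn or metrics supplied: with multi_label=True the task falls back to
    # binary_cross_entropy_with_logits, Accuracy(subset_accuracy=True), and a
    # Classes(multi_label=True) serializer.
    task = ClassificationTask(nn.Linear(16, 3), multi_label=True)

    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(task, DataLoader(TensorDataset(x, y), batch_size=8))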
9 changes: 8 additions & 1 deletion flash/tabular/classification/model.py
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Tuple, Type
from typing import Any, Callable, List, Mapping, Optional, Tuple, Type, Union

import torch
from torch.nn import functional as F
from torchmetrics import Metric

from flash.core.classification import ClassificationTask
from flash.data.process import Serializer
from flash.utils.imports import _TABNET_AVAILABLE

if _TABNET_AVAILABLE:
@@ -35,6 +36,8 @@ class TabularClassifier(ClassificationTask):
        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
        metrics: Metrics to compute for training and evaluation.
        learning_rate: Learning rate to use for training, defaults to `1e-3`
+        multi_label: Whether the targets are multi-label or not.
+        serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs.
        **tabnet_kwargs: Optional additional arguments for the TabNet model, see
            `pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
    """
@@ -48,6 +51,8 @@ def __init__(
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        metrics: List[Metric] = None,
        learning_rate: float = 1e-3,
+        multi_label: bool = False,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
        **tabnet_kwargs,
    ):
        self.save_hyperparameters()
@@ -68,6 +73,8 @@ def __init__(
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
+            multi_label=multi_label,
+            serializer=serializer,
        )

    def forward(self, x_in) -> torch.Tensor:
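The same arguments are now exposed on `TabularClassifier` itself, so a serializer can be chosen at construction time. A sketch (the `num_features`/`num_classes` values are placeholders for a purely numerical dataset; real data with categorical columns would also need embedding sizes):

    from flash.core.classification import Probabilities
    from flash.tabular import TabularClassifier

    # Request per-class probabilities up front instead of attaching a
    # serializer to the model after loading it.
    model = TabularClassifier(
        num_features=10,
        num_classes=2,
        serializer=Probabilities(),
    )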
6 changes: 1 addition & 5 deletions flash/text/classification/data.py
@@ -48,11 +48,7 @@ def __init__(
        max_length: Maximum number of tokens within a single sentence.
        target: The field storing the class id of the associated text.
        filetype: .csv or .json format type.
-        label_to_class_mapping: Dictionnary mapping target labels to class indexes.
-
-        Returns:
-            TextClassificationPreprocess: The constructed preprocess objects.
-
+        label_to_class_mapping: Dictionary mapping target labels to class indexes.
    """

        super().__init__()
9 changes: 6 additions & 3 deletions flash/text/classification/model.py
@@ -16,7 +16,6 @@
from typing import Callable, Mapping, Optional, Sequence, Type, Union

import torch
-from torchmetrics import Accuracy
from transformers import BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

@@ -33,16 +32,19 @@ class TextClassifier(ClassificationTask):
        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
        metrics: Metrics to compute for training and evaluation.
        learning_rate: Learning rate to use for training, defaults to `1e-3`
+        multi_label: Whether the targets are multi-label or not.
+        serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs.
    """

    def __init__(
        self,
        num_classes: int,
        backbone: str = "prajjwal1/bert-tiny",
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
-        metrics: Union[Callable, Mapping, Sequence, None] = [Accuracy()],
+        metrics: Union[Callable, Mapping, Sequence, None] = None,
        learning_rate: float = 1e-3,
-        serializer: Optional[Serializer] = None,
+        multi_label: bool = False,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
    ):
        self.save_hyperparameters()

@@ -58,6 +60,7 @@ def __init__(
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
+            multi_label=multi_label,
            serializer=serializer,
        )
        self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
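Note that this change also removes a mutable default argument: previously every `TextClassifier` shared the single `[Accuracy()]` list created at import time. With `None`, the default is resolved per instance inside `ClassificationTask`. A sketch of the resulting construction:

    from flash.text import TextClassifier

    # metrics=None is resolved to a fresh torchmetrics.Accuracy for each task,
    # so these two models no longer share (and jointly update) one metric object.
    model_a = TextClassifier(num_classes=2)
    model_b = TextClassifier(num_classes=2)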
33 changes: 10 additions & 23 deletions flash/vision/classification/model.py
@@ -12,24 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from types import FunctionType
-from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union, Any
-import torchmetrics
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
+import torchmetrics
from torch import nn
-from torch.nn import functional as F
-from torchmetrics import Accuracy
from torch.optim.lr_scheduler import _LRScheduler

-from flash.core.classification import Classes, ClassificationTask
+from flash.core.classification import ClassificationTask
from flash.core.registry import FlashRegistry
-from flash.data.process import Preprocess, Serializer
+from flash.data.process import Serializer
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
from flash.vision.classification.data import ImageClassificationPreprocess
-
-
-def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-    """Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision."""
-    return F.binary_cross_entropy_with_logits(x, y.float())


class ImageClassifier(ClassificationTask):
@@ -62,10 +55,10 @@ class ImageClassifier(ClassificationTask):
        pretrained: Use a pretrained backbone, defaults to ``True``.
        loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
        optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
-        metrics: Metrics to compute for training and evaluation,
-            defaults to :class:`torchmetrics.Accuracy`.
+        metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.Accuracy`.
        learning_rate: Learning rate to use for training, defaults to ``1e-3``.
-        multi_label: Whether the labels are multi labels or not.
+        multi_label: Whether the targets are multi-label or not.
+        serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs.
    """

    backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
@@ -87,13 +80,6 @@ def __init__(
        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
    ):
-
-        if metrics is None:
-            metrics = Accuracy(subset_accuracy=multi_label)
-
-        if loss_fn is None:
-            loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
-
        super().__init__(
            model=None,
            loss_fn=loss_fn,
@@ -103,7 +89,8 @@ def __init__(
            scheduler_kwargs=scheduler_kwargs,
            metrics=metrics,
            learning_rate=learning_rate,
-            serializer=serializer or Classes(multi_label=multi_label),
+            multi_label=multi_label,
+            serializer=serializer,
        )

        self.save_hyperparameters()
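From the caller's perspective nothing changes here: the defaults simply come from `ClassificationTask` now, which keeps `ImageClassifier` and the other classification tasks in sync. A sketch (`resnet18` is one of the registered backbones):

    from flash.vision import ImageClassifier

    # multi_label=True is forwarded to ClassificationTask, which selects
    # binary_cross_entropy_with_logits, subset accuracy, and a multi-label
    # Classes serializer; this logic previously lived in this file.
    model = ImageClassifier(num_classes=4, backbone="resnet18", multi_label=True)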
4 changes: 2 additions & 2 deletions tests/core/test_model.py
@@ -66,7 +66,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    val_dl = torch.utils.data.DataLoader(DummyDataset())
-    task = ClassificationTask(model, F.nll_loss, metrics=metrics)
+    task = ClassificationTask(model, loss_fn=F.nll_loss, metrics=metrics)
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    result = trainer.fit(task, train_dl, val_dl)
    assert result
@@ -125,7 +125,7 @@ def test_classification_task_trainer_predict(tmpdir):
def test_task_datapipeline_save(tmpdir):
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    task = ClassificationTask(model, F.nll_loss, postprocess=DummyPostprocess())
+    task = ClassificationTask(model, loss_fn=F.nll_loss, postprocess=DummyPostprocess())

    # to check later
    task.postprocess.test = True
4 changes: 2 additions & 2 deletions tests/core/test_trainer.py
@@ -53,7 +53,7 @@ def test_task_fit(tmpdir: str):
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    val_dl = torch.utils.data.DataLoader(DummyDataset())
-    task = ClassificationTask(model, F.nll_loss)
+    task = ClassificationTask(model, loss_fn=F.nll_loss)
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    result = trainer.fit(task, train_dl, val_dl)
    assert result
@@ -63,7 +63,7 @@ def test_task_finetune(tmpdir: str):
    model = DummyClassifier()
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    val_dl = torch.utils.data.DataLoader(DummyDataset())
-    task = ClassificationTask(model, F.nll_loss)
+    task = ClassificationTask(model, loss_fn=F.nll_loss)
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    result = trainer.finetune(task, train_dl, val_dl, strategy=NoFreeze())
    assert result
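
These test updates follow from the new base-class signature: `loss_fn` is now declared after `*args` in `ClassificationTask.__init__` and forwarded explicitly, so a loss passed positionally would collide with the keyword and raise a `TypeError`. A short sketch:

    import torch.nn.functional as F
    from torch import nn

    from flash.core.classification import ClassificationTask

    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())

    # Previously the loss rode through *args into Task:
    #     ClassificationTask(model, F.nll_loss)
    # That call would now forward F.nll_loss alongside the explicit loss_fn
    # keyword ("got multiple values for argument 'loss_fn'"), so name it:
    task = ClassificationTask(model, loss_fn=F.nll_loss)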