Skip to content

Commit

Permalink
Prune metrics: other classification 7/n (#6584)
Browse files Browse the repository at this point in the history
* confusion_matrix

* iou

* f_beta

* hamming_distance

* stat_scores

* tests

* flake8

* chlog
  • Loading branch information
Borda authored Mar 19, 2021
1 parent 3b72bcc commit 3a56a60
Show file tree
Hide file tree
Showing 20 changed files with 155 additions and 2,421 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

[#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573),

[#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584),

)


Expand Down
90 changes: 7 additions & 83 deletions pytorch_lightning/metrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,14 @@
# limitations under the License.
from typing import Any, Optional

import torch
from torchmetrics import Metric
from torchmetrics import ConfusionMatrix as _ConfusionMatrix

from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
from pytorch_lightning.utilities.deprecation import deprecated


class ConfusionMatrix(Metric):
"""
Computes the `confusion matrix
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
multiclass, and multilabel data. Accepts probabilities from a model output or
integer class values in prediction. Works with multi-dimensional preds and
target.
Note:
This metric produces a multi-dimensional output, so it can not be directly logged.
Forward accepts
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
num_classes: Number of classes in the dataset.
normalize: Normalization mode for confusion matrix. Choose from
- ``None`` or ``'none'``: no normalization (default)
- ``'true'``: normalization over the targets (most commonly used)
- ``'pred'``: normalization over the predictions
- ``'all'``: normalization over the whole matrix
threshold:
Threshold value for binary or multi-label probabilites. default: 0.5
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example:
>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
[1., 1.]])
"""
class ConfusionMatrix(_ConfusionMatrix):

@deprecated(target=_ConfusionMatrix, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(
self,
num_classes: int,
Expand All @@ -80,35 +30,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):

super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)
self.num_classes = num_classes
self.normalize = normalize
self.threshold = threshold

allowed_normalize = ('true', 'pred', 'all', 'none', None)
assert self.normalize in allowed_normalize, \
f"Argument average needs to one of the following: {allowed_normalize}"

self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
self.confmat += confmat
This implementation refers to :class:`~torchmetrics.ConfusionMatrix`.
def compute(self) -> torch.Tensor:
"""
Computes confusion matrix
.. deprecated::
Use :class:`~torchmetrics.ConfusionMatrix`. Will be removed in v1.5.0.
"""
return _confusion_matrix_compute(self.confmat, self.normalize)
180 changes: 15 additions & 165 deletions pytorch_lightning/metrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,72 +13,15 @@
# limitations under the License.
from typing import Any, Optional

import torch
from torchmetrics import Metric
from torchmetrics import F1 as _F1
from torchmetrics import FBeta as _FBeta

from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.deprecation import deprecated


class FBeta(Metric):
r"""
Computes `F-score <https://en.wikipedia.org/wiki/F-score>`_, specifically:
.. math::
F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
{(\beta^2 * \text{precision}) + \text{recall}}
Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data.
Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
Forward accepts
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
num_classes: Number of classes in the dataset.
beta: Beta coefficient in the F measure.
threshold:
Threshold value for binary or multi-label probabilities. default: 0.5
average:
- ``'micro'`` computes metric globally
- ``'macro'`` computes metric for each class and uniformly averages them
- ``'weighted'`` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
- ``'none'`` or ``None`` computes and returns the metric per class
multilabel: If predictions are from multilabel classification.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``.
Example:
>>> from pytorch_lightning.metrics import FBeta
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f_beta = FBeta(num_classes=3, beta=0.5)
>>> f_beta(preds, target)
tensor(0.3333)
"""
class FBeta(_FBeta):

@deprecated(target=_FBeta, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(
self,
num_classes: int,
Expand All @@ -90,103 +33,17 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)

self.num_classes = num_classes
self.beta = beta
self.threshold = threshold
self.average = average
self.multilabel = multilabel

allowed_average = ("micro", "macro", "weighted", "none", None)
if self.average not in allowed_average:
raise ValueError(
'Argument `average` expected to be one of the following:'
f' {allowed_average} but got {self.average}'
)

self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
true_positives, predicted_positives, actual_positives = _fbeta_update(
preds, target, self.num_classes, self.threshold, self.multilabel
)

self.true_positives += true_positives
self.predicted_positives += predicted_positives
self.actual_positives += actual_positives
This implementation refers to :class:`~torchmetrics.FBeta`.
def compute(self) -> torch.Tensor:
.. deprecated::
Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0.
"""
Computes fbeta over state.
"""
return _fbeta_compute(
self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average
)


class F1(FBeta):
"""
Computes F1 metric. F1 metrics correspond to a harmonic mean of the
precision and recall scores.
Works with binary, multiclass, and multilabel data.
Accepts logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.

Forward accepts

- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
num_classes: Number of classes in the dataset.
threshold:
Threshold value for binary or multi-label logits. default: 0.5
average:
- ``'micro'`` computes metric globally
- ``'macro'`` computes metric for each class and uniformly averages them
- ``'weighted'`` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
- ``'none'`` or ``None`` computes and returns the metric per class
multilabel: If predictions are from multilabel classification.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example:
>>> from pytorch_lightning.metrics import F1
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1(num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
"""
class F1(_F1):

@deprecated(target=_F1, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(
self,
num_classes: int,
Expand All @@ -197,16 +54,9 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
if multilabel is not False:
rank_zero_warn(f'The `multilabel={multilabel}` parameter is unused and will not have any effect.')
"""
This implementation refers to :class:`~torchmetrics.F1`.
super().__init__(
num_classes=num_classes,
beta=1.0,
threshold=threshold,
average=average,
multilabel=multilabel,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)
.. deprecated::
Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0.
"""
Loading

0 comments on commit 3a56a60

Please sign in to comment.