Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Thiel's U Statistic (Uncertainty) Metric #1337

Merged
merged 35 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6655755
Updated metrics with thiels_u.py functional code
Nov 15, 2022
75322ae
Updated __init__ blocks
Nov 15, 2022
172deca
Added a thiel's u metric
Nov 15, 2022
7ad92a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2022
0dfa393
removed dependence on scipy
Nov 15, 2022
2602db1
Merge branch 'shenoy/theils_u' of github.com:shenoynikhil/metrics int…
Nov 15, 2022
3fcac2d
Renamed Class
Nov 15, 2022
cdc1052
Merge branch 'master' into shenoy/theils_u
Borda Nov 16, 2022
ed34d58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 16, 2022
f33b69b
Modified files to tensorize and changed spell error
Nov 17, 2022
21f7f6f
Updated documentation
Nov 17, 2022
f5b7343
Added theil's u matrix
Nov 17, 2022
213a370
Updated documentation
Nov 17, 2022
04055d5
Updated init with nofa
Nov 17, 2022
93029b9
Added unit test based on crammer's v test
Nov 17, 2022
58c4b1e
Merge branch 'master' into shenoy/theils_u
shenoynikhil Nov 17, 2022
b8b2a22
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
1cd9002
Updated return type and made return consistent
Nov 17, 2022
a366386
Merge branch 'master' into shenoy/theils_u
shenoynikhil Nov 17, 2022
62800f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
64b0b4c
Updated with mypy fix
Nov 18, 2022
3fb2c62
Merge branch 'shenoy/theils_u' of github.com:shenoynikhil/metrics int…
Nov 18, 2022
00ceb0c
changelog
SkafteNicki Nov 18, 2022
ffdf438
doc page
SkafteNicki Nov 18, 2022
afa5718
docs
Borda Nov 18, 2022
52da5a6
tensor
Borda Nov 18, 2022
5bd1223
import
Borda Nov 18, 2022
9ed3373
Merge branch 'shenoy/theils_u' of https://github.com/shenoynikhil/met…
Borda Nov 18, 2022
53bb33f
Merge branch 'master' into shenoy/theils_u
SkafteNicki Nov 18, 2022
80b9786
Apply suggestions from code review
Borda Nov 18, 2022
f085aaf
Merge branch 'master' into shenoy/theils_u
mergify[bot] Nov 18, 2022
e1a6bbf
move compute
SkafteNicki Nov 18, 2022
1c0c511
update docstrings
SkafteNicki Nov 18, 2022
92d9a3b
remove unused bias term
SkafteNicki Nov 18, 2022
51e21d3
Increase atol to 1e-6 for matrix test
stancld Nov 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `CramersV` ([#1298](https://github.com/Lightning-AI/metrics/pull/1298))
* `PearsonsContingencyCoefficient` ([#1334](https://github.com/Lightning-AI/metrics/pull/1334))
* `TschuprowsT` ([#1334](https://github.com/Lightning-AI/metrics/pull/1334))
* `TheilsU` ([#1337](https://github.com/Lightning-AI/metrics/pull/1334))


- Added option to pass `distributed_available_fn` to metrics to allow checks for custom communication backend for making `dist_sync_fn` actually useful ([#1301](https://github.com/Lightning-AI/metrics/pull/1301))
Expand Down Expand Up @@ -52,7 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
-


## [0.10.3] - 2022-11-16
Expand Down
26 changes: 26 additions & 0 deletions docs/source/nominal/theils_u.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. customcarditem::
:header: Theil's U
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Nominal

#########
Theil's U
#########

Module Interface
________________

.. autoclass:: torchmetrics.TheilsU
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.theils_u
:noindex:

theils_u_matrix
^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.nominal.theils_u_matrix
:noindex:
3 changes: 2 additions & 1 deletion src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TschuprowsT # noqa: E402
from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TheilsU, TschuprowsT # noqa: E402
from torchmetrics.regression import ( # noqa: E402
ConcordanceCorrCoef,
CosineSimilarity,
Expand Down Expand Up @@ -185,6 +185,7 @@
"StatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TheilsU",
"TotalVariation",
"TranslationEditRate",
"TschuprowsT",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
pearsons_contingency_coefficient,
pearsons_contingency_coefficient_matrix,
)
from torchmetrics.functional.nominal.theils_u import theils_u
Borda marked this conversation as resolved.
Show resolved Hide resolved
from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
Expand Down Expand Up @@ -171,6 +172,7 @@
"structural_similarity_index_measure",
"stat_scores",
"symmetric_mean_absolute_percentage_error",
"theils_u",
"total_variation",
Borda marked this conversation as resolved.
Show resolved Hide resolved
"translation_edit_rate",
"tschuprows_t",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/nominal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
pearsons_contingency_coefficient,
pearsons_contingency_coefficient_matrix,
)
from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix # noqa: F401
from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix # noqa: F401
181 changes: 181 additions & 0 deletions src/torchmetrics/functional/nominal/theils_u.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_update
from torchmetrics.functional.nominal.utils import (
_drop_empty_rows_and_cols,
_handle_nan_in_data,
_nominal_input_validation,
)


def _conditional_entropy_compute(confmat: Tensor) -> Tensor:
r"""Compute Conditional Entropy Statistic based on a pre-computed confusion matrix.

.. math::
H(X|Y) = \sum_{x, y ~ (X, Y)} p(x, y)\frac{p(y)}{p(x, y)}

Args:
confmat: Confusion matrix for observed data

Returns:
Conditional Entropy Value
"""
confmat = _drop_empty_rows_and_cols(confmat)
total_occurrences = confmat.sum()
# iterate over all i, j combinations
p_xy_m = confmat / total_occurrences
# get p_y by summing over x dim (=1)
p_y = confmat.sum(1) / total_occurrences
# repeat over rows (shape = p_xy_m.shape[1]) for tensor multiplication
p_y_m = p_y.unsqueeze(1).repeat(1, p_xy_m.shape[1])

# entropy calculated as p_xy * log (p_xy / p_y)
return torch.nansum(p_xy_m * torch.log(p_y_m / p_xy_m))


def _theils_u_compute(confmat: Tensor) -> Tensor:
"""Compute Theil's U statistic based on a pre-computed confusion matrix.

Args:
confmat: Confusion matrix for observed data

Returns:
Theil's U statistic
"""
confmat = _drop_empty_rows_and_cols(confmat)

# compute conditional entropy
s_xy = _conditional_entropy_compute(confmat)

# compute H(x)
total_occurrences = confmat.sum()
p_x = confmat.sum(0) / total_occurrences
s_x = -torch.sum(p_x * torch.log(p_x))

# compute u statistic
if s_x == 0:
return torch.tensor(0, device=confmat.device)

return (s_x - s_xy) / s_x


def _theils_u_update(
preds: Tensor,
target: Tensor,
num_classes: int,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
"""Computes the bins to update the confusion matrix with for Theil's U calculation.

Args:
preds: 1D or 2D tensor of categorical (nominal) data
target: 1D or 2D tensor of categorical (nominal) data
num_classes: Integer specifing the number of classes
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN`s when ``nan_strategy = 'replace```

Returns:
Non-reduced confusion matrix
"""
preds = preds.argmax(1) if preds.ndim == 2 else preds
target = target.argmax(1) if target.ndim == 2 else target
preds, target = _handle_nan_in_data(preds, target, nan_strategy, nan_replace_value)
return _multiclass_confusion_matrix_update(preds, target, num_classes)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved


def theils_u(
preds: Tensor,
target: Tensor,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
r"""Computes Theil's U Statistic (Uncertainty Coefficient).

The value is between 0 and 1, i.e. 0 means y has no information about x while value 1
means y has complete information about x.

Args:
preds: 1D or 2D tensor of categorical (nominal) data
- 1D shape: (batch_size,)
- 2D shape: (batch_size, num_classes)
target: 1D or 2D tensor of categorical (nominal) data
- 1D shape: (batch_size,)
- 2D shape: (batch_size, num_classes)
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``
Returns:
Theil's U Statistic: Tensor

Example:
>>> from torchmetrics.functional import theils_u
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> theils_u(preds, target)
tensor(0.8530)
"""
num_classes = len(torch.cat([preds, target]).unique())
confmat = _theils_u_update(preds, target, num_classes, nan_strategy, nan_replace_value)
return _theils_u_compute(confmat)


def theils_u_matrix(
matrix: Tensor,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
) -> Tensor:
r"""Compute `Theil's U`_ statistic between a set of multiple variables.

This can serve as a convenient tool to compute Theil's U statistic for analyses of correlation between categorical
variables in your dataset.

Args:
matrix: A tensor of categorical (nominal) data, where:
- rows represent a number of data points
- columns represent a number of categorical (nominal) features
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``

Returns:
Theil's U statistic for a dataset of categorical variables

Example:
>>> from torchmetrics.functional.nominal import theils_u_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> theils_u_matrix(matrix)
tensor([[1.0000, 0.0202, 0.0142, 0.0196, 0.0353],
[0.0202, 1.0000, 0.0070, 0.0136, 0.0065],
[0.0143, 0.0070, 1.0000, 0.0125, 0.0206],
[0.0198, 0.0137, 0.0125, 1.0000, 0.0312],
[0.0352, 0.0065, 0.0204, 0.0308, 1.0000]])
"""
_nominal_input_validation(nan_strategy, nan_replace_value)
num_variables = matrix.shape[1]
theils_u_matrix_value = torch.ones(num_variables, num_variables, device=matrix.device)
for i, j in itertools.combinations(range(num_variables), 2):
x, y = matrix[:, i], matrix[:, j]
num_classes = len(torch.cat([x, y]).unique())
confmat = _theils_u_update(x, y, num_classes, nan_strategy, nan_replace_value)
theils_u_matrix_value[i, j] = _theils_u_compute(confmat)
theils_u_matrix_value[j, i] = _theils_u_compute(confmat.T)
return theils_u_matrix_value
1 change: 1 addition & 0 deletions src/torchmetrics/nominal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.
from torchmetrics.nominal.cramers import CramersV # noqa: F401
from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient # noqa: F401
from torchmetrics.nominal.theils_u import TheilsU # noqa: F401
from torchmetrics.nominal.tschuprows import TschuprowsT # noqa: F401
98 changes: 98 additions & 0 deletions src/torchmetrics/nominal/theils_u.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.nominal.theils_u import _theils_u_compute, _theils_u_update
from torchmetrics.functional.nominal.utils import _nominal_input_validation
from torchmetrics.metric import Metric


class TheilsU(Metric):
"""Compute `TheilsU` statistic measuring the association between two categorical (nominal) data series.

.. math::
U(X|Y) = \frac{H(X) - H(X|Y)}{H(X)}

where H(X) is entropy of variable X while H(X|Y) is the conditional entropy of X given Y

Theils's U is an asymmetric coefficient, i.e.

.. math::
V(preds, target) != V(target, preds)

The output values lies in [0, 1]. 0 means y has no information about x while value 1 means y has complete
information about x.

Article: https://en.wikipedia.org/wiki/Uncertainty_coefficient

Args:
num_classes: Integer specifing the number of classes
nan_strategy: Indication of whether to replace or drop ``NaN`` values
nan_replace_value: Value to replace ``NaN``s when ``nan_strategy = 'replace'``
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Returns:
Theil's U Statistic: Tensor

Example:
>>> from torchmetrics import TheilsU
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> TheilsU(num_classes=10)(preds, target)
tensor(0.8530)
"""

full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
confmat: Tensor

def __init__(
self,
num_classes: int,
nan_strategy: Literal["replace", "drop"] = "replace",
nan_replace_value: Optional[Union[int, float]] = 0.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.num_classes = num_classes

_nominal_input_validation(nan_strategy, nan_replace_value)
self.nan_strategy = nan_strategy
self.nan_replace_value = nan_replace_value

self.add_state("confmat", torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.

Args:
preds: 1D or 2D tensor of categorical (nominal) data
- 1D shape: (batch_size,)
- 2D shape: (batch_size, num_classes)
target: 1D or 2D tensor of categorical (nominal) data
- 1D shape: (batch_size,)
- 2D shape: (batch_size, num_classes)
"""
confmat = _theils_u_update(preds, target, self.num_classes, self.nan_strategy, self.nan_replace_value)
self.confmat += confmat

def compute(self) -> Tensor:
"""Computer Theil's U statistic."""
return _theils_u_compute(self.confmat)
Loading