Add Tschuprow's T and Pearson's Contingency Coefficient #1334

Merged
merged 22 commits on Nov 15, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `CramersV` to the new nominal package ([#1298](https://github.com/Lightning-AI/metrics/pull/1298))


- Added `PearsonsContingencyCoefficient` and `TschuprowsT` to nominal package ([#1334](https://github.com/Lightning-AI/metrics/pull/1334))


### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
2 changes: 2 additions & 0 deletions docs/source/links.rst
@@ -97,3 +97,5 @@
.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
.. _Tschuprow's T: https://en.wikipedia.org/wiki/Tschuprow%27s_T
.. _Pearson's Contingency Coefficient: https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/pearcont.htm
26 changes: 26 additions & 0 deletions docs/source/nominal/pearsons_contingency_coefficient.rst
@@ -0,0 +1,26 @@
.. customcarditem::
:header: Pearson's Contingency Coefficient
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Nominal

#################################
Pearson's Contingency Coefficient
#################################

Module Interface
________________

.. autoclass:: torchmetrics.PearsonsContingencyCoefficient
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.pearsons_contingency_coefficient
:noindex:

pearsons_contingency_coefficient_matrix
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.nominal.pearsons_contingency_coefficient_matrix
:noindex:
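A quick way for reviewers to exercise the new metric locally; this is a minimal sketch, assuming the module interface takes ``num_classes`` like the existing ``CramersV`` and the functional form shares the ``nan_strategy`` defaults of ``cramers_v`` shown later in this diff:

```python
import torch
from torchmetrics import PearsonsContingencyCoefficient
from torchmetrics.functional import pearsons_contingency_coefficient

# Two nominal variables encoded as integer category indices.
preds = torch.randint(0, 4, (100,))
target = torch.randint(0, 4, (100,))

# Functional interface: returns a scalar tensor in [0, 1].
score = pearsons_contingency_coefficient(preds, target)

# Module interface (assumed to take `num_classes`, mirroring CramersV).
metric = PearsonsContingencyCoefficient(num_classes=4)
metric.update(preds, target)
print(metric.compute())
```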
26 changes: 26 additions & 0 deletions docs/source/nominal/tschuprows_t.rst
@@ -0,0 +1,26 @@
.. customcarditem::
:header: Tschuprow's T
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Nominal

#############
Tschuprow's T
#############

Module Interface
________________

.. autoclass:: torchmetrics.TschuprowsT
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.tschuprows_t
:noindex:

tschuprows_t_matrix
^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.nominal.tschuprows_t_matrix
:noindex:
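An analogous sketch for Tschuprow's T; the ``bias_correction`` flag is assumed to behave like the one on ``cramers_v`` further down in this diff, and the module interface is assumed to take ``num_classes``:

```python
import torch
from torchmetrics import TschuprowsT
from torchmetrics.functional import tschuprows_t

preds = torch.randint(0, 4, (100,))
target = torch.randint(0, 4, (100,))

# Functional interface; bias correction assumed to default to True, as for cramers_v.
score = tschuprows_t(preds, target, bias_correction=True)

# Module interface (assumed to mirror CramersV).
metric = TschuprowsT(num_classes=4)
metric.update(preds, target)
print(metric.compute())
```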
1 change: 1 addition & 0 deletions requirements/nominal_test.txt
@@ -1,2 +1,3 @@
pandas # cannot pin version due to numpy version incompatibility
dython # todo: pin version, but some version resolution issue
scipy # cannot pin version due to some version conflicts with `oldest` CI configuration
4 changes: 3 additions & 1 deletion src/torchmetrics/__init__.py
@@ -53,7 +53,7 @@
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.nominal import CramersV # noqa: E402
from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TschuprowsT # noqa: E402
from torchmetrics.regression import ( # noqa: E402
ConcordanceCorrCoef,
CosineSimilarity,
@@ -152,6 +152,7 @@
"MultioutputWrapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"PearsonCorrCoef",
"PearsonsContingencyCoefficient",
"PermutationInvariantTraining",
"Perplexity",
"Precision",
@@ -186,6 +187,7 @@
"SymmetricMeanAbsolutePercentageError",
"TotalVariation",
"TranslationEditRate",
"TschuprowsT",
"UniversalImageQualityIndex",
"WeightedMeanAbsolutePercentageError",
"WordErrorRate",
12 changes: 11 additions & 1 deletion src/torchmetrics/functional/__init__.py
@@ -42,7 +42,12 @@
)
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.nominal.cramers import cramers_v
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
from torchmetrics.functional.nominal.pearson import (
pearsons_contingency_coefficient,
pearsons_contingency_coefficient_matrix,
)
from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
@@ -104,6 +109,7 @@
"confusion_matrix",
"cosine_similarity",
"cramers_v",
"cramers_v_matrix",
"tweedie_deviance_score",
"dice_score",
"dice",
@@ -131,6 +137,8 @@
"pairwise_linear_similarity",
"pairwise_manhattan_distance",
"pearson_corrcoef",
"pearsons_contingency_coefficient",
"pearsons_contingency_coefficient_matrix",
"permutation_invariant_training",
"perplexity",
"pit_permutate",
@@ -165,6 +173,8 @@
"symmetric_mean_absolute_percentage_error",
"total_variation",
"translation_edit_rate",
"tschuprows_t",
"tschuprows_t_matrix",
"universal_image_quality_index",
"spectral_angle_mapper",
"weighted_mean_absolute_percentage_error",
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/nominal/__init__.py
@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix # noqa: F401
from torchmetrics.functional.nominal.pearson import ( # noqa: F401
pearsons_contingency_coefficient,
pearsons_contingency_coefficient_matrix,
)
from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix # noqa: F401
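Each metric also gains a ``*_matrix`` variant (exported above) for computing pairwise associations between the columns of a matrix of categorical variables, analogous to ``cramers_v_matrix`` in the next file. A minimal sketch of how these might be called, assuming they share ``cramers_v_matrix``'s convention of one variable per column:

```python
import torch
from torchmetrics.functional.nominal import (
    cramers_v_matrix,
    pearsons_contingency_coefficient_matrix,
    tschuprows_t_matrix,
)

# A matrix of 5 categorical variables (columns) with 100 observations each.
matrix = torch.randint(0, 4, (100, 5))

# Each call is assumed to return a symmetric 5x5 tensor with ones on the diagonal,
# matching the cramers_v_matrix docstring example later in this diff.
cv = cramers_v_matrix(matrix)
pc = pearsons_contingency_coefficient_matrix(matrix)
tt = tschuprows_t_matrix(matrix)
```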
82 changes: 20 additions & 62 deletions src/torchmetrics/functional/nominal/cramers.py
@@ -19,53 +19,14 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_update
from torchmetrics.functional.nominal.utils import _handle_nan_in_data
from torchmetrics.utilities.prints import rank_zero_warn


def _cramers_input_validation(nan_strategy: str, nan_replace_value: Optional[Union[int, float]]) -> None:
if nan_strategy not in ["replace", "drop"]:
raise ValueError(
f"Argument `nan_strategy` is expected to be one of `['replace', 'drop']`, but got {nan_strategy}"
)
if nan_strategy == "replace" and not isinstance(nan_replace_value, (int, float)):
raise ValueError(
"Argument `nan_replace` is expected to be of a type `int` or `float` when `nan_strategy = 'replace`, "
f"but got {nan_replace_value}"
)


def _compute_expected_freqs(confmat: Tensor) -> Tensor:
"""Compute the expected frequenceis from the provided confusion matrix."""
margin_sum_rows, margin_sum_cols = confmat.sum(1), confmat.sum(0)
expected_freqs = torch.einsum("r, c -> rc", margin_sum_rows, margin_sum_cols) / confmat.sum()
return expected_freqs


def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor:
"""Chi-square test of independenc of variables in a confusion matrix table.

Adapted from: https://github.com/scipy/scipy/blob/v1.9.2/scipy/stats/contingency.py.
"""
expected_freqs = _compute_expected_freqs(confmat)
# Get degrees of freedom
df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1
if df == 0:
return torch.tensor(0.0, device=confmat.device)

if df == 1 and bias_correction:
diff = expected_freqs - confmat
direction = diff.sign()
confmat += direction * torch.minimum(0.5 * torch.ones_like(direction), direction.abs())

return torch.sum((confmat - expected_freqs) ** 2 / expected_freqs)


def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor:
"""Drop all rows and columns containing only zeros."""
confmat = confmat[confmat.sum(1) != 0]
confmat = confmat[:, confmat.sum(0) != 0]
return confmat
from torchmetrics.functional.nominal.utils import (
_compute_bias_corrected_values,
_compute_chi_squared,
_drop_empty_rows_and_cols,
_handle_nan_in_data,
_nominal_input_validation,
_unable_to_use_bias_correction_warning,
)


def _cramers_v_update(
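The inline helpers removed above (`_cramers_input_validation`, `_compute_expected_freqs`, `_compute_chi_squared`, `_drop_empty_rows_and_cols`) now live in a shared `torchmetrics/functional/nominal/utils.py` so all three nominal metrics can reuse them. That file is not shown in this diff; the following is a rough reconstruction from the removed code (names and signatures taken from the new imports, everything else an assumption):

```python
# Hypothetical reconstruction of parts of torchmetrics/functional/nominal/utils.py,
# pieced together from the helpers removed above -- not the actual file contents.
import torch
from torch import Tensor


def _compute_expected_freqs(confmat: Tensor) -> Tensor:
    """Compute the expected frequencies from the provided confusion matrix."""
    margin_sum_rows, margin_sum_cols = confmat.sum(1), confmat.sum(0)
    return torch.einsum("r, c -> rc", margin_sum_rows, margin_sum_cols) / confmat.sum()


def _compute_chi_squared(confmat: Tensor, bias_correction: bool) -> Tensor:
    """Chi-square test of independence of variables in a confusion matrix table."""
    expected_freqs = _compute_expected_freqs(confmat)
    # Degrees of freedom of the contingency table.
    df = expected_freqs.numel() - sum(expected_freqs.shape) + expected_freqs.ndim - 1
    if df == 0:
        return torch.tensor(0.0, device=confmat.device)
    if df == 1 and bias_correction:
        # Yates' continuity correction for 2x2 tables.
        diff = expected_freqs - confmat
        direction = diff.sign()
        confmat = confmat + direction * torch.minimum(0.5 * torch.ones_like(direction), direction.abs())
    return torch.sum((confmat - expected_freqs) ** 2 / expected_freqs)


def _drop_empty_rows_and_cols(confmat: Tensor) -> Tensor:
    """Drop all rows and columns containing only zeros."""
    confmat = confmat[confmat.sum(1) != 0]
    return confmat[:, confmat.sum(0) != 0]
```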
@@ -110,17 +71,13 @@ def _cramers_v_compute(confmat: Tensor, bias_correction: bool) -> Tensor:
n_rows, n_cols = confmat.shape

if bias_correction:
phi_squared_corrected = torch.max(
torch.tensor(0.0, device=confmat.device), phi_squared - ((n_rows - 1) * (n_cols - 1)) / (cm_sum - 1)
phi_squared_corrected, rows_corrected, cols_corrected = _compute_bias_corrected_values(
phi_squared, n_rows, n_cols, cm_sum
)
rows_corrected = n_rows - (n_rows - 1) ** 2 / (cm_sum - 1)
cols_corrected = n_cols - (n_cols - 1) ** 2 / (cm_sum - 1)
if min(rows_corrected, cols_corrected) == 1:
rank_zero_warn(
"Unable to compute Cramer's V using bias correction. Please consider to set `bias_correction=False`."
)
if torch.min(rows_corrected, cols_corrected) == 1:
_unable_to_use_bias_correction_warning(metric_name="Cramer's V")
return torch.tensor(float("nan"), device=confmat.device)
cramers_v_value = torch.sqrt(phi_squared_corrected / min(rows_corrected - 1, cols_corrected - 1))
cramers_v_value = torch.sqrt(phi_squared_corrected / torch.min(rows_corrected - 1, cols_corrected - 1))
else:
cramers_v_value = torch.sqrt(phi_squared / min(n_rows - 1, n_cols - 1))
return cramers_v_value.clamp(0.0, 1.0)
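Similarly, the bias-correction arithmetic that used to be inlined here is now delegated to `_compute_bias_corrected_values`. A hedged reconstruction of that helper from the removed lines (the actual implementation in `utils.py` is not shown in this diff):

```python
from typing import Tuple

import torch
from torch import Tensor


def _compute_bias_corrected_values(
    phi_squared: Tensor, n_rows: int, n_cols: int, confmat_sum: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
    """Bias-corrected phi-squared and dimension terms, reconstructed from the removed inline code."""
    phi_squared_corrected = torch.max(
        torch.tensor(0.0, device=phi_squared.device),
        phi_squared - ((n_rows - 1) * (n_cols - 1)) / (confmat_sum - 1),
    )
    rows_corrected = n_rows - (n_rows - 1) ** 2 / (confmat_sum - 1)
    cols_corrected = n_cols - (n_cols - 1) ** 2 / (confmat_sum - 1)
    return phi_squared_corrected, rows_corrected, cols_corrected
```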
@@ -136,19 +93,19 @@ def cramers_v(
r"""Compute `Cramer's V`_ statistic measuring the association between two categorical (nominal) data series.

.. math::
V = \sqrt{\frac{\chi^2 / 2}{\min(r - 1, k - 1)}}
V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}}

where

.. math::
\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}

Cramer's V is a symmetric coefficient, i.e.
where :math:`n_{ij}` denotes the number of times the values :math:`(A_i, B_j)` are observed, with :math:`A_i` and
:math:`B_j` representing the categories observed in ``preds`` and ``target``, respectively.

.. math::
V(preds, target) = V(target, preds)
Cramer's V is a symmetric coefficient, i.e. :math:`V(preds, target) = V(target, preds)`.

The output values lies in [0, 1].
The output values lie in [0, 1], with 1 meaning perfect association.

Args:
preds: 1D or 2D tensor of categorical (nominal) data
@@ -172,6 +129,7 @@ def cramers_v(
>>> cramers_v(preds, target)
tensor(0.5284)
"""
_nominal_input_validation(nan_strategy, nan_replace_value)
num_classes = len(torch.cat([preds, target]).unique())
confmat = _cramers_v_update(preds, target, num_classes, nan_strategy, nan_replace_value)
return _cramers_v_compute(confmat, bias_correction)
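For reviewers comparing the three nominal-association metrics touched by this PR: all of them rescale the same :math:`\chi^2` statistic computed on an :math:`r \times k` contingency table with :math:`n` observations; only the normalization differs (standard textbook definitions, not taken from this diff):

```latex
V = \sqrt{\frac{\chi^2 / n}{\min(r - 1,\, k - 1)}}     % Cramer's V
T = \sqrt{\frac{\chi^2 / n}{\sqrt{(r - 1)(k - 1)}}}    % Tschuprow's T
C = \sqrt{\frac{\chi^2}{\chi^2 + n}}                   % Pearson's contingency coefficient
```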
@@ -210,7 +168,7 @@ def cramers_v_matrix(
[0.0542, 0.0000, 0.0000, 1.0000, 0.1100],
[0.1337, 0.0000, 0.0649, 0.1100, 1.0000]])
"""
_cramers_input_validation(nan_strategy, nan_replace_value)
_nominal_input_validation(nan_strategy, nan_replace_value)
num_variables = matrix.shape[1]
cramers_v_matrix_value = torch.ones(num_variables, num_variables, device=matrix.device)
for i, j in itertools.combinations(range(num_variables), 2):