From d070b945fee26678cdff1383d22c22b3b5d404d8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 30 Jun 2021 19:14:21 +0200 Subject: [PATCH 1/6] overload update --- torchmetrics/audio/si_sdr.py | 5 +++-- torchmetrics/audio/si_snr.py | 3 ++- torchmetrics/audio/snr.py | 5 +++-- torchmetrics/classification/accuracy.py | 5 +++-- torchmetrics/classification/auc.py | 5 +++-- torchmetrics/classification/auroc.py | 5 +++-- torchmetrics/classification/average_precision.py | 5 +++-- torchmetrics/classification/binned_precision_recall.py | 7 ++++--- torchmetrics/classification/cohen_kappa.py | 5 +++-- torchmetrics/classification/confusion_matrix.py | 5 +++-- torchmetrics/classification/hamming_distance.py | 5 +++-- torchmetrics/classification/hinge.py | 5 +++-- torchmetrics/classification/kldivergence.py | 5 +++-- torchmetrics/classification/matthews_corrcoef.py | 5 +++-- torchmetrics/classification/precision_recall_curve.py | 5 +++-- torchmetrics/classification/roc.py | 5 +++-- torchmetrics/classification/stat_scores.py | 5 +++-- torchmetrics/image/fid.py | 5 +++-- torchmetrics/image/inception.py | 5 +++-- torchmetrics/image/kid.py | 5 +++-- torchmetrics/image/psnr.py | 5 +++-- torchmetrics/image/ssim.py | 5 +++-- torchmetrics/metric.py | 3 ++- torchmetrics/regression/cosine_similarity.py | 5 +++-- torchmetrics/regression/explained_variance.py | 5 +++-- torchmetrics/regression/mean_absolute_error.py | 5 +++-- torchmetrics/regression/mean_absolute_percentage_error.py | 5 +++-- torchmetrics/regression/mean_squared_error.py | 5 +++-- torchmetrics/regression/mean_squared_log_error.py | 5 +++-- torchmetrics/regression/pearson.py | 5 +++-- torchmetrics/regression/r2score.py | 5 +++-- torchmetrics/regression/spearman.py | 5 +++-- torchmetrics/retrieval/retrieval_metric.py | 5 +++-- 33 files changed, 98 insertions(+), 65 deletions(-) diff --git a/torchmetrics/audio/si_sdr.py b/torchmetrics/audio/si_sdr.py index 1d24f4d9832..91062019aa6 100644 --- a/torchmetrics/audio/si_sdr.py +++ b/torchmetrics/audio/si_sdr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload from torch import Tensor, tensor @@ -84,7 +84,8 @@ def __init__( self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/audio/si_snr.py b/torchmetrics/audio/si_snr.py index 4d9092cd21d..1966b96e3df 100644 --- a/torchmetrics/audio/si_snr.py +++ b/torchmetrics/audio/si_snr.py @@ -79,7 +79,8 @@ def __init__( self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index 2f81e667703..9141c08537b 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload from torch import Tensor, tensor @@ -89,7 +89,8 @@ def __init__( self.add_state("sum_snr", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index e23bf057d76..295b0e556b7 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload from torch import Tensor, tensor @@ -216,7 +216,8 @@ def __init__( self.mode: DataType = None # type: ignore self.multiclass = multiclass - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. diff --git a/torchmetrics/classification/auc.py b/torchmetrics/classification/auc.py index 65449cae84e..a8247f249ce 100644 --- a/torchmetrics/classification/auc.py +++ b/torchmetrics/classification/auc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload from torch import Tensor @@ -69,7 +69,8 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/auroc.py b/torchmetrics/classification/auroc.py index bddf674acc2..96bb1d866fd 100644 --- a/torchmetrics/classification/auroc.py +++ b/torchmetrics/classification/auroc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor @@ -148,7 +148,8 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/average_precision.py b/torchmetrics/classification/average_precision.py index 397faaac91d..42c4cb8fbbf 100644 --- a/torchmetrics/classification/average_precision.py +++ b/torchmetrics/classification/average_precision.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Union, overload import torch from torch import Tensor @@ -98,7 +98,8 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 29182d945db..c3aacf0a1bb 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, overload from warnings import warn import torch @@ -23,7 +23,7 @@ def _recall_at_precision(precision: Tensor, recall: Tensor, thresholds: Tensor, - min_precision: float) -> Tuple[Tensor, Tensor]: + min_precision: float,) -> Tuple[Tensor, Tensor]: try: max_recall, _, best_threshold = max((r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision) @@ -155,7 +155,8 @@ def __init__( dist_reduce_fx="sum", ) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Args preds: (n_samples, n_classes) tensor diff --git a/torchmetrics/classification/cohen_kappa.py b/torchmetrics/classification/cohen_kappa.py index 56c4e504b99..784bcba71f3 100644 --- a/torchmetrics/classification/cohen_kappa.py +++ b/torchmetrics/classification/cohen_kappa.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, overload import torch from torch import Tensor @@ -100,7 +100,8 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index 01d97555974..26d910ad656 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, overload import torch from torch import Tensor @@ -120,7 +120,8 @@ def __init__( default = torch.zeros(num_classes, 2, 2) if multilabel else torch.zeros(num_classes, num_classes) self.add_state("confmat", default=default, dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/hamming_distance.py b/torchmetrics/classification/hamming_distance.py index 40d201dbc21..a6f8dd5d433 100644 --- a/torchmetrics/classification/hamming_distance.py +++ b/torchmetrics/classification/hamming_distance.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor, tensor @@ -90,7 +90,8 @@ def __init__( raise ValueError("The `threshold` should lie in the (0,1) interval.") self.threshold = threshold - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. diff --git a/torchmetrics/classification/hinge.py b/torchmetrics/classification/hinge.py index 31d9abc524e..d11e28040f3 100644 --- a/torchmetrics/classification/hinge.py +++ b/torchmetrics/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, overload from torch import Tensor, tensor @@ -115,7 +115,8 @@ def __init__( self.squared = squared self.multiclass_mode = multiclass_mode - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: measure, total = _hinge_update(preds, target, squared=self.squared, multiclass_mode=self.multiclass_mode) self.measure = measure + self.measure diff --git a/torchmetrics/classification/kldivergence.py b/torchmetrics/classification/kldivergence.py index 970daefdd77..e87bbf4ec77 100644 --- a/torchmetrics/classification/kldivergence.py +++ b/torchmetrics/classification/kldivergence.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor @@ -91,7 +91,8 @@ def __init__( self.add_state('measures', [], dist_reduce_fx='cat') self.add_state('total', torch.zeros(1), dist_reduce_fx='sum') - def update(self, p: Tensor, q: Tensor) -> None: # type: ignore + @overload + def update(self, p: Tensor, q: Tensor) -> None: measures, total = _kld_update(p, q, self.log_prob) if self.reduction is None or self.reduction == 'none': self.measures.append(measures) diff --git a/torchmetrics/classification/matthews_corrcoef.py b/torchmetrics/classification/matthews_corrcoef.py index cf113c35e1a..54a7ff803fc 100644 --- a/torchmetrics/classification/matthews_corrcoef.py +++ b/torchmetrics/classification/matthews_corrcoef.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor @@ -95,7 +95,8 @@ def __init__( self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/precision_recall_curve.py b/torchmetrics/classification/precision_recall_curve.py index e5e0edf4b04..5f11a74520d 100644 --- a/torchmetrics/classification/precision_recall_curve.py +++ b/torchmetrics/classification/precision_recall_curve.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, overload import torch from torch import Tensor @@ -109,7 +109,8 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/roc.py b/torchmetrics/classification/roc.py index 41793e9f3f4..dd59c2630da 100644 --- a/torchmetrics/classification/roc.py +++ b/torchmetrics/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union, overload import torch from torch import Tensor @@ -133,7 +133,8 @@ def __init__( ' For large datasets this may lead to large memory footprint.' ) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index d9bf9dc47a0..4f489c31a4b 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, overload import torch from torch import Tensor @@ -189,7 +189,8 @@ def __init__( for s in ("tp", "fp", "tn", "fn"): self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types. diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 98ef3e503e8..b9c27b98295 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, overload import numpy as np import torch @@ -247,7 +247,8 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - def update(self, imgs: Tensor, real: bool) -> None: # type: ignore + @overload + def update(self, imgs: Tensor, real: bool) -> None: """ Update the state with extracted features Args: diff --git a/torchmetrics/image/inception.py b/torchmetrics/image/inception.py index c65aa52298e..22303d590ee 100644 --- a/torchmetrics/image/inception.py +++ b/torchmetrics/image/inception.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union, overload import torch from torch import Tensor @@ -145,7 +145,8 @@ def __init__( self.splits = splits self.add_state("features", [], dist_reduce_fx=None) - def update(self, imgs: Tensor) -> None: # type: ignore + @overload + def update(self, imgs: Tensor) -> None: """ Update the state with extracted features Args: diff --git a/torchmetrics/image/kid.py b/torchmetrics/image/kid.py index 9b930edca4f..e3492a7c3c1 100644 --- a/torchmetrics/image/kid.py +++ b/torchmetrics/image/kid.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union, overload import torch from torch import Tensor @@ -238,7 +238,8 @@ def __init__( self.add_state("real_features", [], dist_reduce_fx=None) self.add_state("fake_features", [], dist_reduce_fx=None) - def update(self, imgs: Tensor, real: bool) -> None: # type: ignore + @overload + def update(self, imgs: Tensor, real: bool) -> None: """ Update the state with extracted features Args: diff --git a/torchmetrics/image/psnr.py b/torchmetrics/image/psnr.py index 1f3501034f8..891ef4f09d2 100644 --- a/torchmetrics/image/psnr.py +++ b/torchmetrics/image/psnr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union, overload import torch from torch import Tensor, tensor @@ -112,7 +112,8 @@ def __init__( self.reduction = reduction self.dim = tuple(dim) if isinstance(dim, Sequence) else dim - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/image/ssim.py b/torchmetrics/image/ssim.py index df9c1973453..4c539b5f082 100644 --- a/torchmetrics/image/ssim.py +++ b/torchmetrics/image/ssim.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, overload import torch from torch import Tensor @@ -84,7 +84,8 @@ def __init__( self.k2 = k2 self.reduction = reduction - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 7151643be55..cfa55346994 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -19,7 +19,7 @@ from collections.abc import Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Union, overload import torch from torch import Tensor, nn @@ -334,6 +334,7 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any: return wrapped_func + @overload @abstractmethod def update(self, *_: Any, **__: Any) -> None: """ diff --git a/torchmetrics/regression/cosine_similarity.py b/torchmetrics/regression/cosine_similarity.py index 3f2536694cf..130e9c8c38b 100644 --- a/torchmetrics/regression/cosine_similarity.py +++ b/torchmetrics/regression/cosine_similarity.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor @@ -80,7 +80,8 @@ def __init__( self.add_state("target", [], dist_reduce_fx="cat") self.reduction = reduction - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update metric states with predictions and targets. diff --git a/torchmetrics/regression/explained_variance.py b/torchmetrics/regression/explained_variance.py index 633035468f6..4e3435cbcb6 100644 --- a/torchmetrics/regression/explained_variance.py +++ b/torchmetrics/regression/explained_variance.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union, overload import torch from torch import Tensor, tensor @@ -109,7 +109,8 @@ def __init__( self.add_state("sum_squared_target", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("n_obs", default=tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_absolute_error.py b/torchmetrics/regression/mean_absolute_error.py index 78b7cceddfa..52cbafa37ce 100644 --- a/torchmetrics/regression/mean_absolute_error.py +++ b/torchmetrics/regression/mean_absolute_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor, tensor @@ -66,7 +66,8 @@ def __init__( self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_absolute_percentage_error.py b/torchmetrics/regression/mean_absolute_percentage_error.py index ade6c81100f..8890133ef2c 100644 --- a/torchmetrics/regression/mean_absolute_percentage_error.py +++ b/torchmetrics/regression/mean_absolute_percentage_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor, tensor @@ -77,7 +77,8 @@ def __init__( self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_squared_error.py b/torchmetrics/regression/mean_squared_error.py index 94ead23e732..e558d281766 100644 --- a/torchmetrics/regression/mean_squared_error.py +++ b/torchmetrics/regression/mean_squared_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor, tensor @@ -71,7 +71,8 @@ def __init__( self.add_state("total", default=tensor(0), dist_reduce_fx="sum") self.squared = squared - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/regression/mean_squared_log_error.py b/torchmetrics/regression/mean_squared_log_error.py index 69ef426b7b6..d6f4a22b41d 100644 --- a/torchmetrics/regression/mean_squared_log_error.py +++ b/torchmetrics/regression/mean_squared_log_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor, tensor @@ -72,7 +72,8 @@ def __init__( self.add_state("sum_squared_log_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 6028e023e96..f25f66f57e7 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, overload import torch from torch import Tensor @@ -77,7 +77,8 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/regression/r2score.py b/torchmetrics/regression/r2score.py index fbb6371beb4..78783916b80 100644 --- a/torchmetrics/regression/r2score.py +++ b/torchmetrics/regression/r2score.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor, tensor @@ -123,7 +123,8 @@ def __init__( self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/regression/spearman.py b/torchmetrics/regression/spearman.py index 3778e64125c..9e496a20e69 100644 --- a/torchmetrics/regression/spearman.py +++ b/torchmetrics/regression/spearman.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor @@ -75,7 +75,8 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor) -> None: """ Update state with predictions and targets. diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 1ee9e5f656c..dfee7a5702a 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload import torch from torch import Tensor, tensor @@ -90,7 +90,8 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # type: ignore + @overload + def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: """ Check shape, check and convert dtypes, flatten and add to accumulators. """ if indexes is None: raise ValueError("Argument `indexes` cannot be None") From 7a194cf4b79717051bc02e14057e9dfbb7080923 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 30 Jun 2021 19:15:21 +0200 Subject: [PATCH 2/6] fmt --- torchmetrics/audio/si_snr.py | 2 +- torchmetrics/classification/binned_precision_recall.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchmetrics/audio/si_snr.py b/torchmetrics/audio/si_snr.py index 1966b96e3df..977e03089aa 100644 --- a/torchmetrics/audio/si_snr.py +++ b/torchmetrics/audio/si_snr.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, overload from torch import Tensor, tensor diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index c3aacf0a1bb..2464a127cef 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -22,8 +22,12 @@ from torchmetrics.utilities.data import METRIC_EPS, to_onehot -def _recall_at_precision(precision: Tensor, recall: Tensor, thresholds: Tensor, - min_precision: float,) -> Tuple[Tensor, Tensor]: +def _recall_at_precision( + precision: Tensor, + recall: Tensor, + thresholds: Tensor, + min_precision: float, +) -> Tuple[Tensor, Tensor]: try: max_recall, _, best_threshold = max((r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision) From 3710a0659869d7c0d8121124a14e440732236ef4 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 30 Jun 2021 19:21:20 +0200 Subject: [PATCH 3/6] ... --- torchmetrics/metric.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index cfa55346994..7d2982cb822 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -19,7 +19,7 @@ from collections.abc import Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any, Callable, Dict, Generator, List, Optional, Union, overload +from typing import Any, Callable, Dict, Generator, List, Optional, Union import torch from torch import Tensor, nn @@ -334,9 +334,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any: return wrapped_func - @overload @abstractmethod - def update(self, *_: Any, **__: Any) -> None: + def update(self, *_, **__) -> None: """ Override this method to update the state variables of your metric class. """ From 008b45a417857eb8b7d1c925c992ac26a6713063 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 30 Jun 2021 19:23:24 +0200 Subject: [PATCH 4/6] ... --- torchmetrics/retrieval/retrieval_metric.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index dfee7a5702a..1782708c34f 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, overload +from typing import Any, Callable, Optional import torch from torch import Tensor, tensor @@ -90,7 +90,6 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - @overload def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: """ Check shape, check and convert dtypes, flatten and add to accumulators. """ if indexes is None: From 4c75dc3b9ef7bf4fd09a8a84b9d966d2baf32dd2 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 30 Jun 2021 19:25:18 +0200 Subject: [PATCH 5/6] . --- torchmetrics/retrieval/retrieval_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 1782708c34f..6ff944ce682 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -90,7 +90,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor, indexes: Tensor, *_, **__) -> None: """ Check shape, check and convert dtypes, flatten and add to accumulators. """ if indexes is None: raise ValueError("Argument `indexes` cannot be None") From 2a1516a2dfffc8571f0bde4c2d97b4c58c157ad6 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 30 Jun 2021 19:28:16 +0200 Subject: [PATCH 6/6] . --- torchmetrics/retrieval/retrieval_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 6ff944ce682..8a65049ea38 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -90,7 +90,7 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) - def update(self, preds: Tensor, target: Tensor, indexes: Tensor, *_, **__) -> None: + def update(self, *, preds: Tensor, target: Tensor, indexes: Tensor, **__) -> None: """ Check shape, check and convert dtypes, flatten and add to accumulators. """ if indexes is None: raise ValueError("Argument `indexes` cannot be None")