From 442d6fd5cfe24f893ab86e18cb20eca7f81ac40f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:37:35 +0100 Subject: [PATCH 1/8] explained_variance --- .../metrics/functional/explained_variance.py | 66 +---------- .../metrics/regression/explained_variance.py | 106 ++---------------- tests/metrics/test_remove_1-5_metrics.py | 22 +++- 3 files changed, 29 insertions(+), 165 deletions(-) diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index fa8d43c06c7ef..534032024d5c0 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -14,74 +14,18 @@ from typing import Sequence, Tuple, Union import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import explained_variance as _explained_variance - -def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - _check_same_shape(preds, target) - return preds, target - - -def _explained_variance_compute( - preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - diff_avg = torch.mean(target - preds, dim=0) - numerator = torch.mean((target - preds - diff_avg)**2, dim=0) - - target_avg = torch.mean(target, dim=0) - denominator = torch.mean((target - target_avg)**2, dim=0) - - # Take care of division by zero - nonzero_numerator = numerator != 0 - nonzero_denominator = denominator != 0 - valid_score = nonzero_numerator & nonzero_denominator - output_scores = torch.ones_like(diff_avg) - output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score]) - output_scores[nonzero_numerator & ~nonzero_denominator] = 0. - - # Decide what to do in multioutput case - # Todo: allow user to pass in tensor with weights - if multioutput == 'raw_values': - return output_scores - if multioutput == 'uniform_average': - return torch.mean(output_scores) - if multioutput == 'variance_weighted': - denom_sum = torch.sum(denominator) - return torch.sum(denominator / denom_sum * output_scores) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_explained_variance, ver_deprecate="1.3.0", ver_remove="1.5.0") def explained_variance( preds: torch.Tensor, target: torch.Tensor, multioutput: str = 'uniform_average', ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ - Computes explained variance. - - Args: - preds: estimated labels - target: ground truth labels - multioutput: Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - Example: - - >>> from pytorch_lightning.metrics.functional import explained_variance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance(preds, target, multioutput='raw_values') - tensor([0.9677, 1.0000]) + .. deprecated:: + Use :func:`torchmetrics.functional.explained_variance`. Will be removed in v1.5.0. 
""" - preds, target = _explained_variance_update(preds, target) - return _explained_variance_compute(preds, target, multioutput) diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 8b0259694ef4c..4f820718545cb 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -13,72 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import ExplainedVariance as _ExplainedVariance -from pytorch_lightning.metrics.functional.explained_variance import ( - _explained_variance_compute, - _explained_variance_update, -) -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class ExplainedVariance(Metric): - r""" - Computes `explained variance - `_: - - .. math:: \text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)} - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. - - Forward accepts - - - ``preds`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - ``target`` (long tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - In the case of multioutput, as default the variances will be uniformly - averaged over the additional dimensions. Please see argument `multioutput` - for changing this behavior. - - Args: - multioutput: - Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. - - Example: - - >>> from pytorch_lightning.metrics import ExplainedVariance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance = ExplainedVariance() - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance = ExplainedVariance(multioutput='raw_values') - >>> explained_variance(preds, target) - tensor([0.9677, 1.0000]) - """ +class ExplainedVariance(_ExplainedVariance): + @deprecated(target=_ExplainedVariance, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, multioutput: str = 'uniform_average', @@ -87,43 +29,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') - if multioutput not in allowed_multioutput: - raise ValueError( - f'Invalid input to argument `multioutput`. 
Choose one of the following: {allowed_multioutput}' - ) - self.multioutput = multioutput - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `ExplainedVariance` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - preds, target = _explained_variance_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) + This implementation refers to :class:`~torchmetrics.ExplainedVariance`. - def compute(self): - """ - Computes explained variance over state. + .. deprecated:: + Use :class:`~torchmetrics.ExplainedVariance`. Will be removed in v1.5.0. """ - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) - return _explained_variance_compute(preds, target, self.multioutput) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index 339d07b163632..ea869b8b944ca 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -31,7 +31,7 @@ PrecisionRecallCurve, Recall, ROC, - StatScores, + StatScores, ExplainedVariance, ) from pytorch_lightning.metrics.functional import ( auc, @@ -47,7 +47,7 @@ precision_recall_curve, recall, roc, - stat_scores, + stat_scores, explained_variance, ) from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -232,8 +232,20 @@ def test_v1_5_metric_detect(): IoU(num_classes=1) target = torch.randint(0, 2, (10, 25, 25)) - pred = torch.tensor(target) - pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + preds = torch.tensor(target) + preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15] iou.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.allclose(iou(pred, target), torch.tensor(0.9660), atol=1e-4) + assert torch.allclose(iou(preds, target), torch.tensor(0.9660), atol=1e-4) + + +def test_v1_5_metric_regress(): + ExplainedVariance.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ExplainedVariance() + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + explained_variance.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.allclose(explained_variance(preds, target), torch.tensor(0.9572), atol=1e-4) From e825229f955b6adbae4155e72b5daaeeb22f8389 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:38:21 +0100 Subject: [PATCH 2/8] tests --- tests/accelerators/test_cpu.py | 1 + .../regression/test_explained_variance.py | 77 ------------------- tests/metrics/test_remove_1-5_metrics.py | 6 +- tests/utilities/test_argparse.py | 4 +- 4 files changed, 7 insertions(+), 81 deletions(-) delete mode 100644 tests/metrics/regression/test_explained_variance.py diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 349e4175a7444..bcb351984a175 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -2,6 +2,7 @@ import pytest import torch + from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import 
SingleDevicePlugin diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py deleted file mode 100644 index adab562ac6055..0000000000000 --- a/tests/metrics/regression/test_explained_variance.py +++ /dev/null @@ -1,77 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import explained_variance_score - -from pytorch_lightning.metrics.functional import explained_variance -from pytorch_lightning.metrics.regression import ExplainedVariance -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_target, sk_preds) - - -def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_target, sk_preds) - - -@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -class TestExplainedVariance(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - ExplainedVariance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - dist_sync_on_step, - metric_args=dict(multioutput=multioutput), - ) - - def test_explained_variance_functional(self, multioutput, preds, target, sk_metric): - self.run_functional_metric_test( - preds, - target, - explained_variance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - metric_args=dict(multioutput=multioutput), - ) - - -def test_error_on_different_shape(metric_class=ExplainedVariance): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index ea869b8b944ca..627a78825a345 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -22,6 +22,7 @@ AUROC, AveragePrecision, ConfusionMatrix, + ExplainedVariance, F1, FBeta, HammingDistance, @@ -31,13 +32,14 @@ PrecisionRecallCurve, Recall, ROC, - StatScores, ExplainedVariance, + StatScores, ) from pytorch_lightning.metrics.functional import ( auc, auroc, average_precision, confusion_matrix, + explained_variance, f1, fbeta, hamming_distance, @@ -47,7 +49,7 @@ precision_recall_curve, recall, roc, - stat_scores, explained_variance, + stat_scores, ) from 
pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse.py index fdf5ae0cafe65..aef266d639b4a 100644 --- a/tests/utilities/test_argparse.py +++ b/tests/utilities/test_argparse.py @@ -7,13 +7,13 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.argparse import ( + _gpus_arg_default, + _int_or_float_type, add_argparse_args, from_argparse_args, get_abbrev_qualified_cls_name, parse_argparser, parse_args_from_docstring, - _gpus_arg_default, - _int_or_float_type ) From 89423decf21f7eaaaa64ef45a16645886d14b4fe Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:44:03 +0100 Subject: [PATCH 3/8] mean_absolute_error --- .../metrics/functional/mean_absolute_error.py | 34 ++-------- .../metrics/regression/mean_absolute_error.py | 64 ++----------------- tests/metrics/test_remove_1-5_metrics.py | 14 +++- 3 files changed, 24 insertions(+), 88 deletions(-) diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 2bd8f125ecb9e..85aa07c802eca 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -11,40 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_absolute_error as _mean_absolute_error - -def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_abs_error = torch.sum(torch.abs(preds - target)) - n_obs = target.numel() - return sum_abs_error, n_obs - - -def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_abs_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_absolute_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean absolute error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with MAE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_absolute_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_absolute_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_absolute_error`. Will be removed in v1.5.0. """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) - return _mean_absolute_error_compute(sum_abs_error, n_obs) diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 484ccbe83284e..8510275c127d7 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -13,42 +13,14 @@ # limitations under the License. 
from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanAbsoluteError as _MeanAbsoluteError -from pytorch_lightning.metrics.functional.mean_absolute_error import ( - _mean_absolute_error_compute, - _mean_absolute_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanAbsoluteError(Metric): - r""" - Computes `mean absolute error `_ (MAE): - - .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} | - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanAbsoluteError - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> mean_absolute_error = MeanAbsoluteError() - >>> mean_absolute_error(preds, target) - tensor(0.5000) - """ +class MeanAbsoluteError(_MeanAbsoluteError): + @deprecated(target=_MeanAbsoluteError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -56,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) + This implementation refers to :class:`~torchmetrics.MeanAbsoluteError`. - self.sum_abs_error += sum_abs_error - self.total += n_obs - - def compute(self): - """ - Computes mean absolute error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanAbsoluteError`. Will be removed in v1.5.0. 
""" - return _mean_absolute_error_compute(self.sum_abs_error, self.total) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index 627a78825a345..1cf2d55744322 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -32,7 +32,7 @@ PrecisionRecallCurve, Recall, ROC, - StatScores, + StatScores, MeanAbsoluteError, ) from pytorch_lightning.metrics.functional import ( auc, @@ -49,7 +49,7 @@ precision_recall_curve, recall, roc, - stat_scores, + stat_scores, mean_absolute_error, ) from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -246,8 +246,18 @@ def test_v1_5_metric_regress(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): ExplainedVariance() + MeanAbsoluteError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanAbsoluteError() + target = torch.tensor([3, -0.5, 2, 7]) preds = torch.tensor([2.5, 0.0, 2, 8]) explained_variance.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert torch.allclose(explained_variance(preds, target), torch.tensor(0.9572), atol=1e-4) + + x = torch.tensor([0., 1, 2, 3]) + y = torch.tensor([0., 1, 2, 2]) + mean_absolute_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_absolute_error(x, y) == 0.25 From a150dfbf3e409596e09da62657b3d4e7b0877b5f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:47:22 +0100 Subject: [PATCH 4/8] mean_squared_error --- .../metrics/functional/mean_squared_error.py | 33 ++-------- .../metrics/regression/mean_squared_error.py | 65 ++----------------- tests/metrics/test_remove_1-5_metrics.py | 12 +++- 3 files changed, 22 insertions(+), 88 deletions(-) diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 66c0aadef0651..6801a0aa6e9e8 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -14,37 +14,14 @@ from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_squared_error as _mean_squared_error - -def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) - n_obs = target.numel() - return sum_squared_error, n_obs - - -def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_squared_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with MSE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_error`. Will be removed in v1.5.0. 
""" - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - return _mean_squared_error_compute(sum_squared_error, n_obs) diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index c26371514e7cd..cbe09faf0046c 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -13,43 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanSquaredError as _MeanSquaredError -from pytorch_lightning.metrics.functional.mean_squared_error import ( - _mean_squared_error_compute, - _mean_squared_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanSquaredError(Metric): - r""" - Computes `mean squared error `_ (MSE): - - .. math:: \text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredError - >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) - >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) - >>> mean_squared_error = MeanSquaredError() - >>> mean_squared_error(preds, target) - tensor(0.8750) - - """ +class MeanSquaredError(_MeanSquaredError): + @deprecated(target=_MeanSquaredError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -57,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - - self.sum_squared_error += sum_squared_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredError`. - def compute(self): - """ - Computes mean squared error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredError`. Will be removed in v1.5.0. 
""" - return _mean_squared_error_compute(self.sum_squared_error, self.total) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index 1cf2d55744322..75940a01020fa 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -32,7 +32,7 @@ PrecisionRecallCurve, Recall, ROC, - StatScores, MeanAbsoluteError, + StatScores, MeanAbsoluteError, MeanSquaredError, ) from pytorch_lightning.metrics.functional import ( auc, @@ -49,7 +49,7 @@ precision_recall_curve, recall, roc, - stat_scores, mean_absolute_error, + stat_scores, mean_absolute_error, mean_squared_error, ) from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -250,6 +250,10 @@ def test_v1_5_metric_regress(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): MeanAbsoluteError() + MeanSquaredError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredError() + target = torch.tensor([3, -0.5, 2, 7]) preds = torch.tensor([2.5, 0.0, 2, 8]) explained_variance.warned = False @@ -261,3 +265,7 @@ def test_v1_5_metric_regress(): mean_absolute_error.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert mean_absolute_error(x, y) == 0.25 + + mean_squared_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_squared_error(x, y) == 0.25 From 9dc7dd2c03ce9f57781b09f65c48a6662d8895f2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:50:45 +0100 Subject: [PATCH 5/8] mean_relative_error --- .../metrics/functional/mean_relative_error.py | 36 +++---------------- tests/metrics/test_remove_1-5_metrics.py | 13 +++++-- 2 files changed, 16 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index bfe5eb6b847d7..8c6e10a17320b 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -14,40 +14,14 @@ from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error - -def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - target_nz = target.clone() - target_nz[target == 0] = 1 - sum_rltv_error = torch.sum(torch.abs((preds - target) / target_nz)) - n_obs = target.numel() - return sum_rltv_error, n_obs - - -def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_rltv_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_relative_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean relative error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with mean relative error - - Example: - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_relative_error(x, y) - tensor(0.1250) - + .. deprecated:: + Use :func:`torchmetrics.functional.regression.mean_relative_error`. Will be removed in v1.5.0. 
""" - sum_rltv_error, n_obs = _mean_relative_error_update(preds, target) - return _mean_relative_error_compute(sum_rltv_error, n_obs) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index 75940a01020fa..595e5d9d8cc41 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -27,12 +27,14 @@ FBeta, HammingDistance, IoU, + MeanAbsoluteError, + MeanSquaredError, MetricCollection, Precision, PrecisionRecallCurve, Recall, ROC, - StatScores, MeanAbsoluteError, MeanSquaredError, + StatScores, ) from pytorch_lightning.metrics.functional import ( auc, @@ -44,14 +46,17 @@ fbeta, hamming_distance, iou, + mean_absolute_error, + mean_squared_error, precision, precision_recall, precision_recall_curve, recall, roc, - stat_scores, mean_absolute_error, mean_squared_error, + stat_scores, ) from pytorch_lightning.metrics.functional.accuracy import accuracy +from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -266,6 +271,10 @@ def test_v1_5_metric_regress(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert mean_absolute_error(x, y) == 0.25 + mean_relative_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_relative_error(x, y) == 0.125 + mean_squared_error.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert mean_squared_error(x, y) == 0.25 From e1f2ca07de91aec2988022aede0985d3c19f6dda Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:55:27 +0100 Subject: [PATCH 6/8] mean_squared_log_error --- .../functional/mean_squared_log_error.py | 33 ++----- .../regression/mean_squared_log_error.py | 67 ++------------ tests/metrics/regression/test_mean_error.py | 87 ------------------- tests/metrics/test_remove_1-5_metrics.py | 14 ++- 4 files changed, 25 insertions(+), 176 deletions(-) delete mode 100644 tests/metrics/regression/test_mean_error.py diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index baec63c7248f2..ac8154918f22f 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -14,37 +14,14 @@ from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error - -def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2)) - n_obs = target.numel() - return sum_squared_log_error, n_obs - - -def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_log_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_squared_log_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared log error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with RMSLE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_log_error 
- >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_log_error(x, y) - tensor(0.0207) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_log_error`. Will be removed in v1.5.0. """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index caaf09a3663ff..795d6f5409abf 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -13,45 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanSquaredLogError as _MeanSquaredLogError -from pytorch_lightning.metrics.functional.mean_squared_log_error import ( - _mean_squared_log_error_compute, - _mean_squared_log_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanSquaredLogError(Metric): - r""" - Computes `mean squared logarithmic error - `_ - (MSLE): - - .. math:: \text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredLogError - >>> target = torch.tensor([2.5, 5, 4, 8]) - >>> preds = torch.tensor([3, 5, 2.5, 7]) - >>> mean_squared_log_error = MeanSquaredLogError() - >>> mean_squared_log_error(preds, target) - tensor(0.0397) - - """ +class MeanSquaredLogError(_MeanSquaredLogError): + @deprecated(target=_MeanSquaredLogError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -59,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - - self.sum_squared_log_error += sum_squared_log_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredLogError`. - def compute(self): - """ - Compute mean squared logarithmic error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredLogError`. Will be removed in v1.5.0. 
""" - return _mean_squared_log_error_compute(self.sum_squared_log_error, self.total) diff --git a/tests/metrics/regression/test_mean_error.py b/tests/metrics/regression/test_mean_error.py deleted file mode 100644 index 041ce12f11164..0000000000000 --- a/tests/metrics/regression/test_mean_error.py +++ /dev/null @@ -1,87 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error -from sklearn.metrics import mean_squared_error as sk_mean_squared_error -from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error - -from pytorch_lightning.metrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error -from pytorch_lightning.metrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_preds, sk_target) - - -def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_preds, sk_target) - - -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -@pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", - [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error), - ], -) -class TestMeanError(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_mean_error_class( - self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step - ): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - dist_sync_on_step=dist_sync_on_step, - ) - - def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - ) - - -@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) -def test_error_on_different_shape(metric_class): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index 595e5d9d8cc41..eaf17ec0792da 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ 
b/tests/metrics/test_remove_1-5_metrics.py @@ -29,6 +29,7 @@ IoU, MeanAbsoluteError, MeanSquaredError, + MeanSquaredLogError, MetricCollection, Precision, PrecisionRecallCurve, @@ -48,6 +49,7 @@ iou, mean_absolute_error, mean_squared_error, + mean_squared_log_error, precision, precision_recall, precision_recall_curve, @@ -259,11 +261,16 @@ def test_v1_5_metric_regress(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): MeanSquaredError() + MeanSquaredLogError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredLogError() + target = torch.tensor([3, -0.5, 2, 7]) preds = torch.tensor([2.5, 0.0, 2, 8]) explained_variance.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.allclose(explained_variance(preds, target), torch.tensor(0.9572), atol=1e-4) + res = explained_variance(preds, target) + assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4) x = torch.tensor([0., 1, 2, 3]) y = torch.tensor([0., 1, 2, 2]) @@ -278,3 +285,8 @@ def test_v1_5_metric_regress(): mean_squared_error.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert mean_squared_error(x, y) == 0.25 + + mean_squared_log_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = mean_squared_log_error(x, y) + assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) From e8f23e5c10017b15598d45243300e92aa6ecaf10 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:57:56 +0100 Subject: [PATCH 7/8] chlog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51ad97decd867..1378fbaec80c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -90,6 +90,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584), + [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636), + ) From 99c6b68906248f842a1111878f7fc79f3b5c92d1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 18:01:55 +0100 Subject: [PATCH 8/8] flake8 --- pytorch_lightning/metrics/functional/explained_variance.py | 2 +- pytorch_lightning/metrics/functional/mean_relative_error.py | 1 - pytorch_lightning/metrics/functional/mean_squared_error.py | 1 - pytorch_lightning/metrics/functional/mean_squared_log_error.py | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 534032024d5c0..bcfe698bf4c5e 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Sequence, Tuple, Union +from typing import Sequence, Union import torch from torchmetrics.functional import explained_variance as _explained_variance diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index 8c6e10a17320b..be21371bdc91a 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 6801a0aa6e9e8..9d1850dcd8689 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torchmetrics.functional import mean_squared_error as _mean_squared_error diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index ac8154918f22f..56654ea47daf2 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error
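
Note on the recurring pattern in this series: every patch replaces a metric implementation with a thin shim that forwards to torchmetrics via the `deprecated` decorator from `pytorch_lightning.utilities.deprecation`. Below is a minimal sketch of the warn-once-and-forward behaviour the tests above rely on. This sketch is an assumption for illustration, not the actual helper: the real decorator also maps the wrapped signature onto the target's and handles wrapping `__init__` for the class shims such as `ExplainedVariance`.

    from functools import wraps
    from warnings import warn

    def deprecated_sketch(target, ver_deprecate, ver_remove):
        """Forward calls to ``target``, emitting a DeprecationWarning once."""

        def decorator(fn):

            @wraps(fn)
            def wrapper(*args, **kwargs):
                if not wrapper.warned:
                    warn(
                        f"`{fn.__qualname__}` was deprecated in v{ver_deprecate}."
                        f" It will be removed in v{ver_remove}."
                        f" Use `{target.__module__}.{target.__qualname__}` instead.",
                        DeprecationWarning,
                    )
                    # Warn only on the first call; tests re-arm the flag
                    # explicitly, e.g. `explained_variance.warned = False`.
                    wrapper.warned = True
                return target(*args, **kwargs)

            wrapper.warned = False
            return wrapper

        return decorator

This also explains the shape of the tests in `test_remove_1-5_metrics.py`: each one resets the `warned` flag on the shim (or on its `__init__` for the class versions), asserts via `pytest.deprecated_call` that the warning matching "It will be removed in v1.5.0" fires, and checks that the result forwarded to torchmetrics is numerically unchanged.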