diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b4c0f0dc1759..c761cdb6912bf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
 - Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
 
 ### Changed
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 860d2fd5c7335..a32362af23e5c 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -234,6 +234,12 @@ RMSLE
 .. autoclass:: pytorch_lightning.metrics.regression.RMSLE
     :noindex:
 
+SSIM
+^^^^
+
+.. autoclass:: pytorch_lightning.metrics.regression.SSIM
+    :noindex:
+
 ----------------
 
 Functional Metrics
@@ -403,6 +409,12 @@ psnr (F)
 .. autofunction:: pytorch_lightning.metrics.functional.psnr
     :noindex:
 
+ssim (F)
+^^^^^^^^
+
+.. autofunction:: pytorch_lightning.metrics.functional.ssim
+    :noindex:
+
 stat_scores_multiple_classes (F)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py
index 2cfbbfa01f6b6..2a107c639bd3f 100644
--- a/pytorch_lightning/metrics/__init__.py
+++ b/pytorch_lightning/metrics/__init__.py
@@ -6,6 +6,7 @@
     PSNR,
     RMSE,
     RMSLE,
+    SSIM
 )
 from pytorch_lightning.metrics.classification import (
     Accuracy,
@@ -54,6 +55,7 @@
     "PSNR",
     "RMSE",
     "RMSLE",
+    "SSIM"
 ]
 __sequence_metrics = ["BLEUScore"]
 __all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py
index eb92cabf8e5e7..4d940ad18bd6a 100644
--- a/pytorch_lightning/metrics/functional/__init__.py
+++ b/pytorch_lightning/metrics/functional/__init__.py
@@ -26,5 +26,6 @@
     psnr,
     rmse,
     rmsle,
+    ssim
 )
 from pytorch_lightning.metrics.functional.nlp import bleu_score
diff --git a/pytorch_lightning/metrics/functional/regression.py b/pytorch_lightning/metrics/functional/regression.py
index c13dfc80b2066..68f7bef93f7ea 100644
--- a/pytorch_lightning/metrics/functional/regression.py
+++ b/pytorch_lightning/metrics/functional/regression.py
@@ -1,3 +1,5 @@
+from typing import Sequence
+
 import torch
 from torch.nn import functional as F
 
@@ -182,3 +184,116 @@ def psnr(
     psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
     psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
     return psnr
+
+
+def _gaussian_kernel(channel, kernel_size, sigma, device):
+    def gaussian(kernel_size, sigma, device):
+        gauss = torch.arange(
+            start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device
+        )
+        gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
+        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)
+
+    gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device)
+    gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device)
+    kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)  # (kernel_size, 1) * (1, kernel_size)
+
+    return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
+
+
+def ssim(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    kernel_size: Sequence[int] = (11, 11),
+    sigma: Sequence[float] = (1.5, 1.5),
+    reduction: str = "elementwise_mean",
+    data_range: float = None,
+    k1: float = 0.01,
+    k2: float = 0.03
+) -> torch.Tensor:
+    """
+    Computes Structural Similarity Index Measure
+
+    Args:
+        pred: Estimated image
+        target: Ground truth image
+        kernel_size: Size of the gaussian kernel. Default: (11, 11)
+        sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
+        reduction: A method for reducing ssim over all elements in the ``pred`` tensor. Default: ``elementwise_mean``
+
+            Available reduction methods:
+            - elementwise_mean: takes the mean
+            - none: no reduction is applied
+            - sum: adds the elements
+
+        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
+        k1: Parameter of SSIM. Default: 0.01
+        k2: Parameter of SSIM. Default: 0.03
+
+    Returns:
+        A Tensor with SSIM
+
+    Example:
+
+        >>> pred = torch.rand([16, 1, 16, 16])
+        >>> target = pred * 1.25
+        >>> ssim(pred, target)
+        tensor(0.9520)
+    """
+
+    if pred.dtype != target.dtype:
+        raise TypeError(
+            "Expected `pred` and `target` to have the same data type."
+            f" Got pred: {pred.dtype} and target: {target.dtype}."
+        )
+
+    if pred.shape != target.shape:
+        raise ValueError(
+            "Expected `pred` and `target` to have the same shape."
+            f" Got pred: {pred.shape} and target: {target.shape}."
+        )
+
+    if len(pred.shape) != 4 or len(target.shape) != 4:
+        raise ValueError(
+            "Expected `pred` and `target` to have BxCxHxW shape."
+            f" Got pred: {pred.shape} and target: {target.shape}."
+        )
+
+    if len(kernel_size) != 2 or len(sigma) != 2:
+        raise ValueError(
+            "Expected `kernel_size` and `sigma` to have a length of two."
+            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
+        )
+
+    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
+        raise ValueError(f"Expected `kernel_size` to contain odd positive numbers. Got {kernel_size}.")
+
+    if any(y <= 0 for y in sigma):
+        raise ValueError(f"Expected `sigma` to contain positive numbers. Got {sigma}.")
+
+    if data_range is None:
+        data_range = max(pred.max() - pred.min(), target.max() - target.min())
+
+    C1 = pow(k1 * data_range, 2)
+    C2 = pow(k2 * data_range, 2)
+    device = pred.device
+
+    channel = pred.size(1)
+    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
+    mu_pred = F.conv2d(pred, kernel, groups=channel)
+    mu_target = F.conv2d(target, kernel, groups=channel)
+
+    mu_pred_sq = mu_pred.pow(2)
+    mu_target_sq = mu_target.pow(2)
+    mu_pred_target = mu_pred * mu_target
+
+    sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq
+    sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq
+    sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target
+
+    UPPER = 2 * sigma_pred_target + C2
+    LOWER = sigma_pred_sq + sigma_target_sq + C2
+
+    ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)
+
+    return reduce(ssim_idx, reduction)
diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py
index e94c0bd60c7dd..5b69868e1f776 100644
--- a/pytorch_lightning/metrics/regression.py
+++ b/pytorch_lightning/metrics/regression.py
@@ -1,3 +1,5 @@
+from typing import Sequence
+
 import torch
 
 from pytorch_lightning.metrics.functional.regression import (
@@ -5,7 +7,8 @@
     mse,
     psnr,
     rmse,
-    rmsle
+    rmsle,
+    ssim
 )
 from pytorch_lightning.metrics.metric import Metric
 
@@ -229,3 +232,62 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             A Tensor with psnr score.
         """
         return psnr(pred, target, self.data_range, self.base, self.reduction)
+
+
+class SSIM(Metric):
+    """
+    Computes Structural Similarity Index Measure
+
+    Example:
+
+        >>> pred = torch.rand([16, 1, 16, 16])
+        >>> target = pred * 1.25
+        >>> metric = SSIM()
+        >>> metric(pred, target)
+        tensor(0.9520)
+    """
+
+    def __init__(
+        self,
+        kernel_size: Sequence[int] = (11, 11),
+        sigma: Sequence[float] = (1.5, 1.5),
+        reduction: str = "elementwise_mean",
+        data_range: float = None,
+        k1: float = 0.01,
+        k2: float = 0.03
+    ):
+        """
+        Args:
+            kernel_size: Size of the gaussian kernel. Default: (11, 11)
+            sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
+            reduction: A method for reducing ssim. Default: ``elementwise_mean``
+
+                Available reduction methods:
+                - elementwise_mean: takes the mean
+                - none: no reduction is applied
+                - sum: adds the elements
+
+            data_range: Range of the image. If ``None``, it is determined from the image (max - min)
+            k1: Parameter of SSIM. Default: 0.01
+            k2: Parameter of SSIM. Default: 0.03
+        """
+        super().__init__(name="ssim")
+        self.kernel_size = kernel_size
+        self.sigma = sigma
+        self.reduction = reduction
+        self.data_range = data_range
+        self.k1 = k1
+        self.k2 = k2
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Actual metric computation
+
+        Args:
+            pred: Estimated image
+            target: Ground truth image
+
+        Return:
+            torch.Tensor: SSIM score
+        """
+        return ssim(pred, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2)
diff --git a/tests/metrics/functional/test_regression.py b/tests/metrics/functional/test_regression.py
index 1434b86dff238..c9df4f1ba3b9e 100644
--- a/tests/metrics/functional/test_regression.py
+++ b/tests/metrics/functional/test_regression.py
@@ -2,13 +2,15 @@
 import torch
 import numpy as np
 from skimage.metrics import peak_signal_noise_ratio as ski_psnr
+from skimage.metrics import structural_similarity as ski_ssim
 
 from pytorch_lightning.metrics.functional import (
     mae,
     mse,
     psnr,
     rmse,
-    rmsle
+    rmsle,
+    ssim
 )
 
 
@@ -86,10 +88,57 @@ def test_psnr_against_sklearn(sklearn_metric, torch_metric):
     for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
         pred = torch.randint(n_cls_pred, (500,), device=device, dtype=torch.float)
         target = torch.randint(n_cls_target, (500,), device=device, dtype=torch.float)
-
+
         sk_score = sklearn_metric(target.cpu().detach().numpy(),
                                   pred.cpu().detach().numpy(),
                                   data_range=n_cls_target)
         sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
         pl_score = torch_metric(pred, target, data_range=n_cls_target)
         assert torch.allclose(sk_score, pl_score)
+
+
+@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [
+    pytest.param(16, 1, 0.125, False),
+    pytest.param(32, 1, 0.25, False),
+    pytest.param(48, 3, 0.5, True),
+    pytest.param(64, 4, 0.75, True),
+    pytest.param(128, 5, 1, True)
+])
+def test_ssim(size, channel, plus, multichannel):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    pred = torch.rand(1, channel, size, size, device=device)
+    target = pred + plus
+    ssim_idx = ssim(pred, target)
+    np_pred = np.random.rand(size, size, channel)
+    if multichannel is False:
+        np_pred = np_pred[:, :, 0]
+    np_target = np.add(np_pred, plus)
+    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
+    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)
+
+    ssim_idx = ssim(pred, pred)
+    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
+
+
+@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
+    pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # shape
+    pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # len(shape)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]),  # len(kernel), len(sigma)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]),  # len(kernel), len(sigma)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]),  # len(kernel), len(sigma)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]),  # invalid kernel input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]),  # invalid kernel input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]),  # invalid kernel input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]),  # invalid sigma input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]),  # invalid sigma input
+])
+def test_ssim_invalid_inputs(pred, target, kernel, sigma):
+    pred_t = torch.rand(pred)
+    target_t = torch.rand(target, dtype=torch.float64)
+    with pytest.raises(TypeError):
+        ssim(pred_t, target_t)
+
+    pred = torch.rand(pred)
+    target = torch.rand(target)
+    with pytest.raises(ValueError):
+        ssim(pred, target, kernel, sigma)
diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py
index dfdfd2a8d0ca2..955e6253e3225 100644
--- a/tests/metrics/test_regression.py
+++ b/tests/metrics/test_regression.py
@@ -6,7 +6,7 @@
 from skimage.metrics import peak_signal_noise_ratio as ski_psnr
 
 from pytorch_lightning.metrics.regression import (
-    MAE, MSE, RMSE, RMSLE, PSNR
+    MAE, MSE, RMSE, RMSLE, PSNR, SSIM
 )
 
 
@@ -58,3 +58,13 @@ def test_psnr():
     target = torch.tensor([0., 1, 2, 2])
     score = psnr(pred, target)
     assert isinstance(score, torch.Tensor)
+
+
+def test_ssim():
+    ssim = SSIM()
+    assert ssim.name == 'ssim'
+
+    pred = torch.rand([16, 1, 16, 16])
+    target = pred * 1.25
+    score = ssim(pred, target)
+    assert isinstance(score, torch.Tensor)
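
A minimal usage sketch of the `ssim` functional and the `SSIM` metric class introduced in this diff; the tensor shapes and values below are illustrative assumptions, not part of the patch.

import torch

from pytorch_lightning.metrics import SSIM
from pytorch_lightning.metrics.functional import ssim

# Illustrative BxCxHxW image batches; both interfaces expect 4D tensors
# of the same dtype and shape.
pred = torch.rand(4, 3, 32, 32)
target = 0.9 * pred + 0.1 * torch.rand(4, 3, 32, 32)

# Functional interface, using the defaults defined in the diff above.
score = ssim(pred, target, kernel_size=(11, 11), sigma=(1.5, 1.5))

# Class-based metric, e.g. for computing the score inside a LightningModule.
metric = SSIM(reduction="elementwise_mean")
score_from_metric = metric(pred, target)

print(score, score_from_metric)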