metrics: add SSIM #2671

Merged 3 commits on Jul 23, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

### Changed
12 changes: 12 additions & 0 deletions docs/source/metrics.rst
@@ -234,6 +234,12 @@ RMSLE
.. autoclass:: pytorch_lightning.metrics.regression.RMSLE
:noindex:

SSIM
^^^^

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
:noindex:

----------------

Functional Metrics
@@ -403,6 +409,12 @@ psnr (F)
.. autofunction:: pytorch_lightning.metrics.functional.psnr
:noindex:

ssim (F)
^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.ssim
:noindex:

stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

2 changes: 2 additions & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -6,6 +6,7 @@
PSNR,
RMSE,
RMSLE,
SSIM
)
from pytorch_lightning.metrics.classification import (
Accuracy,
@@ -54,6 +55,7 @@
"PSNR",
"RMSE",
"RMSLE",
"SSIM"
]
__sequence_metrics = ["BLEUScore"]
__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -26,5 +26,6 @@
psnr,
rmse,
rmsle,
ssim
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
115 changes: 115 additions & 0 deletions pytorch_lightning/metrics/functional/regression.py
@@ -1,3 +1,5 @@
from typing import Sequence

import torch
from torch.nn import functional as F

@@ -182,3 +184,116 @@ def psnr(
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
return psnr


def _gaussian_kernel(channel, kernel_size, sigma, device):
def gaussian(kernel_size, sigma, device):
gauss = torch.arange(
start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device
)
gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)

gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device)
gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def ssim(
pred: torch.Tensor,
target: torch.Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: float = None,
k1: float = 0.01,
k2: float = 0.03
) -> torch.Tensor:
"""
Computes Structural Similarity Index Measure

Args:
pred: Estimated image
target: Ground truth image
kernel_size: Size of the Gaussian kernel. Default: (11, 11)
sigma: Standard deviation of the Gaussian kernel. Default: (1.5, 1.5)
reduction: A method for reducing the SSIM over all elements in the ``pred`` tensor. Default: ``elementwise_mean``

Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction will be applied
- sum: sums the elements

data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM. Default: 0.01
k2: Parameter of SSIM. Default: 0.03

Returns:
A Tensor with SSIM

Example:

>>> pred = torch.rand([16, 1, 16, 16])
>>> target = pred * 1.25
>>> ssim(pred, target)
tensor(0.9520)
"""

if pred.dtype != target.dtype:
raise TypeError(
"Expected `pred` and `target` to have the same data type."
f" Got pred: {pred.dtype} and target: {target.dtype}."
)

if pred.shape != target.shape:
raise ValueError(
"Expected `pred` and `target` to have the same shape."
f" Got pred: {pred.shape} and target: {target.shape}."
)

if len(pred.shape) != 4 or len(target.shape) != 4:
raise ValueError(
"Expected `pred` and `target` to have BxCxHxW shape."
f" Got pred: {pred.shape} and target: {target.shape}."
)

if len(kernel_size) != 2 or len(sigma) != 2:
raise ValueError(
"Expected `kernel_size` and `sigma` to have the length of two."
f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
)

if any(x % 2 == 0 or x <= 0 for x in kernel_size):
raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

if any(y <= 0 for y in sigma):
raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

if data_range is None:
data_range = max(pred.max() - pred.min(), target.max() - target.min())

C1 = pow(k1 * data_range, 2)
C2 = pow(k2 * data_range, 2)
device = pred.device

channel = pred.size(1)
kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
mu_pred = F.conv2d(pred, kernel, groups=channel)
mu_target = F.conv2d(target, kernel, groups=channel)

mu_pred_sq = mu_pred.pow(2)
mu_target_sq = mu_target.pow(2)
mu_pred_target = mu_pred * mu_target

sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq
sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq
sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target

UPPER = 2 * sigma_pred_target + C2
LOWER = sigma_pred_sq + sigma_target_sq + C2

ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)

return reduce(ssim_idx, reduction)
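For reference, the index computed above is the standard SSIM of Wang et al. (2004); writing L for ``data_range``, the per-window index and constants are

    \mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)},
    \qquad C_1 = (k_1 L)^2, \quad C_2 = (k_2 L)^2

where the local means, variances, and covariance are obtained by convolving ``pred`` and ``target`` with the Gaussian window built by ``_gaussian_kernel``, and ``reduce`` then applies the requested reduction over the resulting index map.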
64 changes: 63 additions & 1 deletion pytorch_lightning/metrics/regression.py
@@ -1,11 +1,14 @@
from typing import Sequence

import torch

from pytorch_lightning.metrics.functional.regression import (
mae,
mse,
psnr,
rmse,
rmsle
rmsle,
ssim
)
from pytorch_lightning.metrics.metric import Metric

@@ -229,3 +232,62 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with psnr score.
"""
return psnr(pred, target, self.data_range, self.base, self.reduction)


class SSIM(Metric):
"""
Computes Structural Similarity Index Measure

Example:

>>> pred = torch.rand([16, 1, 16, 16])
>>> target = pred * 1.25
>>> metric = SSIM()
>>> metric(pred, target)
tensor(0.9520)
"""

def __init__(
self,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: float = None,
k1: float = 0.01,
k2: float = 0.03
):
"""
Args:
kernel_size: Size of the Gaussian kernel. Default: (11, 11)
sigma: Standard deviation of the Gaussian kernel. Default: (1.5, 1.5)
reduction: A method for reducing the SSIM over all elements. Default: ``elementwise_mean``

Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction will be applied
- sum: sums the elements

data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM. Default: 0.01
k2: Parameter of SSIM. Default: 0.03
"""
super().__init__(name="ssim")
self.kernel_size = kernel_size
self.sigma = sigma
self.reduction = reduction
self.data_range = data_range
self.k1 = k1
self.k2 = k2

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation

Args:
pred: Estimated image
target: Ground truth image

Return:
torch.Tensor: SSIM Score
"""
return ssim(pred, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2)
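Beyond the doctest above, a minimal usage sketch of the module-style metric (the kernel size, data range, and tensor shapes below are illustrative choices, not values taken from this PR):

>>> import torch
>>> from pytorch_lightning.metrics import SSIM
>>> pred = torch.rand([4, 3, 32, 32])
>>> target = pred * 0.9
>>> metric = SSIM(kernel_size=(7, 7), data_range=1.0)
>>> score = metric(pred, target)  # scalar, since reduction defaults to elementwise_mean
>>> score.shape
torch.Size([])

The constructor arguments mirror the functional ``ssim`` signature, so any combination valid there can be fixed once at construction time and reused across batches.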
53 changes: 51 additions & 2 deletions tests/metrics/functional/test_regression.py
@@ -2,13 +2,15 @@
import torch
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
from skimage.metrics import structural_similarity as ski_ssim

from pytorch_lightning.metrics.functional import (
mae,
mse,
psnr,
rmse,
rmsle
rmsle,
ssim
)


@@ -86,10 +88,57 @@ def test_psnr_against_sklearn(sklearn_metric, torch_metric):
for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
pred = torch.randint(n_cls_pred, (500,), device=device, dtype=torch.float)
target = torch.randint(n_cls_target, (500,), device=device, dtype=torch.float)

sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=n_cls_target)
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target, data_range=n_cls_target)
assert torch.allclose(sk_score, pl_score)


@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [
pytest.param(16, 1, 0.125, False),
pytest.param(32, 1, 0.25, False),
pytest.param(48, 3, 0.5, True),
pytest.param(64, 4, 0.75, True),
pytest.param(128, 5, 1, True)
])
def test_ssim(size, channel, plus, multichannel):
device = "cuda" if torch.cuda.is_available() else "cpu"
pred = torch.rand(1, channel, size, size, device=device)
target = pred + plus
ssim_idx = ssim(pred, target)
np_pred = np.random.rand(size, size, channel)
if multichannel is False:
np_pred = np_pred[:, :, 0]
np_target = np.add(np_pred, plus)
sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)

ssim_idx = ssim(pred, pred)
assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))


@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # shape
pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
])
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
pred_t = torch.rand(pred)
target_t = torch.rand(target, dtype=torch.float64)
with pytest.raises(TypeError):
ssim(pred_t, target_t)

pred = torch.rand(pred)
target = torch.rand(target)
with pytest.raises(ValueError):
ssim(pred, target, kernel, sigma)
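A rough sketch of reproducing the scikit-image cross-check interactively (assuming scikit-image is installed; note that ``ski_ssim`` takes HxW(xC) arrays while the functional metric expects NxCxHxW tensors, and the tolerance mirrors the atol used in the test above):

>>> import numpy as np
>>> import torch
>>> from skimage.metrics import structural_similarity as ski_ssim
>>> from pytorch_lightning.metrics.functional import ssim
>>> np_pred = np.random.rand(32, 32)
>>> np_target = np_pred + 0.25
>>> sk_score = ski_ssim(np_pred, np_target, win_size=11, gaussian_weights=True, data_range=1.25)
>>> pl_score = ssim(torch.from_numpy(np_pred)[None, None].float(),
...                 torch.from_numpy(np_target)[None, None].float(), data_range=1.25)
>>> bool(torch.isclose(pl_score, torch.tensor(sk_score, dtype=torch.float), atol=1e-2))
True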
12 changes: 11 additions & 1 deletion tests/metrics/test_regression.py
@@ -6,7 +6,7 @@
from skimage.metrics import peak_signal_noise_ratio as ski_psnr

from pytorch_lightning.metrics.regression import (
MAE, MSE, RMSE, RMSLE, PSNR
MAE, MSE, RMSE, RMSLE, PSNR, SSIM
)


@@ -58,3 +58,13 @@ def test_psnr():
target = torch.tensor([0., 1, 2, 2])
score = psnr(pred, target)
assert isinstance(score, torch.Tensor)


def test_ssim():
ssim = SSIM()
assert ssim.name == 'ssim'

pred = torch.rand([16, 1, 16, 16])
target = pred * 1.25
score = ssim(pred, target)
assert isinstance(score, torch.Tensor)