diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py index fbc5b2d1b9..d7e019747d 100644 --- a/mmseg/models/losses/__init__.py +++ b/mmseg/models/losses/__init__.py @@ -5,11 +5,12 @@ from .dice_loss import DiceLoss from .focal_loss import FocalLoss from .lovasz_loss import LovaszLoss +from .tversky_loss import TverskyLoss from .utils import reduce_loss, weight_reduce_loss, weighted_loss __all__ = [ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', - 'FocalLoss' + 'FocalLoss', 'TverskyLoss' ] diff --git a/mmseg/models/losses/tversky_loss.py b/mmseg/models/losses/tversky_loss.py new file mode 100644 index 0000000000..4ad14f783c --- /dev/null +++ b/mmseg/models/losses/tversky_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from +https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333 +(Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + tversky_loss = binary_tversky_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + alpha=alpha, + beta=beta, + smooth=smooth) + if class_weight is not None: + tversky_loss *= class_weight[i] + total_loss += tversky_loss + return total_loss / num_classes + + +@weighted_loss +def binary_tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) + FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1) + FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1) + tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth) + + return 1 - tversky + + +@LOSSES.register_module() +class TverskyLoss(nn.Module): + """TverskyLoss. This loss is proposed in `Tversky loss function for image + segmentation using 3D fully convolutional deep networks. + + `_. + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + alpha(float, in [0, 1]): + The coefficient of false positives. Default: 0.3. + beta (float, in [0, 1]): + The coefficient of false negatives. Default: 0.7. + Note: alpha + beta = 1. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_tversky'. + """ + + def __init__(self, + smooth=1, + class_weight=None, + loss_weight=1.0, + ignore_index=255, + alpha=0.3, + beta=0.7, + loss_name='loss_tversky'): + super(TverskyLoss, self).__init__() + self.smooth = smooth + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!' + self.alpha = alpha + self.beta = beta + self._loss_name = loss_name + + def forward(self, pred, target, **kwargs): + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * tversky_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + alpha=self.alpha, + beta=self.beta, + smooth=self.smooth, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/tests/test_models/test_losses/test_tversky_loss.py b/tests/test_models/test_losses/test_tversky_loss.py new file mode 100644 index 0000000000..24a4b57e9b --- /dev/null +++ b/tests/test_models/test_losses/test_tversky_loss.py @@ -0,0 +1,76 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + + +def test_tversky_lose(): + from mmseg.models import build_loss + + # test alpha + beta != 1 + with pytest.raises(AssertionError): + loss_cfg = dict( + type='TverskyLoss', + class_weight=[1.0, 2.0, 3.0], + loss_weight=1.0, + alpha=0.4, + beta=0.7, + loss_name='loss_tversky') + tversky_loss = build_loss(loss_cfg) + logits = torch.rand(8, 3, 4, 4) + labels = (torch.rand(8, 4, 4) * 3).long() + tversky_loss(logits, labels, ignore_index=1) + + # test tversky loss + loss_cfg = dict( + type='TverskyLoss', + class_weight=[1.0, 2.0, 3.0], + loss_weight=1.0, + ignore_index=1, + loss_name='loss_tversky') + tversky_loss = build_loss(loss_cfg) + logits = torch.rand(8, 3, 4, 4) + labels = (torch.rand(8, 4, 4) * 3).long() + tversky_loss(logits, labels) + + # test loss with class weights from file + import os + import tempfile + + import mmcv + import numpy as np + tmp_file = tempfile.NamedTemporaryFile() + + mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file + loss_cfg = dict( + type='TverskyLoss', + class_weight=f'{tmp_file.name}.pkl', + loss_weight=1.0, + ignore_index=1, + loss_name='loss_tversky') + tversky_loss = build_loss(loss_cfg) + tversky_loss(logits, labels) + + np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file + loss_cfg = dict( + type='TverskyLoss', + class_weight=f'{tmp_file.name}.pkl', + loss_weight=1.0, + ignore_index=1, + loss_name='loss_tversky') + tversky_loss = build_loss(loss_cfg) + tversky_loss(logits, labels) + tmp_file.close() + os.remove(f'{tmp_file.name}.pkl') + os.remove(f'{tmp_file.name}.npy') + + # test tversky loss has name `loss_tversky` + loss_cfg = dict( + type='TverskyLoss', + smooth=2, + loss_weight=1.0, + ignore_index=1, + alpha=0.3, + beta=0.7, + loss_name='loss_tversky') + tversky_loss = build_loss(loss_cfg) + assert tversky_loss.loss_name == 'loss_tversky'