
Modify Dice, Jaccard and Tversky losses #8138

Merged

15 commits merged into Project-MONAI:dev from zifuwanggg:8094-modify-dice-loss on Dec 3, 2024

Conversation

@zifuwanggg (Contributor) commented Oct 10, 2024

Fixes #8094.

Description

The Dice, Jaccard and Tversky losses in monai.losses.dice and monai.losses.tversky are modified based on JDTLoss and segmentation_models.pytorch.

In the original versions, when squared_pred=False, the loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_p^p + |y|_p^p - |x-y|_p^p}{2}$. When $p$ is 2 (squared_pred=True), this reformulation reduces to the classical inner product $\langle x,y \rangle$. When $p$ is 1 (squared_pred=False), the reformulation is provably equivalent to the original versions whenever the ground truth is binary (i.e., one-hot hard labels). Moreover, since the new versions are minimized if and only if the prediction is identical to the ground truth, even when the ground truth includes fractional values, they resolve the issue with soft labels [1, 2].
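A concrete single-pixel sketch of this failure (illustrative code, not part of the PR; p = 1, raw Dice score without smoothing terms):

import torch

y = torch.tensor(0.5)  # soft ground truth for a single pixel
for x in torch.linspace(0, 1, 5):
    old = 2 * (x * y) / (x + y)              # original intersection <x, y>
    new = (x + y - (x - y).abs()) / (x + y)  # reformulated intersection
    print(f"x={x:.2f}  old={old:.3f}  new={new:.3f}")

# The original score grows monotonically toward x = 1, so the Dice loss is
# minimized at x = 1; the reformulated score is maximal exactly at x = y = 0.5.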

In summary, there are three scenarios:

  • [Scenario 1] $x$ is nonnegative and $y$ is binary: the new versions coincide with the original versions.
  • [Scenario 2] Both $x$ and $y$ are nonnegative: the new versions differ from the original versions. The new versions are minimized if and only if $x=y$, whereas the original versions may not be, which makes the originals incorrect in this setting.
  • [Scenario 3] Either $x$ or $y$ is negative: the new versions differ from the original versions. Again, the new versions are minimized if and only if $x=y$, whereas the original versions may not be.

Due to these differences, particularly in Scenarios 2 and 3, some tests fail with the new versions:

  • The target is non-binary: test_multi_scale
  • The input is negative: test_dice_loss, test_tversky_loss, test_generalized_dice_loss, test_masked_loss, test_seg_loss_integration

The failures in test_multi_scale are expected, since the original versions are incorrectly defined for non-binary targets. Furthermore, because the Dice, Jaccard, and Tversky losses are fundamentally defined over probabilities, which are nonnegative, the new versions should not be tested against negative input or target values.

Example

import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()

def dice_old(x, y, ord, dims):
    # |x|_p^p + |y|_p^p
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    # original intersection: inner product <x, y>
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, ord, dims):
    # |x|_p^p + |y|_p^p
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    # |x - y|_p^p
    difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord
    # reformulated intersection: (|x|_p^p + |y|_p^p - |x - y|_p^p) / 2
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims))
print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims))
print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims))

print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims))
print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims))
print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims))

# p = 1, one-hot label (Scenario 1): old and new agree.
# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# p = 1, soft label (Scenario 2): they differ; only the new version rewards matching the soft target.
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# p = 1, pred compared with itself: only the new version attains the optimum of 1.
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])

# p = 2: the reformulation reduces to <x, y>, so old and new always agree.
# tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935])
# tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503])
# tensor([1., 1., 1.]) tensor([1., 1., 1.])

References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@ericspod (Member)

Hi @zifuwanggg, thanks for the contribution. I have an issue with this change: the behaviour of the losses is now very different, as seen in the CI/CD errors. I would instead suggest adding the new loss functions in a "soft_losses.py" file or something like that, rather than changing the existing losses. Other users may rely on the existing behaviour, and in situations where non-binarised values are accidentally used due to incorrect postprocessing, there would be less feedback about the problem.

@ericspod ericspod requested review from Nic-Ma, KumoLiu and csudre October 12, 2024 21:52
@zifuwanggg (Contributor, Author)

Hi @ericspod, thank you for reviewing my code. While adding new loss functions as separate .py files could be a workaround, my concern is that this approach would lead to a lot of duplicated code, as the core differences are only in 2-3 lines.

Would it make sense to add an attribute to the existing loss classes and create a new helper function, so that the default behavior remains unchanged? Something like the following.

class DiceLoss(_Loss):
    def __init__(
        ...
        binary_label: bool = True,
    ):
        ...
        self.binary_label = binary_label

    def forward(...):
        ...
        f = compute_score(self.binary_label)
        ...


class GeneralizedDiceLoss(_Loss):
    def __init__(
        ...
        binary_label: bool = True,
    ):
        ...
        self.binary_label = binary_label

    def forward(...):
        ...
        f = compute_score(self.binary_label)
        ...


def compute_score(binary_label):
    if binary_label:
        ...
    else:
        ...
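For concreteness, a runnable sketch of what the shared helper's core could compute for the Dice intersection (compute_intersection and soft_label are illustrative names here, not the final API):

import torch

def compute_intersection(x: torch.Tensor, y: torch.Tensor,
                         dims, soft_label: bool = False) -> torch.Tensor:
    if not soft_label:
        # Original formulation: inner product <x, y>.
        return torch.sum(x * y, dim=dims)
    # Reformulation for p = 1: (|x|_1 + |y|_1 - |x - y|_1) / 2. This equals
    # <x, y> for x in [0, 1] and binary y, and makes the resulting Dice score
    # maximal iff x == y, even for soft labels.
    cardinality = torch.sum(x.abs(), dim=dims) + torch.sum(y.abs(), dim=dims)
    difference = torch.sum((x - y).abs(), dim=dims)
    return (cardinality - difference) / 2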

@ericspod (Member)

> Hi @ericspod, thank you for reviewing my code. While adding new loss functions as separate .py files could be a workaround, my concern is that this approach would lead to a lot of duplicated code, as the core differences are only in 2-3 lines.

Hi @zifuwanggg, I appreciate wanting to reduce duplicated code; we already have too much of that in these loss functions, so adding more isn't great. I think we can try to parameterise the loss functions in some way, either with a function as you suggest or some other mechanism, so long as the default behaviour is preserved. If you want to have a go at refactoring along those lines, we can return to this. In the future we do need to refactor all of these loss functions to reduce duplication anyway.

@zifuwanggg (Contributor, Author) commented Oct 21, 2024

Hi @ericspod, I've created losses/utils.py and added a helper function that is shared by both dice.py and tversky.py.

Unit tests pass, but mypy tests fail. This seems related to #8149 and #8161.

@zifuwanggg (Contributor, Author)

Hi @ericspod, all CI/CD tests now pass. @KumoLiu, thanks for the commit.

@ericspod (Member) left a comment

I have a minor comment about the check in the helper function being expensive to compute, but otherwise we also need tests with soft labels to ensure that this formulation of the losses works. I do want to get others to review this as well to be doubly sure the changes are compatible. Thanks again.

Review threads on monai/losses/utils.py and monai/losses/dice.py (outdated, resolved)
@zifuwanggg (Contributor, Author) commented Nov 30, 2024

Hi @ericspod, sorry for the late response.

I've removed the costly check and modified the description of soft_label as you suggested. I've also added some test cases: when the input is identical to the target and soft_label=True, the loss value is zero.
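For example (a sketch of that property, assuming the soft_label keyword added in this PR; check the merged signature for details):

import torch
from monai.losses import DiceLoss

pred = torch.rand(2, 3, 8, 8).softmax(dim=1)

# A prediction identical to a soft target is optimal under soft_label=True,
# so the loss is ~0 (up to the smooth_nr/smooth_dr smoothing terms).
print(DiceLoss(soft_label=True)(pred, pred))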

Hi @KumoLiu @csudre @Nic-Ma, could you please kindly review the changes?

@KumoLiu (Contributor) left a comment

Thanks for the contribution! Overall it looks good to me. I just left one comment; please also sign off to fix the DCO error:
https://github.com/Project-MONAI/MONAI/pull/8138/checks?check_run_id=33736878625

Review thread on monai/losses/utils.py (resolved)
zifuwanggg and others added 3 commits December 2, 2024 11:15
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: 3f74183
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: a778e58
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: aeef0af
I, Zifu Wang <[email protected]>, hereby add my Signed-off-by to this commit: 58c5396

Signed-off-by: Zifu Wang <[email protected]>
@KumoLiu (Contributor) commented Dec 2, 2024

/build

@KumoLiu KumoLiu merged commit 9808ce2 into Project-MONAI:dev Dec 3, 2024
28 checks passed
@zifuwanggg zifuwanggg deleted the 8094-modify-dice-loss branch December 3, 2024 10:54
Successfully merging this pull request may close this issue: Jaccard, Dice and Tversky losses are incompatible with soft labels (#8094).