Skip to content

Commit

Permalink
harmonization and clarification of dice losses variants docs and asso…
Browse files Browse the repository at this point in the history
…ciated tests (#7587)

### Description

This PR aims to clarify and harmonise the code for the DiceLoss variants
in the `monai/losses/dice.py` file. With the `to_onehot_y` `softmax` and
`sigmoid` arguments, I didn't necessarily understand the ValueError that
occurred when I passed a target of size NH[WD]. I had a bit of trouble
reading the documentation and understanding it. I thought that they had
to be the same shape as they are displayed, unlike the number of
dimensions in the input, so I added that.
Besides, in the documentation is written:
```python
"""
raises:
      ValueError: When number of channels for target is neither 1 nor the same as input.

"""
```
Trying to reproduce this, we give an input with a number of channels $N$
and target a number of channels of $M$, with $M \neq N$ and $M > 1$.
```python
loss = DiceCELoss()
input = torch.rand(1, 4, 3, 3)
target = torch.randn(1, 2, 3, 3)
loss(input, target)
>: AssertionError: ground truth has different shape (torch.Size([1, 2, 3, 3])) from input (torch.Size([1, 4, 3, 3]))
```
This error in the Dice is an `AssertionError` and not a `ValueError` as
expected and the explanation can be confusing and doesn't give a clear
idea of the error here. The classes concerned and harmonised are
`DiceFocalLoss`, `DiceCELoss` and `GeneralizedDiceFocalLoss` with the
addition of tests that behave correctly and handle this harmonisation.

Also, feel free to modify or make suggestions regarding the changes made
in the docstring to make them more understandable (in my opinion, but
other readers and users will probably have a different view).

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 5, 2024
1 parent 195d7dd commit 625967c
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 15 deletions.
42 changes: 35 additions & 7 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
Returns:
torch.Tensor: value of the loss.
"""
if len(input.shape) != len(target.shape):
if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

Expand Down Expand Up @@ -899,14 +909,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
Returns:
torch.Tensor: value of the loss.
"""
if len(input.shape) != len(target.shape):
if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

if self.to_onehot_y:
n_pred_ch = input.shape[1]
if n_pred_ch == 1:
Expand Down Expand Up @@ -1015,15 +1035,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target (torch.Tensor): the shape should be BNH[WD] or B1H[WD].
Raises:
ValueError: When the input and target tensors have different numbers of dimensions, or the target
channel isn't either one-hot encoded or categorical with the same shape of the input.
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
Returns:
torch.Tensor: value of the loss.
"""
if input.dim() != target.dim():
raise ValueError(
f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions."
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

gdl_loss = self.generalized_dice(input, target)
Expand Down
18 changes: 14 additions & 4 deletions tests/test_dice_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,20 @@ def test_result(self, input_param, input_data, expected_val):
result = diceceloss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

# def test_ill_shape(self):
# loss = DiceCELoss()
# with self.assertRaisesRegex(ValueError, ""):
# loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
def test_ill_shape(self):
loss = DiceCELoss()
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = DiceCELoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = DiceCELoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

# def test_ill_reduction(self):
# with self.assertRaisesRegex(ValueError, ""):
Expand Down
14 changes: 12 additions & 2 deletions tests/test_dice_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot):

def test_ill_shape(self):
loss = DiceFocalLoss()
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = DiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = DiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
Expand Down
14 changes: 12 additions & 2 deletions tests/test_generalized_dice_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,18 @@ def test_result_no_onehot_no_bg(self):

def test_ill_shape(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
Expand Down

0 comments on commit 625967c

Please sign in to comment.