Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

harmonization and clarification of dice losses variants docs and associated tests #7587

Merged
merged 6 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.

Returns:
torch.Tensor: value of the loss.

"""
if len(input.shape) != len(target.shape):
if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

Expand Down Expand Up @@ -899,14 +909,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.

Returns:
torch.Tensor: value of the loss.
"""
if len(input.shape) != len(target.shape):
if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

if self.to_onehot_y:
n_pred_ch = input.shape[1]
if n_pred_ch == 1:
Expand Down Expand Up @@ -1015,15 +1035,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target (torch.Tensor): the shape should be BNH[WD] or B1H[WD].

Raises:
ValueError: When the input and target tensors have different numbers of dimensions, or the target
channel isn't either one-hot encoded or categorical with the same shape of the input.
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.

Returns:
torch.Tensor: value of the loss.
"""
if input.dim() != target.dim():
raise ValueError(
f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions."
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

gdl_loss = self.generalized_dice(input, target)
Expand Down
18 changes: 14 additions & 4 deletions tests/test_dice_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,20 @@ def test_result(self, input_param, input_data, expected_val):
result = diceceloss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

# def test_ill_shape(self):
# loss = DiceCELoss()
# with self.assertRaisesRegex(ValueError, ""):
# loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
def test_ill_shape(self):
loss = DiceCELoss()
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = DiceCELoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = DiceCELoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

# def test_ill_reduction(self):
# with self.assertRaisesRegex(ValueError, ""):
Expand Down
14 changes: 12 additions & 2 deletions tests/test_dice_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot):

def test_ill_shape(self):
loss = DiceFocalLoss()
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = DiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = DiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
Expand Down
14 changes: 12 additions & 2 deletions tests/test_generalized_dice_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,18 @@ def test_result_no_onehot_no_bg(self):

def test_ill_shape(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
Expand Down
Loading