diff --git a/monai/losses/dice.py b/monai/losses/dice.py index b3c0f57c6e..f1c357d31f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -778,12 +778,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When number of channels for target is neither 1 nor the same as input. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. + + Returns: + torch.Tensor: value of the loss. """ - if len(input.shape) != len(target.shape): + if input.dim() != target.dim(): raise ValueError( "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) @@ -899,14 +909,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When number of channels for target is neither 1 nor the same as input. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. + Returns: + torch.Tensor: value of the loss. """ - if len(input.shape) != len(target.shape): + if input.dim() != target.dim(): raise ValueError( "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) + if self.to_onehot_y: n_pred_ch = input.shape[1] if n_pred_ch == 1: @@ -1015,15 +1035,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target (torch.Tensor): the shape should be BNH[WD] or B1H[WD]. Raises: - ValueError: When the input and target tensors have different numbers of dimensions, or the target - channel isn't either one-hot encoded or categorical with the same shape of the input. + ValueError: When number of dimensions for input and target are different. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. Returns: torch.Tensor: value of the loss. """ if input.dim() != target.dim(): raise ValueError( - f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions." + "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " + f"got shape {input.shape} and {target.shape}." ) gdl_loss = self.generalized_dice(input, target) diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 225618ed2c..97c7ae5050 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -93,10 +93,20 @@ def test_result(self, input_param, input_data, expected_val): result = diceceloss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) - # def test_ill_shape(self): - # loss = DiceCELoss() - # with self.assertRaisesRegex(ValueError, ""): - # loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + def test_ill_shape(self): + loss = DiceCELoss() + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = DiceCELoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = DiceCELoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) # def test_ill_reduction(self): # with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 13899da003..814a174762 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -69,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot): def test_ill_shape(self): loss = DiceFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = DiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = DiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) def test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_generalized_dice_focal_loss.py b/tests/test_generalized_dice_focal_loss.py index 8a0a80865e..65252611ca 100644 --- a/tests/test_generalized_dice_focal_loss.py +++ b/tests/test_generalized_dice_focal_loss.py @@ -59,8 +59,18 @@ def test_result_no_onehot_no_bg(self): def test_ill_shape(self): loss = GeneralizedDiceFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = GeneralizedDiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = GeneralizedDiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) def test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""):