Skip to content

Commit

Permalink
Merge similar test components with parameterized (#7663)
Browse files Browse the repository at this point in the history
### Description

I noticed some test cases contain same duplicated asserts. Having
multiple asserts in one test cases can cause potential issues like when
the first assert fails, the test case stops and won't check the second
assert. By using @parameterized.expand, this issue can be resolved and
the caching also saves execution time.

Added sign-offs from  #7648 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Han Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
freddiewanah and KumoLiu authored Apr 23, 2024
1 parent a59676f commit ec6aa33
Show file tree
Hide file tree
Showing 15 changed files with 243 additions and 473 deletions.
27 changes: 8 additions & 19 deletions tests/test_affine_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,17 @@ def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):

class TestAffineTransform(unittest.TestCase):

def test_affine_shift(self):
affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]])
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
out = AffineTransform(align_corners=False)(image, affine)
out = out.detach().cpu().numpy()
expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)

def test_affine_shift_1(self):
affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]])
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
out = AffineTransform(align_corners=False)(image, affine)
out = out.detach().cpu().numpy()
expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)

def test_affine_shift_2(self):
affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
@parameterized.expand(
[
(torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]),
(torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]), [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]),
(torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]), [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]),
]
)
def test_affine_transforms(self, affine, expected):
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
out = AffineTransform(align_corners=False)(image, affine)
out = out.detach().cpu().numpy()
expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)

def test_zoom(self):
Expand Down
36 changes: 16 additions & 20 deletions tests/test_compute_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import torch
from parameterized import parameterized

from monai.metrics import FBetaScore
from tests.utils import assert_allclose
Expand All @@ -33,26 +34,21 @@ def test_expecting_success_and_device(self):
assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
np.testing.assert_equal(result.device, y_pred.device)

def test_expecting_success2(self):
metric = FBetaScore(beta=0.5)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
)
assert_allclose(metric.aggregate()[0], torch.Tensor([0.609756]), atol=1e-6, rtol=1e-6)

def test_expecting_success3(self):
metric = FBetaScore(beta=2)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
)
assert_allclose(metric.aggregate()[0], torch.Tensor([0.862069]), atol=1e-6, rtol=1e-6)

def test_denominator_is_zero(self):
metric = FBetaScore(beta=2)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
)
assert_allclose(metric.aggregate()[0], torch.Tensor([0.0]), atol=1e-6, rtol=1e-6)
@parameterized.expand(
[
(0.5, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.609756])), # success_beta_0_5
(2, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.862069])), # success_beta_2
(
2, # success_beta_2, denominator_zero
torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
torch.Tensor([0.0]),
),
]
)
def test_success_and_zero(self, beta, y, expected_score):
metric = FBetaScore(beta=beta)
metric(y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=y)
assert_allclose(metric.aggregate()[0], expected_score, atol=1e-6, rtol=1e-6)

def test_number_of_dimensions_less_than_2_should_raise_error(self):
metric = FBetaScore()
Expand Down
40 changes: 25 additions & 15 deletions tests/test_global_mutual_information_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import torch
from parameterized import parameterized

from monai import transforms
from monai.losses.image_dissimilarity import GlobalMutualInformationLoss
Expand Down Expand Up @@ -116,24 +117,33 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.

class TestGlobalMutualInformationLossIll(unittest.TestCase):

def test_ill_shape(self):
@parameterized.expand(
[
(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)), # mismatched_simple_dims
(
torch.ones((1, 3, 3), dtype=torch.float),
torch.ones((1, 3), dtype=torch.float),
), # mismatched_advanced_dims
]
)
def test_ill_shape(self, input1, input2):
loss = GlobalMutualInformationLoss()
with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))
with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))

def test_ill_opts(self):
with self.assertRaises(ValueError):
loss.forward(input1, input2)

@parameterized.expand(
[
(0, "mean", ValueError, ""), # num_bins_zero
(-1, "mean", ValueError, ""), # num_bins_negative
(64, "unknown", ValueError, ""), # reduction_unknown
(64, None, ValueError, ""), # reduction_none
]
)
def test_ill_opts(self, num_bins, reduction, expected_exception, expected_message):
pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(num_bins=0)(pred, target)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(num_bins=-1)(pred, target)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(reduction="unknown")(pred, target)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(reduction=None)(pred, target)
with self.assertRaisesRegex(expected_exception, expected_message):
GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target)


if __name__ == "__main__":
Expand Down
22 changes: 6 additions & 16 deletions tests/test_hausdorff_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,12 @@ def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
HausdorffDTLoss(reduction=None)(chn_input, chn_target)

def test_input_warnings(self):
@parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
def test_input_warnings(self, include_background, softmax, to_onehot_y):
chn_input = torch.ones((1, 1, 1, 3))
chn_target = torch.ones((1, 1, 1, 3))
with self.assertWarns(Warning):
loss = HausdorffDTLoss(include_background=False)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = HausdorffDTLoss(softmax=True)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = HausdorffDTLoss(to_onehot_y=True)
loss = HausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
loss.forward(chn_input, chn_target)


Expand All @@ -256,17 +251,12 @@ def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
LogHausdorffDTLoss(reduction=None)(chn_input, chn_target)

def test_input_warnings(self):
@parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
def test_input_warnings(self, include_background, softmax, to_onehot_y):
chn_input = torch.ones((1, 1, 1, 3))
chn_target = torch.ones((1, 1, 1, 3))
with self.assertWarns(Warning):
loss = LogHausdorffDTLoss(include_background=False)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = LogHausdorffDTLoss(softmax=True)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = LogHausdorffDTLoss(to_onehot_y=True)
loss = LogHausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
loss.forward(chn_input, chn_target)


Expand Down
21 changes: 7 additions & 14 deletions tests/test_median_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,20 @@

import numpy as np
import torch
from parameterized import parameterized

from monai.networks.layers import MedianFilter


class MedianFilterTestCase(unittest.TestCase):
@parameterized.expand([(torch.ones(1, 1, 2, 3, 5), [1, 2, 4]), (torch.ones(1, 1, 4, 3, 4), 1)]) # 3d_big # 3d
def test_3d(self, input_tensor, radius):
filter = MedianFilter(radius).to(torch.device("cpu:0"))

def test_3d_big(self):
a = torch.ones(1, 1, 2, 3, 5)
g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0"))
expected = input_tensor.numpy()
output = filter(input_tensor).cpu().numpy()

expected = a.numpy()
out = g(a).cpu().numpy()
np.testing.assert_allclose(out, expected, rtol=1e-5)

def test_3d(self):
a = torch.ones(1, 1, 4, 3, 4)
g = MedianFilter(1).to(torch.device("cpu:0"))

expected = a.numpy()
out = g(a).cpu().numpy()
np.testing.assert_allclose(out, expected, rtol=1e-5)
np.testing.assert_allclose(output, expected, rtol=1e-5)

def test_3d_radii(self):
a = torch.ones(1, 1, 4, 3, 2)
Expand Down
29 changes: 18 additions & 11 deletions tests/test_multi_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,24 @@ def test_shape(self, input_param, input_data, expected_val):
result = MultiScaleLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(loss=dice_loss, kernel="none")
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(loss=dice_loss, scales=[-1])(
torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
)
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(
torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
)
@parameterized.expand(
[
({"loss": dice_loss, "kernel": "none"}, None, None), # kernel_none
({"loss": dice_loss, "scales": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))), # scales_negative
(
{"loss": dice_loss, "scales": [-1], "reduction": "none"},
torch.ones((1, 1, 3)),
torch.ones((1, 1, 3)),
), # scales_negative_reduction_none
]
)
def test_ill_opts(self, kwargs, input, target):
if input is None and target is None:
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(**kwargs)
else:
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(**kwargs)(input, target)

def test_script(self):
input_param, input_data, expected_val = TEST_CASES[0]
Expand Down
27 changes: 8 additions & 19 deletions tests/test_optional_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@

import unittest

from parameterized import parameterized

from monai.utils import OptionalImportError, exact_version, optional_import


class TestOptionalImport(unittest.TestCase):

def test_default(self):
my_module, flag = optional_import("not_a_module")
@parameterized.expand(["not_a_module", "torch.randint"])
def test_default(self, import_module):
my_module, flag = optional_import(import_module)
self.assertFalse(flag)
with self.assertRaises(OptionalImportError):
my_module.test

my_module, flag = optional_import("torch.randint")
with self.assertRaises(OptionalImportError):
self.assertFalse(flag)
print(my_module.test)

def test_import_valid(self):
my_module, flag = optional_import("torch")
self.assertTrue(flag)
Expand All @@ -47,18 +45,9 @@ def test_import_wrong_number(self):
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))

def test_import_good_number(self):
my_module, flag = optional_import("torch", "0")
my_module.nn
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))

my_module, flag = optional_import("torch", "0.0.0.1")
my_module.nn
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))

my_module, flag = optional_import("torch", "1.1.0")
@parameterized.expand(["0", "0.0.0.1", "1.1.0"])
def test_import_good_number(self, version_number):
my_module, flag = optional_import("torch", version_number)
my_module.nn
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))
Expand Down
8 changes: 3 additions & 5 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ def test_1d(self):
with self.assertRaises(NotImplementedError):
PerceptualLoss(spatial_dims=1)

def test_medicalnet_on_2d_data(self):
@parameterized.expand(["medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets"])
def test_medicalnet_on_2d_data(self, network_type):
with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")

with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets")
PerceptualLoss(spatial_dims=2, network_type=network_type)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit ec6aa33

Please sign in to comment.