diff --git a/tests/unit/algo/visual_prompting/test_sam.py b/tests/unit/algo/visual_prompting/test_sam.py index c88aa279d5d..94f4f1fc3d9 100644 --- a/tests/unit/algo/visual_prompting/test_sam.py +++ b/tests/unit/algo/visual_prompting/test_sam.py @@ -94,6 +94,36 @@ def __init__(self): for param in mock_model.mask_decoder.parameters(): assert param.requires_grad != freeze_mask_decoder + def test_forward_for_tracing(self, mocker) -> None: + mixin = CommonSettingMixin() + mixin.model = mock.Mock() + mock_forward_for_tracing = mocker.patch.object(mixin.model, "forward_for_tracing") + + image_embeddings = torch.zeros((1, 256, 64, 64)) + point_coords = torch.zeros((1, 10, 2)) + point_labels = torch.zeros((1, 10)) + mask_input = torch.zeros((1, 1, 256, 256)) + has_mask_input = torch.zeros((1, 1)) + ori_shape = torch.zeros((1, 2)) + + mixin.forward_for_tracing( + image_embeddings=image_embeddings, + point_coords=point_coords, + point_labels=point_labels, + mask_input=mask_input, + has_mask_input=has_mask_input, + ori_shape=ori_shape, + ) + + mock_forward_for_tracing.assert_called_once_with( + image_embeddings=image_embeddings, + point_coords=point_coords, + point_labels=point_labels, + mask_input=mask_input, + has_mask_input=has_mask_input, + ori_shape=ori_shape, + ) + class TestSAM: @pytest.fixture() @@ -128,34 +158,6 @@ def test_build_model(self, sam: SAM) -> None: assert isinstance(segment_anything.mask_decoder, SAMMaskDecoder) assert isinstance(segment_anything.criterion, SAMCriterion) - def test_forward_for_tracing(self, mocker, sam) -> None: - mock_forward_for_tracing = mocker.patch.object(sam.model, "forward_for_tracing") - - image_embeddings = torch.zeros((1, 256, 64, 64)) - point_coords = torch.zeros((1, 10, 2)) - point_labels = torch.zeros((1, 10)) - mask_input = torch.zeros((1, 1, 256, 256)) - has_mask_input = torch.zeros((1, 1)) - ori_shape = torch.zeros((1, 2)) - - sam.forward_for_tracing( - image_embeddings=image_embeddings, - point_coords=point_coords, - point_labels=point_labels, - mask_input=mask_input, - has_mask_input=has_mask_input, - ori_shape=ori_shape, - ) - - mock_forward_for_tracing.assert_called_once_with( - image_embeddings=image_embeddings, - point_coords=point_coords, - point_labels=point_labels, - mask_input=mask_input, - has_mask_input=has_mask_input, - ori_shape=ori_shape, - ) - class TestZeroShotSAM: @pytest.fixture()