Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 committed Aug 23, 2024
1 parent 8a73422 commit 0f71e01
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions tests/unit/algo/visual_prompting/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,36 @@ def __init__(self):
for param in mock_model.mask_decoder.parameters():
assert param.requires_grad != freeze_mask_decoder

def test_forward_for_tracing(self, mocker) -> None:
mixin = CommonSettingMixin()
mixin.model = mock.Mock()
mock_forward_for_tracing = mocker.patch.object(mixin.model, "forward_for_tracing")

image_embeddings = torch.zeros((1, 256, 64, 64))
point_coords = torch.zeros((1, 10, 2))
point_labels = torch.zeros((1, 10))
mask_input = torch.zeros((1, 1, 256, 256))
has_mask_input = torch.zeros((1, 1))
ori_shape = torch.zeros((1, 2))

mixin.forward_for_tracing(
image_embeddings=image_embeddings,
point_coords=point_coords,
point_labels=point_labels,
mask_input=mask_input,
has_mask_input=has_mask_input,
ori_shape=ori_shape,
)

mock_forward_for_tracing.assert_called_once_with(
image_embeddings=image_embeddings,
point_coords=point_coords,
point_labels=point_labels,
mask_input=mask_input,
has_mask_input=has_mask_input,
ori_shape=ori_shape,
)


class TestSAM:
@pytest.fixture()
Expand Down Expand Up @@ -128,34 +158,6 @@ def test_build_model(self, sam: SAM) -> None:
assert isinstance(segment_anything.mask_decoder, SAMMaskDecoder)
assert isinstance(segment_anything.criterion, SAMCriterion)

def test_forward_for_tracing(self, mocker, sam) -> None:
mock_forward_for_tracing = mocker.patch.object(sam.model, "forward_for_tracing")

image_embeddings = torch.zeros((1, 256, 64, 64))
point_coords = torch.zeros((1, 10, 2))
point_labels = torch.zeros((1, 10))
mask_input = torch.zeros((1, 1, 256, 256))
has_mask_input = torch.zeros((1, 1))
ori_shape = torch.zeros((1, 2))

sam.forward_for_tracing(
image_embeddings=image_embeddings,
point_coords=point_coords,
point_labels=point_labels,
mask_input=mask_input,
has_mask_input=has_mask_input,
ori_shape=ori_shape,
)

mock_forward_for_tracing.assert_called_once_with(
image_embeddings=image_embeddings,
point_coords=point_coords,
point_labels=point_labels,
mask_input=mask_input,
has_mask_input=has_mask_input,
ori_shape=ori_shape,
)


class TestZeroShotSAM:
@pytest.fixture()
Expand Down

0 comments on commit 0f71e01

Please sign in to comment.