diff --git a/tests/unit/algo/visual_prompting/test_sam.py b/tests/unit/algo/visual_prompting/test_sam.py
index c88aa279d5d..94f4f1fc3d9 100644
--- a/tests/unit/algo/visual_prompting/test_sam.py
+++ b/tests/unit/algo/visual_prompting/test_sam.py
@@ -94,6 +94,36 @@ def __init__(self):
         for param in mock_model.mask_decoder.parameters():
             assert param.requires_grad != freeze_mask_decoder
 
+    def test_forward_for_tracing(self, mocker) -> None:
+        mixin = CommonSettingMixin()
+        mixin.model = mock.Mock()
+        mock_forward_for_tracing = mocker.patch.object(mixin.model, "forward_for_tracing")
+
+        image_embeddings = torch.zeros((1, 256, 64, 64))
+        point_coords = torch.zeros((1, 10, 2))
+        point_labels = torch.zeros((1, 10))
+        mask_input = torch.zeros((1, 1, 256, 256))
+        has_mask_input = torch.zeros((1, 1))
+        ori_shape = torch.zeros((1, 2))
+
+        mixin.forward_for_tracing(
+            image_embeddings=image_embeddings,
+            point_coords=point_coords,
+            point_labels=point_labels,
+            mask_input=mask_input,
+            has_mask_input=has_mask_input,
+            ori_shape=ori_shape,
+        )
+
+        mock_forward_for_tracing.assert_called_once_with(
+            image_embeddings=image_embeddings,
+            point_coords=point_coords,
+            point_labels=point_labels,
+            mask_input=mask_input,
+            has_mask_input=has_mask_input,
+            ori_shape=ori_shape,
+        )
+
 
 class TestSAM:
     @pytest.fixture()
@@ -128,34 +158,6 @@ def test_build_model(self, sam: SAM) -> None:
         assert isinstance(segment_anything.mask_decoder, SAMMaskDecoder)
         assert isinstance(segment_anything.criterion, SAMCriterion)
 
-    def test_forward_for_tracing(self, mocker, sam) -> None:
-        mock_forward_for_tracing = mocker.patch.object(sam.model, "forward_for_tracing")
-
-        image_embeddings = torch.zeros((1, 256, 64, 64))
-        point_coords = torch.zeros((1, 10, 2))
-        point_labels = torch.zeros((1, 10))
-        mask_input = torch.zeros((1, 1, 256, 256))
-        has_mask_input = torch.zeros((1, 1))
-        ori_shape = torch.zeros((1, 2))
-
-        sam.forward_for_tracing(
-            image_embeddings=image_embeddings,
-            point_coords=point_coords,
-            point_labels=point_labels,
-            mask_input=mask_input,
-            has_mask_input=has_mask_input,
-            ori_shape=ori_shape,
-        )
-
-        mock_forward_for_tracing.assert_called_once_with(
-            image_embeddings=image_embeddings,
-            point_coords=point_coords,
-            point_labels=point_labels,
-            mask_input=mask_input,
-            has_mask_input=has_mask_input,
-            ori_shape=ori_shape,
-        )
-
 
 class TestZeroShotSAM:
     @pytest.fixture()