diff --git a/test/test_ops.py b/test/test_ops.py index 9858450f76c..a6f161051fc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -357,6 +357,20 @@ def _test_boxes_shape(self): self._helper_boxes_shape(ops.ps_roi_align) +class MultiScaleRoIAlignTester(unittest.TestCase): + def test_msroialign_repr(self): + fmap_names = ['0'] + output_size = (7, 7) + sampling_ratio = 2 + # Pass mock feature map names + t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio) + + # Check integrity of object __repr__ attribute + expected_string = (f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, " + f"sampling_ratio={sampling_ratio})") + self.assertEqual(t.__repr__(), expected_string) + + class NMSTester(unittest.TestCase): def reference_nms(self, boxes, scores, iou_threshold): """ diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index 463ce7e5ddc..02dbf3904bb 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -258,3 +258,7 @@ def forward( result = _onnx_merge_levels(levels, tracing_results) return result + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, " + f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})")