diff --git a/mmocr/visualization/base_visualizer.py b/mmocr/visualization/base_visualizer.py index 7c2cbb628..b53972cd1 100644 --- a/mmocr/visualization/base_visualizer.py +++ b/mmocr/visualization/base_visualizer.py @@ -63,7 +63,9 @@ def get_labels_image(self, bboxes: Union[np.ndarray, torch.Tensor], colors: Union[str, Sequence[str]] = 'k', font_size: Union[int, float] = 10, - auto_font_size: bool = False) -> np.ndarray: + auto_font_size: bool = False, + font_families: Union[str, List[str]] = 'sans-serif' + ) -> np.ndarray: """Draw labels on image. Args: @@ -80,6 +82,8 @@ def get_labels_image(self, to 10. auto_font_size (bool): Whether to automatically adjust font size. Defaults to False. + font_families (Union[str, List[str]]): The font families of labels. + Defaults to 'sans-serif'. """ if colors is not None and isinstance(colors, (list, tuple)): size = math.ceil(len(labels) / len(colors)) @@ -96,7 +100,7 @@ def get_labels_image(self, horizontal_alignments='center', colors='k', font_sizes=font_size, - font_families=self.font_families) + font_families=font_families) return self.get_image() def get_polygons_image(self, diff --git a/mmocr/visualization/kie_visualizer.py b/mmocr/visualization/kie_visualizer.py index 72cf0cd45..fb69e3440 100644 --- a/mmocr/visualization/kie_visualizer.py +++ b/mmocr/visualization/kie_visualizer.py @@ -152,12 +152,16 @@ def _draw_instances( empty_shape = (img_shape[0], img_shape[1], 3) text_image = np.full(empty_shape, 255, dtype=np.uint8) - text_image = self.get_labels_image(text_image, texts, bboxes) + text_image = self.get_labels_image( + text_image, texts, bboxes, font_families=self.font_families) classes_image = np.full(empty_shape, 255, dtype=np.uint8) bbox_classes = [class_names[int(i)]['name'] for i in bbox_labels] - classes_image = self.get_labels_image(classes_image, bbox_classes, - bboxes) + classes_image = self.get_labels_image( + classes_image, + bbox_classes, + bboxes, + font_families=self.font_families) if polygons: polygons = [polygon.reshape(-1, 2) for polygon in polygons] image = self.get_polygons_image( diff --git a/mmocr/visualization/textspotting_visualizer.py b/mmocr/visualization/textspotting_visualizer.py index 19a5e4ad3..6371b063f 100644 --- a/mmocr/visualization/textspotting_visualizer.py +++ b/mmocr/visualization/textspotting_visualizer.py @@ -45,7 +45,10 @@ def _draw_instances( empty_shape = (img_shape[0], img_shape[1], 3) text_image = np.full(empty_shape, 255, dtype=np.uint8) text_image = self.get_labels_image( - text_image, labels=texts, bboxes=bboxes) + text_image, + labels=texts, + bboxes=bboxes, + font_families=self.font_families) if polygons: polygons = [polygon.reshape(-1, 2) for polygon in polygons] image = self.get_polygons_image(