From 7a9be759e3af89d465c23efa78070e500af117cb Mon Sep 17 00:00:00 2001 From: ProtossDragoon Date: Wed, 7 Dec 2022 14:00:58 +0900 Subject: [PATCH] [Style] add `font_families` argument to fn --- mmocr/visualization/base_visualizer.py | 8 ++++++-- mmocr/visualization/kie_visualizer.py | 19 ++++++++++++++----- .../visualization/textspotting_visualizer.py | 5 ++++- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/mmocr/visualization/base_visualizer.py b/mmocr/visualization/base_visualizer.py index 7c2cbb628b..b53972cd1b 100644 --- a/mmocr/visualization/base_visualizer.py +++ b/mmocr/visualization/base_visualizer.py @@ -63,7 +63,9 @@ def get_labels_image(self, bboxes: Union[np.ndarray, torch.Tensor], colors: Union[str, Sequence[str]] = 'k', font_size: Union[int, float] = 10, - auto_font_size: bool = False) -> np.ndarray: + auto_font_size: bool = False, + font_families: Union[str, List[str]] = 'sans-serif' + ) -> np.ndarray: """Draw labels on image. Args: @@ -80,6 +82,8 @@ def get_labels_image(self, to 10. auto_font_size (bool): Whether to automatically adjust font size. Defaults to False. + font_families (Union[str, List[str]]): The font families of labels. + Defaults to 'sans-serif'. """ if colors is not None and isinstance(colors, (list, tuple)): size = math.ceil(len(labels) / len(colors)) @@ -96,7 +100,7 @@ def get_labels_image(self, horizontal_alignments='center', colors='k', font_sizes=font_size, - font_families=self.font_families) + font_families=font_families) return self.get_image() def get_polygons_image(self, diff --git a/mmocr/visualization/kie_visualizer.py b/mmocr/visualization/kie_visualizer.py index 72cf0cd458..f97bd3cd3f 100644 --- a/mmocr/visualization/kie_visualizer.py +++ b/mmocr/visualization/kie_visualizer.py @@ -152,12 +152,16 @@ def _draw_instances( empty_shape = (img_shape[0], img_shape[1], 3) text_image = np.full(empty_shape, 255, dtype=np.uint8) - text_image = self.get_labels_image(text_image, texts, bboxes) + text_image = self.get_labels_image( + text_image, texts, bboxes, font_families=self.font_families) classes_image = np.full(empty_shape, 255, dtype=np.uint8) bbox_classes = [class_names[int(i)]['name'] for i in bbox_labels] - classes_image = self.get_labels_image(classes_image, bbox_classes, - bboxes) + classes_image = self.get_labels_image( + classes_image, + bbox_classes, + bboxes, + font_families=self.font_families) if polygons: polygons = [polygon.reshape(-1, 2) for polygon in polygons] image = self.get_polygons_image( @@ -176,8 +180,13 @@ def _draw_instances( cat_image = [image, text_image, classes_image] if is_openset: edge_image = np.full(empty_shape, 255, dtype=np.uint8) - edge_image = self._draw_edge_label(edge_image, edge_labels, bboxes, - texts, arrow_colors) + edge_image = self._draw_edge_label( + edge_image, + edge_labels, + bboxes, + texts, + arrow_colors, + font_families=self.font_families) cat_image.append(edge_image) return self._cat_image(cat_image, axis=1) diff --git a/mmocr/visualization/textspotting_visualizer.py b/mmocr/visualization/textspotting_visualizer.py index 19a5e4ad3d..6371b063f4 100644 --- a/mmocr/visualization/textspotting_visualizer.py +++ b/mmocr/visualization/textspotting_visualizer.py @@ -45,7 +45,10 @@ def _draw_instances( empty_shape = (img_shape[0], img_shape[1], 3) text_image = np.full(empty_shape, 255, dtype=np.uint8) text_image = self.get_labels_image( - text_image, labels=texts, bboxes=bboxes) + text_image, + labels=texts, + bboxes=bboxes, + font_families=self.font_families) if polygons: polygons = [polygon.reshape(-1, 2) for polygon in polygons] image = self.get_polygons_image(