Skip to content

Commit

Permalink
Add Visualizer.imshow()
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Mar 3, 2022
1 parent 9386c16 commit 88b830a
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions yolort/utils/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch
from PIL import Image
from torch import Tensor
from yolort.v5.utils.plots import Colors

Expand Down Expand Up @@ -37,14 +38,7 @@ class Visualizer:
instances on an image. Default: None
"""

def __init__(
self,
image: Union[Tensor, np.ndarray],
*,
metalabels: Optional[str] = None,
scale: float = 1.0,
line_width: Optional[int] = None,
) -> None:
def __init__(self, image: Union[Tensor, np.ndarray], *, metalabels: Optional[str] = None) -> None:

if isinstance(image, torch.Tensor):
if image.dtype != torch.uint8:
Expand Down Expand Up @@ -73,9 +67,7 @@ def __init__(
if metalabels is not None:
self.metadata = np.loadtxt(metalabels, dtype="str", delimiter="\n")

self.scale = scale
self.cpu_device = torch.device("cpu")
self.line_width = line_width or max(round(sum(self.img.shape) / 2 * 0.003), 2)
self.line_width = max(round(sum(self.img.shape) / 2 * 0.003), 2)
self.assigned_colors = Colors()
self.output = self.img

Expand Down Expand Up @@ -103,6 +95,30 @@ def draw_instance_predictions(self, predictions: Dict[str, Tensor]):
self.overlay_instances(boxes=boxes, labels=labels, colors=colors)
return self.output

def imshow(self, scale: Optional[float] = None):
"""
A replacement of cv2.imshow() for using in Jupyter notebooks.
Args:
scale (float, optional): zoom ratio to show the image. Default: None
"""
from IPython.display import display

img = self.output

img = img.clip(0, 255).astype("uint8")
# cv2 stores colors as BGR; convert to RGB
if self.is_bgr and img.ndim == 3:
if img.shape[2] == 4:
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

if scale is not None:
img = cv2.resize(img, None, fx=scale, fy=scale)

display(Image.fromarray(img))

def overlay_instances(
self,
boxes: np.ndarray,
Expand Down

0 comments on commit 88b830a

Please sign in to comment.