Skip to content

Commit

Permalink
Fix bounding boxes rescales bug
Browse files Browse the repository at this point in the history
  • Loading branch information
laugh12321 committed Mar 2, 2024
1 parent 653aaa3 commit 0f90cd0
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 33 deletions.
4 changes: 2 additions & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def parse_opt() -> argparse.Namespace:
total_images = 0
print(f"Infering data in {opt.input}")
batcher = ImageBatcher(opt.input, *model.input_spec())
for batch, images, batch_ratio_pad in batcher:
for batch, images, batch_shape in batcher:
start_time_ns = time.perf_counter_ns()
detections = model.infer(batch, batch_ratio_pad)
detections = model.infer(batch, batch_shape)
end_time_ns = time.perf_counter_ns()
elapsed_time_ms = (end_time_ns - start_time_ns) / 1e6
total_time += elapsed_time_ms
Expand Down
24 changes: 12 additions & 12 deletions python/infer/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# ==============================================================================
import random
from pathlib import Path
from typing import Tuple, List, Iterator
from typing import Union, Tuple, List, Iterator

import cv2
import numpy as np
Expand Down Expand Up @@ -80,24 +80,24 @@ def __init__(
self.num_batches = 1 + int((self.num_images - 1) / self.batch_size)
self.batches = [self.images[i * self.batch_size: (i + 1) * self.batch_size] for i in range(self.num_batches)]

def __iter__(self) -> Iterator[Tuple[np.ndarray, List, List, List]]:
def __iter__(self) -> Iterator[Tuple[np.ndarray, List[Union[str, Path]], List[Tuple[int, int]]]]:
"""
Iterator function to yield batches of preprocessed images.
Yields:
Tuple[np.ndarray, List[str], List[Tuple[float, float]]]: Batch data, image paths, and ratio/padding information.
Tuple[np.ndarray, List[Union[str, Path]], List[Tuple[int, int]]]: Batch data, image paths, and image shape information.
"""
for batch_images in self.batches:
batch_ratio_pad = []
batch_shape = []
batch_data = np.zeros(self.shape, dtype=self.dtype)
with ThreadPoolExecutor(max_workers=len(batch_images)) as executor:
results = list(executor.map(self._preprocess_image, batch_images))

for idx, (im, *ratio_pad) in enumerate(results):
for idx, (im, shape) in enumerate(results):
batch_data[idx] = im
batch_ratio_pad.append(ratio_pad)
batch_shape.append(shape)

yield np.ascontiguousarray(batch_data), batch_images, batch_ratio_pad
yield np.ascontiguousarray(batch_data), batch_images, batch_shape

def _find_images(self, input_path: Path, shuffle_files: bool) -> None:
"""
Expand Down Expand Up @@ -160,26 +160,26 @@ def _handle_tensor_shape(self) -> Tuple[np.int32, np.int32]:

return width, height

def _preprocess_image(self, image_path) -> Tuple[np.ndarray, float, Tuple[float, float]]:
def _preprocess_image(self, image_path: Union[str, Path]) -> Tuple[np.ndarray, Tuple[int, int]]:
"""
Preprocesses an image by reading, resizing, and normalizing.
Args:
image_path (str): The path to the input image file.
Returns:
Tuple[np.ndarray, float, Tuple[float, float]]: Preprocessed image, scale ratio, padding.
Tuple[np.ndarray, Tuple[int, int]]: Preprocessed image, image shape.
"""
# Read the image
image = cv2.imread(str(image_path))

# Resize and pad the image
image, ratio, pad = letterbox(image, (self.height, self.width))
image, shape = letterbox(image, (self.height, self.width))

# Convert color format and normalize pixel values
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(self.dtype) / 255.0

# Transpose the image to CHW format
image = np.transpose(image, (2, 0, 1))

return image, ratio, pad
return image, shape
14 changes: 7 additions & 7 deletions python/infer/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ def _infer(self) -> None:

self.stream.synchronize()

def _filter(self, outputs, ratio_pad, idx: int = 0) -> DetectInfo:
def _filter(self, outputs, output_shape: Tuple[int, int], idx: int = 0) -> DetectInfo:
"""
Filter and process the inference results.
Args:
outputs (dict): Dictionary containing output tensors.
ratio_pad (_type_): _description_
output_shape (Tuple[int, int]): output image shape.
idx (int, optional): Index of the batch. Defaults to 0.
Returns:
Expand All @@ -140,7 +140,7 @@ def _filter(self, outputs, ratio_pad, idx: int = 0) -> DetectInfo:
detection_scores = outputs['detection_scores'][idx, :num_detections]
detection_classes = outputs['detection_classes'][idx, :num_detections]

detection_boxes = scale_boxes(detection_boxes, (self.height, self.width), ratio_pad)
detection_boxes = scale_boxes(detection_boxes, (self.height, self.width), output_shape)

return DetectInfo(
num=num_detections,
Expand Down Expand Up @@ -172,13 +172,13 @@ def input_spec(self) -> Tuple[Tuple[int, int, int, int], np.dtype]:
"""
return self.inputs[0].shape, self.inputs[0].dtype

def infer(self, batch, batch_ratio_pad) -> List[DetectInfo]:
def infer(self, batch: np.ndarray, batch_shape: List[Tuple[int, int]]) -> List[DetectInfo]:
"""
Run inference on the TensorRT engine.
Args:
batch (_type_): Input batch for inference.
batch_ratio_pad (_type_): Ratio_pad for each batch item.
batch (np.ndarray): Input batch for inference.
batch_shape (List[Tuple[int, int]]): image shape for each batch item.
Returns:
List[DetectInfo]: List of processed detection information for each batch item.
Expand All @@ -190,4 +190,4 @@ def infer(self, batch, batch_ratio_pad) -> List[DetectInfo]:

# Process the results
outputs = {tensor.name: tensor.host.reshape(tensor.shape) for tensor in self.outputs}
return [self._filter(outputs, ratio_pad, idx) for idx, ratio_pad in enumerate(batch_ratio_pad)]
return [self._filter(outputs, ratio_pad, idx) for idx, ratio_pad in enumerate(batch_shape)]
24 changes: 12 additions & 12 deletions python/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def generate_random_rgb() -> Tuple[int, int, int]:
return [(label.strip(), generate_random_rgb()) for label in f]


def letterbox(image: np.ndarray, new_shape: Union[Tuple[int, int], int], color: Tuple[int, int, int] = (114, 114, 114)) -> Tuple[np.ndarray, float, Tuple[float, float]]:
def letterbox(image: np.ndarray, new_shape: Union[Tuple[int, int], int], color: Tuple[int, int, int] = (114, 114, 114)) -> Tuple[np.ndarray, Tuple[int, int]]:
"""
Resizes and pads the input image to the specified new shape.
Expand All @@ -66,7 +66,7 @@ def letterbox(image: np.ndarray, new_shape: Union[Tuple[int, int], int], color:
color (Tuple[int, int, int], optional): The color used for padding. Defaults to (114, 114, 114).
Returns:
Tuple[np.ndarray, float, Tuple[float, float]]: Resized image, scale ratio, padding.
Tuple[np.ndarray, Tuple[int, int]]: Resized image, image origal shape.
"""
shape = image.shape[:2] # Current shape [height, width]

Expand All @@ -91,34 +91,34 @@ def letterbox(image: np.ndarray, new_shape: Union[Tuple[int, int], int], color:
# Add border to the image for padding
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

return image, r, (dw, dh)
return image, shape


def scale_boxes(boxes: np.ndarray, shape: Tuple[int, int], ratio_pad: Tuple[float, Tuple[float, float]]) -> np.ndarray:
def scale_boxes(boxes: np.ndarray, input_shape: Tuple[int, int], output_shape: Tuple[int, int]) -> np.ndarray:
"""
Rescales (xyxy) bounding boxes to the target shape using the provided `ratio_pad`.
Rescales (xyxy) bounding boxes from input_shape to output_shape.
Args:
boxes (np.ndarray): Input bounding boxes in (xyxy) format.
shape (Tuple[int, int]): Target shape (height, width).
ratio_pad (Tuple[float, Tuple[float, float]]): Tuple containing the scaling ratio
and padding values used during preprocessing.
input_shape (Tuple[int, int]): Source shape (height, width).
output_shape (Tuple[int, int]): Target shape (height, width).
Returns:
np.ndarray: Rescaled bounding boxes.
"""
ratio, pad = ratio_pad
gain = min(input_shape[0] / output_shape[0], input_shape[1] / output_shape[1]) # gain = old / new
pad = (input_shape[1] - output_shape[1] * gain) / 2, (input_shape[0] - output_shape[0] * gain) / 2 # wh padding

# Adjust for padding
boxes[..., [0, 2]] -= pad[0] # x padding
boxes[..., [1, 3]] -= pad[1] # y padding

# Rescale using the ratio
boxes /= ratio
boxes[..., :4] /= gain

# Clip coordinates to be within the target shape
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, output_shape[1]) # x1, x2
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, output_shape[0]) # y1, y2

return boxes

Expand Down

0 comments on commit 0f90cd0

Please sign in to comment.