Rescale to original scale after post-processor #47

Merged · 9 commits · Feb 2, 2021
3 changes: 2 additions & 1 deletion models/_utils.py
@@ -1,7 +1,8 @@
# Modified from ultralytics/yolov5 by Zhiqiang Wang
import math

import torch
from torch import nn, Tensor
from torch import Tensor
import torch.nn.functional as F
from torchvision.ops import box_convert

1 change: 1 addition & 0 deletions models/backbone_utils.py
@@ -1,3 +1,4 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import List, Optional
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
3 changes: 1 addition & 2 deletions models/darknet.py
@@ -1,8 +1,7 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import torch
from torch import nn, Tensor
from torch.hub import load_state_dict_from_url
from torch.nn.modules import conv
from torch.nn.modules.linear import Linear

from .common import Conv, SPP, Focus, BottleneckCSP
from .experimental import C3
2 changes: 1 addition & 1 deletion models/path_aggregation_network.py
@@ -1,5 +1,5 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import torch

from torch import nn, Tensor

from .common import Conv, BottleneckCSP
46 changes: 39 additions & 7 deletions models/pl_wrapper.py
@@ -2,15 +2,14 @@
import argparse

import torch
from torch import nn, Tensor
from torchvision.models.utils import load_state_dict_from_url
from torch import Tensor

import pytorch_lightning as pl

from . import yolo
from .transform import nested_tensor_from_tensor_list
from .transform import GeneralizedYOLOTransform, nested_tensor_from_tensor_list

from typing import Any, List, Optional
from typing import Any, List, Dict, Tuple, Optional


class YOLOLitWrapper(pl.LightningModule):
@@ -24,6 +23,8 @@ def __init__(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
min_size: int = 320,
max_size: int = 416,
**kwargs: Any,
):
"""
@@ -41,9 +42,40 @@ def __init__(
self.model = yolo.__dict__[arch](
pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs)

def forward(self, inputs: List[Tensor], targets: Optional[Tensor] = None):
sample = nested_tensor_from_tensor_list(inputs)
return self.model(sample.tensors, targets=targets)
self.transform = GeneralizedYOLOTransform(min_size, max_size)

def forward(
self,
inputs: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None,
) -> List[Dict[str, Tensor]]:
"""
Args:
inputs (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).

"""
# get the original image sizes
original_image_sizes: List[Tuple[int, int]] = []
for img in inputs:
val = img.shape[-2:]
assert len(val) == 2
original_image_sizes.append((val[0], val[1]))

# Transform the input
samples, targets = self.transform(inputs, targets)
# Compute the detections
detections = self.model(samples.tensors, targets=targets)
# Rescale the coordinates to the original image sizes
detections = self.transform.postprocess(detections, samples.image_sizes, original_image_sizes)

return detections

def training_step(self, batch, batch_idx):

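For orientation, a minimal inference sketch using the reworked wrapper (this is not part of the PR; the constructor arguments, the import path, and the use of random tensors as stand-in images are assumptions):

import torch
from models.pl_wrapper import YOLOLitWrapper

# Two images of different sizes: the transform resizes and pads them into a common
# batch, the model runs on samples.tensors, and postprocess maps each prediction
# back to its own original resolution, so detections[i]["boxes"] is expressed in
# the coordinates of the i-th input image.
model = YOLOLitWrapper(min_size=320, max_size=416)
model.eval()
images = [torch.rand(3, 720, 1280), torch.rand(3, 480, 640)]
with torch.no_grad():
    detections = model(images)
print(detections[0]["boxes"], detections[0]["scores"], detections[0]["labels"])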
222 changes: 216 additions & 6 deletions models/transform.py
@@ -3,9 +3,11 @@
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F

import torchvision

from typing import Optional, List
from typing import Dict, Optional, List, Tuple


class NestedTensor(object):
@@ -15,17 +17,153 @@ class NestedTensor(object):
This works by padding the images to the same size,
and storing in a field the original sizes of each image
"""
def __init__(self, tensors):
def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
"""
Args:
tensors (Tensor)
image_sizes (list[tuple[int, int]])
"""
self.tensors = tensors
self.image_sizes = image_sizes

def to(self, device) -> "NestedTensor":
cast_tensor = self.tensors.to(device)
return NestedTensor(cast_tensor)
return NestedTensor(cast_tensor, self.image_sizes)

def __repr__(self):
return str(self.tensors)
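# Illustrative sketch, not part of this diff (hypothetical sizes): a NestedTensor now
# pairs the padded batch with each image's pre-padding (H, W) size, which postprocess
# consumes later, e.g.
#   batch = torch.zeros(2, 3, 416, 416)
#   sample = NestedTensor(batch, [(416, 416), (234, 416)])
#   sample.image_sizes  ->  [(416, 416), (234, 416)]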


class GeneralizedYOLOTransform(nn.Module):
"""
Performs input / target transformation before feeding the data to a YOLO model.

The transformations it performs are:
- input normalization (mean subtraction and std division)
- input / target resizing to match min_size / max_size

It returns a NestedTensor for the inputs, and a batched (N, 6) Tensor for the targets
"""
def __init__(self, min_size, max_size) -> None:
super().__init__()
if not isinstance(min_size, (list, tuple)):
min_size = (min_size,)
self.min_size = min_size
self.max_size = max_size

def forward(
self,
images: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]],
) -> Tuple[NestedTensor, Optional[Tensor]]:

images = [img for img in images]
if targets is not None:
# make a copy of targets to avoid modifying it in-place
# once torchscript supports dict comprehension
# this can be simplified as follows
# targets = [{k: v for k,v in t.items()} for t in targets]
targets_copy: List[Dict[str, Tensor]] = []
for t in targets:
data: Dict[str, Tensor] = {}
for k, v in t.items():
data[k] = v
targets_copy.append(data)
targets = targets_copy

for i in range(len(images)):
image = images[i]
target_index = targets[i] if targets is not None else None

if image.dim() != 3:
raise ValueError("images is expected to be a list of 3d tensors "
"of shape [C, H, W], got {}".format(image.shape))

image, target_index = self.resize(image, target_index)
images[i] = image
if targets is not None and target_index is not None:
targets[i] = target_index

image_sizes = [img.shape[-2:] for img in images]
images = nested_tensor_from_tensor_list(images)
image_sizes_list: List[Tuple[int, int]] = []
for image_size in image_sizes:
assert len(image_size) == 2
image_sizes_list.append((image_size[0], image_size[1]))

image_list = NestedTensor(images, image_sizes_list)

if targets is not None:
targets_batched = []
for i, target in enumerate(targets):
num_objects = len(target['labels'])
if num_objects > 0:
targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32)
targets_merged[:, 1] = target['labels']
targets_merged[:, 2:] = target['boxes']
targets_batched.append(targets_merged)
targets_batched = torch.cat(targets_batched, dim=0)
else:
targets_batched = None

return image_list, targets_batched
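# Illustrative sketch, not part of this diff (hypothetical values): targets_batched
# stacks one row per ground-truth object across the whole batch, laid out as
# [image_index, label, x1, y1, x2, y2], e.g.
#   tensor([[0., 1., 10., 20., 50., 80.],    # image 0, class 1
#           [0., 3., 15., 25., 60., 90.],    # image 0, class 3
#           [1., 0., 30., 40., 70., 95.]])   # image 1, class 0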

def torch_choice(self, k: List[int]) -> int:
"""
Implements `random.choice` via torch ops so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
is fixed.
"""
index = int(torch.empty(1).uniform_(0., float(len(k))).item())
return k[index]
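# Illustrative note, not part of this diff: torch_choice is a TorchScript-friendly
# stand-in for random.choice over the configured min_size values; e.g. with
# GeneralizedYOLOTransform((320, 352, 384, 416), 416), each training batch is resized
# towards one of 320 / 352 / 384 / 416, picked uniformly at random.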

def resize(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]],
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:

h, w = image.shape[-2:]
if self.training:
size = float(self.torch_choice(self.min_size))
else:
# FIXME assume for now that testing uses the largest scale
size = float(self.min_size[-1])
if torchvision._is_tracing():
image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
else:
image, target = _resize_image_and_masks(image, size, float(self.max_size), target)

if target is None:
return image, target

bbox = target["boxes"]
bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
target["boxes"] = bbox

return image, target

def postprocess(
self,
result: Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]],
image_shapes: List[Tuple[int, int]],
original_image_sizes: List[Tuple[int, int]],
) -> List[Dict[str, Tensor]]:

if torch.jit.is_scripting():
predictions = result[1]
else:
predictions = result

for i, (pred, im_s, o_im_s) in enumerate(zip(predictions, image_shapes, original_image_sizes)):
boxes = pred["boxes"]
boxes = resize_boxes(boxes, im_s, o_im_s)
predictions[i]["boxes"] = boxes

return predictions


def nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisible: int = 32):
# TODO make this more general
if tensor_list[0].ndim == 3:
@@ -46,7 +184,7 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisible: in
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
else:
raise ValueError('not supported')
return NestedTensor(tensor_batched)
return tensor_batched
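# Illustrative sketch, not part of this diff (hypothetical shapes): the batch is padded
# to the per-dimension maximum over the list, rounded up to a multiple of
# size_divisible (32), e.g.
#   imgs = [torch.rand(3, 230, 500), torch.rand(3, 234, 416)]
#   nested_tensor_from_tensor_list(imgs).shape  ->  torch.Size([2, 3, 256, 512])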


def _max_by_axis(the_list: List[List[int]]) -> List[int]:
@@ -60,7 +198,7 @@ def _max_by_axis(the_list: List[List[int]]) -> List[int]:
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisible: int = 32) -> NestedTensor:
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisible: int = 32) -> Tensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
@@ -83,4 +221,76 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisib

tensor = torch.stack(padded_imgs)

return NestedTensor(tensor)
return tensor


@torch.jit.unused
def _resize_image_and_masks_onnx(
image: Tensor,
self_min_size: float,
self_max_size: float,
target: Optional[Dict[str, Tensor]],
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:

from torch.onnx import operators

im_shape = operators.shape_as_tensor(image)[-2:]
min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32)
scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)

image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
align_corners=False)[0]

if target is None:
return image, target

if "masks" in target:
mask = target["masks"]
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
target["masks"] = mask
return image, target


def _resize_image_and_masks(
image: Tensor,
self_min_size: float,
self_max_size: float,
target: Optional[Dict[str, Tensor]],
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:

im_shape = torch.tensor(image.shape[-2:])
min_size = float(torch.min(im_shape))
max_size = float(torch.max(im_shape))
scale_factor = self_min_size / min_size
if max_size * scale_factor > self_max_size:
scale_factor = self_max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
align_corners=False)[0]

if target is None:
return image, target

if "masks" in target:
mask = target["masks"]
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
target["masks"] = mask
return image, target
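# Illustrative note, not part of this diff (hypothetical sizes): the scale factor is
# self_min_size / min(H, W), capped so the longer side stays within self_max_size.
# For a 480x640 image with min_size=320 and max_size=416: 320 / 480 = 0.667, but
# 640 * 0.667 > 416, so the factor drops to 416 / 640 = 0.65 and the image is resized
# to roughly 312 x 416.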


def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
ratios = [
torch.tensor(s, dtype=torch.float32, device=boxes.device) /
torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
for s, s_orig in zip(new_size, original_size)
]
ratio_height, ratio_width = ratios
xmin, ymin, xmax, ymax = boxes.unbind(1)

xmin = xmin * ratio_width
xmax = xmax * ratio_width
ymin = ymin * ratio_height
ymax = ymax * ratio_height
return torch.stack((xmin, ymin, xmax, ymax), dim=1)
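For reference, a small usage sketch of resize_boxes as postprocess calls it (values and sizes here are hypothetical, and the import path assumes the repository root is on the Python path). Note the argument order: the resized shape is passed as original_size and the original frame as new_size:

import torch
from models.transform import resize_boxes

# A box predicted on the 312x416 (H x W) resized image, mapped back to a 480x640 original:
# x coordinates are scaled by 640 / 416 and y coordinates by 480 / 312.
pred = torch.tensor([[50.0, 40.0, 200.0, 180.0]])
restored = resize_boxes(pred, (312, 416), (480, 640))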
1 change: 1 addition & 0 deletions models/yolo.py
@@ -88,6 +88,7 @@ def forward(
samples (NestedTensor): Expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.