Skip to content

Commit

Permalink
Replaced to_tensor() with pil_to_tensor() + convert_image_dtype() (#298)
Browse files Browse the repository at this point in the history
* Replaced to_tensor() with pil_to_tensor() + convert_image_dtype()

* Revert collate_fn and default_val_transforms

* Minor fixes for type annotations
  • Loading branch information
zhiqwang authored Feb 1, 2022
1 parent 61df774 commit c6131f4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
2 changes: 1 addition & 1 deletion yolort/data/coco_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def compute(self):
def derive_coco_results(self, class_names: Optional[List[str]] = None):
"""
Derive the desired score numbers from summarized COCOeval. Modified from
https://github.com/facebookresearch/detectron2/blob/7205996/detectron2/evaluation/coco_evaluation.py#L291
https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/coco_evaluation.py
Args:
coco_eval (None or COCOEval): None represents no predictions from model.
Expand Down
59 changes: 48 additions & 11 deletions yolort/data/transforms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""
Transforms for Data Augmentation
Mostly copy-paste from https://github.com/pytorch/vision/blob/0013d93/references/detection/transforms.py
"""
from typing import List, Tuple, Dict, Optional
from typing import List, Dict, Tuple, Optional

import torch
import torchvision
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T

Expand All @@ -31,7 +27,8 @@ def default_train_transforms(hflip_prob=0.5):
RandomZoomOut(),
RandomIoUCrop(),
RandomHorizontalFlip(p=hflip_prob),
ToTensor(),
PILToTensor(),
ConvertImageDtype(torch.float),
]
)

Expand All @@ -40,7 +37,17 @@ def default_val_transforms():
return ToTensor()


class Compose(object):
def _flip_coco_person_keypoints(kps, width):
flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
flipped_data = kps[:, flip_inds]
flipped_data[..., 0] = width - flipped_data[..., 0]
# Maintain COCO convention that if visibility == 0, then x, y = 0
inds = flipped_data[..., 2] == 0
flipped_data[inds] = 0
return flipped_data


class Compose:
def __init__(self, transforms):
self.transforms = transforms

Expand All @@ -63,6 +70,10 @@ def forward(
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
return image, target


Expand All @@ -72,7 +83,32 @@ def forward(
image: Tensor,
target: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.to_tensor(image)
image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
return image, target


class PILToTensor(nn.Module):
    """Transform that turns the image into a tensor via ``F.pil_to_tensor``.

    The target, if any, is passed through unchanged.
    """

    def forward(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Convert ``image`` and return it together with ``target``.

        Args:
            image: Input image. NOTE(review): annotated as ``Tensor`` but
                handed to ``F.pil_to_tensor``, so a PIL image is presumably
                expected here -- confirm.
            target: Optional annotation dict; returned as-is.

        Returns:
            ``(converted_image, target)``.
        """
        return F.pil_to_tensor(image), target


class ConvertImageDtype(nn.Module):
    """Transform that converts the image to a fixed dtype.

    The conversion is delegated to ``F.convert_image_dtype``; the target,
    if any, is passed through unchanged.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        """Remember the dtype every incoming image is converted to."""
        super().__init__()
        self.dtype = dtype

    def forward(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Convert ``image`` to ``self.dtype`` and return it with ``target``.

        Args:
            image: Image tensor to convert.
            target: Optional annotation dict; returned as-is.

        Returns:
            ``(converted_image, target)``.
        """
        converted = F.convert_image_dtype(image, self.dtype)
        return converted, target


Expand Down Expand Up @@ -148,7 +184,7 @@ def forward(

# check at least 1 box with jaccard limitations
boxes = target["boxes"][is_within_crop_area]
ious = box_ops.box_iou(
ious = torchvision.ops.boxes.box_iou(
boxes,
torch.tensor(
[[left, top, right, bottom]],
Expand Down Expand Up @@ -291,7 +327,8 @@ def forward(

is_pil = F._is_pil_image(image)
if is_pil:
image = F.to_tensor(image)
image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
image = image[..., permutation, :, :]
if is_pil:
image = F.to_pil_image(image)
Expand Down

0 comments on commit c6131f4

Please sign in to comment.