Merge branch 'master' into add-letterbox-loader
zhiqwang committed May 15, 2021
2 parents 6e85656 + 273e8b6 commit 2e9c143
Showing 1 changed file with 28 additions and 39 deletions.
67 changes: 28 additions & 39 deletions yolort/models/transform.py
@@ -135,10 +135,8 @@ def resize(
         else:
             # FIXME assume for now that testing uses the largest scale
             size = float(self.min_size[-1])
-        if torchvision._is_tracing():
-            image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
-        else:
-            image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
+
+        image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
 
         if target is None:
             return image, target
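The hunk above drops the dedicated ONNX call site: resize() now always delegates to _resize_image_and_masks, which handles tracing internally (see the third hunk below). The scale rule both paths share is sketched here; compute_scale is a hypothetical standalone helper written only for illustration, not part of the commit.

    def compute_scale(height: int, width: int, min_size: float, max_size: float) -> float:
        # Scale so the short side reaches min_size, unless that would push
        # the long side past max_size, e.g. (480, 640, 320.0, 640.0) -> 2/3.
        return min(min_size / min(height, width), max_size / max(height, width))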
@@ -216,7 +214,7 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisib
 
     for img in tensor_list:
         padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
-        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+        padded_img = F.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
         padded_imgs.append(padded_img)
 
     tensor = torch.stack(padded_imgs)
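This hunk only swaps the spelled-out torch.nn.functional.pad for the module's existing F alias; behavior is unchanged. As a reminder of the convention the pad tuple relies on, F.pad consumes (before, after) pairs starting from the last dimension, so the tuple here pads width, then height, then channels, each on the trailing side only. A standalone sketch, assuming a CHW image and a 512-pixel target height:

    import torch
    import torch.nn.functional as F

    img = torch.rand(3, 480, 640)  # C, H, W
    max_size = (3, 512, 640)
    padding = [s1 - s2 for s1, s2 in zip(max_size, img.shape)]
    # Pad tuple is read last dim first: (0, W_after, 0, H_after, 0, C_after)
    padded = F.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
    print(padded.shape)  # torch.Size([3, 512, 640])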
@@ -225,57 +223,48 @@
 
 
 @torch.jit.unused
-def _resize_image_and_masks_onnx(
-    image: Tensor,
-    self_min_size: float,
-    self_max_size: float,
-    target: Optional[Dict[str, Tensor]],
-) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+def _get_shape_onnx(image: Tensor) -> Tensor:
     from torch.onnx import operators
-    im_shape = operators.shape_as_tensor(image)[-2:]
-    min_size = torch.min(im_shape).to(dtype=torch.float32)
-    max_size = torch.max(im_shape).to(dtype=torch.float32)
-    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
-
-    image = torch.nn.functional.interpolate(
-        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
-        align_corners=False)[0]
-
-    if target is None:
-        return image, target
-
-    if "masks" in target:
-        mask = target["masks"]
-        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
-        target["masks"] = mask
-    return image, target
+    return operators.shape_as_tensor(image)[-2:]
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> float:
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
 
 
 def _resize_image_and_masks(
     image: Tensor,
     self_min_size: float,
     self_max_size: float,
-    target: Optional[Dict[str, Tensor]],
+    target: Optional[Dict[str, Tensor]] = None,
 ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
-    im_shape = torch.tensor(image.shape[-2:])
-    min_size = float(torch.min(im_shape))
-    max_size = float(torch.max(im_shape))
-    scale_factor = self_min_size / min_size
-    if max_size * scale_factor > self_max_size:
-        scale_factor = self_max_size / max_size
-    image = torch.nn.functional.interpolate(
-        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
-        align_corners=False)[0]
+
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image).to(dtype=torch.float32)
+    else:
+        im_shape = torch.tensor(image.shape[-2:], dtype=torch.float64)
+
+    min_size = self_min_size / torch.min(im_shape)
+    max_size = self_max_size / torch.max(im_shape)
+    scale = torch.min(min_size, max_size)
+
+    if torchvision._is_tracing():
+        scale_factor = _fake_cast_onnx(scale)
+    else:
+        scale_factor = scale.item()
+
+    image = F.interpolate(image[None], size=None, scale_factor=scale_factor, mode='bilinear',
+                          recompute_scale_factor=True, align_corners=False)[0]
 
     if target is None:
         return image, target
 
     if "masks" in target:
         mask = target["masks"]
-        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
+        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor,
+                             recompute_scale_factor=True)[:, 0].byte()
         target["masks"] = mask
     return image, target

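The rewritten _resize_image_and_masks folds the former ONNX-only path into two torchvision._is_tracing() branches: shapes come from torch.onnx.operators under tracing, and _fake_cast_onnx lets the traced scale stay a tensor while TorchScript's type checker sees a float. A hypothetical eager-mode usage sketch, not part of the commit (the import path follows this file's location; the exact output size depends on rounding inside F.interpolate):

    import torch
    from yolort.models.transform import _resize_image_and_masks

    image = torch.rand(3, 480, 640)
    # min_size=320, max_size=640: scale = min(320/480, 640/640) = 2/3
    resized, _ = _resize_image_and_masks(image, 320.0, 640.0)
    print(resized.shape)  # ~ torch.Size([3, 320, 427])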