From 273e8b6cdd12a39eb46ea31ba7986089dc344b26 Mon Sep 17 00:00:00 2001
From: Zhiqiang Wang
Date: Fri, 14 May 2021 22:59:13 +0800
Subject: [PATCH] Unify onnx and JIT resize implementations (#105)

* Unify onnx and JIT resize implementations
* Minor refactoring
* Fix numerical error of division between PT and ONNX
* Fix onnx inference with different inputs
* Remove unused code
---
 yolort/models/transform.py | 67 ++++++++++++++++----------------------
 1 file changed, 28 insertions(+), 39 deletions(-)

diff --git a/yolort/models/transform.py b/yolort/models/transform.py
index b6c5681f..2d3b9878 100644
--- a/yolort/models/transform.py
+++ b/yolort/models/transform.py
@@ -130,10 +130,8 @@ def resize(
         else:
             # FIXME assume for now that testing uses the largest scale
             size = float(self.min_size[-1])
-        if torchvision._is_tracing():
-            image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
-        else:
-            image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
+
+        image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
 
         if target is None:
             return image, target
@@ -211,7 +209,7 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisib
 
     for img in tensor_list:
         padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
-        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+        padded_img = F.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
         padded_imgs.append(padded_img)
 
     tensor = torch.stack(padded_imgs)
@@ -220,57 +218,48 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisib
 
 
 @torch.jit.unused
-def _resize_image_and_masks_onnx(
-    image: Tensor,
-    self_min_size: float,
-    self_max_size: float,
-    target: Optional[Dict[str, Tensor]],
-) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
-
+def _get_shape_onnx(image: Tensor) -> Tensor:
     from torch.onnx import operators
+    return operators.shape_as_tensor(image)[-2:]
 
-    im_shape = operators.shape_as_tensor(image)[-2:]
-    min_size = torch.min(im_shape).to(dtype=torch.float32)
-    max_size = torch.max(im_shape).to(dtype=torch.float32)
-    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
-
-    image = torch.nn.functional.interpolate(
-        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
-        align_corners=False)[0]
-
-    if target is None:
-        return image, target
 
-    if "masks" in target:
-        mask = target["masks"]
-        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
-        target["masks"] = mask
-    return image, target
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> float:
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
 
 
 def _resize_image_and_masks(
     image: Tensor,
     self_min_size: float,
     self_max_size: float,
-    target: Optional[Dict[str, Tensor]],
+    target: Optional[Dict[str, Tensor]] = None,
 ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
-    im_shape = torch.tensor(image.shape[-2:])
-    min_size = float(torch.min(im_shape))
-    max_size = float(torch.max(im_shape))
-    scale_factor = self_min_size / min_size
-    if max_size * scale_factor > self_max_size:
-        scale_factor = self_max_size / max_size
-    image = torch.nn.functional.interpolate(
-        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
-        align_corners=False)[0]
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image).to(dtype=torch.float32)
+    else:
+        im_shape = torch.tensor(image.shape[-2:], dtype=torch.float64)
+
+    min_size = self_min_size / torch.min(im_shape)
+    max_size = self_max_size / torch.max(im_shape)
+    scale = torch.min(min_size, max_size)
+
+    if torchvision._is_tracing():
+        scale_factor = _fake_cast_onnx(scale)
+    else:
+        scale_factor = scale.item()
+
+    image = F.interpolate(image[None], size=None, scale_factor=scale_factor, mode='bilinear',
+                          recompute_scale_factor=True, align_corners=False)[0]
 
     if target is None:
         return image, target
 
     if "masks" in target:
         mask = target["masks"]
-        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
+        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor,
+                             recompute_scale_factor=True)[:, 0].byte()
         target["masks"] = mask
     return image, target
 
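
To smoke-test the unified resize path after applying this patch, the minimal eager-mode sketch below may help. It is illustrative only and not part of the patch: it assumes yolort (with this change applied) and torchvision are importable, and the 400x600 input shape plus the 320/640 size bounds are arbitrary choices.

    import torch
    from yolort.models.transform import _resize_image_and_masks

    # Dummy 3-channel image: min side 400, max side 600.
    image = torch.rand(3, 400, 600)

    # Eager mode takes the scale.item() branch:
    # scale = min(320 / 400, 640 / 600) = 0.8, so both sides shrink by 0.8.
    resized, _ = _resize_image_and_masks(image, self_min_size=320.0, self_max_size=640.0)
    print(resized.shape)  # expected: torch.Size([3, 320, 480])

Under tracing (torchvision._is_tracing() returns True), the same function routes through _get_shape_onnx and _fake_cast_onnx instead, so eager inference, TorchScript, and ONNX export now share a single resize implementation rather than the two divergent ones removed above.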