From 74376d61cb48e18fabc7572707b27b92548f83c7 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 16 Dec 2020 00:54:43 +0800 Subject: [PATCH] Replacing all torch.jit.annotations with typing (#22) --- models/_utils.py | 3 ++- models/anchor_utils.py | 2 +- models/backbone.py | 6 +++--- models/box_head.py | 7 ++++--- models/yolo.py | 2 +- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/models/_utils.py b/models/_utils.py index 0a0e4afc..1bc373c9 100644 --- a/models/_utils.py +++ b/models/_utils.py @@ -2,9 +2,10 @@ import torch from torch import nn, Tensor -from torch.jit.annotations import Tuple, List from torchvision.ops import box_convert +from typing import Tuple, List + class BalancedPositiveNegativeSampler(object): """ diff --git a/models/anchor_utils.py b/models/anchor_utils.py index ebe15107..dcb6aa21 100644 --- a/models/anchor_utils.py +++ b/models/anchor_utils.py @@ -1,7 +1,7 @@ # Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. import torch from torch import nn, Tensor -from torch.jit.annotations import Tuple, List +from typing import Tuple, List class AnchorGenerator(nn.Module): diff --git a/models/backbone.py b/models/backbone.py index e4a28b07..a74cb3e5 100644 --- a/models/backbone.py +++ b/models/backbone.py @@ -6,7 +6,7 @@ import torch from torch import nn, Tensor -from torch.jit.annotations import List, Dict, Optional +from typing import List, Dict, Optional from .common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat from .experimental import MixConv2d, CrossConv, C3 @@ -54,7 +54,7 @@ def __init__(self, layers, save_list): def forward(self, x: Tensor) -> Tensor: out = x - y = torch.jit.annotate(List[Tensor], []) + y: List[Tensor] = [] for i, m in enumerate(self.model): if m.f > 0: # Concat layer @@ -172,7 +172,7 @@ def __init__(self, model, return_layers, save_list): def forward(self, x): out = OrderedDict() - y = torch.jit.annotate(List[Tensor], []) + y: List[Tensor] = [] for i, (name, module) in enumerate(self.items()): if module.f > 0: # Concat layer diff --git a/models/box_head.py b/models/box_head.py index 8fb1c96e..8fa77e08 100644 --- a/models/box_head.py +++ b/models/box_head.py @@ -2,13 +2,14 @@ import torch from torch import nn, Tensor -from torch.jit.annotations import Tuple, List, Dict, Optional from torchvision.ops import batched_nms from . import _utils as det_utils from ._utils import FocalLoss from utils.box_ops import bbox_iou +from typing import Tuple, List, Dict, Optional + class YoloHead(nn.Module): def __init__(self, in_channels: List[int], num_anchors: int, num_classes: int): @@ -38,7 +39,7 @@ def get_result_from_head(self, features: Tensor, idx: int) -> Tensor: return out def forward(self, x: List[Tensor]) -> Tensor: - all_pred_logits = torch.jit.annotate(List[Tensor], []) # inference output + all_pred_logits: List[Tensor] = [] # inference output for i, features in enumerate(x): pred_logits = self.get_result_from_head(features, i) @@ -327,7 +328,7 @@ def forward( For visualization, this should be the image size after data augment, but before padding """ num_images = len(image_shapes) - detections = torch.jit.annotate(List[Dict[str, Tensor]], []) + detections: List[Dict[str, Tensor]] = [] for index in range(num_images): # image index, image inference pred_logits = torch.sigmoid(head_outputs[index]) diff --git a/models/yolo.py b/models/yolo.py index e728d45e..b40b9e5d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -122,7 +122,7 @@ def forward( like `scores` and `labels`. """ # get the original image sizes - original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) + original_image_sizes: List[Tuple[int, int]] = [] for img in images: val = img.shape[-2:] assert len(val) == 2