Skip to content

Commit

Permalink
Replacing all torch.jit.annotations with typing (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang authored Dec 15, 2020
1 parent 20291c9 commit 74376d6
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 9 deletions.
3 changes: 2 additions & 1 deletion models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import torch
from torch import nn, Tensor
from torch.jit.annotations import Tuple, List
from torchvision.ops import box_convert

from typing import Tuple, List


class BalancedPositiveNegativeSampler(object):
"""
Expand Down
2 changes: 1 addition & 1 deletion models/anchor_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
import torch
from torch import nn, Tensor
from torch.jit.annotations import Tuple, List
from typing import Tuple, List


class AnchorGenerator(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from torch import nn, Tensor
from torch.jit.annotations import List, Dict, Optional
from typing import List, Dict, Optional

from .common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat
from .experimental import MixConv2d, CrossConv, C3
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(self, layers, save_list):

def forward(self, x: Tensor) -> Tensor:
out = x
y = torch.jit.annotate(List[Tensor], [])
y: List[Tensor] = []

for i, m in enumerate(self.model):
if m.f > 0: # Concat layer
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(self, model, return_layers, save_list):

def forward(self, x):
out = OrderedDict()
y = torch.jit.annotate(List[Tensor], [])
y: List[Tensor] = []

for i, (name, module) in enumerate(self.items()):
if module.f > 0: # Concat layer
Expand Down
7 changes: 4 additions & 3 deletions models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import torch
from torch import nn, Tensor

from torch.jit.annotations import Tuple, List, Dict, Optional
from torchvision.ops import batched_nms

from . import _utils as det_utils
from ._utils import FocalLoss
from utils.box_ops import bbox_iou

from typing import Tuple, List, Dict, Optional


class YoloHead(nn.Module):
def __init__(self, in_channels: List[int], num_anchors: int, num_classes: int):
Expand Down Expand Up @@ -38,7 +39,7 @@ def get_result_from_head(self, features: Tensor, idx: int) -> Tensor:
return out

def forward(self, x: List[Tensor]) -> Tensor:
all_pred_logits = torch.jit.annotate(List[Tensor], []) # inference output
all_pred_logits: List[Tensor] = [] # inference output

for i, features in enumerate(x):
pred_logits = self.get_result_from_head(features, i)
Expand Down Expand Up @@ -327,7 +328,7 @@ def forward(
For visualization, this should be the image size after data augment, but before padding
"""
num_images = len(image_shapes)
detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
detections: List[Dict[str, Tensor]] = []

for index in range(num_images): # image index, image inference
pred_logits = torch.sigmoid(head_outputs[index])
Expand Down
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def forward(
like `scores` and `labels`.
"""
# get the original image sizes
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
val = img.shape[-2:]
assert len(val) == 2
Expand Down

0 comments on commit 74376d6

Please sign in to comment.