Skip to content

Commit

Permalink
Replacing all torch.jit.annotations with typing (#3174)
Browse files Browse the repository at this point in the history
Summary:
* Replacing all torch.jit.annotations with typing

* Replacing remaining typing

Reviewed By: fmassa

Differential Revision: D25679213

fbshipit-source-id: 297d52d7ed1322d350619e298a9c2bbaa771d2a2
  • Loading branch information
datumbox authored and facebook-github-bot committed Dec 23, 2020
1 parent aade017 commit c3879ec
Show file tree
Hide file tree
Showing 27 changed files with 40 additions and 56 deletions.
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch
from torch import Tensor
from torch.autograd import gradcheck
from torch.jit.annotations import Tuple
from torch.nn.modules.utils import _pair
from torchvision import ops
from typing import Tuple


class OpTester(object):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math

import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
from typing import List, Tuple

from torchvision.ops.misc import FrozenBatchNorm2d

Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torch import nn, Tensor

from torch.jit.annotations import List, Optional, Dict
from typing import List, Optional, Dict
from .image_list import ImageList


Expand Down Expand Up @@ -148,7 +148,7 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
anchors: List[List[torch.Tensor]] = []
for i in range(len(image_list.image_sizes)):
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
anchors.append(anchors_in_image)
Expand Down
8 changes: 3 additions & 5 deletions torchvision/models/detection/generalized_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
"""

from collections import OrderedDict
from typing import Union
import torch
from torch import nn
from torch import nn, Tensor
import warnings
from torch.jit.annotations import Tuple, List, Dict, Optional
from torch import Tensor
from typing import Tuple, List, Dict, Optional, Union


class GeneralizedRCNN(nn.Module):
Expand Down Expand Up @@ -71,7 +69,7 @@ def forward(self, images, targets=None):
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))

original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
val = img.shape[-2:]
assert len(val) == 2
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/image_list.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
from typing import List, Tuple


class ImageList(object):
Expand Down
11 changes: 5 additions & 6 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import warnings

import torch
import torch.nn as nn
from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple, Optional
from torch import nn, Tensor
from typing import Dict, List, Tuple, Optional

from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url
Expand Down Expand Up @@ -402,7 +401,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):

num_images = len(image_shapes)

detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
detections: List[Dict[str, Tensor]] = []

for index in range(num_images):
box_regression_per_image = [br[index] for br in box_regression]
Expand Down Expand Up @@ -486,7 +485,7 @@ def forward(self, images, targets=None):
"Tensor, got {:}.".format(type(boxes)))

# get the original image sizes
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
val = img.shape[-2:]
assert len(val) == 2
Expand Down Expand Up @@ -524,7 +523,7 @@ def forward(self, images, targets=None):
anchors = self.anchor_generator(images, features)

losses = {}
detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
detections: List[Dict[str, Tensor]] = []
if self.training:
assert targets is not None

Expand Down
8 changes: 4 additions & 4 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from . import _utils as det_utils

from torch.jit.annotations import Optional, List, Dict, Tuple
from typing import Optional, List, Dict, Tuple


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
Expand Down Expand Up @@ -379,7 +379,7 @@ def expand_masks(mask, padding):
scale = expand_masks_tracing_scale(M, padding)
else:
scale = float(M + 2 * padding) / M
padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
padded_mask = F.pad(mask, (padding,) * 4)
return padded_mask, scale


Expand Down Expand Up @@ -482,7 +482,7 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
return ret


class RoIHeads(torch.nn.Module):
class RoIHeads(nn.Module):
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
Expand Down Expand Up @@ -753,7 +753,7 @@ def forward(self,
box_features = self.box_head(box_features)
class_logits, box_regression = self.box_predictor(box_features)

result = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
result: List[Dict[str, torch.Tensor]] = []
losses = {}
if self.training:
assert labels is not None and regression_targets is not None
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from . import _utils as det_utils
from .image_list import ImageList

from torch.jit.annotations import List, Optional, Dict, Tuple
from typing import List, Optional, Dict, Tuple

# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn, Tensor
from torch.nn import functional as F
import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional
from typing import List, Tuple, Dict, Optional

from .image_list import ImageList
from .roi_heads import paste_masks_in_image
Expand Down Expand Up @@ -109,7 +109,7 @@ def forward(self,

image_sizes = [img.shape[-2:] for img in images]
images = self.batch_images(images)
image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
image_sizes_list: List[Tuple[int, int]] = []
for image_size in image_sizes:
assert len(image_size) == 2
image_sizes_list.append((image_size[0], image_size[1]))
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor
# N x 480 x 14 x 14
x = self.inception4a(x)
# N x 512 x 14 x 14
aux1 = torch.jit.annotate(Optional[Tensor], None)
aux1: Optional[Tensor] = None
if self.aux1 is not None:
if self.training:
aux1 = self.aux1(x)
Expand All @@ -173,7 +173,7 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor
# N x 512 x 14 x 14
x = self.inception4d(x)
# N x 528 x 14 x 14
aux2 = torch.jit.annotate(Optional[Tensor], None)
aux2: Optional[Tensor] = None
if self.aux2 is not None:
if self.training:
aux2 = self.aux2(x)
Expand Down
7 changes: 3 additions & 4 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from collections import namedtuple
import warnings
import torch
import torch.nn as nn
from torch import nn, Tensor
import torch.nn.functional as F
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, Tuple, List

Expand All @@ -17,7 +16,7 @@
}

InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
Expand Down Expand Up @@ -171,7 +170,7 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
aux = torch.jit.annotate(Optional[Tensor], None)
aux: Optional[Tensor] = None
if self.AuxLogits is not None:
if self.training:
aux = self.AuxLogits(x)
Expand Down
1 change: 0 additions & 1 deletion torchvision/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.jit.annotations import Optional

from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.googlenet import (
Expand Down
1 change: 0 additions & 1 deletion torchvision/models/quantization/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch.nn.functional as F
from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs
from torch.jit.annotations import Optional
from torchvision.models.utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model

Expand Down
2 changes: 0 additions & 2 deletions torchvision/ops/_box_convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import torch
from torch.jit.annotations import Tuple
from torch import Tensor
import torchvision


def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import Tensor
from torch.jit.annotations import List
from typing import List


def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch.jit.annotations import Tuple
from torch import Tensor
from typing import Tuple
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
import torchvision
from torchvision.extension import _assert_has_ops
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple
from typing import Optional, Tuple
from torchvision.extension import _assert_has_ops


Expand Down
3 changes: 1 addition & 2 deletions torchvision/ops/feature_pyramid_network.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from torch.jit.annotations import Tuple, List, Dict, Optional
from typing import Tuple, List, Dict, Optional


class ExtraFPNBlock(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions torchvision/ops/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import warnings
import torch
from torch import Tensor, Size
from torch.jit.annotations import List, Optional, Tuple
from torch import Tensor
from typing import List, Optional


class Conv2d(torch.nn.Conv2d):
Expand Down
9 changes: 3 additions & 6 deletions torchvision/ops/poolers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor

import torchvision
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from torch.jit.annotations import Optional, List, Dict, Tuple
import torchvision
from typing import Optional, List, Dict, Tuple, Union


# copying result_idx_in_level to a specific index in result[]
Expand Down Expand Up @@ -149,7 +146,7 @@ def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
# assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:]
possible_scales = torch.jit.annotate(List[float], [])
possible_scales: List[float] = []
for s1, s2 in zip(size, original_size):
approx_scale = float(s1) / float(s2)
scale = 2 ** float(torch.tensor(approx_scale).log2().round())
Expand Down
1 change: 0 additions & 1 deletion torchvision/ops/ps_roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from torch import nn, Tensor

from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, Tuple

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
Expand Down
1 change: 0 additions & 1 deletion torchvision/ops/ps_roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from torch import nn, Tensor

from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, Tuple

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import nn, Tensor

from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2
from torch.jit.annotations import BroadcastingList2

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import nn, Tensor

from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2
from torch.jit.annotations import BroadcastingList2

from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
Expand Down
3 changes: 1 addition & 2 deletions torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from enum import Enum
from torch import Tensor
from torch.jit.annotations import List, Tuple
from typing import Optional
from typing import List, Tuple, Optional

from . import functional as F, InterpolationMode

Expand Down
3 changes: 1 addition & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import numbers
import warnings
from enum import Enum
from typing import Any, Optional

import numpy as np
from PIL import Image

import torch
from torch import Tensor
from torch.jit.annotations import List, Tuple
from typing import List, Tuple, Any, Optional

try:
import accimage
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
from torch.jit.annotations import List, BroadcastingList2
from torch.jit.annotations import BroadcastingList2
from typing import Optional, Tuple, List


def _is_tensor_a_torch_image(x: Tensor) -> bool:
Expand Down

0 comments on commit c3879ec

Please sign in to comment.