From 2bbd1f7a15f74aee9a625ba24559d92c3e8027b8 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Sat, 28 Nov 2020 03:55:53 +0800 Subject: [PATCH] Correcting incorrect types (#3) torchscript support `torch.dtype` now, refer to https://github.com/pytorch/vision/pull/3032 --- models/anchor_utils.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/models/anchor_utils.py b/models/anchor_utils.py index fddb37f2..ebe15107 100644 --- a/models/anchor_utils.py +++ b/models/anchor_utils.py @@ -17,8 +17,13 @@ def __init__( self.strides = strides self.anchor_grids = anchor_grids - def set_wh_weights(self, grid_sizes, dtype, device): - # type: (List[List[int]], int, Device) -> Tensor # noqa: F821 + def set_wh_weights( + self, + grid_sizes: List[List[int]], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + wh_weights = [] for size, stride in zip(grid_sizes, self.strides): @@ -31,8 +36,13 @@ def set_wh_weights(self, grid_sizes, dtype, device): return torch.cat(wh_weights) - def set_xy_weights(self, grid_sizes, dtype, device): - # type: (List[List[int]], int, Device) -> Tensor # noqa: F821 + def set_xy_weights( + self, + grid_sizes: List[List[int]], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + xy_weights = [] for size, anchor_grid in zip(grid_sizes, self.anchor_grids): @@ -45,8 +55,12 @@ def set_xy_weights(self, grid_sizes, dtype, device): return torch.cat(xy_weights) - def grid_anchors(self, grid_sizes, device): - # type: (List[List[int]], Device) -> Tensor # noqa: F821 + def grid_anchors( + self, + grid_sizes: List[List[int]], + device: torch.device = torch.device("cpu"), + ) -> Tensor: + anchors = [] for size in grid_sizes: