Skip to content

Commit

Permalink
Correcting incorrect types (#3)
Browse files Browse the repository at this point in the history
torchscript support `torch.dtype` now, refer to pytorch/vision#3032
  • Loading branch information
zhiqwang authored Nov 27, 2020
1 parent 288ddd0 commit 2bbd1f7
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions models/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ def __init__(
self.strides = strides
self.anchor_grids = anchor_grids

def set_wh_weights(self, grid_sizes, dtype, device):
# type: (List[List[int]], int, Device) -> Tensor # noqa: F821
def set_wh_weights(
self,
grid_sizes: List[List[int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> Tensor:

wh_weights = []

for size, stride in zip(grid_sizes, self.strides):
Expand All @@ -31,8 +36,13 @@ def set_wh_weights(self, grid_sizes, dtype, device):

return torch.cat(wh_weights)

def set_xy_weights(self, grid_sizes, dtype, device):
# type: (List[List[int]], int, Device) -> Tensor # noqa: F821
def set_xy_weights(
self,
grid_sizes: List[List[int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> Tensor:

xy_weights = []

for size, anchor_grid in zip(grid_sizes, self.anchor_grids):
Expand All @@ -45,8 +55,12 @@ def set_xy_weights(self, grid_sizes, dtype, device):

return torch.cat(xy_weights)

def grid_anchors(self, grid_sizes, device):
# type: (List[List[int]], Device) -> Tensor # noqa: F821
def grid_anchors(
self,
grid_sizes: List[List[int]],
device: torch.device = torch.device("cpu"),
) -> Tensor:

anchors = []

for size in grid_sizes:
Expand Down

0 comments on commit 2bbd1f7

Please sign in to comment.