Skip to content

Commit

Permalink
Fixing shifts in AnchorGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Oct 20, 2021
1 parent 503cb7e commit 299fac4
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions yolort/models/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,12 @@ def _generate_grids(

return grids

def _generate_shifts(self, grid_sizes: List[List[int]]) -> List[Tensor]:

shifts = []
for i, (height, width) in enumerate(grid_sizes):
shift = (
(self.anchors[i].clone() * self.strides[i])
.view((1, self.num_anchors, 1, 1, 2))
.expand((1, self.num_anchors, height, width, 2))
).float()
shifts.append(shift)

return shifts
def _generate_shifts(self) -> List[Tensor]:
return self.anchors.clone().view(self.num_layers, 1, -1, 1, 1, 2)

def forward(self, feature_maps: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
dtype, device = feature_maps[0].dtype, feature_maps[0].device
grids = self._generate_grids(grid_sizes, dtype, device)
shifts = self._generate_shifts(grid_sizes)
shifts = self._generate_shifts()
return grids, shifts

0 comments on commit 299fac4

Please sign in to comment.