Skip to content

Commit

Permalink
Fix anchors in AnchorGenerator._generate_shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Dec 20, 2021
1 parent 8aaec51 commit a0ef9e7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
11 changes: 8 additions & 3 deletions yolort/models/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def __init__(self, strides: List[int], anchor_grids: List[List[float]]):
super().__init__()
assert len(strides) == len(anchor_grids)
self.strides = strides
self.anchor_grids = anchor_grids
self.num_layers = len(anchor_grids)
self.num_anchors = len(anchor_grids[0]) // 2
self.register_buffer("anchors", torch.tensor(anchor_grids).float().view(self.num_layers, -1, 2))

def _generate_grids(
self,
Expand All @@ -39,12 +39,17 @@ def _generate_shifts(
self,
grid_sizes: List[List[int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> List[Tensor]:

shifts = []
anchors = torch.tensor(self.anchor_grids, dtype=dtype, device=device)
strides = torch.tensor(self.strides, dtype=dtype, device=device)
anchors = anchors.view(self.num_layers, -1, 2) / strides.view(-1, 1, 1)

for i, (height, width) in enumerate(grid_sizes):
shift = (
(self.anchors[i].clone() * self.strides[i])
(anchors[i].clone() * self.strides[i])
.view((1, self.num_anchors, 1, 1, 2))
.expand((1, self.num_anchors, height, width, 2))
.to(dtype=dtype)
Expand All @@ -56,5 +61,5 @@ def forward(self, feature_maps: List[Tensor]) -> Tuple[List[Tensor], List[Tensor
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
dtype, device = feature_maps[0].dtype, feature_maps[0].device
grids = self._generate_grids(grid_sizes, dtype=dtype, device=device)
shifts = self._generate_shifts(grid_sizes, dtype=dtype)
shifts = self._generate_shifts(grid_sizes, dtype=dtype, device=device)
return grids, shifts
4 changes: 2 additions & 2 deletions yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
criterion = SetCriterion(
anchor_generator.num_anchors,
anchor_generator.strides,
anchor_generator.anchors,
anchor_generator.anchor_grids,
num_classes,
)
self.compute_loss = criterion
Expand Down Expand Up @@ -270,7 +270,7 @@ def build_model(
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict, strict=False)
model.load_state_dict(state_dict)

return model

Expand Down

0 comments on commit a0ef9e7

Please sign in to comment.