Skip to content

Commit

Permalink
Fix missing outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Dec 20, 2021
1 parent 595d758 commit d18a54f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 2 additions & 2 deletions yolort/models/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Tuple, Optional
from typing import Tuple, Optional

import torch
from torch import nn, Tensor
Expand Down Expand Up @@ -54,7 +54,7 @@ def decode_single(
rel_codes: Tensor,
grid: Tensor,
shift: Tensor,
stride: List[int],
stride: int,
) -> Tuple[Tensor, Tensor]:
"""
From a set of original boxes and encoded relative box offsets,
Expand Down
12 changes: 8 additions & 4 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,15 @@ def _concat_pred_logits(
all_pred_logits = []

for i, head_output in enumerate(head_outputs):
head_feature = head_output.sigmoid()
head_feature = torch.sigmoid(head_output)
pred_xy, pred_wh = det_utils.decode_single(
head_feature[..., :4],
grids[i],
shifts[i],
self.strides[i],
)
pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
all_pred_logits.append(pred_logits.reshape(batch_size, -1, K))
all_pred_logits.append(pred_logits.view(batch_size, -1, K))

all_pred_logits = torch.cat(all_pred_logits, dim=1)

Expand Down Expand Up @@ -394,7 +394,8 @@ def forward(
pred_scores = []

for idx in range(batch_size): # image idx, image inference
boxes, scores = self._decode_pred_logits(all_pred_logits[idx])
pred_logits = all_pred_logits[idx]
boxes, scores = self._decode_pred_logits(pred_logits)
bbox_regression.append(boxes)
pred_scores.append(scores)

Expand Down Expand Up @@ -449,7 +450,8 @@ def forward(
detections: List[Dict[str, Tensor]] = []

for idx in range(batch_size): # image idx, image inference
boxes, scores = self._decode_pred_logits(all_pred_logits[idx])
pred_logits = all_pred_logits[idx]
boxes, scores = self._decode_pred_logits(pred_logits)
# remove low scoring boxes
inds, labels = torch.where(scores > self.score_thresh)
boxes, scores = boxes[inds], scores[inds, labels]
Expand All @@ -461,3 +463,5 @@ def forward(
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

detections.append({"scores": scores, "labels": labels, "boxes": boxes})

return detections
2 changes: 1 addition & 1 deletion yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def build_model(
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)

return model

Expand Down

0 comments on commit d18a54f

Please sign in to comment.