Skip to content

Commit

Permalink
Make generalized_box_iou and box_iou share common code.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Feb 11, 2021
1 parent 059b19b commit fb169a6
Showing 1 changed file with 18 additions and 21 deletions.
39 changes: 18 additions & 21 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,21 @@ def box_area(boxes: Tensor) -> Tensor:

# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
area1 = box_area(boxes1)
area2 = box_area(boxes2)

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter

return inter, union


def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
"""
Return intersection-over-union (Jaccard index) of boxes.
Expand All @@ -200,16 +215,8 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
"""
area1 = box_area(boxes1)
area2 = box_area(boxes2)

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

iou = inter / (area1[:, None] + area2 - inter)
inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union
return iou


Expand All @@ -234,17 +241,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

area1 = box_area(boxes1)
area2 = box_area(boxes2)

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter

inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union

lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
Expand Down

0 comments on commit fb169a6

Please sign in to comment.