-
Notifications
You must be signed in to change notification settings - Fork 374
/
generalized_iou_loss.py
46 lines (38 loc) · 1.34 KB
/
generalized_iou_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#! /usr/bin/python
# -*- encoding: utf-8 -*-
import torch
def generalized_iou_loss(gt_bboxes, pr_bboxes, reduction='mean', offset=0.0, eps=1e-7):
    """Compute the Generalized IoU (GIoU) loss ``1 - GIoU`` between paired boxes.

    This is the loss proposed in the GIoU paper (Rezatofighi et al., CVPR 2019).

    Args:
        gt_bboxes: tensor of shape (..., 4), boxes as (x1, y1, x2, y2).
        pr_bboxes: tensor of shape (..., 4), boxes as (x1, y1, x2, y2),
            paired element-wise with ``gt_bboxes`` (leading dims must broadcast).
        reduction: 'mean' (default), 'sum' or 'none'.
        offset: amount added to every width/height. 0 (default) treats
            coordinates as continuous; 1 reproduces the inclusive-pixel
            convention where [x1, x2] spans ``x2 - x1 + 1`` pixels. It is
            applied identically to the box areas, the intersection and the
            enclosing box, so identical boxes always give IoU == 1.
        eps: lower bound for the union / enclosure denominators, so that
            degenerate (zero-area) boxes yield a finite loss instead of NaN.

    Returns:
        A scalar tensor for 'mean' / 'sum' (0 for empty input with 'mean'),
        or the per-box losses with shape ``(...)`` for 'none'.

    Raises:
        ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
    """
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(
            "reduction must be 'mean', 'sum' or 'none', got {!r}".format(reduction))

    # Per-box areas; the same offset is used below so union stays consistent.
    gt_area = ((gt_bboxes[..., 2] - gt_bboxes[..., 0] + offset)
               * (gt_bboxes[..., 3] - gt_bboxes[..., 1] + offset))
    pr_area = ((pr_bboxes[..., 2] - pr_bboxes[..., 0] + offset)
               * (pr_bboxes[..., 3] - pr_bboxes[..., 1] + offset))

    # IoU: overlap rectangle, with non-overlapping pairs clamped to zero size.
    lt = torch.max(gt_bboxes[..., :2], pr_bboxes[..., :2])
    rb = torch.min(gt_bboxes[..., 2:], pr_bboxes[..., 2:])
    wh = (rb - lt + offset).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    union = gt_area + pr_area - inter
    iou = inter / union.clamp(min=eps)

    # Smallest axis-aligned box enclosing both boxes of each pair.
    lt = torch.min(gt_bboxes[..., :2], pr_bboxes[..., :2])
    rb = torch.max(gt_bboxes[..., 2:], pr_bboxes[..., 2:])
    wh = (rb - lt + offset).clamp(min=0)
    enclosure = wh[..., 0] * wh[..., 1]

    # GIoU penalizes the part of the enclosure not covered by the union.
    giou = iou - (enclosure - union) / enclosure.clamp(min=eps)
    loss = 1. - giou

    if reduction == 'mean':
        # mean() of an empty tensor is NaN; 0 * sum() gives 0 and keeps the graph.
        return loss.mean() if loss.numel() > 0 else 0. * loss.sum()
    if reduction == 'sum':
        return loss.sum()
    return loss
if __name__ == '__main__':
    # Smoke run: two partially overlapping 2x2 boxes; prints the per-box loss.
    demo_gt = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
    demo_pr = torch.tensor([[2, 3, 4, 5]], dtype=torch.float32)
    print(generalized_iou_loss(demo_gt, demo_pr, reduction='none'))