-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy path_utils.py
540 lines (436 loc) · 21.6 KB
/
_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
class BalancedPositiveNegativeSampler:
    """
    Randomly draws a bounded number of elements per image, capping the share
    of positives at a configured fraction of the draw.
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
        """
        Args:
            batch_size_per_image (int): how many elements to draw for each image
            positive_fraction (float): largest share of the draw that may be positives
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
            matched_idxs: one tensor per image. Entries ``>= 1`` are treated as
                positives, entries equal to ``0`` as negatives; every other
                value (e.g. ``-1``) is never sampled.

        Returns:
            pos_idx (list[tensor])
            neg_idx (list[tensor])

            Two lists of ``uint8`` masks, one mask per image and shaped like
            the matching input tensor. The first list flags the sampled
            positives, the second the sampled negatives.
        """
        sampled_pos_masks = []
        sampled_neg_masks = []
        for labels in matched_idxs:
            pos_candidates = torch.where(labels >= 1)[0]
            neg_candidates = torch.where(labels == 0)[0]

            # Positive quota, shrunk when the image has too few positives.
            pos_quota = int(self.batch_size_per_image * self.positive_fraction)
            pos_quota = min(pos_candidates.numel(), pos_quota)
            # Negatives fill what is left of the batch, limited by availability.
            neg_quota = self.batch_size_per_image - pos_quota
            neg_quota = min(neg_candidates.numel(), neg_quota)

            # Shuffle each candidate pool and keep the leading `quota` entries.
            pos_order = torch.randperm(pos_candidates.numel(), device=pos_candidates.device)[:pos_quota]
            neg_order = torch.randperm(neg_candidates.numel(), device=neg_candidates.device)[:neg_quota]
            chosen_pos = pos_candidates[pos_order]
            chosen_neg = neg_candidates[neg_order]

            # Turn the chosen indices into 0/1 masks.
            pos_mask = torch.zeros_like(labels, dtype=torch.uint8)
            neg_mask = torch.zeros_like(labels, dtype=torch.uint8)
            pos_mask[chosen_pos] = 1
            neg_mask[chosen_neg] = 1

            sampled_pos_masks.append(pos_mask)
            sampled_neg_masks.append(neg_mask)

        return sampled_pos_masks, sampled_neg_masks
@torch.jit._script_if_tracing
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
    """
    Compute the regression targets that describe ``reference_boxes``
    relative to ``proposals``.

    Both box tensors are indexed as ``[:, 0..3]`` and read as
    ``(x1, y1, x2, y2)`` rows.

    Args:
        reference_boxes (Tensor): reference boxes
        proposals (Tensor): boxes to be encoded
        weights (Tensor[4]): the weights for ``(x, y, w, h)``

    Returns:
        Tensor: the ``(dx, dy, dw, dh)`` targets concatenated along dim 1.
    """
    # Pull every column out explicitly; this keeps the math JIT-fusion friendly.
    scale_x = weights[0]
    scale_y = weights[1]
    scale_w = weights[2]
    scale_h = weights[3]

    src_x1 = proposals[:, 0].unsqueeze(1)
    src_y1 = proposals[:, 1].unsqueeze(1)
    src_x2 = proposals[:, 2].unsqueeze(1)
    src_y2 = proposals[:, 3].unsqueeze(1)

    dst_x1 = reference_boxes[:, 0].unsqueeze(1)
    dst_y1 = reference_boxes[:, 1].unsqueeze(1)
    dst_x2 = reference_boxes[:, 2].unsqueeze(1)
    dst_y2 = reference_boxes[:, 3].unsqueeze(1)

    # Size and centre of the boxes being encoded.
    src_w = src_x2 - src_x1
    src_h = src_y2 - src_y1
    src_cx = src_x1 + 0.5 * src_w
    src_cy = src_y1 + 0.5 * src_h

    # Size and centre of the reference boxes.
    dst_w = dst_x2 - dst_x1
    dst_h = dst_y2 - dst_y1
    dst_cx = dst_x1 + 0.5 * dst_w
    dst_cy = dst_y1 + 0.5 * dst_h

    # Centre shifts are expressed in units of the source size; size changes
    # are expressed as log-ratios.
    delta_x = scale_x * (dst_cx - src_cx) / src_w
    delta_y = scale_y * (dst_cy - src_cy) / src_h
    delta_w = scale_w * torch.log(dst_w / src_w)
    delta_h = scale_h * torch.log(dst_h / src_h)

    return torch.cat((delta_x, delta_y, delta_w, delta_h), dim=1)
class BoxCoder:
    """
    Converts between absolute boxes and the delta representation that the
    box regressors are trained on.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
    ) -> None:
        """
        Args:
            weights (4-element tuple): scaling applied to the ``(x, y, w, h)`` deltas
            bbox_xform_clip (float): upper bound applied to the width/height
                deltas before they are exponentiated in ``decode_single``
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
        """
        Encode per-image lists of boxes in a single pass.

        The lists are concatenated, encoded together and the result is split
        back using the number of reference boxes of each image.
        """
        counts = [len(b) for b in reference_boxes]
        all_reference = torch.cat(reference_boxes, dim=0)
        all_proposals = torch.cat(proposals, dim=0)
        encoded = self.encode_single(all_reference, all_proposals)
        return encoded.split(counts, 0)

    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some
        reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded
        """
        # The weights are materialised with the dtype/device of the reference boxes.
        weight_tensor = torch.as_tensor(
            self.weights, dtype=reference_boxes.dtype, device=reference_boxes.device
        )
        return encode_boxes(reference_boxes, proposals, weight_tensor)

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        """
        Decode deltas against per-image lists of boxes.

        Args:
            rel_codes (Tensor): encoded boxes for all images
            boxes (list or tuple of Tensor): reference boxes, one tensor per image

        Returns:
            Tensor: decoded boxes; reshaped to ``(total_boxes, -1, 4)`` whenever
            at least one reference box is present.
        """
        torch._assert(
            isinstance(boxes, (list, tuple)),
            "This function expects boxes of type list or tuple.",
        )
        torch._assert(
            isinstance(rel_codes, torch.Tensor),
            "This function expects rel_codes of type torch.Tensor.",
        )
        counts = [b.size(0) for b in boxes]
        stacked_boxes = torch.cat(boxes, dim=0)

        # Accumulate the total with an explicit loop.
        total = 0
        for count in counts:
            total += count

        if total > 0:
            rel_codes = rel_codes.reshape(total, -1)
        decoded = self.decode_single(rel_codes, stacked_boxes)
        if total > 0:
            decoded = decoded.reshape(total, -1, 4)
        return decoded

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """
        boxes = boxes.to(rel_codes.dtype)

        box_w = boxes[:, 2] - boxes[:, 0]
        box_h = boxes[:, 3] - boxes[:, 1]
        box_cx = boxes[:, 0] + 0.5 * box_w
        box_cy = boxes[:, 1] + 0.5 * box_h

        # Every group of four columns holds one (dx, dy, dw, dh) tuple.
        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Cap the log-size deltas so torch.exp() cannot receive huge values.
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        new_cx = dx * box_w[:, None] + box_cx[:, None]
        new_cy = dy * box_h[:, None] + box_cy[:, None]
        new_w = torch.exp(dw) * box_w[:, None]
        new_h = torch.exp(dh) * box_h[:, None]

        # Half extents: distance from the centre to the box sides.
        half_h = torch.tensor(0.5, dtype=new_cy.dtype, device=new_h.device) * new_h
        half_w = torch.tensor(0.5, dtype=new_cx.dtype, device=new_w.device) * new_w

        x1 = new_cx - half_w
        y1 = new_cy - half_h
        x2 = new_cx + half_w
        y2 = new_cy + half_h
        return torch.stack((x1, y1, x2, y2), dim=2).flatten(1)
class BoxLinearCoder:
    """
    The linear box-to-box transform defined in FCOS. A box is described by the
    distances from the centre of a source box to the four edges of a target box.
    """

    def __init__(self, normalize_by_size: bool = True) -> None:
        """
        Args:
            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded

        Returns:
            Tensor: the encoded relative box offsets that can be used to
            decode the boxes.
        """
        # Centre of every reference box.
        center_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
        center_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])

        # Distances from that centre to the left/top/right/bottom proposal edges.
        to_left = center_x - proposals[..., 0]
        to_top = center_y - proposals[..., 1]
        to_right = proposals[..., 2] - center_x
        to_bottom = proposals[..., 3] - center_y
        offsets = torch.stack((to_left, to_top, to_right, to_bottom), dim=-1)

        if not self.normalize_by_size:
            return offsets

        # Express the offsets in units of the reference box width/height.
        ref_w = reference_boxes[..., 2] - reference_boxes[..., 0]
        ref_h = reference_boxes[..., 3] - reference_boxes[..., 1]
        ref_extent = torch.stack((ref_w, ref_h, ref_w, ref_h), dim=-1)
        return offsets / ref_extent

    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.

        Returns:
            Tensor: the predicted boxes with the encoded relative box offsets.

        .. note::
            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
        """
        boxes = boxes.to(dtype=rel_codes.dtype)

        center_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
        center_y = 0.5 * (boxes[..., 1] + boxes[..., 3])

        if self.normalize_by_size:
            # Undo the size normalisation applied by ``encode``.
            box_w = boxes[..., 2] - boxes[..., 0]
            box_h = boxes[..., 3] - boxes[..., 1]
            extent = torch.stack((box_w, box_h, box_w, box_h), dim=-1)
            rel_codes = rel_codes * extent

        left = center_x - rel_codes[..., 0]
        top = center_y - rel_codes[..., 1]
        right = center_x + rel_codes[..., 2]
        bottom = center_y + rel_codes[..., 3]
        return torch.stack((left, top, right, bottom), dim=-1)
class Matcher:
    """
    Assigns at most one ground-truth element to every predicted "element"
    (e.g., a box). A ground-truth element may end up assigned to any number
    of predictions, including none.

    The decision is driven by an MxN match_quality_matrix holding the quality
    of every (ground-truth, predicted) pair; for boxes this would typically be
    the IoU overlap.

    The result is a tensor of size N whose n-th entry is the index of the
    ground-truth element matched to prediction n, or a negative value when the
    prediction has no match.
    """

    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    __annotations__ = {
        "BELOW_LOW_THRESHOLD": int,
        "BETWEEN_THRESHOLDS": int,
    }

    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        self.BELOW_LOW_THRESHOLD = -1
        self.BETWEEN_THRESHOLDS = -2
        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
                [0, M - 1] or a negative value indicating that prediction i could not
                be matched.

        Raises:
            ValueError: if the matrix has no elements, i.e. there are no
                ground-truth rows or no prediction columns.
        """
        if match_quality_matrix.numel() == 0:
            # Neither empty targets nor empty proposals are supported during training.
            if match_quality_matrix.shape[0] == 0:
                raise ValueError("No ground-truth boxes available for one of the images during training")
            raise ValueError("No proposal boxes available for one of the images during training")

        # Rows are gt elements, columns are predictions: reducing over dim 0
        # yields the best gt candidate (and its quality) for every prediction.
        best_quality, matches = match_quality_matrix.max(dim=0)

        # Keep the raw argmax around; it is needed to revive low-quality matches.
        if self.allow_low_quality_matches:
            all_matches = matches.clone()
        else:
            all_matches = None  # type: ignore[assignment]

        # Replace weak candidates with the negative sentinel values.
        too_low = best_quality < self.low_threshold
        in_between = (best_quality >= self.low_threshold) & (best_quality < self.high_threshold)
        matches[too_low] = self.BELOW_LOW_THRESHOLD
        matches[in_between] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            if all_matches is None:
                torch._assert(False, "all_matches should not be None")
            else:
                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
        """
        Revive matches for predictions that only have low-quality candidates.

        For every ground-truth element, the predictions sharing its maximum
        quality value (ties included) are collected; each of those predictions
        gets its entry in ``matches`` overwritten in place with the value
        stored in ``all_matches``, i.e. the gt index it had before thresholding.
        """
        # Best quality reachable by each gt element across all predictions.
        best_per_gt, _ = match_quality_matrix.max(dim=1)
        # Every (gt, prediction) pair reaching that best value, ties included.
        tied_pairs = torch.where(match_quality_matrix == best_per_gt[:, None])
        # Example tied_pairs:
        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
        # The first tensor lists gt indices, the second the prediction indices;
        # gt items 1, 2, 3 and 5 each appear twice because of ties.
        preds_to_restore = tied_pairs[1]
        matches[preds_to_restore] = all_matches[preds_to_restore]
class SSDMatcher(Matcher):
    """
    Matcher variant that uses a single threshold and additionally forces, for
    every ground-truth element, its best-scoring prediction to be matched to it.
    """

    def __init__(self, threshold: float) -> None:
        """
        Args:
            threshold (float): used as both the high and the low threshold of
                the base matcher; low-quality matches are disabled.
        """
        super().__init__(threshold, threshold, allow_low_quality_matches=False)

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Run the base matching, then overwrite the entry of each ground-truth
        element's highest-quality prediction with that element's index.
        """
        matches = super().__call__(match_quality_matrix)

        # Column index of the best prediction for every gt row.
        _, best_pred_per_gt = match_quality_matrix.max(dim=1)
        gt_indices = torch.arange(
            best_pred_per_gt.size(0), dtype=torch.int64, device=best_pred_per_gt.device
        )
        matches[best_pred_per_gt] = gt_indices

        return matches
def overwrite_eps(model: nn.Module, eps: float) -> None:
    """
    Set ``eps`` on every FrozenBatchNorm2d layer found in ``model``.

    This is necessary to address the BC-breaking change introduced
    by the bug-fix at pytorch/vision#2933. The overwrite is applied
    only when the pretrained weights are loaded to maintain compatibility
    with previous versions.

    Args:
        model (nn.Module): The model on which we perform the overwrite.
        eps (float): The new value of eps.
    """
    frozen_norm_layers = (layer for layer in model.modules() if isinstance(layer, FrozenBatchNorm2d))
    for layer in frozen_norm_layers:
        layer.eps = eps
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
    """
    This method retrieves the number of output channels of a specific model.

    The model is temporarily switched to eval mode and fed a single all-zero
    3-channel image under ``torch.no_grad()``. If the model was in training
    mode it is switched back afterwards, even when the forward pass raises.

    Args:
        model (nn.Module): The model for which we estimate the out_channels.
            It should return a single Tensor or an OrderedDict[Tensor].
        size (Tuple[int, int]): The size (wxh) of the input.

    Returns:
        out_channels (List[int]): A list of the output channels of the model.
    """
    in_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            # Place the dummy input on the model's device. A model without
            # parameters falls back to its buffers and finally to the CPU
            # instead of failing with a bare StopIteration.
            probe = next(model.parameters(), None)
            if probe is None:
                probe = next(model.buffers(), None)
            device = probe.device if probe is not None else torch.device("cpu")

            # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
            tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
            features = model(tmp_img)
            if isinstance(features, torch.Tensor):
                features = OrderedDict([("0", features)])
            out_channels = [x.size(1) for x in features.values()]
    finally:
        # Always undo the temporary eval() so a failing forward pass does not
        # silently leave the caller's model in eval mode.
        if in_training:
            model.train()

    return out_channels
@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> int:
    """
    Return ``v`` unchanged while declaring an ``int`` return type.

    No conversion happens at runtime: the object passed in is the object
    returned, and the annotation mismatch is silenced for the type checker.
    NOTE(review): presumably this lets a traced tensor flow through an
    int-typed signature (see its use in ``_topk_min``) -- confirm.
    """
    return v  # type: ignore[return-value]
def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
    """
    Pick a k-value for ``topk`` that never exceeds the size of ``input``
    along ``axis``.

    ONNX spec requires the k-value to be less than or equal to the number of
    inputs along the provided dim. Outside of tracing, python's min() does the
    job. While tracing, python's min() would be evaluated once and baked in as
    a static value, so the model would be traced incorrectly and eventually
    fail at the topk node; torch.min() over tensors is used there instead.

    Args:
        input (Tensor): The original input tensor.
        orig_kval (int): The provided k-value.
        axis(int): Axis along which we retrieve the input size.

    Returns:
        min_kval (int): Appropriately selected k-value.
    """
    if torch.jit.is_tracing():
        # Build the comparison out of tensor ops so it is recorded in the trace.
        dim_size = torch._shape_as_tensor(input)[axis].unsqueeze(0)
        requested_k = torch.tensor([orig_kval], dtype=dim_size.dtype)
        smallest = torch.min(torch.cat((requested_k, dim_size), 0))
        return _fake_cast_onnx(smallest)

    return min(orig_kval, input.size(axis))
def _box_loss(
    type: str,
    box_coder: BoxCoder,
    anchors_per_image: Tensor,
    matched_gt_boxes_per_image: Tensor,
    bbox_regression_per_image: Tensor,
    cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
    """
    Compute the summed box regression loss for one image.

    ``"l1"`` and ``"smooth_l1"`` compare the raw regression output against the
    encoded ground-truth targets; ``"ciou"``, ``"diou"`` and ``"giou"`` decode
    the regression output into boxes first and compare those against the
    ground-truth boxes.

    Args:
        type (str): one of ``"l1"``, ``"smooth_l1"``, ``"ciou"``, ``"diou"``, ``"giou"``.
        box_coder (BoxCoder): coder used to encode targets / decode predictions.
        anchors_per_image (Tensor): anchors the regression is relative to.
        matched_gt_boxes_per_image (Tensor): ground-truth boxes matched to the anchors.
        bbox_regression_per_image (Tensor): predicted regression values.
        cnf (Optional[Dict[str, float]]): optional overrides; ``"beta"`` is read
            for ``"smooth_l1"`` (default 1.0) and ``"eps"`` for the IoU-based
            losses (default 1e-7).

    Returns:
        Tensor: the loss reduced with ``reduction="sum"``.
    """
    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")

    if type == "l1" or type == "smooth_l1":
        # Regression-space losses: compare against the encoded targets.
        regression_target = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        if type == "l1":
            return F.l1_loss(bbox_regression_per_image, regression_target, reduction="sum")
        beta = 1.0
        if cnf is not None and "beta" in cnf:
            beta = cnf["beta"]
        return F.smooth_l1_loss(bbox_regression_per_image, regression_target, reduction="sum", beta=beta)

    # IoU-based losses: compare decoded boxes against the ground-truth boxes.
    decoded_boxes = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
    eps = 1e-7
    if cnf is not None and "eps" in cnf:
        eps = cnf["eps"]

    if type == "ciou":
        return complete_box_iou_loss(decoded_boxes, matched_gt_boxes_per_image, reduction="sum", eps=eps)
    if type == "diou":
        return distance_box_iou_loss(decoded_boxes, matched_gt_boxes_per_image, reduction="sum", eps=eps)
    # otherwise giou
    return generalized_box_iou_loss(decoded_boxes, matched_gt_boxes_per_image, reduction="sum", eps=eps)