diff --git a/slotformer/video_prediction/vp_utils.py b/slotformer/video_prediction/vp_utils.py index 5069b2c..8093520 100644 --- a/slotformer/video_prediction/vp_utils.py +++ b/slotformer/video_prediction/vp_utils.py @@ -229,7 +229,7 @@ def hungarian_miou(gt_mask, pred_mask): N, M = true_oh.shape[-1], pred_oh.shape[-1] # compute all pairwise IoU intersect = (true_oh[:, :, None] * pred_oh[:, None, :]).sum(0) # [N, M] - union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :] # [N, M] + union = (true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :]) - intersect # [N, M] iou = intersect / (union + 1e-8) # [N, M] iou = iou.detach().cpu().numpy() # find the best match for each gt