diff --git a/scripts/extract_dino_correspondences.py b/scripts/extract_dino_correspondences.py index 3bf5e8a..50dceec 100644 --- a/scripts/extract_dino_correspondences.py +++ b/scripts/extract_dino_correspondences.py @@ -93,11 +93,10 @@ def find_correspondences(extractor: ViTExtractor, image_path1: str, image_path2: bb_cls_attn = (bb_cls_attn1 + bb_cls_attn2) / 2 ranks = bb_cls_attn - for k in range(n_clusters): - for i, (label, rank) in enumerate(zip(kmeans.labels_, ranks)): - if rank > bb_topk_sims[label]: - bb_topk_sims[label] = rank - bb_indices_to_show[label] = i + for i, (label, rank) in enumerate(zip(kmeans.labels_, ranks)): + if rank > bb_topk_sims[label]: + bb_topk_sims[label] = rank + bb_indices_to_show[label] = i # get coordinates to show indices_to_show = torch.nonzero(bbs_mask, as_tuple=False).squeeze(dim=1)[