diff --git a/exps/muti_label/dataset.py b/exps/muti_label/dataset.py
index e23db0a..a73f876 100644
--- a/exps/muti_label/dataset.py
+++ b/exps/muti_label/dataset.py
@@ -22,8 +22,11 @@ def parse_data_info(self, raw_data_info):
         # print(raw_img_info)
         # print([cur_cates[ann['category_id']-1] for ann in raw_ann_info])
         # to one-hot
-        category_ids = torch.unique(torch.tensor([ann['category_id'] for ann in raw_ann_info])) - 1
-        cate_one_hot = torch.eye(len(cur_cates))[category_ids].sum(dim=0)
+        if len(raw_ann_info) == 0:
+            cate_one_hot = torch.zeros(len(cur_cates))
+        else:
+            category_ids = torch.unique(torch.tensor([ann['category_id'] for ann in raw_ann_info])) - 1
+            cate_one_hot = torch.eye(len(cur_cates))[category_ids].sum(dim=0)
 
         return {
             "img_path": os.path.join(self.data_root, raw_img_info['file_name']),
@@ -45,9 +48,9 @@ def load_data_list(self) -> list[dict]:
             ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
             raw_ann_info = self.lvis.load_anns(ann_ids)
 
-            if len(raw_ann_info) == 0:
-                # print(f"Image {img_id} has no annotations, skipped.")
-                continue
+            # if len(raw_ann_info) == 0:
+            #     # print(f"Image {img_id} has no annotations, skipped.")
+            #     continue
 
             parsed_data_info = self.parse_data_info({
                 'raw_ann_info':
diff --git a/oadp/dp/roi_heads.py b/oadp/dp/roi_heads.py
index 3c52b5b..37a62a8 100644
--- a/oadp/dp/roi_heads.py
+++ b/oadp/dp/roi_heads.py
@@ -191,8 +191,8 @@ def forward(
                 img_path = img_meta['img_path']
                 if img_path not in self.ram_pred_result:
                     print(f"Image path {img_path} not found in RAM prediction results")
-                    ram_cls_score.append(np.ones((1, cls_score.shape[1])))
-                    raise ValueError(f"Image path {img_path} not found in RAM prediction results")
+                    ram_cls_score.append(np.ones((cls_score.shape[1] - 1, )))
+                    # raise ValueError(f"Image path {img_path} not found in RAM prediction results")
                 else:
                     ram_cls_score.append(self.ram_pred_result[img_path])
             ram_cls_score = np.concatenate(ram_cls_score, axis=0)
@@ -203,8 +203,8 @@ def forward(
         ram_cls_score = repeat(ram_cls_score, 'c -> b c', b=num_box).sigmoid()
         ram_cls_score_with_bg = torch.cat([ram_cls_score, background_score], dim=1).to(cls_score.device)
 
-        return cls_score * ram_cls_score_with_bg
-
+        return (cls_score.softmax(-1) * ram_cls_score_with_bg).log()
+
 
 @MODELS.register_module()
 class RAMEnsembleOADPRoIHead(OADPRoIHead):
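
The behavioural change in `roi_heads.py` is the fused return value: instead of multiplying raw classifier logits by the RAM weights, the new code weights the softmax probabilities and then takes the log, presumably so that callers which expect a logit-like output keep working. Below is a minimal standalone sketch of that before/after computation under assumed shapes; the sizes and the dummy tensors are hypothetical, while names such as `cls_score`, `ram_cls_score`, `background_score`, and `num_box` follow the diff.

```python
# Sketch of the RAM-weighted score fusion (assumed shapes, not the actual module).
import torch
from einops import repeat

num_box, num_classes = 4, 6                    # hypothetical: boxes per image, classes incl. background

cls_score = torch.randn(num_box, num_classes)  # detector logits, last column = background
ram_cls_score = torch.randn(num_classes - 1)   # per-image RAM logits, no background entry

# Broadcast the image-level RAM scores to every box and squash them to (0, 1).
ram_prob = repeat(ram_cls_score, 'c -> b c', b=num_box).sigmoid()

# Background gets a neutral weight of 1 so it is neither boosted nor suppressed.
background_score = torch.ones(num_box, 1)
ram_prob_with_bg = torch.cat([ram_prob, background_score], dim=1)

# Old behaviour: scale the raw logits directly.
fused_old = cls_score * ram_prob_with_bg

# New behaviour: scale the softmax probabilities, then return to log space.
fused_new = (cls_score.softmax(-1) * ram_prob_with_bg).log()
print(fused_new.shape)  # torch.Size([4, 6])
```

The companion fix in the missing-image branch appends a vector of ones of shape `(cls_score.shape[1] - 1,)`, which under this reading acts as a neutral RAM weight for all foreground classes when no RAM prediction is available, rather than raising an error.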